# -*- coding: utf-8 -*-
import inspect
import asyncio
import logging
import time
import concurrent.futures
from typing import Callable, Optional, List, Any, Dict
logger = logging.getLogger(__name__)
[docs]
class TaskEngineMixin:
[docs]
def init_task_engine(
self,
broker_url: Optional[str] = None,
backend_url: Optional[str] = None,
):
self.celery_app = None
self.active_tasks: Dict[str, Dict[str, Any]] = {}
self._registered_queues: set[str] = set()
self.task_locks: Dict[str, asyncio.Lock] = {}
self._tasks_lock: Optional[asyncio.Lock] = None
if broker_url and backend_url:
try:
from celery import Celery
self.celery_app = Celery(
"agentscope_runtime",
broker=broker_url,
backend=backend_url,
)
logger.info("Celery task engine initialized.")
except ImportError:
logger.warning(
"Celery not installed, "
"using fallback in-memory processing.",
)
self.celery_app = None
except Exception as e:
logger.error(f"Celery initialization error: {e}")
self.celery_app = None
else:
logger.info("Celery not configured. Fallback to in-memory mode.")
def _coerce_result(self, x):
# Normalize Pydantic models first
if hasattr(x, "model_dump"): # pydantic v2
x = x.model_dump()
elif hasattr(x, "dict"): # pydantic v1
x = x.dict()
# Preserve simple primitives as-is
if isinstance(x, (str, int, float, bool)) or x is None:
return x
# Recursively coerce dictionaries
if isinstance(x, dict):
return {k: self._coerce_result(v) for k, v in x.items()}
# Recursively coerce lists
if isinstance(x, list):
return [self._coerce_result(item) for item in x]
# Fallback: string representation for anything else
return str(x)
[docs]
def register_celery_task(self, func: Callable, queue: str = "celery"):
if self.celery_app is None:
raise RuntimeError("Celery is not configured.")
mod_name = func.__module__
if mod_name == "__main__":
import os
import sys
mod_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
unique_name = f"tasks.{mod_name}.{func.__name__}"
logger.info(
f"Registered Celery task '{unique_name}' on queue '{queue}'",
)
self._registered_queues.add(queue)
async def _collect_async_gen(agen):
items = []
async for x in agen:
items.append(self._coerce_result(x))
return items
def _collect_gen(gen):
return [self._coerce_result(x) for x in gen]
@self.celery_app.task(name=unique_name, queue=queue)
def wrapper(*args, **kwargs):
# 1) async generator function
if inspect.isasyncgenfunction(func):
result = func(*args, **kwargs)
# 2) async function
elif inspect.iscoroutinefunction(func):
result = asyncio.run(func(*args, **kwargs))
else:
result = func(*args, **kwargs)
# 3) async generator
if inspect.isasyncgen(result):
return asyncio.run(_collect_async_gen(result))
# 4) sync generator
if inspect.isgenerator(result):
return _collect_gen(result)
# 5) normal return
return self._coerce_result(result)
return wrapper
[docs]
def submit_celery_task(self, func: Callable, *args, **kwargs):
if not hasattr(func, "celery_task"):
raise RuntimeError(f"Function {func.__name__} is not registered.")
return func.celery_task.delay(*args, **kwargs)
[docs]
def start_embedded_celery_worker(self):
"""Initialize Celery worker in a background daemon thread."""
import threading
def start_celery_worker():
try:
logger.info(
"Initializing Celery worker in a "
"background daemon thread...",
)
queues = (
list(self._registered_queues)
if self._registered_queues
else ["celery"]
)
self._run_celery_task_processor(
loglevel="INFO",
concurrency=1,
queues=queues,
)
except Exception as e:
logger.error(f"Embedded Celery worker failed: {e}")
threading.Thread(target=start_celery_worker, daemon=True).start()
def _run_celery_task_processor(
self,
loglevel: str = "INFO",
concurrency: Optional[int] = None,
queues: Optional[List[str]] = None,
):
"""Run Celery worker in this process."""
if self.celery_app is None:
raise RuntimeError("Celery is not configured.")
cmd = ["worker", f"--loglevel={loglevel}"]
if concurrency:
cmd.append(f"--concurrency={concurrency}")
if queues:
cmd += ["-Q", ",".join(queues)]
self.celery_app.worker_main(cmd)
async def _get_task_lock(self, task_id: str) -> asyncio.Lock:
"""
Get or create a lock for a specific task.
Args:
task_id: Task identifier
Returns:
asyncio.Lock for the specified task
"""
if self._tasks_lock is None:
self._tasks_lock = asyncio.Lock()
async with self._tasks_lock:
if task_id not in self.task_locks:
self.task_locks[task_id] = asyncio.Lock()
return self.task_locks[task_id]
[docs]
async def execute_background_task(
self,
task_id: str,
func: Callable,
request: dict,
queue: str,
):
# pylint:disable=unused-argument
try:
self.active_tasks[task_id].update(
{
"status": "running",
"started_at": time.time(),
},
)
if inspect.isasyncgenfunction(func):
result = []
async for item in func(request):
result.append(self._coerce_result(item))
elif inspect.iscoroutinefunction(func):
result = await func(request)
result = self._coerce_result(result)
elif inspect.isgeneratorfunction(func):
def collect_gen():
return [self._coerce_result(x) for x in func(request)]
with concurrent.futures.ThreadPoolExecutor() as executor:
result = await asyncio.get_event_loop().run_in_executor(
executor,
collect_gen,
)
else:
with concurrent.futures.ThreadPoolExecutor() as executor:
result = await asyncio.get_event_loop().run_in_executor(
executor,
func,
request,
)
result = self._coerce_result(result)
self.active_tasks[task_id].update(
{
"status": "completed",
"result": result,
"completed_at": time.time(),
},
)
except Exception as e:
self.active_tasks[task_id].update(
{
"status": "failed",
"error": str(e),
"failed_at": time.time(),
},
)
[docs]
async def execute_stream_query_task(
self,
task_id: str,
stream_func: Callable,
request: dict,
queue: str,
timeout: Optional[float] = None,
):
"""
Execute stream query as background task and collect ONLY the final
response.
Design: Only store the last event (final response), ignore
intermediate events to reduce memory usage.
Args:
task_id: Unique task identifier
stream_func: Streaming function (runner.stream_query)
request: Request dict
queue: Queue name
timeout: Task execution timeout in seconds
Returns:
Final response event as dict
Raises:
TimeoutError: If task exceeds specified timeout
RuntimeError: If stream yields no events
"""
# pylint:disable=unused-argument
task_lock = await self._get_task_lock(task_id)
try:
async with task_lock:
self.active_tasks[task_id].update(
{
"status": "running",
"started_at": time.time(),
},
)
final_response = None
start_time = time.time()
event_count = 0
async def stream_with_collection():
nonlocal final_response, event_count
async for event in stream_func(request):
event_count += 1
if hasattr(event, "model_dump"):
final_response = event.model_dump()
elif hasattr(event, "dict"):
final_response = event.dict()
else:
final_response = {"data": str(event)}
if timeout is not None:
await asyncio.wait_for(
stream_with_collection(),
timeout=timeout,
)
else:
await stream_with_collection()
if event_count == 0 or final_response is None:
raise RuntimeError(
f"Stream function yielded no events for task {task_id}",
)
elapsed = time.time() - start_time
async with task_lock:
self.active_tasks[task_id].update(
{
"status": "completed",
"result": final_response,
"completed_at": time.time(),
"elapsed_time": elapsed,
},
)
return final_response
except asyncio.TimeoutError:
async with task_lock:
self.active_tasks[task_id].update(
{
"status": "failed",
"error": f"Task exceeded timeout of {timeout}s",
"error_type": "TimeoutError",
"failed_at": time.time(),
},
)
raise
except Exception as e:
async with task_lock:
self.active_tasks[task_id].update(
{
"status": "failed",
"error": str(e),
"error_type": type(e).__name__,
"failed_at": time.time(),
},
)
raise
[docs]
def get_task_status(self, task_id: str):
# pylint:disable=too-many-return-statements
if self.celery_app:
result = self.celery_app.AsyncResult(task_id)
if result.state == "PENDING":
return {"status": "pending", "result": None}
elif result.state == "SUCCESS":
return {"status": "finished", "result": result.result}
elif result.state == "FAILURE":
return {"status": "error", "result": str(result.info)}
else:
return {"status": result.state, "result": None}
else:
if (
not hasattr(self, "active_tasks")
or task_id not in self.active_tasks
):
return {"error": f"Task {task_id} not found"}
task_info = self.active_tasks[task_id]
task_status = task_info.get("status", "unknown")
if task_status == "submitted":
return {"status": "pending", "result": None}
if task_status == "running":
return {
"status": "running",
"started_at": task_info.get("started_at"),
"result": None,
}
elif task_status == "completed":
return {
"status": "finished",
"result": task_info.get("result"),
}
elif task_status == "failed":
return {
"status": "error",
"result": task_info.get("error", "Unknown error"),
}
else:
return {"status": task_status, "result": None}