Source code for agentscope_runtime.engine.app.celery_mixin
# -*- coding: utf-8 -*-
import inspect
import asyncio
import logging
from typing import Callable, Optional, List

from celery import Celery

logger = logging.getLogger(__name__)
class CeleryMixin:
"""
Celery task processing mixin that provides core Celery functionality.
Can be reused by BaseApp and FastAPIAppFactory.
"""
    def __init__(
        self,
        broker_url: Optional[str] = None,
        backend_url: Optional[str] = None,
    ):
        # Celery is only configured when both broker and backend URLs are given.
        if broker_url and backend_url:
            self.celery_app = Celery(
                "agentscope_runtime",
                broker=broker_url,
                backend=backend_url,
            )
        else:
            self.celery_app = None
        self._registered_queues: set[str] = set()
    def get_registered_queues(self) -> set[str]:
        """Return the set of queue names that have registered tasks."""
        return self._registered_queues
    def register_celery_task(self, func: Callable, queue: str = "celery"):
        """Register a Celery task for the given function."""
        if self.celery_app is None:
            raise RuntimeError("Celery is not configured.")
        self._registered_queues.add(queue)

        def _coerce_result(x):
            # Normalize Pydantic models first
            if hasattr(x, "model_dump"):  # pydantic v2
                x = x.model_dump()
            elif hasattr(x, "dict"):  # pydantic v1
                x = x.dict()
            # Preserve simple primitives as-is
            if isinstance(x, (str, int, float, bool)) or x is None:
                return x
            # Recursively coerce dictionaries
            if isinstance(x, dict):
                return {k: _coerce_result(v) for k, v in x.items()}
            # Recursively coerce lists
            if isinstance(x, list):
                return [_coerce_result(item) for item in x]
            # Fallback: string representation for anything else
            return str(x)

        async def _collect_async_gen(agen):
            # Drain an async generator, coercing each yielded item
            items = []
            async for x in agen:
                items.append(_coerce_result(x))
            return items

        def _collect_gen(gen):
            # Drain a sync generator, coercing each yielded item
            return [_coerce_result(x) for x in gen]

        @self.celery_app.task(queue=queue)
        def wrapper(*args, **kwargs):
            # 1) async generator function: calling it returns an async generator
            if inspect.isasyncgenfunction(func):
                result = func(*args, **kwargs)
            # 2) coroutine function: run it to completion
            elif inspect.iscoroutinefunction(func):
                result = asyncio.run(func(*args, **kwargs))
            # 3) plain sync function
            else:
                result = func(*args, **kwargs)
            # 4) async generator result: collect it in a fresh event loop
            if inspect.isasyncgen(result):
                return asyncio.run(_collect_async_gen(result))
            # 5) sync generator result: collect it eagerly
            if inspect.isgenerator(result):
                return _collect_gen(result)
            # 6) normal return value
            return _coerce_result(result)

        # submit_task() expects the caller to expose the returned task on the
        # original function as `func.celery_task`.
        return wrapper
    def run_task_processor(
        self,
        loglevel: str = "INFO",
        concurrency: Optional[int] = None,
        queues: Optional[List[str]] = None,
    ):
        """Run a Celery worker in this process."""
        if self.celery_app is None:
            raise RuntimeError("Celery is not configured.")
        cmd = ["worker", f"--loglevel={loglevel}"]
        if concurrency:
            cmd.append(f"--concurrency={concurrency}")
        if queues:
            cmd += ["-Q", ",".join(queues)]
        self.celery_app.worker_main(cmd)
    def get_task_status(self, task_id: str):
        """Get task status from the Celery result backend."""
        if self.celery_app is None:
            return {"error": "Celery not configured"}
        result = self.celery_app.AsyncResult(task_id)
        if result.state == "PENDING":
            return {"status": "pending", "result": None}
        elif result.state == "SUCCESS":
            return {"status": "finished", "result": result.result}
        elif result.state == "FAILURE":
            return {"status": "error", "result": str(result.info)}
        else:
            return {"status": result.state, "result": None}
    def submit_task(self, func: Callable, *args, **kwargs):
        """Submit a task directly to the Celery queue."""
        if not hasattr(func, "celery_task"):
            raise RuntimeError(
                f"Function {func.__name__} is not registered as a task",
            )
        return func.celery_task.delay(*args, **kwargs)
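
Usage sketch (illustrative only, not part of the module): the MyCeleryApp host class, the add and stream_tokens functions, and the Redis URLs below are hypothetical. The sketch assumes the caller attaches the task returned by register_celery_task() to the function as celery_task, which is the attribute submit_task() looks for.

from agentscope_runtime.engine.app.celery_mixin import CeleryMixin


class MyCeleryApp(CeleryMixin):
    """Minimal host class; real apps reuse CeleryMixin via BaseApp."""


def add(a: int, b: int) -> int:
    return a + b


async def stream_tokens(prompt: str):
    # An async generator: the task wrapper collects its items into a list.
    for token in prompt.split():
        yield {"token": token}


if __name__ == "__main__":
    app = MyCeleryApp(
        broker_url="redis://localhost:6379/0",
        backend_url="redis://localhost:6379/1",
    )

    # Register tasks and expose them where submit_task() expects to find them.
    add.celery_task = app.register_celery_task(add, queue="math")
    stream_tokens.celery_task = app.register_celery_task(
        stream_tokens,
        queue="math",
    )

    # Producer side: enqueue calls and poll their status by task id.
    pending = app.submit_task(add, 2, 3)
    print(app.get_task_status(pending.id))

    streamed = app.submit_task(stream_tokens, "hello agentscope runtime")
    print(app.get_task_status(streamed.id))

    # Worker side (normally a separate process) consuming the "math" queue:
    # app.run_task_processor(loglevel="INFO", queues=["math"])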