# Source code for agentscope_runtime.engine.deployers.utils.service_utils.fastapi_factory

# -*- coding: utf-8 -*-
# pylint:disable=too-many-branches, unused-argument, too-many-return-statements
# pylint:disable=protected-access

import asyncio
import functools
import inspect
import json
import logging
from contextlib import asynccontextmanager
from dataclasses import asdict, is_dataclass
from typing import Optional, Callable, Type, Any, List, Dict

from a2a.types import A2ARequest
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

from agentscope_runtime.engine.schemas.agent_schemas import AgentRequest
from agentscope_runtime.engine.schemas.response_api import ResponseAPI
from ..deployment_modes import DeploymentMode
from ...adapter.a2a.a2a_protocol_adapter import A2AFastAPIDefaultAdapter
from ...adapter.protocol_adapter import ProtocolAdapter
from ...adapter.responses.response_api_protocol_adapter import (
    ResponseAPIDefaultAdapter,
)

logger = logging.getLogger(__name__)


class _WrappedFastAPI(FastAPI):
    """FastAPI subclass that can dynamically augment OpenAPI schemas."""

    # Template used so nested model refs resolve under components/schemas.
    _REF_TEMPLATE = "#/components/schemas/{model}"

    def openapi(self) -> dict[str, Any]:
        """Generate OpenAPI schema with protocol-specific components."""
        schema = super().openapi()
        adapters = getattr(self.state, "protocol_adapters", None) or []

        # Decide which protocol request models need to be published, in the
        # same order as before: A2A, ResponseAPI, then always AgentRequest.
        models_to_publish: list[tuple[str, Any]] = []
        if any(isinstance(a, A2AFastAPIDefaultAdapter) for a in adapters):
            models_to_publish.append(("A2ARequest", A2ARequest))
        if any(isinstance(a, ResponseAPIDefaultAdapter) for a in adapters):
            models_to_publish.append(("ResponseAPI", ResponseAPI))
        models_to_publish.append(("AgentRequest", AgentRequest))

        for name, model in models_to_publish:
            self._inject_schema(
                schema,
                name,
                model.model_json_schema(ref_template=self._REF_TEMPLATE),
            )

        return schema

    @staticmethod
    def _inject_schema(
        openapi_schema: dict[str, Any],
        schema_name: str,
        schema_definition: dict[str, Any],
    ) -> None:
        """Insert schema definition (and nested defs) into OpenAPI."""
        components = openapi_schema.setdefault("components", {})
        schemas = components.setdefault("schemas", {})

        # Nested $defs never overwrite existing component schemas; the
        # top-level schema itself always does.
        for def_name, def_schema in schema_definition.pop("$defs", {}).items():
            if def_name not in schemas:
                schemas[def_name] = def_schema

        schemas[schema_name] = schema_definition


async def error_stream(e):
    """Yield a single SSE ``data:`` event describing a request-parsing error.

    Args:
        e: The exception (or any object) whose string form is reported.

    Yields:
        One SSE-formatted string carrying a JSON ``{"error": ...}`` payload.
    """
    yield (
        "data: "
        f"{json.dumps({'error': f'Request parsing error: {str(e)}'})}\n\n"
    )
class FastAPIAppFactory:
    """Factory for creating FastAPI applications with unified architecture."""

    @staticmethod
    def create_app(
        func: Optional[Callable] = None,
        runner: Optional[Any] = None,
        endpoint_path: str = "/process",
        request_model: Optional[Type] = None,
        response_type: str = "sse",
        stream: bool = True,
        before_start: Optional[Callable] = None,
        after_finish: Optional[Callable] = None,
        mode: DeploymentMode = DeploymentMode.DAEMON_THREAD,
        protocol_adapters: Optional[list[ProtocolAdapter]] = None,
        custom_endpoints: Optional[
            List[Dict]
        ] = None,  # New parameter for custom endpoints
        # Celery parameters
        broker_url: Optional[str] = None,
        backend_url: Optional[str] = None,
        enable_embedded_worker: bool = False,
        app_kwargs: Optional[Dict] = None,
        **kwargs: Any,
    ) -> FastAPI:
        """Create a FastAPI application with unified architecture.

        Args:
            func: Custom processing function
            runner: Runner instance (for DAEMON_THREAD mode)
            endpoint_path: API endpoint path for the processing function
            request_model: Pydantic model for request validation
            response_type: Response type - "json", "sse", or "text"
            stream: Enable streaming responses
            before_start: Callback function called before server starts
            after_finish: Callback function called after server finishes
            mode: Deployment mode
            protocol_adapters: Protocol adapters
            custom_endpoints: List of custom endpoint configurations
            broker_url: Celery broker URL
            backend_url: Celery backend URL
            enable_embedded_worker: Whether to run embedded Celery worker
            app_kwargs: Additional keyword arguments for the FastAPI app
            **kwargs: Additional keyword arguments

        Returns:
            FastAPI application instance
        """
        # Initialize Celery mixin if broker and backend URLs are provided
        celery_mixin = None
        if broker_url and backend_url:
            try:
                from ....app.celery_mixin import CeleryMixin

                celery_mixin = CeleryMixin(
                    broker_url=broker_url,
                    backend_url=backend_url,
                )
            except ImportError:
                # CeleryMixin not available, will use fallback task processing
                celery_mixin = None

        # Create lifespan manager
        @asynccontextmanager
        async def lifespan(app: FastAPI):
            """Application lifespan manager."""
            # Startup
            try:
                await FastAPIAppFactory._handle_startup(
                    app,
                    mode,
                    runner,
                    before_start,
                    **kwargs,
                )
                yield
            finally:
                # Shutdown
                await FastAPIAppFactory._handle_shutdown(
                    app,
                    after_finish,
                    **kwargs,
                )

        # Create FastAPI app
        app = _WrappedFastAPI(lifespan=lifespan, **(app_kwargs or {}))

        # Store configuration in app state
        app.state.deployment_mode = mode
        app.state.stream_enabled = stream
        app.state.custom_func = func
        app.state.runner = runner
        app.state.response_type = response_type
        app.state.endpoint_path = endpoint_path
        app.state.protocol_adapters = protocol_adapters  # Store for later use
        app.state.custom_endpoints = (
            custom_endpoints or []
        )  # Store custom endpoints

        # Store Celery configuration
        app.state.celery_mixin = celery_mixin
        app.state.broker_url = broker_url
        app.state.backend_url = backend_url
        app.state.enable_embedded_worker = enable_embedded_worker

        # Add middleware
        FastAPIAppFactory._add_middleware(app, mode)

        # Add routes
        FastAPIAppFactory._add_routes(
            app,
            endpoint_path,
            request_model,
            stream,
            mode,
        )

        # Note: protocol_adapters will be added in _handle_startup
        # after runner is available
        return app

    @staticmethod
    async def _handle_startup(
        app: FastAPI,
        mode: DeploymentMode,
        external_runner: Optional[Any],
        before_start: Optional[Callable],
        **kwargs,
    ):
        """Handle application startup."""
        try:
            # aexit any possible running instances before set up runner
            await app.state.runner.__aexit__(None, None, None)
            await app.state.runner.__aenter__()
        except Exception as e:
            logger.error(
                f"Warning: Error during runner setup: {e}",
            )

        # Call custom startup callback
        if before_start:
            if asyncio.iscoroutinefunction(before_start):
                await before_start(app, **kwargs)
            else:
                before_start(app, **kwargs)

        # Add protocol adapter endpoints after runner is available
        if (
            hasattr(app.state, "protocol_adapters")
            and app.state.protocol_adapters
        ):
            # Determine the effective function to use
            if hasattr(app.state, "custom_func") and app.state.custom_func:
                effective_func = app.state.custom_func
            elif hasattr(app.state, "runner") and app.state.runner:
                # Use stream_query if streaming is enabled, otherwise query
                if (
                    hasattr(app.state, "stream_enabled")
                    and app.state.stream_enabled
                ):
                    effective_func = app.state.runner.stream_query
                else:
                    effective_func = app.state.runner.query
            else:
                effective_func = None

            if effective_func:
                for protocol_adapter in app.state.protocol_adapters:
                    protocol_adapter.add_endpoint(app=app, func=effective_func)

        # Add custom endpoints after runner is available
        if (
            hasattr(app.state, "custom_endpoints")
            and app.state.custom_endpoints
        ):
            FastAPIAppFactory._add_custom_endpoints(app)

        # Start embedded Celery worker if enabled
        if (
            hasattr(app.state, "enable_embedded_worker")
            and app.state.enable_embedded_worker
            and hasattr(app.state, "celery_mixin")
            and app.state.celery_mixin
        ):
            # Start Celery worker in background thread
            import threading

            def start_celery_worker():
                try:
                    celery_mixin = app.state.celery_mixin
                    # Get registered queues or use default
                    queues = (
                        list(celery_mixin.get_registered_queues())
                        if celery_mixin.get_registered_queues()
                        else ["celery"]
                    )
                    celery_mixin.run_task_processor(
                        loglevel="INFO",
                        concurrency=1,
                        queues=queues,
                    )
                except Exception as e:
                    logger.error(f"Failed to start Celery worker: {e}")

            worker_thread = threading.Thread(
                target=start_celery_worker,
                daemon=True,
            )
            worker_thread.start()

    @staticmethod
    async def _handle_shutdown(
        app: FastAPI,
        after_finish: Optional[Callable],
        **kwargs,
    ):
        """Handle application shutdown."""
        # Call custom shutdown callback
        if after_finish:
            if asyncio.iscoroutinefunction(after_finish):
                await after_finish(app, **kwargs)
            else:
                after_finish(app, **kwargs)

        # Cleanup internal runner
        runner = app.state.runner
        if runner:
            try:
                # Clean up runner
                await runner.__aexit__(None, None, None)
            except Exception as e:
                logger.error(f"Warning: Error during runner cleanup: {e}")

    @staticmethod
    async def _create_internal_runner():
        """Create internal runner with configured services."""
        from agentscope_runtime.engine import Runner

        # Create runner (agent will be set later)
        runner = Runner()

        # Initialize runner
        await runner.__aenter__()
        return runner

    @staticmethod
    def _add_middleware(app: FastAPI, mode: DeploymentMode):
        """Add middleware based on deployment mode."""
        # Common middleware
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Mode-specific middleware
        if mode == DeploymentMode.DETACHED_PROCESS:
            # Add process management middleware
            @app.middleware("http")
            async def process_middleware(request: Request, call_next):
                # Add process-specific headers
                response = await call_next(request)
                response.headers["X-Process-Mode"] = "detached"
                return response

        elif mode == DeploymentMode.STANDALONE:
            # Add configuration middleware
            @app.middleware("http")
            async def config_middleware(request: Request, call_next):
                # Add configuration headers
                response = await call_next(request)
                response.headers["X-Deployment-Mode"] = "standalone"
                return response

    @staticmethod
    def _add_routes(
        app: FastAPI,
        endpoint_path: str,
        request_model: Optional[Type],
        stream_enabled: bool,
        mode: DeploymentMode,
    ):
        """Add routes to the FastAPI application."""

        # Health check endpoint
        @app.get("/health")
        async def health_check():
            """Health check endpoint."""
            status = {"status": "healthy", "mode": mode.value}

            # Add service health checks
            if hasattr(app.state, "runner") and app.state.runner:
                status["runner"] = "ready"
            else:
                status["runner"] = "not_ready"

            return status

        # Agent API endpoint
        @app.post(
            endpoint_path,
            openapi_extra={
                "requestBody": {
                    "content": {
                        "application/json": {
                            "schema": {
                                "$ref": "#/components/schemas/AgentRequest",
                            },
                        },
                    },
                    "required": True,
                    "description": "Agent API Request Format."
                    "See https://runtime.agentscope.io/en/protocol.html for "
                    "more details.",
                },
            },
            tags=["agent-api"],
        )
        async def agent_api(request: dict):
            """
            Agent API endpoint, see
            <https://runtime.agentscope.io/en/protocol.html>
            for more details.
            """
            return StreamingResponse(
                FastAPIAppFactory._create_stream_generator(
                    app,
                    request=request,
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        # Root endpoint
        @app.get("/")
        async def root():
            """Root endpoint."""
            return {
                "service": "AgentScope Runtime",
                "mode": mode.value,
                "endpoints": {
                    "process": endpoint_path,
                    "stream": (
                        f"{endpoint_path}/stream" if stream_enabled else None
                    ),
                    "health": "/health",
                },
            }

        # Mode-specific endpoints
        FastAPIAppFactory._add_process_control_endpoints(app)

    @staticmethod
    def _add_process_control_endpoints(app: FastAPI):
        """Add process control endpoints for detached mode."""

        @app.post("/shutdown")
        async def shutdown_process_simple():
            """Gracefully shutdown the process (simple endpoint)."""
            # Import here to avoid circular imports
            import os
            import signal

            # Schedule shutdown after response
            async def delayed_shutdown():
                await asyncio.sleep(0.5)
                os.kill(os.getpid(), signal.SIGTERM)

            asyncio.create_task(delayed_shutdown())
            return {"status": "shutting down"}

        @app.post("/admin/shutdown")
        async def shutdown_process():
            """Gracefully shutdown the process."""
            # Import here to avoid circular imports
            import os
            import signal

            # Schedule shutdown after response
            async def delayed_shutdown():
                await asyncio.sleep(1)
                os.kill(os.getpid(), signal.SIGTERM)

            asyncio.create_task(delayed_shutdown())
            return {"message": "Shutdown initiated"}

        @app.get("/admin/status")
        async def get_process_status():
            """Get process status information."""
            import os
            import psutil

            process = psutil.Process(os.getpid())
            return {
                "pid": os.getpid(),
                "status": process.status(),
                "memory_usage": process.memory_info().rss,
                "cpu_percent": process.cpu_percent(),
                "uptime": process.create_time(),
            }

    @staticmethod
    async def _handle_request(
        app: FastAPI,
        request: dict,
        stream_enabled: bool,
    ):
        """Handle a standard request."""
        try:
            # Get runner instance
            runner = FastAPIAppFactory._get_runner_instance(app)
            if not runner:
                return JSONResponse(
                    status_code=503,
                    content={
                        "error": "Service not ready",
                        "message": "Runner not initialized",
                    },
                )

            # Handle custom function vs runner
            if app.state.custom_func:
                # Use custom function
                result = await FastAPIAppFactory._call_custom_function(
                    app.state.custom_func,
                    request,
                )
                return {"response": result}
            else:
                # Use runner
                if stream_enabled:
                    # Collect streaming response
                    result = await FastAPIAppFactory._collect_stream_response(
                        runner,
                        request,
                    )
                    return {"response": result}
                else:
                    # Direct query
                    result = await runner.query(request)
                    return {"response": result}

        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": "Internal server error", "message": str(e)},
            )

    @staticmethod
    async def _create_stream_generator(app: FastAPI, request: dict):
        """Create streaming response generator."""
        try:
            runner = FastAPIAppFactory._get_runner_instance(app)
            if not runner:
                yield (
                    f"data: {json.dumps({'error': 'Runner not initialized'})}"
                    f"\n\n"
                )
                return

            if app.state.custom_func:
                # Handle custom function (convert to stream)
                result = await FastAPIAppFactory._call_custom_function(
                    app.state.custom_func,
                    request,
                )
                yield f"data: {json.dumps({'text': str(result)})}\n\n"
            else:
                # Use runner streaming
                async for chunk in runner.stream_query(request):
                    if hasattr(chunk, "model_dump_json"):
                        yield f"data: {chunk.model_dump_json()}\n\n"
                    elif hasattr(chunk, "json"):
                        yield f"data: {chunk.json()}\n\n"
                    else:
                        yield f"data: {json.dumps({'text': str(chunk)})}\n\n"

        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"

    @staticmethod
    async def _collect_stream_response(runner, request: dict) -> str:
        """Collect streaming response into a single string."""
        response_parts = []
        async for chunk in runner.stream_query(request):
            if hasattr(chunk, "text"):
                response_parts.append(chunk.text)
            else:
                response_parts.append(str(chunk))
        return "".join(response_parts)

    @staticmethod
    async def _call_custom_function(func: Callable, request: dict):
        """Call custom function with proper parameters."""
        if asyncio.iscoroutinefunction(func):
            return await func(
                user_id="default",
                request=request,
                request_id="generated",
            )
        else:
            return func(
                user_id="default",
                request=request,
                request_id="generated",
            )

    @staticmethod
    def _get_runner_instance(app: FastAPI):
        """Get runner instance from app state."""
        if hasattr(app.state, "runner"):
            return app.state.runner
        return None

    @staticmethod
    def _create_handler_wrapper(handler: Callable):
        """Create a wrapper for a handler that preserves function signature.

        This wrapper maintains the handler's signature to enable FastAPI's
        automatic parameter parsing and dependency injection. For async
        handlers, it returns an async wrapper; for sync handlers, it
        returns a sync wrapper.

        Args:
            handler: The handler function to wrap

        Returns:
            A wrapped handler that preserves the original function signature
        """
        is_awaitable = inspect.iscoroutinefunction(handler)

        if is_awaitable:

            @functools.wraps(handler)
            async def wrapped_handler(*args, **kwargs):
                return await handler(*args, **kwargs)

            wrapped_handler.__signature__ = inspect.signature(handler)
            return wrapped_handler
        else:

            @functools.wraps(handler)
            def wrapped_handler(*args, **kwargs):
                return handler(*args, **kwargs)

            wrapped_handler.__signature__ = inspect.signature(handler)
            return wrapped_handler

    @staticmethod
    def _to_sse_event(item: Any) -> str:
        """Normalize streaming items into JSON-serializable structures."""

        def _serialize(value: Any, depth: int = 0):
            # Guard against cyclic / pathologically nested structures.
            if depth > 20:
                return f"<too-deep-level-{depth}-{str(value)}>"
            if isinstance(value, (list, tuple, set)):
                return [_serialize(i, depth=depth + 1) for i in value]
            elif isinstance(value, dict):
                return {
                    k: _serialize(v, depth=depth + 1)
                    for k, v in value.items()
                }
            elif isinstance(value, (str, int, float, bool, type(None))):
                return value
            elif isinstance(value, BaseModel):
                return value.model_dump()
            elif is_dataclass(value):
                return asdict(value)
            # Fall back to duck-typed converters, then plain str().
            for attr in ("to_map", "to_dict"):
                method = getattr(value, attr, None)
                if callable(method):
                    return method()
            return str(value)

        serialized = _serialize(item, depth=0)
        return f"data: {json.dumps(serialized, ensure_ascii=False)}\n\n"

    @staticmethod
    def _create_streaming_parameter_wrapper(
        handler: Callable,
    ):
        """Create a wrapper for streaming handlers that handles parameter
        parsing."""
        is_async_gen = inspect.isasyncgenfunction(handler)

        # NOTE:
        # -----
        # FastAPI >= 0.123.5 uses Dependant.is_coroutine_callable, which in
        # turn unwraps callables via inspect.unwrap() and then inspects the
        # unwrapped target to decide whether it is a coroutine function /
        # generator / async generator.
        #
        # If we decorate an async-generator handler with
        # functools.wraps(handler), FastAPI will unwrap back to the original
        # async-generator function and *misclassify* the endpoint as
        # non-coroutine. It will then call our async wrapper *without awaiting
        # it*, and later try to JSON-encode the resulting coroutine object,
        # causing errors like:
        #     TypeError("'coroutine' object is not iterable")
        #
        # To avoid that, we deliberately do NOT use functools.wraps() here.
        # Instead, we manually copy the key metadata (name, qualname, doc,
        # module, and signature) from the original handler, but we do NOT set
        # __wrapped__. This ensures:
        #   * FastAPI sees the wrapper itself as the callable (an async def),
        #     so Dependant.is_coroutine_callable is True, and it is properly
        #     awaited.
        #   * FastAPI still sees the correct signature for parameter parsing.
        if is_async_gen:

            async def wrapped_handler(*args, **kwargs):
                async def generate():
                    try:
                        async for chunk in handler(*args, **kwargs):
                            yield FastAPIAppFactory._to_sse_event(
                                chunk,
                            )
                    except Exception as e:
                        logger.error(
                            f"Error in streaming handler: {e}",
                            exc_info=True,
                        )
                        err_event = {
                            "error": str(e),
                            "error_type": e.__class__.__name__,
                            "message": "Error in streaming generator",
                        }
                        yield FastAPIAppFactory._to_sse_event(err_event)

                return StreamingResponse(
                    generate(),
                    media_type="text/event-stream",
                )

        else:

            def wrapped_handler(*args, **kwargs):
                def generate():
                    try:
                        for chunk in handler(*args, **kwargs):
                            yield FastAPIAppFactory._to_sse_event(chunk)
                    except Exception as e:
                        logger.error(
                            f"Error in streaming handler: {e}",
                            exc_info=True,
                        )
                        err_event = {
                            "error": str(e),
                            "error_type": e.__class__.__name__,
                            "message": "Error in streaming generator",
                        }
                        yield FastAPIAppFactory._to_sse_event(err_event)

                return StreamingResponse(
                    generate(),
                    media_type="text/event-stream",
                )

        # Manually propagate essential metadata without creating a __wrapped__
        # chain that would confuse FastAPI's unwrap logic.
        wrapped_handler.__name__ = getattr(
            handler,
            "__name__",
            wrapped_handler.__name__,
        )
        wrapped_handler.__qualname__ = getattr(
            handler,
            "__qualname__",
            wrapped_handler.__qualname__,
        )
        wrapped_handler.__doc__ = getattr(
            handler,
            "__doc__",
            wrapped_handler.__doc__,
        )
        wrapped_handler.__module__ = getattr(
            handler,
            "__module__",
            wrapped_handler.__module__,
        )
        wrapped_handler.__signature__ = inspect.signature(handler)

        # Make sure FastAPI doesn't see any stale __wrapped__ pointing back to
        # the original async-generator; if present, remove it.
        return wrapped_handler

    @staticmethod
    def _add_custom_endpoints(app: FastAPI):
        """Add all custom endpoints to the FastAPI application."""
        if (
            not hasattr(app.state, "custom_endpoints")
            or not app.state.custom_endpoints
        ):
            return

        for endpoint in app.state.custom_endpoints:
            FastAPIAppFactory._register_single_custom_endpoint(
                app,
                endpoint["path"],
                endpoint["handler"],
                endpoint["methods"],
                endpoint,  # Pass the full endpoint config
            )

    @staticmethod
    def _register_single_custom_endpoint(
        app: FastAPI,
        path: str,
        handler: Callable,
        methods: List[str],
        endpoint_config: Dict = None,
    ):
        """Register a single custom endpoint with proper async/sync
        handling."""
        tags = ["custom"]
        for method in methods:
            # Check if this is a task endpoint
            if endpoint_config and endpoint_config.get("task_type"):
                # Create task endpoint with proper execution logic
                task_handler = FastAPIAppFactory._create_task_handler(
                    app,
                    handler,
                    endpoint_config.get("queue", "default"),
                )
                app.add_api_route(
                    path,
                    task_handler,
                    methods=[method],
                    tags=tags,
                )

                # Add task status endpoint - align with BaseApp pattern
                status_path = f"{path}/{{task_id}}"
                status_handler = FastAPIAppFactory._create_task_status_handler(
                    app,
                )
                app.add_api_route(
                    status_path,
                    status_handler,
                    methods=["GET"],
                    tags=tags,
                )
            else:
                # Regular endpoint handling with automatic parameter parsing
                # Check in the correct order: async gen > sync gen > async &
                # sync
                if inspect.isasyncgenfunction(
                    handler,
                ) or inspect.isgeneratorfunction(handler):
                    wrapped_handler = (
                        FastAPIAppFactory._create_streaming_parameter_wrapper(
                            handler,
                        )
                    )
                    app.add_api_route(
                        path,
                        wrapped_handler,
                        methods=[method],
                        tags=tags,
                        response_model=None,
                    )
                else:
                    # Non-streaming endpoint -> wrapper that preserves
                    # handler signature
                    wrapped_handler = (
                        FastAPIAppFactory._create_handler_wrapper(handler)
                    )
                    app.add_api_route(
                        path,
                        wrapped_handler,
                        methods=[method],
                        response_model=None,
                        tags=tags,
                    )

    @staticmethod
    def _create_task_handler(app: FastAPI, task_func: Callable, queue: str):
        """Create a task handler that executes functions asynchronously."""

        async def task_endpoint(request: dict):
            try:
                import uuid

                # Generate task ID
                task_id = str(uuid.uuid4())

                # Check if Celery is available
                if (
                    hasattr(app.state, "celery_mixin")
                    and app.state.celery_mixin
                ):
                    # Use Celery for task processing
                    celery_mixin = app.state.celery_mixin

                    # Register the task function if not already registered
                    if not hasattr(task_func, "celery_task"):
                        celery_task = celery_mixin.register_celery_task(
                            task_func,
                            queue,
                        )
                        task_func.celery_task = celery_task

                    # Submit task to Celery
                    result = celery_mixin.submit_task(task_func, request)

                    return {
                        "task_id": result.id,
                        "status": "submitted",
                        "queue": queue,
                        "message": f"Task {result.id} submitted to Celery "
                        f"queue {queue}",
                    }
                else:
                    # Fallback to in-memory task processing
                    import time

                    # Initialize task storage if not exists
                    if not hasattr(app.state, "active_tasks"):
                        app.state.active_tasks = {}

                    # Create task info for tracking
                    task_info = {
                        "task_id": task_id,
                        "status": "submitted",
                        "queue": queue,
                        "submitted_at": time.time(),
                        "request": request,
                    }
                    app.state.active_tasks[task_id] = task_info

                    # Execute task asynchronously in background
                    asyncio.create_task(
                        FastAPIAppFactory._execute_background_task(
                            app,
                            task_id,
                            task_func,
                            request,
                            queue,
                        ),
                    )

                    return {
                        "task_id": task_id,
                        "status": "submitted",
                        "queue": queue,
                        "message": f"Task {task_id} submitted to queue "
                        f"{queue}",
                    }
            except Exception as e:
                return {
                    "error": str(e),
                    "type": "task",
                    "queue": queue,
                    "status": "failed",
                }

        return task_endpoint

    @staticmethod
    async def _execute_background_task(
        app: FastAPI,
        task_id: str,
        func: Callable,
        request: dict,
        queue: str,
    ):
        """Execute task in background and update status."""
        try:
            import time
            import concurrent.futures

            # Update status to running
            if (
                hasattr(app.state, "active_tasks")
                and task_id in app.state.active_tasks
            ):
                app.state.active_tasks[task_id].update(
                    {
                        "status": "running",
                        "started_at": time.time(),
                    },
                )

            # Execute the actual task function
            if asyncio.iscoroutinefunction(func):
                result = await func(request)
            else:
                # Run sync function in thread pool to avoid blocking
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    result = await asyncio.get_event_loop().run_in_executor(
                        executor,
                        func,
                        request,
                    )

            # Update status to completed
            if (
                hasattr(app.state, "active_tasks")
                and task_id in app.state.active_tasks
            ):
                app.state.active_tasks[task_id].update(
                    {
                        "status": "completed",
                        "result": result,
                        "completed_at": time.time(),
                    },
                )
        except Exception as e:
            # Update status to failed
            if (
                hasattr(app.state, "active_tasks")
                and task_id in app.state.active_tasks
            ):
                app.state.active_tasks[task_id].update(
                    {
                        "status": "failed",
                        "error": str(e),
                        "failed_at": time.time(),
                    },
                )

    @staticmethod
    def _create_task_status_handler(app: FastAPI):
        """Create a handler for checking task status."""

        async def task_status_handler(task_id: str):
            if not task_id:
                return {"error": "task_id required"}

            # Check if Celery is available
            if hasattr(app.state, "celery_mixin") and app.state.celery_mixin:
                # Use Celery for task status checking
                celery_mixin = app.state.celery_mixin
                return celery_mixin.get_task_status(task_id)
            else:
                # Fallback to in-memory task status checking
                if (
                    not hasattr(app.state, "active_tasks")
                    or task_id not in app.state.active_tasks
                ):
                    return {"error": f"Task {task_id} not found"}

                task_info = app.state.active_tasks[task_id]
                task_status = task_info.get("status", "unknown")

                # Align with BaseApp.get_task logic - map internal status to
                # external status format
                if task_status in ["submitted", "running"]:
                    return {"status": "pending", "result": None}
                elif task_status == "completed":
                    return {
                        "status": "finished",
                        "result": task_info.get("result"),
                    }
                elif task_status == "failed":
                    return {
                        "status": "error",
                        "result": task_info.get("error", "Unknown error"),
                    }
                else:
                    return {"status": task_status, "result": None}

        return task_status_handler