Source code for agentscope_runtime.engine.deployers.utils.service_utils.routing.custom_endpoint_mixin

# -*- coding: utf-8 -*-
import functools
import inspect
import logging
import json
from typing import Callable, List, Any
from dataclasses import asdict, is_dataclass

from pydantic import BaseModel
from fastapi.responses import StreamingResponse

logger = logging.getLogger(__name__)


class CustomEndpointMixin:
    """Mixin that registers user-supplied handlers as API routes.

    The host class must provide an ``add_api_route`` method; it is called
    with the path, the wrapped handler, and the ``methods``, ``tags`` and
    ``response_model`` keyword arguments.
    """

    def register_single_custom_endpoint(
        self,
        path: str,
        handler: Callable,
        methods: List[str],
    ):
        """Register a single custom endpoint with proper async/sync handling.

        Generator handlers (async or sync) are exposed as Server-Sent-Events
        streams; every other handler is registered through a thin wrapper
        that preserves its signature.

        Args:
            path: URL path of the route.
            handler: The user callable to expose.
            methods: HTTP methods the route answers to.
        """
        # Generators must be detected first: they need the streaming wrapper,
        # everything else (coroutine or plain function) is a regular endpoint.
        is_streaming = inspect.isasyncgenfunction(
            handler,
        ) or inspect.isgeneratorfunction(handler)

        if is_streaming:
            endpoint = CustomEndpointMixin._create_streaming_parameter_wrapper(
                handler,
            )
        else:
            endpoint = CustomEndpointMixin._create_handler_wrapper(handler)

        self.add_api_route(
            path,
            endpoint,
            methods=methods,
            tags=["custom"],
            response_model=None,
        )

    @staticmethod
    def _create_handler_wrapper(handler: Callable):
        """Create a wrapper for a handler that preserves function signature.

        The wrapper keeps the handler's signature so that the framework's
        automatic parameter parsing and dependency injection keep working.
        For async handlers it returns an async wrapper; for sync handlers
        it returns a sync wrapper.

        Args:
            handler: The handler function to wrap.

        Returns:
            A wrapped handler that preserves the original function signature.
        """
        if inspect.iscoroutinefunction(handler):

            @functools.wraps(handler)
            async def wrapped_handler(*args, **kwargs):
                return await handler(*args, **kwargs)

        else:

            @functools.wraps(handler)
            def wrapped_handler(*args, **kwargs):
                return handler(*args, **kwargs)

        wrapped_handler.__signature__ = inspect.signature(handler)
        return wrapped_handler

    @staticmethod
    def _to_sse_event(item: Any) -> str:
        """Serialize a streamed item as one Server-Sent-Events ``data:`` frame.

        The item is first normalized into JSON-compatible structures:
        lists/tuples/sets become lists, dicts are walked recursively,
        pydantic models, dataclass instances and objects exposing
        ``to_map``/``to_dict`` are converted and then normalized again (their
        output may itself contain non-JSON values), and anything else falls
        back to ``str(value)``.

        Args:
            item: The chunk produced by a streaming handler.

        Returns:
            The frame ``"data: <json>\\n\\n"``.
        """
        json_scalars = (str, int, float, bool, type(None))

        def _serialize(value: Any, depth: int = 0):
            # pylint:disable=too-many-return-statements
            # Hard stop against cyclic or pathologically deep structures.
            if depth > 20:
                return f"<too-deep-level-{depth}-{str(value)}>"
            if isinstance(value, (list, tuple, set)):
                return [_serialize(i, depth=depth + 1) for i in value]
            if isinstance(value, dict):
                # json.dumps only accepts scalar keys; stringify the rest
                # instead of failing the whole event.
                return {
                    (k if isinstance(k, json_scalars) else str(k)): _serialize(
                        v,
                        depth=depth + 1,
                    )
                    for k, v in value.items()
                }
            if isinstance(value, json_scalars):
                return value
            if isinstance(value, type):
                # Class objects: is_dataclass() is true for dataclass
                # *classes* too (asdict() would raise), and an unbound
                # to_dict cannot be called without an instance.
                return str(value)
            if isinstance(value, BaseModel):
                return _serialize(value.model_dump(), depth=depth + 1)
            if is_dataclass(value):
                return _serialize(asdict(value), depth=depth + 1)
            for attr in ("to_map", "to_dict"):
                method = getattr(value, attr, None)
                if callable(method):
                    return _serialize(method(), depth=depth + 1)
            return str(value)

        serialized = _serialize(item, depth=0)
        return f"data: {json.dumps(serialized, ensure_ascii=False)}\n\n"

    @staticmethod
    def _stream_error_event(exc: Exception) -> str:
        """Log a failure raised while streaming and build its SSE error frame.

        Args:
            exc: The exception raised by the handler or during serialization.

        Returns:
            An SSE frame whose payload carries ``error``, ``error_type`` and
            ``message`` keys.
        """
        logger.error("Error in streaming handler: %s", exc, exc_info=exc)
        return CustomEndpointMixin._to_sse_event(
            {
                "error": str(exc),
                "error_type": exc.__class__.__name__,
                "message": "Error in streaming generator",
            },
        )

    @staticmethod
    def _create_streaming_parameter_wrapper(
        handler: Callable,
    ):
        """Create a wrapper for streaming handlers that handles parameter
        parsing.

        The wrapper calls the generator handler with the parsed arguments and
        returns a ``StreamingResponse`` (``text/event-stream``) that emits one
        SSE frame per yielded chunk. If the handler raises, the error is
        logged and a final error frame is emitted instead of propagating.

        Args:
            handler: An async-generator or sync-generator function.

        Returns:
            An ``async def`` wrapper for async generators, a plain function
            for sync generators; both expose the handler's signature.
        """
        is_async_gen = inspect.isasyncgenfunction(handler)

        # NOTE:
        # -----
        # FastAPI >= 0.123.5 uses Dependant.is_coroutine_callable, which in
        # turn unwraps callables via inspect.unwrap() and then inspects the
        # unwrapped target to decide whether it is a coroutine function /
        # generator / async generator.
        #
        # If we decorate an async-generator handler with
        # functools.wraps(handler), FastAPI will unwrap back to the original
        # async-generator function and *misclassify* the endpoint as
        # non-coroutine. It will then call our async wrapper *without awaiting
        # it*, and later try to JSON-encode the resulting coroutine object,
        # causing errors like:
        #     TypeError("'coroutine' object is not iterable")
        #
        # To avoid that, we deliberately do NOT use functools.wraps() here.
        # Instead, we manually copy the key metadata (name, qualname, doc,
        # module, and signature) from the original handler, but we do NOT set
        # __wrapped__. This ensures:
        #   * FastAPI sees the wrapper itself as the callable (an async def),
        #     so Dependant.is_coroutine_callable is True, and it is properly
        #     awaited.
        #   * FastAPI still sees the correct signature for parameter parsing.
        if is_async_gen:

            async def wrapped_handler(*args, **kwargs):
                async def generate():
                    try:
                        async for chunk in handler(*args, **kwargs):
                            yield CustomEndpointMixin._to_sse_event(chunk)
                    except Exception as e:  # pylint:disable=broad-except
                        # Report the failure in-band: the response headers
                        # have already been sent once streaming starts.
                        yield CustomEndpointMixin._stream_error_event(e)

                return StreamingResponse(
                    generate(),
                    media_type="text/event-stream",
                )

        else:

            def wrapped_handler(*args, **kwargs):
                def generate():
                    try:
                        for chunk in handler(*args, **kwargs):
                            yield CustomEndpointMixin._to_sse_event(chunk)
                    except Exception as e:  # pylint:disable=broad-except
                        yield CustomEndpointMixin._stream_error_event(e)

                return StreamingResponse(
                    generate(),
                    media_type="text/event-stream",
                )

        # Manually propagate essential metadata without creating a __wrapped__
        # chain that would confuse FastAPI's unwrap logic.
        for attr in ("__name__", "__qualname__", "__doc__", "__module__"):
            setattr(
                wrapped_handler,
                attr,
                getattr(handler, attr, getattr(wrapped_handler, attr)),
            )
        wrapped_handler.__signature__ = inspect.signature(handler)

        # Make sure FastAPI doesn't see any stale __wrapped__ pointing back to
        # the original generator function; if present, remove it.
        wrapped_handler.__dict__.pop("__wrapped__", None)
        return wrapped_handler