# -*- coding: utf-8 -*-
# type: ignore
import contextvars
import inspect
import json
import os
import re
import time
import threading
import uuid
from collections.abc import Callable
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import (
Any,
AsyncGenerator,
Dict,
Iterable,
Optional,
TypeVar,
Union,
)
from pydantic import BaseModel
from opentelemetry.propagate import extract
from opentelemetry.context import attach
from opentelemetry.trace import StatusCode, NoOpTracerProvider
from opentelemetry import trace as ot_trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPSpanGrpcExporter,
)
from opentelemetry.sdk.resources import SERVICE_NAME, SERVICE_VERSION, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
ConsoleSpanExporter,
)
from .asyncio_util import aenumerate
from .message_util import (
merge_incremental_chunk,
get_finish_reason,
)
from .base import Tracer, TracerHandler, EventContext
from .tracing_metric import TraceType
from .local_logging_handler import LocalLogHandler
from .tracing_util import TracingUtil
# Covariant type variable for the items yielded by traced (async) generators;
# used in the return annotations of the generator wrappers built by trace().
T_co = TypeVar("T_co", covariant=True)
def _str_to_bool(value: str) -> bool:
    """Interpret an environment-style flag string as a boolean.

    Matching is case-insensitive but otherwise exact (no whitespace is
    stripped): ``"true"``, ``"1"``, ``"yes"`` and ``"on"`` are truthy and
    every other string, including ``""``, is ``False``.

    Args:
        value (str): String value to convert.

    Returns:
        bool: ``True`` for a recognised truthy spelling, otherwise ``False``.
    """
    truthy_spellings = ("true", "1", "yes", "on")
    normalized = value.lower()
    return normalized in truthy_spellings
class MineType(str, Enum):
    """MIME type enumeration for content types.

    Members are also ``str`` instances; the tracing code passes them directly
    as span attribute values (``input.mine_type`` / ``output.mine_type``).

    Note: "Mine" is a misspelling of "MIME"; the name is kept because other
    code refers to it.
    """

    TEXT = "text/plain"
    JSON = "application/json"
# OpenTelemetry context of the innermost traced span in the current context;
# the trace() wrappers use it as the parent for nested calls. None means "no
# parent span recorded".
_parent_span_context: contextvars.ContextVar = contextvars.ContextVar(
    "_parent_span_context", default=None
)

# Request id last synchronised from TracingUtil by _init_trace_context();
# the empty string means "none seen yet".
_current_request_id: contextvars.ContextVar[str] = contextvars.ContextVar(
    "_current_request_id", default=""
)

# Trace header last extracted and attached by _init_trace_context().
# NOTE(review): the default dict is one shared object; callers in this file
# read it with an explicit ``.get({})`` default and never mutate it.
_current_trace_header: contextvars.ContextVar[dict] = contextvars.ContextVar(
    "_current_trace_header", default={}
)
def trace(  # pylint: disable=too-many-statements
    trace_type: Union[TraceType, str, None] = None,
    *,
    trace_name: Optional[str] = None,
    is_root_span: Optional[bool] = None,
    get_finish_reason_func: Optional[
        Callable[[Any], Optional[str]]
    ] = get_finish_reason,
    merge_output_func: Optional[
        Callable[[Any], Union[BaseModel, dict, str, None]]
    ] = merge_incremental_chunk,
) -> Any:
    """Decorator for tracing function execution.

    Works on plain functions, coroutine functions, generator functions and
    async generator functions. Every call opens an OpenTelemetry span (from
    ``_get_ot_tracer()``) plus a ``_tracer.event`` context. The call
    arguments are recorded as the span input. For plain and coroutine
    functions, the return value is recorded as the output. For generators,
    the first item, any item with a finish reason, and the merged output are
    recorded. Exceptions mark the span as ``ERROR`` and are re-raised.

    A truthy ``trace_context`` keyword argument is attached and used as the
    parent context; it is still passed through to the function. If the
    function accepts ``**kwargs`` it also receives the active event as
    ``trace_event``.

    Args:
        trace_type (Union[TraceType, str]): The type of trace event.
        trace_name (Optional[str]): The name of trace event; defaults to the
            function name.
        is_root_span (Optional[bool]): Request a root span. It is only
            honoured when there is no parent span (see
            ``_validate_trace_options``).
        get_finish_reason_func (Optional[Callable]): For generators, called
            on a copy of every item; a truthy result is recorded as that
            item's finish reason.
        merge_output_func (Optional[Callable]): For generators, merges all
            yielded items into the output recorded on the span.

    Returns:
        Any: The decorator that applies tracing to a function.
    """

    def wrapper(func: Any) -> Any:
        """Wrap ``func`` with the tracing wrapper matching its kind.

        Args:
            func (Any): The function to be traced.

        Returns:
            Any: The wrapped function with appropriate tracing logic.
        """

        @wraps(func)
        async def async_exec(*args: Any, **kwargs: Any) -> Any:
            """Await a coroutine function inside a traced span.

            Returns:
                Any: The result of the traced function.
            """
            (
                start_payload,
                trace_context,
                parent_ctx,
                span_name,
                span_attributes,
            ) = _prepare_trace_call(
                func,
                args,
                kwargs,
                trace_type,
                trace_name,
                is_root_span,
            )
            with _get_ot_tracer().start_as_current_span(
                span_name,
                context=parent_ctx,
                attributes=span_attributes,
            ) as span:
                span.set_status(status=StatusCode.OK)
                with _tracer.event(
                    span,
                    span_name,
                    payload=start_payload,
                ) as event:
                    _parent_span_context.set(
                        ot_trace.set_span_in_context(span),
                    )
                    func_kwargs = _build_func_kwargs(func, kwargs, event)
                    try:
                        result = await func(*args, **func_kwargs)
                        _record_call_output(span, event, result)
                        return result
                    except Exception as e:
                        _record_span_error(span, event, e)
                        raise
                    finally:
                        # NOTE(review): not restored when an explicit
                        # trace_context was given, which leaves this span
                        # as the parent for later calls in this context.
                        if not trace_context:
                            _parent_span_context.set(parent_ctx)

        @wraps(func)
        def sync_exec(*args: Any, **kwargs: Any) -> Any:
            """Call a plain function inside a traced span.

            Returns:
                Any: The result of the traced function.
            """
            (
                start_payload,
                trace_context,
                parent_ctx,
                span_name,
                span_attributes,
            ) = _prepare_trace_call(
                func,
                args,
                kwargs,
                trace_type,
                trace_name,
                is_root_span,
            )
            with _get_ot_tracer().start_as_current_span(
                span_name,
                context=parent_ctx,
                attributes=span_attributes,
            ) as span:
                span.set_status(status=StatusCode.OK)
                with _tracer.event(
                    span,
                    span_name,
                    payload=start_payload,
                ) as event:
                    _parent_span_context.set(
                        ot_trace.set_span_in_context(span),
                    )
                    func_kwargs = _build_func_kwargs(func, kwargs, event)
                    try:
                        result = func(*args, **func_kwargs)
                        _record_call_output(span, event, result)
                        return result
                    except Exception as e:
                        _record_span_error(span, event, e)
                        raise
                    finally:
                        if not trace_context:
                            _parent_span_context.set(parent_ctx)

        @wraps(func)
        async def async_iter_task(
            *args: Any,
            **kwargs: Any,
        ) -> AsyncGenerator[T_co, None]:
            """Iterate an async generator function inside a traced span.

            Yields:
                T_co: Items from the original generator, unchanged.
            """
            (
                start_payload,
                trace_context,
                parent_ctx,
                span_name,
                span_attributes,
            ) = _prepare_trace_call(
                func,
                args,
                kwargs,
                trace_type,
                trace_name,
                is_root_span,
            )
            with _get_ot_tracer().start_as_current_span(
                span_name,
                context=parent_ctx,
                attributes=span_attributes,
            ) as span:
                span.set_status(status=StatusCode.OK)
                with _tracer.event(
                    span,
                    span_name,
                    payload=start_payload,
                ) as event:
                    _parent_span_context.set(
                        ot_trace.set_span_in_context(span),
                    )
                    func_kwargs = _build_func_kwargs(func, kwargs, event)
                    cumulated = []
                    try:
                        start_time = int(time.time() * 1000)
                        async for i, resp in aenumerate(
                            func(*args, **func_kwargs),
                        ):  # type: ignore
                            cumulated.append(resp)
                            # Trace before yielding so that first_delay
                            # measures only the producer, not the consumer's
                            # handling of the item.
                            _trace_stream_item(
                                i,
                                resp,
                                event,
                                span,
                                start_time,
                                get_finish_reason_func,
                            )
                            yield resp
                        if cumulated and merge_output_func is not None:
                            _trace_merged_resp(
                                cumulated,
                                merge_output_func,
                                event,
                                span,
                            )
                    except Exception as e:
                        _record_span_error(span, event, e)
                        raise
                    finally:
                        if not trace_context:
                            _parent_span_context.set(parent_ctx)

        @wraps(func)
        def iter_task(*args: Any, **kwargs: Any) -> Iterable[T_co]:
            """Iterate a generator function inside a traced span.

            Yields:
                T_co: Items from the traced generator, unchanged.
            """
            (
                start_payload,
                trace_context,
                parent_ctx,
                span_name,
                span_attributes,
            ) = _prepare_trace_call(
                func,
                args,
                kwargs,
                trace_type,
                trace_name,
                is_root_span,
            )
            with _get_ot_tracer().start_as_current_span(
                span_name,
                context=parent_ctx,
                attributes=span_attributes,
            ) as span:
                span.set_status(status=StatusCode.OK)
                with _tracer.event(
                    span,
                    span_name,
                    payload=start_payload,
                ) as event:
                    _parent_span_context.set(
                        ot_trace.set_span_in_context(span),
                    )
                    func_kwargs = _build_func_kwargs(func, kwargs, event)
                    cumulated = []
                    try:
                        start_time = int(time.time() * 1000)
                        for i, resp in enumerate(func(*args, **func_kwargs)):
                            cumulated.append(resp)
                            # Trace before yielding (see async_iter_task).
                            _trace_stream_item(
                                i,
                                resp,
                                event,
                                span,
                                start_time,
                                get_finish_reason_func,
                            )
                            yield resp
                        if cumulated and merge_output_func is not None:
                            _trace_merged_resp(
                                cumulated,
                                merge_output_func,
                                event,
                                span,
                            )
                    except Exception as e:
                        _record_span_error(span, event, e)
                        raise
                    finally:
                        if not trace_context:
                            _parent_span_context.set(parent_ctx)

        # Choose the appropriate wrapper based on function type
        if inspect.isasyncgenfunction(func):
            wrapper_func = async_iter_task
        elif inspect.isgeneratorfunction(func):
            wrapper_func = iter_task
        elif inspect.iscoroutinefunction(func):
            wrapper_func = async_exec
        else:
            wrapper_func = sync_exec
        # Preserve the original function's signature (best effort: some
        # callables have no retrievable signature).
        try:
            wrapper_func.__signature__ = inspect.signature(func)
        except (ValueError, TypeError):
            pass
        return wrapper_func

    return wrapper


def _prepare_trace_call(
    func: Any,
    args: Any,
    kwargs: Any,
    trace_type: Union[TraceType, str, None],
    trace_name: Optional[str],
    is_root_span: Optional[bool],
) -> tuple[Dict, Any, Any, Any, Dict]:
    """Compute everything needed to open the span for one traced call.

    The steps below run in the same order for every wrapper kind.

    Returns:
        tuple: ``(start_payload, trace_context, parent_ctx, span_name,
        span_attributes)``.
    """
    _init_trace_context()
    start_payload = _get_start_payload(args, kwargs, func)
    trace_context = kwargs.get("trace_context") if kwargs else None
    if trace_context:
        # NOTE(review): the token returned by attach() is never detached,
        # so this context stays attached after the call - confirm intended.
        attach(trace_context)
    parent_ctx = (
        trace_context if trace_context else _parent_span_context.get(None)
    )
    (
        final_trace_type,
        final_trace_name,
        final_is_root_span,
    ) = _validate_trace_options(
        trace_type,
        trace_name,
        is_root_span,
        func.__name__,
        parent_ctx,
    )
    # Auto generate request_id for root span if needed
    _set_request_id(parent_ctx)
    common_attrs = TracingUtil.get_common_attributes() or {}
    span_attributes = {
        "gen_ai.span.kind": final_trace_type,
        "gen_ai.user.query_root_flag": 1 if final_is_root_span else 0,
        "input.mine_type": MineType.JSON,
        # default=str: an unserialisable argument (e.g. a datetime inside a
        # pydantic model_dump()) must not make tracing break the call.
        "input.value": json.dumps(
            start_payload,
            ensure_ascii=False,
            default=str,
        ),
        **common_attrs,
    }
    return (
        start_payload,
        trace_context,
        parent_ctx,
        final_trace_name,
        span_attributes,
    )


def _build_func_kwargs(func: Any, kwargs: Any, event: Any) -> Dict:
    """Copy the call kwargs, adding ``trace_event`` if func takes **kwargs."""
    func_kwargs = kwargs.copy() if kwargs else {}
    if _function_accepts_kwargs(func):
        func_kwargs["trace_event"] = event
    return func_kwargs


def _record_call_output(span: Any, event: Any, result: Any) -> None:
    """Record a non-streaming return value on the span and end the event."""
    end_payload = _obj_to_dict(result)
    output_mine_type, output_value = _get_ot_type_and_value(end_payload)
    span.set_attribute("output.mine_type", output_mine_type)
    span.set_attribute("output.value", output_value)
    event.on_end(payload=end_payload)


def _record_span_error(span: Any, event: Any, exc: Exception) -> None:
    """Mark the span as failed and log the exception text on the event."""
    span.set_status(
        status=StatusCode.ERROR,
        description=f"exception={exc}",
    )
    event.on_log(str(exc))


def _trace_stream_item(
    index: int,
    resp: Any,
    event: Any,
    span: Any,
    start_time: int,
    get_finish_reason_func: Optional[Callable],
) -> None:
    """Trace one streamed item: first-item data and any finish reason."""
    if index == 0:
        _trace_first_resp(resp, event, span, start_time)
    if get_finish_reason_func is not None:
        _trace_last_resp(resp, get_finish_reason_func, event, span)
def _get_start_payload(args: Any, kwargs: Any, func: Any = None) -> Dict:
    """Extract and format the start payload from function arguments.

    Positional arguments are named after ``func``'s signature where
    possible. Each positional argument that lands on a POSITIONAL_ONLY or
    POSITIONAL_OR_KEYWORD parameter is stored under that parameter's name.
    A leading ``self`` parameter is skipped. Values that land on ``*args``,
    or beyond the declared parameters, are not named.

    If that pass names nothing (no ``func``, unreadable signature, or all
    positionals went to ``*args``), each positional argument is converted
    with ``_obj_to_dict`` and every one that becomes a dict is merged in.
    ``self`` is excluded here too. Finally, keyword arguments are merged,
    except those whose name starts with ``trace_``.

    Args:
        args (Any): Positional arguments from the function call.
        kwargs (Any): Keyword arguments from the function call.
        func (Any): The function being traced (optional).

    Returns:
        Dict: The formatted start payload for tracing.
    """
    merged = {}
    # Number of leading positional args that belong to ``self``.
    skip = 0
    # Map positional arguments to the parameter names in func's signature.
    if func is not None and isinstance(args, tuple) and len(args) > 0:
        try:
            params = list(inspect.signature(func).parameters.values())
            # Skip the ``self`` parameter of instance methods.
            if params and params[0].name == "self":
                skip = 1
            for i, arg in enumerate(args[skip:], start=skip):
                # Only positional / positional-or-keyword parameters get a
                # name; *args and **kwargs slots are skipped.
                if i < len(params) and params[i].kind in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                ):
                    merged[params[i].name] = _obj_to_dict(arg)
        except (ValueError, TypeError, IndexError):
            # Signature unavailable: rely on the fallback below.
            pass
    # Fallback when nothing could be named: merge each positional argument
    # that converts to a dict. ``self`` is excluded, so the instance itself
    # (e.g. a pydantic model, which _obj_to_dict dumps to a dict) never
    # contributes its own fields to the payload.
    if not merged and isinstance(args, tuple) and len(args) > skip:
        # _obj_to_dict turns a tuple into a list.
        for item in _obj_to_dict(args[skip:]):
            if isinstance(item, dict):
                merged.update(item)
    # Keyword arguments, minus tracing plumbing such as trace_context.
    dict_kwargs = _obj_to_dict(kwargs)
    merged.update(
        {
            key: value
            for key, value in dict_kwargs.items()
            if not key.startswith("trace_")
        },
    )
    return merged
def _init_trace_context() -> None:
    """Synchronise this module's context variables with ``TracingUtil``.

    * Request id: suppose ``TracingUtil.get_request_id()`` reports a
      non-empty id that differs from the one cached in
      ``_current_request_id``. Then the cached parent span is cleared, so
      the next traced call starts without a parent, and the new id is
      cached.
    * Trace header: suppose ``TracingUtil.get_trace_header()`` returns a
      non-empty header that differs from the cached one. Then it is cached,
      converted to a context with ``extract`` and attached.
    """
    known_req_id = _current_request_id.get("")
    incoming_req_id = TracingUtil.get_request_id()
    if incoming_req_id and incoming_req_id != known_req_id:
        _parent_span_context.set(None)
        _current_request_id.set(incoming_req_id)

    known_header = _current_trace_header.get({})
    incoming_header = TracingUtil.get_trace_header()
    if incoming_header and incoming_header != known_header:
        _current_trace_header.set(incoming_header)
        # NOTE(review): the token returned by attach() is never detached.
        attach(extract(incoming_header))
def _set_request_id(parent_ctx: Any) -> None:
    """Auto generate request_id for root span if not already set.

    * Root span (``parent_ctx is None``): if ``TracingUtil`` has no request
      id and no parent span is recorded in ``_parent_span_context``, a new
      UUID4 string is stored with ``TracingUtil.set_request_id``.
    * Child span: if ``TracingUtil`` already reports a request id that is
      missing from its common attributes, the id is set again. Per the
      original comment, this copies a baggage-derived id into the common
      attributes.

    Args:
        parent_ctx: Parent context. If None, this is a root span.
    """
    # Check if we already have a request_id (from context var or baggage)
    current_request_id = TracingUtil.get_request_id()
    if parent_ctx is None:
        # This is a root span.
        if not current_request_id and _parent_span_context.get(None) is None:
            # A truly new request: generate a new request_id.
            TracingUtil.set_request_id(str(uuid.uuid4()))
        return
    # get_common_attributes() may return None (the trace wrappers guard it
    # with ``or {}`` as well); calling .get() on None would raise.
    if current_request_id and not (
        TracingUtil.get_common_attributes() or {}
    ).get("request_id"):
        # Set common attributes from baggage request_id
        TracingUtil.set_request_id(current_request_id)
def _trace_first_resp(
    resp: Any,
    event: EventContext,
    span: Any,
    start_time: int,
) -> None:
    """Record the first item of a traced stream.

    Logs the item on ``event`` with ``step_suffix="first_resp"``. On
    ``span`` it sets ``gen_ai.response.first_delay``, the milliseconds
    elapsed since ``start_time``, and ``gen_ai.response.first_pkg``, the
    rendered item.

    Args:
        resp (Any): The first streamed item.
        event (EventContext): The event of the traced call.
        span (Any): The OpenTelemetry span of the traced call.
        start_time (int): Epoch milliseconds when iteration started.
    """
    first_payload = _obj_to_dict(resp)
    event.on_log("", step_suffix="first_resp", payload=first_payload)
    elapsed_ms = int(time.time() * 1000) - start_time
    span.set_attribute("gen_ai.response.first_delay", elapsed_ms)
    rendered = _get_ot_type_and_value(first_payload)[1]
    span.set_attribute("gen_ai.response.first_pkg", rendered)
def _trace_last_resp(
    resp: Any,
    func: Callable,
    event: EventContext,
    span: Any,
) -> None:
    """Record a streamed item for which ``func`` reports a finish reason.

    ``func`` receives a deep copy of ``resp``. If it returns a truthy
    reason, the copy is logged on ``event``. The step suffix is
    ``"last_resp"`` for ``"stop"`` and otherwise the reason itself. The
    rendered copy is stored in the span attribute
    ``gen_ai.response.pkg_<reason>``.

    Args:
        resp (Any): The streamed item.
        func (Callable): Finish-reason extractor.
        event (EventContext): The event of the traced call.
        span (Any): The OpenTelemetry span of the traced call.
    """
    snapshot = deepcopy(resp)
    reason = func(snapshot)
    if not reason:
        return
    suffix = "last_resp" if reason == "stop" else reason
    snapshot_payload = _obj_to_dict(snapshot)
    event.on_log("", step_suffix=suffix, payload=snapshot_payload)
    rendered = _get_ot_type_and_value(snapshot_payload)[1]
    span.set_attribute("gen_ai.response.pkg_" + reason, rendered)
def _trace_merged_resp(
    cumulated: Any,
    func: Callable,
    event: EventContext,
    span: Any,
) -> None:
    """Merge all streamed items and record the result as the call output.

    ``func`` receives a deep copy of ``cumulated``. The merged result sets
    ``output.mine_type`` / ``output.value`` on the span and ends the event.

    Args:
        cumulated (Any): All items yielded by the traced generator.
        func (Callable): Function merging the items into one output.
        event (EventContext): The event of the traced call.
        span (Any): The OpenTelemetry span of the traced call.
    """
    merged = func(deepcopy(cumulated))
    final_payload = _obj_to_dict(merged)
    mime, rendered = _get_ot_type_and_value(final_payload)
    span.set_attribute("output.mine_type", mime)
    span.set_attribute("output.value", rendered)
    event.on_end(payload=final_payload)
def _get_ot_type_and_value(payload: Any) -> tuple[MineType, Any]:
    """Pick the mime type and span-attribute value for a payload.

    * Dicts are serialised to a JSON string (``MineType.JSON``).
    * ``str``/``int``/``float``/``bool`` are returned unchanged
      (``MineType.TEXT``).
    * Anything else is converted with ``str()`` (``MineType.TEXT``).

    Args:
        payload (Any): Usually a value produced by ``_obj_to_dict``.

    Returns:
        tuple[MineType, Any]: The mime type and the attribute value.
    """
    if isinstance(payload, dict):
        # default=str: values such as datetime/UUID/Decimal (e.g. from a
        # pydantic model_dump()) would otherwise raise TypeError, aborting
        # the traced call only because its output could not be serialised.
        # Non-string dict keys of unsupported types (e.g. tuples) still
        # raise.
        return MineType.JSON, json.dumps(
            payload,
            ensure_ascii=False,
            default=str,
        )
    if isinstance(payload, (str, int, float, bool)):
        return MineType.TEXT, payload
    return MineType.TEXT, str(payload)
def _validate_trace_options(
    trace_type: Union[TraceType, str, None] = None,
    trace_name: Optional[str] = None,
    is_root_span: Optional[bool] = None,
    function_name: Optional[str] = None,
    parent_ctx: Optional[Any] = None,
) -> tuple[str, str | None, bool | None]:
    """Resolve the effective span type, name and root flag.

    A call is a root span only when ``is_root_span`` is truthy and there is
    no parent: ``parent_ctx`` is None and ``_parent_span_context`` is unset.
    Root spans always get type ``TraceType.CHAIN`` and name
    ``"FullCodeApp"``.

    For other spans:

    * A string type is converted with ``TraceType(...)``.
    * A missing type becomes ``TraceType.OTHER``.
    * The name is ``trace_name``, else ``function_name``, else
      ``str(type).lower()``.

    Returns:
        tuple: ``(trace_type, trace_name, is_root)``. ``is_root`` is the
        value of the ``and`` chain, so it may be ``None`` rather than a
        bool.
    """
    root_flag = (
        is_root_span
        and parent_ctx is None
        and not _parent_span_context.get(None)
    )
    if root_flag:
        return TraceType.CHAIN, "FullCodeApp", root_flag

    if not trace_type:
        resolved_type = TraceType.OTHER
    elif isinstance(trace_type, str):
        resolved_type = TraceType(trace_type)
    else:
        resolved_type = trace_type

    resolved_name = (
        trace_name or function_name or str(resolved_type).lower()
    )
    return resolved_type, resolved_name, root_flag
def _obj_to_dict(obj: Any) -> Any:
"""Convert an object to a dictionary representation for tracing.
Args:
obj (Any): The object to convert.
Returns:
Any: The dictionary representation of the object, or the object
itself if it's a primitive type.
"""
if obj is None:
return {}
elif isinstance(obj, (str, int, float, bool, type(None))):
return obj
elif isinstance(obj, dict):
return {k: _obj_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, (list, set, tuple)):
return [_obj_to_dict(item) for item in obj]
elif isinstance(obj, BaseModel):
return obj.model_dump()
else:
result = None
try:
result = str(obj)
except Exception as e:
print(f"{obj} str method failed with error: {e}")
return result
def _function_accepts_kwargs(func: Any) -> bool:
    """Tell whether ``func`` declares a ``**kwargs`` parameter.

    Args:
        func (Any): The callable to inspect.

    Returns:
        bool: ``True`` if a VAR_KEYWORD parameter exists. Returns ``False``
        otherwise, including when no signature can be obtained.
    """
    try:
        parameters = inspect.signature(func).parameters
    except (ValueError, TypeError):
        return False
    for param in parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return True
    return False
def _get_service_name() -> str:
    """Get service name.

    The name comes from the first non-empty value among ``SERVICE_NAME``,
    ``DS_SVC_NAME`` and ``"agentscope_runtime-service"``. For names shaped
    like ``deployment.<name>-<x>-<y>``, only ``<name>`` is returned (the
    two trailing dash-separated segments are dropped).

    Returns:
        str: The extracted service name, or the original name if extraction
        fails.
    """
    raw_name = (
        os.getenv("SERVICE_NAME")
        or os.getenv("DS_SVC_NAME")
        or "agentscope_runtime-service"
    )
    found = re.search(
        r"deployment\.([^-]+(?:-[^-]+)*?)(?=-[^-]+-[^-]+$)",
        raw_name,
    )
    return found.group(1) if found else raw_name
def _get_tracer() -> Tracer:
    """Build the event ``Tracer`` used by the ``trace`` decorator.

    A console-enabled ``LocalLogHandler`` is registered only when the
    ``TRACE_ENABLE_LOG`` environment variable is truthy. Otherwise the
    tracer has no handlers.

    Returns:
        Tracer: The configured tracer.
    """
    log_enabled = _str_to_bool(os.getenv("TRACE_ENABLE_LOG", "false"))
    handlers: list[TracerHandler] = (
        [LocalLogHandler(enable_console=True)] if log_enabled else []
    )
    return Tracer(handlers=handlers)
_otel_tracer_lock = threading.Lock()
# Tracer cached by _get_ot_tracer(); None until its first call.
_otel_tracer = None


# TODO: support more tracing protocols and platforms
def _get_ot_tracer() -> ot_trace.Tracer:
    """Get the OpenTelemetry tracer, creating it on first use.

    If a real global tracer provider is already installed, its tracer is
    used. Otherwise a private ``TracerProvider`` is built from environment
    variables:

    * ``SERVICE_NAME`` / ``DS_SVC_NAME`` / ``SERVICE_VERSION`` set the
      resource attributes.
    * ``TRACE_ENABLE_REPORT`` exports spans over OTLP/gRPC to
      ``TRACE_ENDPOINT``, sending an ``Authentication`` header with the
      value of ``TRACE_AUTHENTICATION``.
    * ``TRACE_ENABLE_DEBUG`` also prints spans to the console.

    The result is cached for the rest of the process.

    Returns:
        ot_trace.Tracer: The OpenTelemetry tracer instance.
    """
    global _otel_tracer

    def _get_ot_tracer_inner() -> ot_trace.Tracer:
        existing_provider = ot_trace.get_tracer_provider()
        # When no provider has been installed, get_tracer_provider()
        # returns a ProxyTracerProvider, not a NoOpTracerProvider. Both mean
        # "not configured"; checking only NoOpTracerProvider skipped the
        # env-configured provider below, and spans stayed non-recording.
        unconfigured = (NoOpTracerProvider, ot_trace.ProxyTracerProvider)
        if not isinstance(existing_provider, unconfigured):
            return ot_trace.get_tracer("agentscope_runtime")
        resource = Resource(
            attributes={
                SERVICE_NAME: _get_service_name(),
                SERVICE_VERSION: os.getenv("SERVICE_VERSION", "1.0.0"),
                "source": "agentscope_runtime-source",
            },
        )
        provider = TracerProvider(resource=resource)
        if _str_to_bool(os.getenv("TRACE_ENABLE_REPORT", "false")):
            span_exporter = BatchSpanProcessor(
                OTLPSpanGrpcExporter(
                    endpoint=os.getenv("TRACE_ENDPOINT", ""),
                    headers=f"Authentication="
                    f"{os.getenv('TRACE_AUTHENTICATION', '')}",
                ),
            )
            provider.add_span_processor(span_exporter)
        if _str_to_bool(os.getenv("TRACE_ENABLE_DEBUG", "false")):
            span_logger = BatchSpanProcessor(ConsoleSpanExporter())
            provider.add_span_processor(span_logger)
        return ot_trace.get_tracer(
            "agentscope_runtime",
            tracer_provider=provider,
        )

    if _otel_tracer is None:
        with _otel_tracer_lock:
            # Re-check under the lock: another thread may have built it.
            if _otel_tracer is None:
                _otel_tracer = _get_ot_tracer_inner()
    return _otel_tracer
# Module-level event tracer shared by every @trace-decorated call. It is
# built once at import time, so TRACE_ENABLE_LOG is read only when this
# module is imported.
_tracer = _get_tracer()