# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name, protected-access
# pylint: disable=too-many-branches, too-many-statements
import inspect
import json
import logging
import os
import secrets
import traceback
from functools import wraps
from typing import Optional, Dict, Union, List
import requests
import shortuuid
from ..client import SandboxHttpClient, TrainingSandboxClient
from ..enums import SandboxType
from ..manager.storage import (
LocalStorage,
OSSStorage,
)
from ..model import (
ContainerModel,
SandboxManagerEnvConfig,
)
from ..registry import SandboxRegistry
from ...common.collections import (
RedisMapping,
RedisQueue,
InMemoryMapping,
InMemoryQueue,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def remote_wrapper(
method: str = "POST",
success_key: str = "data",
):
"""
Decorator to handle both remote and local method execution.
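When the manager was constructed with a base_url (remote mode), the decorated
call is serialized from its positional and keyword arguments and forwarded as
an HTTP request to "<base_url>/<function_name>"; otherwise the original
function body runs locally. A minimal sketch of both modes (the URL, token,
and container name below are illustrative assumptions):

    # Local mode: methods execute in-process.
    manager = SandboxManager()
    manager.get_status("my-container")

    # Remote mode: the same call becomes POST <base_url>/get_status.
    manager = SandboxManager(
        base_url="http://localhost:8000",
        bearer_token="my-token",
    )
    manager.get_status("my-container")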
"""
def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.http_session:
# Execute the original function locally
return func(self, *args, **kwargs)
endpoint = "/" + func.__name__
# Prepare data for remote call
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
data = dict(zip(param_names, args))
data.update(kwargs)
# Make the remote HTTP request
response = self._make_request(method, endpoint, data)
# Process response
if success_key:
return response.get(success_key)
return response
wrapper._is_remote_wrapper = True
wrapper._http_method = method
wrapper._path = "/" + func.__name__
return wrapper
return decorator
class SandboxManager:
def __init__(
self,
config: Optional[SandboxManagerEnvConfig] = None,
base_url=None,
bearer_token=None,
default_type: Union[
SandboxType,
str,
List[Union[SandboxType, str]],
] = SandboxType.BASE,
):
if base_url:
# Initialize HTTP session for remote mode with bearer token
# authentication
self.http_session = requests.Session()
# requests.Session ignores a "timeout" attribute; store it and pass it per request.
self.request_timeout = 30
self.base_url = base_url.rstrip("/")
if bearer_token:
self.http_session.headers.update(
{"Authorization": f"Bearer {bearer_token}"},
)
# Remote mode, return directly
return
else:
self.http_session = None
self.base_url = None
if not config:
config = SandboxManagerEnvConfig(
file_system="local",
redis_enabled=False,
container_deployment="docker",
pool_size=0,
default_mount_dir="sessions_mount_dir",
)
# Support multiple sandbox pools
if isinstance(default_type, (SandboxType, str)):
self.default_type = [SandboxType(default_type)]
else:
self.default_type = [SandboxType(x) for x in list(default_type)]
self.workdir = "/workspace"
self.config = config
self.pool_size = self.config.pool_size
self.prefix = self.config.container_prefix_key
self.default_mount_dir = self.config.default_mount_dir
self.readonly_mounts = self.config.readonly_mounts
self.storage_folder = (
self.config.storage_folder or self.default_mount_dir
)
self.pool_queues = {}
if self.config.redis_enabled:
import redis
redis_client = redis.Redis(
host=self.config.redis_server,
port=self.config.redis_port,
db=self.config.redis_db,
username=self.config.redis_user,
password=self.config.redis_password,
decode_responses=True,
)
try:
redis_client.ping()
# redis-py raises its own ConnectionError, distinct from the built-in one.
except redis.exceptions.ConnectionError as e:
raise RuntimeError(
"Unable to connect to the Redis server.",
) from e
self.container_mapping = RedisMapping(redis_client)
self.session_mapping = RedisMapping(
redis_client,
prefix="session_mapping",
)
# Initialize the per-type sandbox pools
for t in self.default_type:
queue_key = f"{self.config.redis_container_pool_key}:{t.value}"
self.pool_queues[t] = RedisQueue(redis_client, queue_key)
else:
self.container_mapping = InMemoryMapping()
self.session_mapping = InMemoryMapping()
# Initialize the per-type sandbox pools
for t in self.default_type:
self.pool_queues[t] = InMemoryQueue()
self.container_deployment = self.config.container_deployment
if base_url is None:
if self.container_deployment == "docker":
from ...common.container_clients.docker_client import (
DockerClient,
)
self.client = DockerClient(config=self.config)
elif self.container_deployment == "k8s":
from ...common.container_clients.kubernetes_client import (
KubernetesClient,
)
self.client = KubernetesClient(config=self.config)
elif self.container_deployment == "agentrun":
from ...common.container_clients.agentrun_client import (
AgentRunClient,
)
self.client = AgentRunClient(config=self.config)
elif self.container_deployment == "fc":
from ...common.container_clients.fc_client import FCClient
self.client = FCClient(config=self.config)
else:
raise NotImplementedError(
f"Unsupported container_deployment: {self.container_deployment}",
)
else:
self.client = None
self.file_system = self.config.file_system
if self.file_system == "oss":
self.storage = OSSStorage(
self.config.oss_access_key_id,
self.config.oss_access_key_secret,
self.config.oss_endpoint,
self.config.oss_bucket_name,
)
else:
self.storage = LocalStorage()
if self.pool_size > 0:
self._init_container_pool()
logger.debug(str(config))
def __enter__(self):
logger.debug(
"Entering SandboxManager context (sync). "
"Cleanup will be performed automatically on exit.",
)
return self
def __exit__(self, exc_type, exc_value, traceback):
logger.debug(
"Exiting SandboxManager context (sync). Cleaning up resources.",
)
self.cleanup()
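# A minimal context-manager sketch (the sandbox type below is an assumption
# and depends on which sandbox images are registered):
#
#     with SandboxManager() as manager:
#         name = manager.create(sandbox_type="base")
#         print(manager.get_status(name))
#         manager.release(name)
#     # cleanup() runs automatically when the block exits.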
def _generate_container_key(self, session_id):
return f"{self.prefix}{session_id}"
def _make_request(self, method: str, endpoint: str, data: dict):
"""
Make an HTTP request to the specified endpoint.
"""
url = f"{self.base_url}/{endpoint.lstrip('/')}"
# Pass the timeout explicitly; requests.Session does not apply one by itself.
timeout = getattr(self, "request_timeout", 30)
if method.upper() == "GET":
response = self.http_session.get(url, params=data, timeout=timeout)
else:
response = self.http_session.request(method, url, json=data, timeout=timeout)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
error_components = [
f"HTTP {response.status_code} Error: {str(e)}",
]
try:
server_response = response.json()
if "detail" in server_response:
error_components.append(
f"Server Detail: {server_response['detail']}",
)
elif "error" in server_response:
error_components.append(
f"Server Error: {server_response['error']}",
)
else:
error_components.append(
f"Server Response: {server_response}",
)
except (ValueError, json.JSONDecodeError):
if response.text:
error_components.append(
f"Server Response: {response.text}",
)
error = " | ".join(error_components)
logger.error(f"Error making request: {error}")
return {"data": f"Error: {error}"}
return response.json()
def _init_container_pool(self):
"""
Initialize the container pool, pre-creating containers for each default sandbox type up to pool_size.
"""
for t in self.default_type:
queue = self.pool_queues[t]
while queue.size() < self.pool_size:
try:
container_name = self.create(sandbox_type=t.value)
container_model = self.container_mapping.get(
container_name,
)
if container_model:
# Check the pool size again to avoid race condition
if queue.size() < self.pool_size:
queue.enqueue(container_model)
else:
# The pool size has reached the limit
self.release(container_name)
break
else:
logger.error("Failed to create container for pool")
break
except Exception as e:
logger.error(f"Error initializing runtime pool: {e}")
break
@remote_wrapper()
def cleanup(self):
logger.debug(
"Cleaning up resources.",
)
# Clean up pool first
for queue in self.pool_queues.values():
try:
while queue.size() > 0:
container_json = queue.dequeue()
if container_json:
container_model = ContainerModel(**container_json)
logger.debug(
f"Destroy container"
f" {container_model.container_id}",
)
self.release(container_model.session_id)
except Exception as e:
logger.error(f"Error cleaning up runtime pool: {e}")
# Clean up any remaining containers
for key in self.container_mapping.scan(self.prefix):
try:
container_json = self.container_mapping.get(key)
if container_json:
container_model = ContainerModel(**container_json)
logger.debug(
f"Destroy container {container_model.container_id}",
)
self.release(container_model.session_id)
except Exception as e:
logger.error(
f"Error cleaning up container {key}: {e}",
)
@remote_wrapper()
def create_from_pool(self, sandbox_type=None, meta: Optional[Dict] = None):
"""Try to get a container from runtime pool"""
# If not specified, use the first one
sandbox_type = SandboxType(sandbox_type or self.default_type[0])
if sandbox_type not in self.pool_queues:
return self.create(sandbox_type=sandbox_type.value, meta=meta)
queue = self.pool_queues[sandbox_type]
cnt = 0
try:
while True:
if cnt > self.pool_size:
raise RuntimeError(
"No container available in pool after check the pool.",
)
cnt += 1
# Top up the pool with a newly created container
container_name = self.create(sandbox_type=sandbox_type)
new_container_model = self.container_mapping.get(
container_name,
)
if new_container_model:
queue.enqueue(
new_container_model,
)
container_json = queue.dequeue()
if not container_json:
raise RuntimeError(
"No container available in pool.",
)
container_model = ContainerModel(**container_json)
# Add meta field to container
if meta and not container_model.meta:
container_model.meta = meta
self.container_mapping.set(
container_model.container_name,
container_model.model_dump(),
)
# Update session mapping (meta may be None, so guard the lookup)
if meta and "session_ctx_id" in meta:
env_ids = (
self.session_mapping.get(
meta["session_ctx_id"],
)
or []
)
if container_model.container_name not in env_ids:
env_ids.append(container_model.container_name)
self.session_mapping.set(
meta["session_ctx_id"],
env_ids,
)
logger.debug(
f"Retrieved container from pool:"
f" {container_model.session_id}",
)
if (
container_model.version
!= SandboxRegistry.get_image_by_type(
sandbox_type,
)
):
logger.warning(
f"Container {container_model.session_id} outdated, "
f"trying next one in pool",
)
self.release(container_model.session_id)
continue
if self.client.inspect(container_model.container_id) is None:
logger.warning(
f"Container {container_model.container_id} not found "
f"or unexpected error happens.",
)
continue
if (
self.client.get_status(container_model.container_id)
== "running"
):
return container_model.container_name
else:
logger.error(
f"Container {container_model.container_id} is not "
f"running. Trying next one in pool.",
)
# Destroy the stopped container
self.release(container_model.session_id)
except Exception as e:
logger.warning(
"Error getting container from pool, creating a new one instead.",
)
logger.debug(f"{e}: {traceback.format_exc()}")
return self.create(sandbox_type=sandbox_type.value, meta=meta)
@remote_wrapper()
def create(
self,
sandbox_type=None,
mount_dir=None, # TODO: remove to avoid leaking
storage_path=None,
environment: Optional[Dict] = None,
meta: Optional[Dict] = None,
):
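"""
Create a new sandbox container and return its container name, or None on
failure.

Usage sketch (the environment variable and session context id below are
illustrative assumptions):

    name = manager.create(
        sandbox_type="base",
        environment={"MY_VAR": "1"},
        meta={"session_ctx_id": "ctx-123"},
    )
"""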
if sandbox_type is not None:
target_sandbox_type = SandboxType(sandbox_type)
else:
target_sandbox_type = self.default_type[0]
config = SandboxRegistry.get_config_by_type(target_sandbox_type)
if not config:
logger.warning(
f"Not found sandbox {target_sandbox_type}, " f"using default",
)
config = SandboxRegistry.get_config_by_type(
self.default_type[0],
)
image = config.image_name
environment = {
**(config.environment if config.environment else {}),
**(environment if environment else {}),
}
for key, value in environment.items():
if value is None:
logger.error(
f"Env variable {key} is None.",
)
return None
alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
short_uuid = shortuuid.ShortUUID(alphabet=alphabet).uuid()
session_id = str(short_uuid)
if not mount_dir:
if self.default_mount_dir:
mount_dir = os.path.join(self.default_mount_dir, session_id)
os.makedirs(mount_dir, exist_ok=True)
if (
mount_dir
and self.container_deployment != "agentrun"
and self.container_deployment != "fc"
):
if not os.path.isabs(mount_dir):
mount_dir = os.path.abspath(mount_dir)
if storage_path is None:
if self.storage_folder:
storage_path = self.storage.path_join(
self.storage_folder,
session_id,
)
if (
mount_dir
and storage_path
and self.container_deployment != "agentrun"
and self.container_deployment != "fc"
):
self.storage.download_folder(storage_path, mount_dir)
# Check for an existing container with the same name
container_name = self._generate_container_key(session_id)
try:
if self.client.inspect(container_name):
raise ValueError(
f"Container with name {container_name} already exists.",
)
# Generate a random secret token
runtime_token = secrets.token_hex(16)
# Prepare volume bindings if a mount directory is provided
if (
mount_dir
and self.container_deployment != "agentrun"
and self.container_deployment != "fc"
):
volume_bindings = {
mount_dir: {
"bind": self.workdir,
"mode": "rw",
},
}
else:
volume_bindings = {}
if self.readonly_mounts:
for host_path, container_path in self.readonly_mounts.items():
if not os.path.isabs(host_path):
host_path = os.path.abspath(host_path)
volume_bindings[host_path] = {
"bind": container_path,
"mode": "ro",
}
_id, ports, ip, *rest = self.client.create(
image,
name=container_name,
ports=["80/tcp"], # Nginx
volumes=volume_bindings,
environment={
"SECRET_TOKEN": runtime_token,
**environment,
},
runtime_config=config.runtime_config,
)
http_protocol = "http"
if rest and rest[0] == "https":
http_protocol = "https"
if _id is None:
return None
# Check the container status
status = self.client.get_status(container_name)
if status != "running":
logger.warning(
f"Container {container_name} is not running. Current "
f"status: {status}",
)
return None
# TODO: update ContainerModel according to images & backend
container_model = ContainerModel(
session_id=session_id,
container_id=_id,
container_name=container_name,
url=f"{http_protocol}://{ip}:{ports[0]}",
ports=[ports[0]],
mount_dir=str(mount_dir),
storage_path=storage_path,
runtime_token=runtime_token,
version=image,
meta=meta or {},
timeout=config.timeout,
)
# Register in mapping
self.container_mapping.set(
container_model.container_name,
container_model.model_dump(),
)
# Build mapping session_ctx_id to container_name
if meta and "session_ctx_id" in meta:
env_ids = (
self.session_mapping.get(
meta["session_ctx_id"],
)
or []
)
env_ids.append(container_model.container_name)
self.session_mapping.set(meta["session_ctx_id"], env_ids)
logger.debug(
f"Created container {container_name}"
f":{container_model.model_dump()}",
)
return container_name
except Exception as e:
logger.warning(
f"Failed to create container: {e}",
)
logger.debug(f"{traceback.format_exc()}")
self.release(identity=container_name)
return None
@remote_wrapper()
def release(self, identity):
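"""
Stop and remove the container identified by identity (a session id,
container name, or container id), drop it from the container and session
mappings, and upload its mount directory to storage when a storage path
is set. Returns True on success (or when no container is found), False
otherwise.
"""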
try:
container_json = self.get_info(identity)
if not container_json:
logger.warning(
f"No container found for {identity}.",
)
return True
container_info = ContainerModel(**container_json)
# Remove the container entry from the mapping before removing the container
self.container_mapping.delete(container_json.get("container_name"))
# Remove the container from its session mapping, if any
session_ctx_id = container_info.meta.get("session_ctx_id")
if session_ctx_id:
env_ids = self.session_mapping.get(session_ctx_id) or []
env_ids = [
eid
for eid in env_ids
if eid != container_info.container_name
]
if env_ids:
self.session_mapping.set(session_ctx_id, env_ids)
else:
self.session_mapping.delete(session_ctx_id)
self.client.stop(container_info.container_id, timeout=1)
self.client.remove(container_info.container_id, force=True)
logger.debug(f"Container for {identity} destroyed.")
# Upload to storage
if container_info.mount_dir and container_info.storage_path:
self.storage.upload_folder(
container_info.mount_dir,
container_info.storage_path,
)
return True
except Exception as e:
logger.warning(
f"Failed to destroy container: {e}",
)
logger.debug(f"{traceback.format_exc()}")
return False
@remote_wrapper()
def start(self, identity):
try:
container_json = self.get_info(identity)
if not container_json:
logger.warning(
f"No container found for {identity}.",
)
return False
container_info = ContainerModel(**container_json)
self.client.start(container_info.container_id)
status = self.client.get_status(container_info.container_id)
if status != "running":
logger.error(
f"Failed to start container {identity}. "
f"Current status: {status}",
)
return False
logger.debug(f"Container {identity} started.")
return True
except Exception as e:
logger.error(
f"Failed to start container: {e}:"
f" {traceback.format_exc()}",
)
return False
@remote_wrapper()
def stop(self, identity):
try:
container_json = self.get_info(identity)
if not container_json:
logger.warning(f"No container found for {identity}.")
return True
container_info = ContainerModel(**container_json)
self.client.stop(container_info.container_id, timeout=1)
status = self.client.get_status(container_info.container_id)
if status != "exited":
logger.error(
f"Failed to stop container {identity}. "
f"Current status: {status}",
)
return False
logger.debug(f"Container {identity} stopped.")
return True
except Exception as e:
logger.error(
f"Failed to stop container: {e}: {traceback.format_exc()}",
)
return False
@remote_wrapper()
def get_status(self, identity):
"""Get container status by container_name or container_id."""
return self.client.get_status(identity)
@remote_wrapper()
def get_info(self, identity):
"""Get container information by container_name or container_id."""
container_model = self.container_mapping.get(identity)
if container_model is None:
container_model = self.container_mapping.get(
self._generate_container_key(identity),
)
if container_model is None:
raise RuntimeError(f"No container found with id: {identity}.")
if hasattr(container_model, "model_dump"):
# Callers unpack the result into ContainerModel, so return a dict via
# model_dump() rather than a JSON string from model_dump_json().
container_model = container_model.model_dump()
return container_model
def _establish_connection(self, identity):
container_model = ContainerModel(**self.get_info(identity))
# TODO: remake docker name
if (
"sandbox-appworld" in container_model.version
or "sandbox-bfcl" in container_model.version
):
return TrainingSandboxClient(
base_url=container_model.url,
).__enter__()
return SandboxHttpClient(
container_model,
).__enter__()
@remote_wrapper()
def check_health(self, identity):
"""List tool"""
client = self._establish_connection(identity)
return client.check_health()
@remote_wrapper()
def add_mcp_servers(self, identity, server_configs, overwrite=False):
"""
Add MCP servers to runtime.
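The expected shape of server_configs follows the common MCP configuration
convention; the server name and command below are illustrative assumptions,
and the exact schema accepted is defined by the sandbox runtime:

    {
        "mcpServers": {
            "time": {
                "command": "uvx",
                "args": ["mcp-server-time"],
            },
        },
    }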
"""
client = self._establish_connection(identity)
return client.add_mcp_servers(
server_configs=server_configs,
overwrite=overwrite,
)
@remote_wrapper()
def get_session_mapping(self, session_ctx_id: str) -> list:
"""Get all container names bound to a session context"""
return self.session_mapping.get(session_ctx_id) or []
@remote_wrapper()
def list_session_keys(self) -> list:
"""Return all session_ctx_id keys currently in mapping"""
session_keys = []
for key in self.session_mapping.scan():
session_keys.append(key)
return session_keys