# -*- coding: utf-8 -*-
"""Module for the training sandbox client."""
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument,too-many-return-statements
import time
import logging
from typing import Dict, List, Optional, Any
import requests
from requests.exceptions import HTTPError, JSONDecodeError
logger = logging.getLogger(__name__)
[docs]
class TrainingSandboxClient:
"""Client for interacting with the training sandbox."""
[docs]
def __init__(self, base_url: str = "http://localhost:8000"):
self.base_url = base_url.rstrip("/")
self.timeout = 100
self.session = requests.Session()
def __enter__(self):
# Wait for the runtime api server to be healthy
self.wait_until_healthy()
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
[docs]
def wait_until_healthy(self) -> None:
"""
Waits until the runtime service is running for a specified timeout.
"""
start_time = time.time()
while time.time() - start_time < self.timeout:
if self.check_health():
return
time.sleep(1)
raise TimeoutError(
"Runtime service did not start within the specified timeout.",
)
[docs]
def check_health(self) -> bool:
"""
Checks if the runtime service is running by verifying the health
endpoint.
Returns:
bool: True if the service is reachable, False otherwise
"""
endpoint = f"{self.base_url}/healthz"
try:
response_api = self.session.get(endpoint)
return response_api.status_code == 200
except requests.RequestException:
return False
def _make_request(
self,
endpoint: str,
env_type: str = "default",
task_id: str = None,
instance_id: str = None,
messages: Dict[str, Any] = None,
params: Dict[str, Any] = None,
) -> Dict:
"""Request from fastapi"""
url = f"{self.base_url}/{endpoint}"
data = {
"env_type": env_type,
"task_id": task_id,
"instance_id": instance_id,
"messages": messages or {},
"params": params or {},
}
response = self.session.post(url, json=data)
try:
response.raise_for_status()
except HTTPError as e:
try:
detail = response.json().get("detail", "")
except JSONDecodeError:
detail = response.text
raise ValueError(
f"HTTP Error {e.response.status_code}: {detail}",
) from e
return response.json()
[docs]
def get_task_ids(
self,
env_type: str,
split: str = "train",
params: dict | None = None,
) -> List[str]:
"""Get task id list"""
payload = {"env_type": env_type}
if params:
payload["params"] = params
response = self._make_request(
endpoint="get_env_profile",
env_type=env_type,
params={"split": split},
)
return response["data"]
[docs]
def get_env_profile(
self,
env_type: str,
split: str = "train",
params: dict | None = None,
) -> List[str]:
"""get environment profile"""
payload = {"env_type": env_type}
if params:
payload["params"] = params
response = self._make_request(
endpoint="get_env_profile",
env_type=env_type,
params={"split": split},
)
return response["data"]
[docs]
def create_instance(
self,
env_type: str,
task_id: str,
instance_id: str = None,
params: Dict = None,
) -> Dict[str, str]:
"""create instance of a task"""
response = self._make_request(
endpoint="create",
env_type=env_type,
task_id=task_id,
instance_id=instance_id,
params=params,
)
return response["data"]
[docs]
def step(
self,
instance_id: str,
action: Dict = None,
params: Dict = None,
) -> str:
"""execute step transmission"""
action = action or {}
params = params or {}
response = self._make_request(
endpoint="step",
instance_id=instance_id,
messages=action,
params=params,
)
return response["data"]
[docs]
def evaluate(
self,
instance_id: str,
messages: Dict = None,
params: Dict = None,
) -> float:
"""evaluate instance execution"""
messages = messages or {}
params = params or {}
response = self._make_request(
endpoint="evaluate",
instance_id=instance_id,
messages=messages,
params=params,
)
return response["data"]
[docs]
def release_instance(self, instance_id: str) -> bool:
"""release instance from memory"""
response = self._make_request(
endpoint="release",
instance_id=instance_id,
)
return response["success"]
# remined for future
[docs]
def add_mcp_servers(self, server_configs, overwrite=False):
"""add mcp for future"""
return None