# -*- coding: utf-8 -*-
import copy
from abc import abstractmethod
from typing import Dict, Any, Optional
from ..base import ServiceWithLifecycleManager
[docs]
class StateService(ServiceWithLifecycleManager):
"""
Abstract base class for agent state management services.
Stores and manages agent states organized by user_id, session_id,
and round_id. Supports saving, retrieving, listing, and deleting states.
"""
[docs]
async def start(self) -> None:
pass
[docs]
async def stop(self) -> None:
pass
[docs]
@abstractmethod
async def save_state(
self,
user_id: str,
state: Dict[str, Any],
session_id: Optional[str] = None,
round_id: Optional[int] = None,
) -> int:
"""
Save serialized state data for a specific user/session.
If round_id is provided, store the state in that round.
If round_id is None, append as a new round with automatically
assigned round_id.
Args:
user_id: The unique ID of the user.
state: A dictionary representing serialized agent state.
session_id: Optional session/conversation ID. Defaults to
"default".
round_id: Optional conversation round number.
Returns:
The round_id in which the state was saved.
"""
[docs]
@abstractmethod
async def export_state(
self,
user_id: str,
session_id: Optional[str] = None,
round_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Retrieve serialized state data for a user/session.
If round_id is provided, return that round's state.
If round_id is None, return the latest round's state.
Args:
user_id: The unique ID of the user.
session_id: Optional session/conversation ID.
round_id: Optional round number.
Returns:
A dictionary representing the agent state, or None if not found.
"""
[docs]
class InMemoryStateService(StateService):
"""
In-memory implementation of StateService using dictionaries
for sparse round storage.
- Multiple users, sessions, and non-contiguous round IDs are supported.
- If round_id is None when saving, a new round is appended automatically.
- If round_id is None when exporting, the latest round is returned.
"""
_DEFAULT_SESSION_ID = "default"
[docs]
def __init__(self) -> None:
# Structure:
# { user_id: { session_id: { round_id: state_dict } } }
self._store: Optional[
Dict[str, Dict[str, Dict[int, Dict[str, Any]]]]
] = None
self._health = False
[docs]
async def start(self) -> None:
"""Initialize the in-memory store."""
if self._store is None:
self._store = {}
self._health = True
[docs]
async def stop(self) -> None:
"""Clear all in-memory state data."""
if self._store is not None:
self._store.clear()
self._store = None
self._health = False
[docs]
async def health(self) -> bool:
"""Service health check."""
return self._health
[docs]
async def save_state(
self,
user_id: str,
state: Dict[str, Any],
session_id: Optional[str] = None,
round_id: Optional[int] = None,
) -> int:
"""
Save serialized state in sparse dict storage.
If round_id is None, a new round_id will be assigned
as (max existing round_id + 1) or 1 if none exist.
Otherwise, the given round_id will be overwritten.
Returns:
The round_id where the state was saved.
"""
if self._store is None:
raise RuntimeError("Service not started")
sid = session_id or self._DEFAULT_SESSION_ID
self._store.setdefault(user_id, {})
self._store[user_id].setdefault(sid, {})
rounds_dict = self._store[user_id][sid]
# Auto-generate round_id if not provided
if round_id is None:
if rounds_dict:
round_id = max(rounds_dict.keys()) + 1
else:
round_id = 1
# Store a deep copy so caller modifications don't affect saved state
rounds_dict[round_id] = copy.deepcopy(state)
return round_id
[docs]
async def export_state(
self,
user_id: str,
session_id: Optional[str] = None,
round_id: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
Retrieve state data for given user/session/round.
If round_id is None: return the latest round.
If round_id is provided: return that round's state.
Returns:
Dictionary representing the agent state, or None if not found.
"""
if self._store is None:
raise RuntimeError("Service not started")
sid = session_id or self._DEFAULT_SESSION_ID
sessions = self._store.get(user_id, {})
rounds_dict = sessions.get(sid, {})
if not rounds_dict:
return None
if round_id is None:
# Get the latest round_id
latest_round_id = max(rounds_dict.keys())
return rounds_dict[latest_round_id]
return rounds_dict.get(round_id)