Source code for agentscope_runtime.engine.services.agent_state.redis_state_service

# -*- coding: utf-8 -*-
import json
from typing import Optional, Dict, Any

import redis.asyncio as aioredis

from .state_service import StateService


class RedisStateService(StateService):
    """
    Redis-based implementation of StateService.

    Stores agent states in Redis using a hash per (user_id, session_id),
    with round_id as the hash field and serialized state as the value.

    Key layout is ``user_state:{user_id}:{session_id}``; each hash value is
    the ``json.dumps`` encoding of the state dict. When no session id is
    given, ``"default"`` is used.
    """

    # Session id used when callers pass ``None`` (or an empty string).
    _DEFAULT_SESSION_ID = "default"

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        redis_client: Optional[aioredis.Redis] = None,
    ):
        """Configure the service; no connection is opened here.

        Args:
            redis_url: URL used by ``start()`` to create a client when no
                client is set.
            redis_client: Optional pre-built async Redis client to use
                instead of creating one from ``redis_url``.
        """
        self._redis_url = redis_url
        self._redis = redis_client
        # Toggled by start()/stop(); health() pings Redis instead of
        # reading this flag.
        self._health = False

    async def start(self) -> None:
        """Initialize the Redis connection.

        Creates a client from ``redis_url`` with ``decode_responses=True``
        unless a client is already set, then marks the service healthy.
        """
        if self._redis is None:
            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
            )
        self._health = True

    async def stop(self) -> None:
        """Close the Redis connection and mark the service unhealthy.

        This also closes a client injected via ``redis_client``; because
        ``start()`` only builds a client when none is set, a later
        ``start()`` will create a fresh one from ``redis_url``.
        """
        client = self._redis
        try:
            if client:
                # redis-py >= 5.0.1 deprecates ``close()`` in favour of
                # ``aclose()``; fall back to ``close()`` on older versions.
                closer = getattr(client, "aclose", None) or client.close
                await closer()
        finally:
            # Clear state even if closing fails, so the service never keeps
            # a half-closed client around.
            self._redis = None
            self._health = False

    async def health(self) -> bool:
        """Service health check.

        Returns:
            ``True`` if a client is set and answers ``PING``; ``False`` if
            there is no client or the ping raises.
        """
        if not self._redis:
            return False
        try:
            pong = await self._redis.ping()
        except Exception:  # best-effort probe: any failure means unhealthy
            return False
        # Accept both a boolean reply and a raw "PONG" string.
        return pong is True or pong == "PONG"

    def _session_key(self, user_id: str, session_id: str) -> str:
        """Generate the Redis key for a user's session."""
        return f"user_state:{user_id}:{session_id}"

    @staticmethod
    def _parse_round_ids(fields: Any) -> list:
        """Return the integer round ids found among hash ``fields``.

        Only fields consisting solely of ASCII digits count as rounds; any
        other field is ignored. (``str.isdigit`` alone also accepts
        characters such as ``"²"`` that ``int()`` rejects.) ``bytes`` fields
        are accepted too, for clients that do not decode responses.
        """
        round_ids = []
        for field in fields:
            if isinstance(field, bytes):
                # Non-ASCII bytes become U+FFFD and are filtered out below.
                field = field.decode("ascii", errors="replace")
            if field.isascii() and field.isdigit():
                round_ids.append(int(field))
        return round_ids

    async def save_state(
        self,
        user_id: str,
        state: Dict[str, Any],
        session_id: Optional[str] = None,
        round_id: Optional[int] = None,
    ) -> int:
        """Persist ``state`` for a user's session and return its round id.

        Args:
            user_id: Owner of the state.
            state: JSON-serializable state dict (``json.dumps`` raises
                ``TypeError`` otherwise, before anything is written).
            session_id: Session to store under; defaults to ``"default"``.
            round_id: Explicit round to write (overwriting any existing
                value). When ``None``, the next round after the highest
                existing numeric round is allocated, starting at 1.

        Returns:
            The round id the state was stored under.

        Raises:
            RuntimeError: If the service has no Redis client.
        """
        if not self._redis:
            raise RuntimeError("Redis connection is not available")
        sid = session_id or self._DEFAULT_SESSION_ID
        key = self._session_key(user_id, sid)
        payload = json.dumps(state)

        if round_id is not None:
            # Caller chose the round explicitly: overwrite unconditionally.
            await self._redis.hset(key, round_id, payload)
            return round_id

        existing_rounds = self._parse_round_ids(await self._redis.hkeys(key))
        candidate = max(existing_rounds, default=0) + 1
        # HSETNX only writes when the field is absent, so concurrent
        # auto-numbered saves cannot overwrite each other: whoever loses
        # the race for a round id simply moves on to the next one.
        while not await self._redis.hsetnx(key, candidate, payload):
            candidate += 1
        return candidate

    async def export_state(
        self,
        user_id: str,
        session_id: Optional[str] = None,
        round_id: Optional[int] = None,
    ) -> Optional[Dict[str, Any]]:
        """Load a stored state for a user's session.

        Args:
            user_id: Owner of the state.
            session_id: Session to read from; defaults to ``"default"``.
            round_id: Round to load. When ``None``, the highest numeric
                round is used.

        Returns:
            The decoded state dict, or ``None`` if the session has no
            matching round.

        Raises:
            RuntimeError: If the service has no Redis client.
        """
        if not self._redis:
            raise RuntimeError("Redis connection is not available")
        sid = session_id or self._DEFAULT_SESSION_ID
        key = self._session_key(user_id, sid)

        if round_id is None:
            round_ids = self._parse_round_ids(await self._redis.hkeys(key))
            if not round_ids:
                return None
            round_id = max(round_ids)

        # HGET on a missing key or field yields None, so an explicit round
        # needs no prior HKEYS round trip.
        state_json = await self._redis.hget(key, round_id)
        if state_json is None:
            return None
        return json.loads(state_json)