# -*- coding: utf-8 -*-
from typing import Optional, Dict, Any
import json
import redis.asyncio as aioredis
from .memory_service import MemoryService
from ...schemas.agent_schemas import Message, MessageType
class RedisMemoryService(MemoryService):
"""
A Redis-based implementation of the memory service.
"""
def __init__(
self,
redis_url: str = "redis://localhost:6379/0",
redis_client: Optional[aioredis.Redis] = None,
socket_timeout: Optional[float] = 5.0,
socket_connect_timeout: Optional[float] = 5.0,
max_connections: Optional[int] = None,
retry_on_timeout: bool = True,
ttl_seconds: Optional[int] = 3600, # 1 hour in seconds
max_messages_per_session: Optional[int] = None,
health_check_interval: Optional[float] = 30.0,
socket_keepalive: bool = True,
):
"""
Initialize RedisMemoryService.
Args:
redis_url: Redis connection URL.
redis_client: Optional pre-configured Redis client.
socket_timeout: Socket timeout in seconds (default: 5.0).
socket_connect_timeout: Socket connect timeout in seconds
(default: 5.0).
max_connections: Maximum number of connections in the pool
(default: None).
retry_on_timeout: Whether to retry on timeout (default: True).
ttl_seconds: Time-to-live in seconds for memory data. If None,
data never expires (default: 3600, i.e., 1 hour).
max_messages_per_session: Maximum number of messages stored per
session_id field within a user's Redis memory hash. If None,
no limit (default: None).
health_check_interval: Interval in seconds for health checks on
idle connections (default: 30.0). Connections idle longer
than this will be checked before reuse. Set to 0 to disable.
socket_keepalive: Enable TCP keepalive to prevent silent
disconnections (default: True).
"""
self._redis_url = redis_url
self._redis = redis_client
self._DEFAULT_SESSION_ID = "default"
self._socket_timeout = socket_timeout
self._socket_connect_timeout = socket_connect_timeout
self._max_connections = max_connections
self._retry_on_timeout = retry_on_timeout
self._ttl_seconds = ttl_seconds
self._max_messages_per_session = max_messages_per_session
self._health_check_interval = health_check_interval
self._socket_keepalive = socket_keepalive
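    # A minimal usage sketch (not part of the class): constructing the service
    # and driving its lifecycle from an asyncio entry point. The parameter
    # values below are illustrative assumptions, not recommended defaults.
    #
    #   import asyncio
    #
    #   async def main():
    #       memory = RedisMemoryService(
    #           redis_url="redis://localhost:6379/0",
    #           ttl_seconds=1800,              # expire idle users after 30 min
    #           max_messages_per_session=200,  # cap each session's history
    #       )
    #       await memory.start()
    #       assert await memory.health()
    #       await memory.stop()
    #
    #   asyncio.run(main())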
    async def start(self) -> None:
        """Opens the Redis connection, applying the configured timeout,
        keepalive, and connection pool settings."""
if self._redis is None:
self._redis = aioredis.from_url(
self._redis_url,
decode_responses=True,
socket_timeout=self._socket_timeout,
socket_connect_timeout=self._socket_connect_timeout,
max_connections=self._max_connections,
retry_on_timeout=self._retry_on_timeout,
health_check_interval=self._health_check_interval,
socket_keepalive=self._socket_keepalive,
)
async def stop(self) -> None:
"""Closes the Redis connection."""
if self._redis:
await self._redis.aclose()
self._redis = None
async def health(self) -> bool:
"""Checks the health of the service."""
if not self._redis:
return False
try:
pong = await self._redis.ping()
return pong is True or pong == "PONG"
except Exception:
return False
    def _user_key(self, user_id: str) -> str:
        # Each user's memory lives in a single Redis hash.
        return f"user_memory:{user_id}"
    def _serialize(self, messages: list) -> str:
        # Store a message list as a JSON array of dict representations.
        return json.dumps([msg.dict() for msg in messages])
    def _deserialize(self, messages_json: Optional[str]) -> list:
        # An empty or missing field deserializes to an empty history.
        if not messages_json:
            return []
        return [Message.parse_obj(m) for m in json.loads(messages_json)]
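    # Storage layout sketch: one Redis hash per user, keyed by _user_key, with
    # one field per session id and a JSON array of serialized messages as the
    # value. The message fields shown are illustrative; the real shape comes
    # from the Message schema.
    #
    #   HGETALL user_memory:alice
    #   1) "default"
    #   2) "[{\"type\": \"message\", \"content\": [{\"type\": \"text\", ...}]}]"
    #   3) "support-chat"
    #   4) "[{\"type\": \"message\", \"content\": [{\"type\": \"text\", ...}]}]"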
async def add_memory(
self,
user_id: str,
messages: list,
session_id: Optional[str] = None,
) -> None:
if not self._redis:
raise RuntimeError("Redis connection is not available")
key = self._user_key(user_id)
field = session_id if session_id else self._DEFAULT_SESSION_ID
existing_json = await self._redis.hget(key, field)
existing_msgs = self._deserialize(existing_json)
all_msgs = existing_msgs + messages
# Limit the number of messages per session to prevent memory issues
if self._max_messages_per_session is not None:
if len(all_msgs) > self._max_messages_per_session:
# Keep only the most recent messages
all_msgs = all_msgs[-self._max_messages_per_session :]
await self._redis.hset(key, field, self._serialize(all_msgs))
# Set TTL for the key if configured
if self._ttl_seconds is not None:
await self._redis.expire(key, self._ttl_seconds)
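    # Hedged example: appending two turns to a named session. The Message
    # constructor arguments are assumptions based on the fields this service
    # reads (``type`` and a ``content`` list with ``text`` items); consult
    # agent_schemas for the actual signature.
    #
    #   msgs = [
    #       Message(type=MessageType.MESSAGE,
    #               content=[{"type": "text", "text": "What is my order status?"}]),
    #       Message(type=MessageType.MESSAGE,
    #               content=[{"type": "text", "text": "Order 42 shipped yesterday."}]),
    #   ]
    #   await memory.add_memory(user_id="alice", messages=msgs,
    #                           session_id="support-chat")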
async def search_memory( # pylint: disable=too-many-branches
self,
user_id: str,
messages: list,
filters: Optional[Dict[str, Any]] = None,
) -> list:
if not self._redis:
raise RuntimeError("Redis connection is not available")
key = self._user_key(user_id)
        if not messages or not isinstance(messages, list):
            return []
message = messages[-1]
query = await self.get_query_text(message)
if not query:
return []
keywords = set(query.lower().split())
        # Process sessions one at a time so the whole hash is never loaded
        # into memory at once.
matched_messages = []
hash_keys = await self._redis.hkeys(key)
        # Read the optional top_k limit from filters; it is applied to the
        # matched messages after all sessions have been scanned.
        top_k = None
        if filters and isinstance(filters.get("top_k"), int):
            top_k = filters["top_k"]
# Process each session separately to reduce memory footprint
for session_id in hash_keys:
msgs_json = await self._redis.hget(key, session_id)
if not msgs_json:
continue
try:
msgs = self._deserialize(msgs_json)
except Exception:
# Skip corrupted message data
continue
# Match messages in this session
for msg in msgs:
candidate_content = await self.get_query_text(msg)
if candidate_content:
msg_content_lower = candidate_content.lower()
if any(
keyword in msg_content_lower for keyword in keywords
):
matched_messages.append(msg)
# Apply top_k filter if specified
if top_k is not None:
result = matched_messages[-top_k:]
else:
result = matched_messages
# Refresh TTL on read to extend lifetime of actively used data,
# if a TTL is configured and there is existing data for this key.
if self._ttl_seconds is not None and hash_keys:
await self._redis.expire(key, self._ttl_seconds)
return result
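    # Hedged example: keyword search driven by the latest message's text, with
    # ``top_k`` limiting the result to the most recent matches. The query
    # message uses the same assumed schema as the add_memory example above.
    #
    #   query = [Message(type=MessageType.MESSAGE,
    #                    content=[{"type": "text", "text": "order status"}])]
    #   hits = await memory.search_memory(
    #       user_id="alice", messages=query, filters={"top_k": 5})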
    async def get_query_text(self, message: Message) -> str:
        """Returns the first text content of a message, or an empty string if none."""
        if message and message.type == MessageType.MESSAGE:
            for content in message.content:
                if content.type == "text":
                    return content.text
        return ""
async def list_memory(
self,
user_id: str,
filters: Optional[Dict[str, Any]] = None,
) -> list:
if not self._redis:
raise RuntimeError("Redis connection is not available")
key = self._user_key(user_id)
page_num = filters.get("page_num", 1) if filters else 1
page_size = filters.get("page_size", 10) if filters else 10
start_index = (page_num - 1) * page_size
end_index = start_index + page_size
        # All sessions are loaded and concatenated before paging so ordering
        # stays correct; this could be narrowed to load only the sessions
        # that cover the requested page range.
all_msgs = []
hash_keys = await self._redis.hkeys(key)
for session_id in sorted(hash_keys):
msgs_json = await self._redis.hget(key, session_id)
if msgs_json:
try:
msgs = self._deserialize(msgs_json)
all_msgs.extend(msgs)
                except Exception:
                    # Skip corrupted message data, as in search_memory.
                    continue
        # Refresh TTL on active use to keep memory alive, mirroring
        # search_memory behavior.
if self._ttl_seconds is not None and hash_keys:
await self._redis.expire(key, self._ttl_seconds)
return all_msgs[start_index:end_index]
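    # Hedged example: paging through a user's history, 20 messages per page,
    # ordered by sorted session id and then insertion order.
    #
    #   page_two = await memory.list_memory(
    #       user_id="alice", filters={"page_num": 2, "page_size": 20})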
async def delete_memory(
self,
user_id: str,
session_id: Optional[str] = None,
) -> None:
if not self._redis:
raise RuntimeError("Redis connection is not available")
key = self._user_key(user_id)
if session_id:
await self._redis.hdel(key, session_id)
else:
await self._redis.delete(key)
async def clear_all_memory(self) -> None:
"""
        Clears all memory data by removing every user memory key from the
        Redis database.
"""
if not self._redis:
raise RuntimeError("Redis connection is not available")
keys = await self._redis.keys(self._user_key("*"))
if keys:
await self._redis.delete(*keys)
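    # Note: KEYS blocks the Redis server while it walks the keyspace, which can
    # be noticeable on large databases. A non-blocking variant, sketched here
    # as an alternative rather than a drop-in change, would iterate with SCAN:
    #
    #   async for key in self._redis.scan_iter(match=self._user_key("*")):
    #       await self._redis.delete(key)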
async def delete_user_memory(self, user_id: str) -> None:
"""
Deletes all memory data for a specific user.
        Args:
            user_id (str): The ID of the user whose memory data should be
                deleted.
"""
if not self._redis:
raise RuntimeError("Redis connection is not available")
key = self._user_key(user_id)
await self._redis.delete(key)
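# Hedged cleanup sketch showing the three deletion scopes this service offers;
# identifiers are illustrative.
#
#   await memory.delete_memory(user_id="alice", session_id="support-chat")
#   await memory.delete_user_memory(user_id="alice")
#   await memory.clear_all_memory()  # removes every user_memory:* key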