Source code for agentscope_runtime.engine.services.memory.redis_memory_service
# -*- coding: utf-8 -*-
from typing import Optional, Dict, Any
import json
import redis.asyncio as aioredis
from .memory_service import MemoryService
from ...schemas.agent_schemas import Message, MessageType
[docs]
class RedisMemoryService(MemoryService):
    """
    A Redis-based implementation of the memory service.

    Storage layout: every user owns one Redis hash stored under the key
    ``user_memory:<user_id>``. Each hash field is a session id and its value
    is a JSON array of serialized messages for that session. Messages added
    without a session id are stored under the ``"default"`` field.
    """

    # Number of keys removed per DEL command in clear_all_memory().
    _CLEAR_BATCH_SIZE = 500

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        redis_client: Optional[aioredis.Redis] = None,
    ):
        """
        Args:
            redis_url: URL used by :meth:`start` to build a client when none
                was injected.
            redis_client: Optional pre-built async Redis client. When given,
                :meth:`start` keeps it instead of creating a new one.
        """
        self._redis_url = redis_url
        self._redis = redis_client
        self._DEFAULT_SESSION_ID = "default"

    async def start(self) -> None:
        """Create the Redis client from ``redis_url`` if none is set yet."""
        if self._redis is None:
            # decode_responses=True: hash fields/values come back as str.
            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
            )

    async def stop(self) -> None:
        """Close the current Redis client (injected or created) and forget it."""
        if self._redis is None:
            return
        client = self._redis
        try:
            # redis-py >= 5.0.1 deprecates close() in favour of aclose();
            # fall back to close() for older client versions.
            close = getattr(client, "aclose", None) or client.close
            await close()
        finally:
            # Drop the reference even if closing failed, so start() can
            # build a fresh client afterwards.
            self._redis = None

    async def health(self) -> bool:
        """Return True if a client exists and answers PING, else False."""
        if self._redis is None:
            return False
        try:
            pong = await self._redis.ping()
        except Exception:
            # Best-effort health probe: any failure means "unhealthy".
            return False
        # Accept both a boolean True reply and the raw "PONG" string.
        return pong is True or pong == "PONG"

    def _require_redis(self):
        """Return the active client, raising RuntimeError when none is set."""
        if self._redis is None:
            raise RuntimeError("Redis connection is not available")
        return self._redis

    def _user_key(self, user_id):
        """Return the Redis key of the hash holding ``user_id``'s sessions."""
        # Each user is a Redis hash
        return f"user_memory:{user_id}"

    def _serialize(self, messages):
        """Encode messages as a JSON array of their ``.dict()`` forms."""
        # NOTE(review): .dict()/parse_obj are pydantic-v1 style APIs;
        # presumably Message is a pydantic model — confirm before migrating.
        return json.dumps([msg.dict() for msg in messages])

    def _deserialize(self, messages_json):
        """Inverse of _serialize; a missing/empty value yields []."""
        if not messages_json:
            return []
        return [Message.parse_obj(m) for m in json.loads(messages_json)]

    async def _load_all_messages(self, user_id: str) -> list:
        """Fetch every message of a user, sessions in sorted-id order."""
        client = self._require_redis()
        # One HGETALL round trip instead of HKEYS + one HGET per session.
        sessions = await client.hgetall(self._user_key(user_id))
        all_msgs = []
        for session_id in sorted(sessions):
            all_msgs.extend(self._deserialize(sessions[session_id]))
        return all_msgs

    async def add_memory(
        self,
        user_id: str,
        messages: list,
        session_id: Optional[str] = None,
    ) -> None:
        """
        Append ``messages`` to the stored history of one of a user's sessions.

        Args:
            user_id: Owner of the memory.
            messages: Messages to append (serialized via ``.dict()``).
            session_id: Target session; a falsy value maps to ``"default"``.

        Raises:
            RuntimeError: If no Redis client is available.
        """
        client = self._require_redis()
        key = self._user_key(user_id)
        field = session_id or self._DEFAULT_SESSION_ID
        # NOTE(review): this is a non-atomic read-modify-write; concurrent
        # add_memory calls for the same user/session can lose an update.
        # Consider WATCH/MULTI if concurrent writers are expected.
        existing_msgs = self._deserialize(await client.hget(key, field))
        all_msgs = existing_msgs + list(messages)
        await client.hset(key, field, self._serialize(all_msgs))

    async def search_memory(
        self,
        user_id: str,
        messages: list,
        filters: Optional[Dict[str, Any]] = None,
    ) -> list:
        """
        Keyword search across all of a user's sessions.

        The text of the last element of ``messages`` is lower-cased and split
        on whitespace; a stored message matches when its own text (see
        :meth:`get_query_text`), lower-cased, contains any of those words.

        Args:
            user_id: Owner whose memory is searched.
            messages: Conversation so far; only the last element is the query.
            filters: Optional; an int ``"top_k"`` keeps only the last
                ``top_k`` matches (``top_k <= 0`` yields no results).

        Returns:
            Matching messages, ordered by session id and then by append order
            within each session. ``[]`` when there is no usable query text.

        Raises:
            RuntimeError: If a lookup is needed and no Redis client is set.
        """
        if not isinstance(messages, list) or not messages:
            return []
        query = await self.get_query_text(messages[-1])
        if not query:
            return []
        keywords = set(query.lower().split())

        matched_messages = []
        for msg in await self._load_all_messages(user_id):
            candidate_content = await self.get_query_text(msg)
            if not candidate_content:
                continue
            msg_content_lower = candidate_content.lower()
            if any(keyword in msg_content_lower for keyword in keywords):
                matched_messages.append(msg)

        top_k = filters.get("top_k") if filters else None
        if isinstance(top_k, int):
            # Guard top_k <= 0: a bare matched[-0:] would return every match.
            return matched_messages[-top_k:] if top_k > 0 else []
        return matched_messages

    async def get_query_text(self, message: Message) -> str:
        """
        Return the text of the first ``"text"`` content item of a
        ``MessageType.MESSAGE`` message; ``""`` for anything else.
        """
        if not message or message.type != MessageType.MESSAGE:
            return ""
        # Tolerate messages whose content is None.
        for content in message.content or []:
            if content.type == "text":
                return content.text
        return ""

    async def list_memory(
        self,
        user_id: str,
        filters: Optional[Dict[str, Any]] = None,
    ) -> list:
        """
        Return one page of a user's messages across all sessions.

        Sessions are concatenated in sorted session-id order.

        Args:
            user_id: Owner of the memory.
            filters: Optional ``"page_num"`` (1-based, default 1) and
                ``"page_size"`` (default 10). Non-positive values yield [].

        Raises:
            RuntimeError: If no Redis client is available.
        """
        all_msgs = await self._load_all_messages(user_id)
        filters = filters or {}
        page_num = filters.get("page_num", 1)
        page_size = filters.get("page_size", 10)
        if page_num < 1 or page_size < 1:
            # Avoid negative slice bounds that would silently return
            # messages from the end of the list.
            return []
        start_index = (page_num - 1) * page_size
        end_index = start_index + page_size
        return all_msgs[start_index:end_index]

    async def delete_memory(
        self,
        user_id: str,
        session_id: Optional[str] = None,
    ) -> None:
        """
        Delete one session of a user, or all of the user's sessions.

        Args:
            user_id: Owner of the memory.
            session_id: Session to remove. A falsy value (None or "") removes
                the user's entire hash, i.e. every session.

        Raises:
            RuntimeError: If no Redis client is available.
        """
        client = self._require_redis()
        key = self._user_key(user_id)
        if session_id:
            await client.hdel(key, session_id)
        else:
            await client.delete(key)

    async def clear_all_memory(self) -> None:
        """
        Clears all memory data from Redis.
        This method removes all user memory keys from the Redis database.

        Raises:
            RuntimeError: If no Redis client is available.
        """
        client = self._require_redis()
        # SCAN iterates the keyspace incrementally, whereas KEYS blocks the
        # server while walking every key. Deleting already-returned keys
        # during a SCAN is safe.
        batch = []
        async for key in client.scan_iter(match=self._user_key("*")):
            batch.append(key)
            if len(batch) >= self._CLEAR_BATCH_SIZE:
                await client.delete(*batch)
                batch.clear()
        if batch:
            await client.delete(*batch)

    async def delete_user_memory(self, user_id: str) -> None:
        """
        Deletes all memory data for a specific user.

        Args:
            user_id (str): The ID of the user whose memory data should be
                deleted

        Raises:
            RuntimeError: If no Redis client is available.
        """
        client = self._require_redis()
        await client.delete(self._user_key(user_id))