# Source code for agentscope_runtime.engine.services.memory.tablestore_memory_service

# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name
import asyncio
import copy
from enum import Enum
from typing import Any, Dict, List, Optional

import tablestore
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.embeddings import Embeddings
from tablestore import AsyncOTSClient as AsyncTablestoreClient
from tablestore import VectorMetricType
from tablestore_for_agent_memory.base.filter import Filters
from tablestore_for_agent_memory.knowledge.async_knowledge_store import (
    AsyncKnowledgeStore,
)

from ...schemas.agent_schemas import Message, MessageType
from .memory_service import MemoryService
from ..utils.tablestore_service_utils import (
    convert_messages_to_tablestore_documents,
    convert_tablestore_document_to_message,
    get_message_metadata_names,
    tablestore_log,
)


class SearchStrategy(Enum):
    """Retrieval strategy used by ``TablestoreMemoryService.search_memory``."""

    # The query text is passed to the knowledge store's full-text search.
    FULL_TEXT = "full_text"
    # The query text is embedded and passed to the vector search.
    VECTOR = "vector"
class TablestoreMemoryService(MemoryService):
    """
    A Tablestore-based implementation of the memory service.
    based on tablestore_for_agent_memory
    (https://github.com/aliyun/
    alibabacloud-tablestore-for-agent-memory/blob/main/python/docs/knowledge_store_tutorial.ipynb).

    Messages are converted to Tablestore documents and stored through an
    ``AsyncKnowledgeStore``. The store is created by :meth:`start` and
    released by :meth:`stop`; the data methods raise ``RuntimeError`` when
    called while the service is not started.
    """

    _SEARCH_INDEX_NAME = "agentscope_runtime_knowledge_search_index_name"
    _DEFAULT_SESSION_ID = "default"

    def __init__(
        self,
        tablestore_client: AsyncTablestoreClient,
        search_strategy: SearchStrategy = SearchStrategy.FULL_TEXT,
        embedding_model: Optional[Embeddings] = None,
        vector_dimension: int = 1536,
        table_name: Optional[str] = "agentscope_runtime_memory",
        search_index_schema: Optional[List[tablestore.FieldSchema]] = (
            tablestore.FieldSchema("user_id", tablestore.FieldType.KEYWORD),
            tablestore.FieldSchema("session_id", tablestore.FieldType.KEYWORD),
        ),
        text_field: Optional[str] = "text",
        embedding_field: Optional[str] = "embedding",
        vector_metric_type: VectorMetricType = VectorMetricType.VM_COSINE,
        **kwargs: Any,
    ):
        """Store the configuration; no Tablestore call is made here.

        Args:
            tablestore_client: Async Tablestore client handed to the
                knowledge store.
            search_strategy: Strategy used by :meth:`search_memory`.
            embedding_model: Embedding model passed to the message
                converter and used for vector search. When ``None`` a
                ``DashScopeEmbeddings()`` instance is created.
            vector_dimension: Forwarded to the knowledge store.
            table_name: Forwarded to the knowledge store.
            search_index_schema: Search index fields; copied into a list
                so the shared default is never handed out directly.
            text_field: Forwarded to the knowledge store.
            embedding_field: Forwarded to the knowledge store.
            vector_metric_type: Forwarded to the knowledge store.
            **kwargs: Extra keyword arguments forwarded verbatim to
                ``AsyncKnowledgeStore``.
        """
        if embedding_model is None:
            # An embedding model is always available after this point, so
            # no separate "model required for VECTOR" check is needed.
            embedding_model = DashScopeEmbeddings()

        self._search_strategy = search_strategy
        self._embedding_model = embedding_model

        self._tablestore_client = tablestore_client
        self._vector_dimension = vector_dimension
        self._table_name = table_name
        self._search_index_schema = (
            list(search_index_schema)
            if search_index_schema is not None
            else None
        )
        self._text_field = text_field
        self._embedding_field = embedding_field
        self._vector_metric_type = vector_metric_type
        self._knowledge_store: Optional[AsyncKnowledgeStore] = None
        self._knowledge_store_init_parameter_kwargs = kwargs

    def _get_knowledge_store(self) -> AsyncKnowledgeStore:
        """Return the active knowledge store.

        Raises:
            RuntimeError: If the service has not been started (or has
                been stopped).
        """
        if self._knowledge_store is None:
            raise RuntimeError(
                "Tablestore memory service is not started, "
                "call start() first.",
            )
        return self._knowledge_store

    async def _init_knowledge_store(self) -> None:
        """Create the knowledge store and initialise its table."""
        knowledge_store = AsyncKnowledgeStore(
            tablestore_client=self._tablestore_client,
            vector_dimension=self._vector_dimension,
            # Enabling multi tenant would confuse users: we unify the
            # usage of session id and user id, and allow users to
            # configure the index themselves.
            enable_multi_tenant=False,
            table_name=self._table_name,
            search_index_name=TablestoreMemoryService._SEARCH_INDEX_NAME,
            search_index_schema=copy.deepcopy(self._search_index_schema),
            text_field=self._text_field,
            embedding_field=self._embedding_field,
            vector_metric_type=self._vector_metric_type,
            **self._knowledge_store_init_parameter_kwargs,
        )
        await knowledge_store.init_table()
        # Publish the store only after init_table() succeeded; otherwise a
        # failed initialisation would make every later start() a no-op.
        self._knowledge_store = knowledge_store

    async def start(self) -> None:
        """Start the tablestore service"""
        if self._knowledge_store:
            return
        await self._init_knowledge_store()

    async def stop(self) -> None:
        """Close the tablestore service"""
        if self._knowledge_store is None:
            return
        # Clear the attribute first so the service reads as stopped even
        # if close() raises.
        knowledge_store = self._knowledge_store
        self._knowledge_store = None
        await knowledge_store.close()

    async def health(self) -> bool:
        """Checks the health of the service.

        Returns:
            ``False`` when the service is not started or when reading
            from the knowledge store raises; ``True`` otherwise.
        """
        if self._knowledge_store is None:
            tablestore_log("Tablestore memory service is not started.")
            return False

        try:
            # Fetching the first document (if any) is enough to prove the
            # store is reachable; an empty store is healthy too.
            async for _ in await self._knowledge_store.get_all_documents():
                return True
            return True
        except Exception as e:
            tablestore_log(
                f"Tablestore memory service "
                f"cannot access Tablestore, error: {str(e)}.",
            )
            return False

    async def add_memory(
        self,
        user_id: str,
        messages: list,
        session_id: Optional[str] = None,
    ) -> None:
        """Store ``messages`` for ``user_id``.

        Args:
            user_id: Owner of the messages.
            messages: Messages to store; nothing happens when empty.
            session_id: Session the messages belong to. A falsy value is
                replaced by the default session id.

        Raises:
            RuntimeError: If there are messages to store but the service
                is not started.
        """
        if not session_id:
            session_id = TablestoreMemoryService._DEFAULT_SESSION_ID

        if not messages:
            return

        knowledge_store = self._get_knowledge_store()
        tablestore_documents = convert_messages_to_tablestore_documents(
            messages,
            user_id,
            session_id,
            self._embedding_model,
        )

        await asyncio.gather(
            *(
                knowledge_store.put_document(tablestore_document)
                for tablestore_document in tablestore_documents
            ),
        )

    @staticmethod
    async def get_query_text(message: Message) -> str:
        """Return the text of the first ``"text"`` content of ``message``.

        Returns an empty string when ``message`` is falsy, is not of type
        ``MessageType.MESSAGE``, or has no text content.
        """
        if not message or message.type != MessageType.MESSAGE:
            return ""

        # A message without content is treated like one with no entries.
        for content in message.content or []:
            if content.type == "text":
                return content.text

        return ""

    async def search_memory(
        self,
        user_id: str,
        messages: list,
        filters: Optional[Dict[str, Any]] = None,
    ) -> list:
        """Search the memory of ``user_id`` using the last message.

        Args:
            user_id: Only documents whose ``user_id`` matches are searched.
            messages: Conversation messages; the query text is taken from
                the last one.
            filters: Optional settings. An integer ``"top_k"`` limits the
                number of hits (default 100).

        Returns:
            The matched messages, or an empty list when ``messages`` is
            empty / not a list or no query text can be extracted.

        Raises:
            RuntimeError: If a search is needed but the service is not
                started.
            ValueError: If the configured search strategy is unsupported.
        """
        if not messages or not isinstance(messages, list):
            return []

        query = await TablestoreMemoryService.get_query_text(messages[-1])
        if not query:
            return []

        top_k = 100
        if (
            filters
            and "top_k" in filters
            and isinstance(filters["top_k"], int)
        ):
            top_k = filters["top_k"]

        knowledge_store = self._get_knowledge_store()
        user_filter = Filters.eq("user_id", user_id)

        if self._search_strategy == SearchStrategy.FULL_TEXT:
            response = await knowledge_store.full_text_search(
                query=query,
                metadata_filter=user_filter,
                limit=top_k,
                meta_data_to_get=get_message_metadata_names(),
            )
        elif self._search_strategy == SearchStrategy.VECTOR:
            # Use the async embedding entry point so the event loop is not
            # blocked while the query is being embedded.
            query_vector = await self._embedding_model.aembed_query(query)
            response = await knowledge_store.vector_search(
                query_vector=query_vector,
                metadata_filter=user_filter,
                top_k=top_k,
                meta_data_to_get=get_message_metadata_names(),
            )
        else:
            raise ValueError(
                f"Unsupported search strategy: {self._search_strategy}",
            )

        return [
            convert_tablestore_document_to_message(hit.document)
            for hit in response.hits
        ]

    async def list_memory(
        self,
        user_id: str,
        filters: Optional[Dict[str, Any]] = None,
    ) -> list:
        """Return one page of the messages stored for ``user_id``.

        Args:
            user_id: Only documents whose ``user_id`` matches are listed.
            filters: Optional paging settings: ``"page_num"`` (1-based,
                default 1) and ``"page_size"`` (default 10).

        Returns:
            The messages of the requested page, or an empty list when the
            page number is past the last page.

        Raises:
            ValueError: If ``page_num`` or ``page_size`` is below 1.
            RuntimeError: If the service is not started.
        """
        page_num = filters.get("page_num", 1) if filters else 1
        page_size = filters.get("page_size", 10) if filters else 10

        if page_num < 1 or page_size < 1:
            raise ValueError("page_num and page_size must be greater than 0.")

        knowledge_store = self._get_knowledge_store()
        user_filter = Filters.eq("user_id", user_id)

        # The store pages with an opaque token, so reaching page N means
        # walking through the N - 1 pages before it.
        next_token = None
        for _ in range(page_num - 1):
            skipped_page = await knowledge_store.search_documents(
                metadata_filter=user_filter,
                limit=page_size,
                next_token=next_token,
            )
            next_token = skipped_page.next_token
            if not next_token:
                tablestore_log(
                    "Page number exceeds the total number of pages, "
                    "return empty list.",
                )
                return []

        requested_page = await knowledge_store.search_documents(
            metadata_filter=user_filter,
            limit=page_size,
            next_token=next_token,
            meta_data_to_get=get_message_metadata_names(),
        )
        return [
            convert_tablestore_document_to_message(hit.document)
            for hit in requested_page.hits
        ]

    async def delete_memory(
        self,
        user_id: str,
        session_id: Optional[str] = None,
    ) -> None:
        """Delete the stored messages of ``user_id``.

        Args:
            user_id: Owner of the messages to delete.
            session_id: When given, only messages of this session are
                deleted; otherwise all messages of the user are.

        Raises:
            RuntimeError: If the service is not started.
        """
        knowledge_store = self._get_knowledge_store()

        if session_id:
            metadata_filter = Filters.logical_and(
                [
                    Filters.eq("user_id", user_id),
                    Filters.eq("session_id", session_id),
                ],
            )
        else:
            metadata_filter = Filters.eq("user_id", user_id)

        # search_documents returns its results page by page, so follow
        # next_token until it is exhausted; a single call would leave
        # every match beyond the first page undeleted. All ids are
        # collected before deleting so paging is not affected by deletes.
        document_ids = []
        next_token = None
        while True:
            page = await knowledge_store.search_documents(
                metadata_filter=metadata_filter,
                next_token=next_token,
            )
            document_ids.extend(hit.document.document_id for hit in page.hits)
            next_token = page.next_token
            if not next_token:
                break

        await asyncio.gather(
            *(
                knowledge_store.delete_document(document_id)
                for document_id in document_ids
            ),
        )