chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
31
src/agentscope/memory/__init__.py
Normal file
31
src/agentscope/memory/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The memory module."""
|
||||
|
||||
from ._working_memory import (
|
||||
MemoryBase,
|
||||
InMemoryMemory,
|
||||
RedisMemory,
|
||||
AsyncSQLAlchemyMemory,
|
||||
)
|
||||
from ._long_term_memory import (
|
||||
LongTermMemoryBase,
|
||||
Mem0LongTermMemory,
|
||||
ReMePersonalLongTermMemory,
|
||||
ReMeTaskLongTermMemory,
|
||||
ReMeToolLongTermMemory,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Working memory
|
||||
"MemoryBase",
|
||||
"InMemoryMemory",
|
||||
"RedisMemory",
|
||||
"AsyncSQLAlchemyMemory",
|
||||
# Long-term memory
|
||||
"LongTermMemoryBase",
|
||||
"Mem0LongTermMemory",
|
||||
"ReMePersonalLongTermMemory",
|
||||
"ReMeTaskLongTermMemory",
|
||||
"ReMeToolLongTermMemory",
|
||||
]
|
||||
18
src/agentscope/memory/_long_term_memory/__init__.py
Normal file
18
src/agentscope/memory/_long_term_memory/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The long-term memory module for AgentScope."""
|
||||
|
||||
from ._long_term_memory_base import LongTermMemoryBase
|
||||
from ._mem0 import Mem0LongTermMemory
|
||||
from ._reme import (
|
||||
ReMePersonalLongTermMemory,
|
||||
ReMeTaskLongTermMemory,
|
||||
ReMeToolLongTermMemory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LongTermMemoryBase",
|
||||
"Mem0LongTermMemory",
|
||||
"ReMePersonalLongTermMemory",
|
||||
"ReMeTaskLongTermMemory",
|
||||
"ReMeToolLongTermMemory",
|
||||
]
|
||||
@@ -0,0 +1,94 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The long-term memory base class."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agentscope.message import Msg
|
||||
from agentscope.module import StateModule
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
|
||||
class LongTermMemoryBase(StateModule):
    """The long-term memory base class, which should be a time-series
    memory management system.

    Two groups of methods are declared here. Every one of them raises
    `NotImplementedError` in this base class, so subclasses override only
    the ones they support:

    - `record_to_memory` and `retrieve_from_memory` are tool functions that
      let the agent manage the long-term memory voluntarily. You can choose
      not to implement these two functions.
    - `record` and `retrieve` are for developers to use, for example
      retrieving/recording memory at the beginning of each reply and adding
      the retrieved memory to the system prompt.
    """

    async def record(
        self,
        msgs: list[Msg | None],
        **kwargs: Any,
    ) -> Any:
        """A developer-designed method to record information from the given
        input message(s) to the long-term memory.

        Args:
            msgs (`list[Msg | None]`):
                The message(s) to record.
            **kwargs (`Any`):
                Additional implementation-specific keyword arguments.

        Raises:
            `NotImplementedError`:
                Always, unless the method is overridden by a subclass.
        """
        raise NotImplementedError(
            "The `record` method is not implemented. "
            "You can implement it in your own long-term memory class.",
        )

    async def retrieve(
        self,
        msg: Msg | list[Msg] | None,
        limit: int = 5,
        **kwargs: Any,
    ) -> str:
        """A developer-designed method to retrieve information from the
        long-term memory based on the given input message(s). The retrieved
        information will be added to the system prompt of the agent.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message(s) used as the retrieval query.
            limit (`int`, optional):
                The maximum number of memories to retrieve. Defaults to 5.
            **kwargs (`Any`):
                Additional implementation-specific keyword arguments.

        Returns:
            `str`:
                The retrieved information as text.

        Raises:
            `NotImplementedError`:
                Always, unless the method is overridden by a subclass.
        """
        raise NotImplementedError(
            "The `retrieve` method is not implemented. "
            "You can implement it in your own long-term memory class.",
        )

    async def record_to_memory(
        self,
        thinking: str,
        content: list[str],
        **kwargs: Any,
    ) -> ToolResponse:
        """Use this function to record important information that you may
        need later. The target content should be specific and concise, e.g.
        who, when, where, do what, why, how, etc.

        Args:
            thinking (`str`):
                Your thinking and reasoning about what to record
            content (`list[str]`):
                The content to remember, which is a list of strings.

        Returns:
            `ToolResponse`:
                The tool response describing the outcome of the recording.
        """
        raise NotImplementedError(
            "The `record_to_memory` method is not implemented. "
            "You can implement it in your own long-term memory class.",
        )

    async def retrieve_from_memory(
        self,
        keywords: list[str],
        limit: int = 5,
        **kwargs: Any,
    ) -> ToolResponse:
        """Retrieve the memory based on the given keywords.

        Args:
            keywords (`list[str]`):
                The keywords to search for in the memory, which should be
                specific and concise, e.g. the person's name, the date, the
                location, etc.
            limit (`int`, optional):
                The maximum number of memories to retrieve per search, i.e.,
                the number of memories to retrieve for each keyword. Defaults
                to 5.

        Returns:
            `ToolResponse`:
                The tool response carrying the memories that match the
                keywords.
        """
        raise NotImplementedError(
            "The `retrieve_from_memory` method is not implemented. "
            "You can implement it in your own long-term memory class.",
        )
|
||||
@@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The Mem0 long-term memory module for AgentScope."""
|
||||
|
||||
from ._mem0_long_term_memory import Mem0LongTermMemory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Mem0LongTermMemory",
|
||||
]
|
||||
@@ -0,0 +1,746 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Long-term memory implementation using mem0 library.
|
||||
|
||||
This module provides a long-term memory implementation that integrates
|
||||
with the mem0 library to provide persistent memory storage and retrieval
|
||||
capabilities for AgentScope agents.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from importlib import metadata
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from ....embedding import EmbeddingModelBase
|
||||
from .._long_term_memory_base import LongTermMemoryBase
|
||||
from ....message import Msg, TextBlock
|
||||
from ....model import ChatModelBase
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mem0.configs.base import MemoryConfig
|
||||
from mem0.vector_stores.configs import VectorStoreConfig
|
||||
else:
|
||||
MemoryConfig = Any
|
||||
VectorStoreConfig = Any
|
||||
|
||||
|
||||
def _create_agentscope_config_classes() -> tuple:
    """Build the mem0 config subclasses that accept agentscope providers.

    The mem0 base classes are imported inside the function, so the two
    subclasses are created anew on every call.

    Returns:
        `tuple`:
            The pair `(_ASLlmConfig, _ASEmbedderConfig)`.
    """
    from mem0.embeddings.configs import EmbedderConfig
    from mem0.llms.configs import LlmConfig

    class _ASLlmConfig(LlmConfig):
        """LLM config whose `config` validator consults the LLM factory
        registry.

        Attention: in mem0, the validate_config hardcodes the provider, so
        the validator is overridden here to support the agentscope
        providers. We will follow up with the mem0 to improve this.
        """

        @field_validator("config")
        @classmethod
        def validate_config(cls, v: Any, values: Any) -> Any:
            """Return `v` unchanged when the provider is registered in
            `LlmFactory.provider_to_class`; raise `ValueError` otherwise."""
            from mem0.utils.factory import LlmFactory

            provider_name = values.data.get("provider")
            if provider_name not in LlmFactory.provider_to_class:
                raise ValueError(f"Unsupported LLM provider: {provider_name}")
            return v

    class _ASEmbedderConfig(EmbedderConfig):
        """Embedder config whose `config` validator consults the embedder
        factory registry."""

        @field_validator("config")
        @classmethod
        def validate_config(cls, v: Any, values: Any) -> Any:
            """Return `v` unchanged when the provider is registered in
            `EmbedderFactory.provider_to_class`; raise `ValueError`
            otherwise."""
            from mem0.utils.factory import EmbedderFactory

            provider_name = values.data.get("provider")
            if provider_name not in EmbedderFactory.provider_to_class:
                raise ValueError(
                    f"Unsupported Embedder provider: {provider_name}",
                )
            return v

    return _ASLlmConfig, _ASEmbedderConfig
|
||||
|
||||
|
||||
class Mem0LongTermMemory(LongTermMemoryBase):
    """A class that implements the LongTermMemoryBase interface using mem0."""

    @staticmethod
    def _setup_mem0_logging(suppress_mem0_logging: bool) -> None:
        """Raise the mem0 loggers to CRITICAL level if requested.

        Args:
            suppress_mem0_logging (`bool`):
                Whether to suppress mem0 logging. See the `__init__`
                docstring for details on QDRANT validation errors when
                using mem0 1.0.3.
        """
        if not suppress_mem0_logging:
            return

        import logging

        for logger_name in ("mem0", "mem0.memory", "mem0.memory.main"):
            logging.getLogger(logger_name).setLevel(logging.CRITICAL)

    @staticmethod
    def _register_agentscope_providers() -> None:
        """Register the agentscope providers with mem0.

        The embedder is always registered as a dotted class path. The LLM
        registry entry is a plain class path for mem0 <= 0.1.115 and a
        `(class path, config class)` tuple for newer versions.

        Raises:
            `ImportError`:
                If the mem0 library (or the `packaging` helper imported
                alongside it) cannot be imported.
        """
        try:
            from mem0.configs.llms.base import BaseLlmConfig
            from mem0.utils.factory import LlmFactory, EmbedderFactory
            from packaging import version

            # Check mem0 version
            installed_version = version.parse(metadata.version("mem0ai"))
            is_mem0_version_low = installed_version <= version.parse(
                "0.1.115",
            )

            embedder_path = (
                "agentscope.memory._long_term_memory._mem0."
                "_mem0_utils.AgentScopeEmbedding"
            )
            llm_path = (
                "agentscope.memory._long_term_memory._mem0."
                "_mem0_utils.AgentScopeLLM"
            )

            # Register the agentscope providers with mem0
            EmbedderFactory.provider_to_class["agentscope"] = embedder_path
            if is_mem0_version_low:
                # For mem0 version <= 0.1.115, use the old style
                LlmFactory.provider_to_class["agentscope"] = llm_path
            else:
                # For mem0 version > 0.1.115, use the new style
                LlmFactory.provider_to_class["agentscope"] = (
                    llm_path,
                    BaseLlmConfig,
                )

        except ImportError as e:
            raise ImportError(
                "Please install the mem0 library by `pip install mem0ai`",
            ) from e

    @staticmethod
    def _validate_identifiers(
        agent_name: str | None,
        user_name: str | None,
        run_name: str | None,
    ) -> None:
        """Validate that at least one identifier is provided.

        Args:
            agent_name (`str | None`):
                The name of the agent.
            user_name (`str | None`):
                The name of the user.
            run_name (`str | None`):
                The name of the run/session.

        Raises:
            `ValueError`:
                If all identifiers are None.
        """
        if agent_name is None and user_name is None and run_name is None:
            raise ValueError(
                "at least one of agent_name, user_name, and run_name is "
                "required",
            )

    @staticmethod
    def _configure_mem0_config(
        mem0_config: MemoryConfig | None,
        model: ChatModelBase | None,
        embedding_model: EmbeddingModelBase | None,
        vector_store_config: VectorStoreConfig | None,
        _ASLlmConfig: type,
        _ASEmbedderConfig: type,
        **kwargs: Any,
    ) -> MemoryConfig:
        """Configure the mem0 MemoryConfig object.

        Args:
            mem0_config (`MemoryConfig | None`):
                The existing mem0 config, if any. When provided, it is
                updated in place and returned.
            model (`ChatModelBase | None`):
                The chat model to use.
            embedding_model (`EmbeddingModelBase | None`):
                The embedding model to use.
            vector_store_config (`VectorStoreConfig | None`):
                The vector store config to use.
            _ASLlmConfig (`type`):
                The custom LLM config class for agentscope.
            _ASEmbedderConfig (`type`):
                The custom embedder config class for agentscope.
            **kwargs (`Any`):
                Additional keyword arguments. Only 'on_disk' is read, and
                only when a default vector store configuration is created.

        Returns:
            `MemoryConfig`:
                The configured MemoryConfig object.

        Raises:
            `ValueError`:
                If `mem0_config` is None and either `model` or
                `embedding_model` is None.
        """
        import mem0

        if mem0_config is None:
            # Case 2: mem0_config is not provided - create new configuration
            # from individual parameters, both models being mandatory
            if model is None or embedding_model is None:
                raise ValueError(
                    "model and embedding_model are required if mem0_config "
                    "is not provided",
                )

            mem0_config = mem0.configs.base.MemoryConfig(
                llm=_ASLlmConfig(
                    provider="agentscope",
                    config={"model": model},
                ),
                embedder=_ASEmbedderConfig(
                    provider="agentscope",
                    config={"model": embedding_model},
                ),
            )

            if vector_store_config is None:
                # Use default Qdrant configuration with on-disk storage for
                # persistence; with on_disk set to False it will be in
                # memory only
                on_disk = kwargs.get("on_disk", True)
                vector_store_config = (
                    mem0.vector_stores.configs.VectorStoreConfig(
                        config={"on_disk": on_disk},
                    )
                )
            mem0_config.vector_store = vector_store_config
            return mem0_config

        # Case 1: mem0_config is provided - override only the parts for
        # which an individual parameter is given
        if model is not None:
            mem0_config.llm = _ASLlmConfig(
                provider="agentscope",
                config={"model": model},
            )

        if embedding_model is not None:
            mem0_config.embedder = _ASEmbedderConfig(
                provider="agentscope",
                config={"model": embedding_model},
            )

        if vector_store_config is not None:
            mem0_config.vector_store = vector_store_config

        return mem0_config

    def __init__(
        self,
        agent_name: str | None = None,
        user_name: str | None = None,
        run_name: str | None = None,
        model: ChatModelBase | None = None,
        embedding_model: EmbeddingModelBase | None = None,
        vector_store_config: VectorStoreConfig | None = None,
        mem0_config: MemoryConfig | None = None,
        default_memory_type: str | None = None,
        suppress_mem0_logging: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize the Mem0LongTermMemory instance

        Args:
            agent_name (`str | None`, optional):
                The name of the agent. Default is None.
            user_name (`str | None`, optional):
                The name of the user. Default is None.
            run_name (`str | None`, optional):
                The name of the run/session. Default is None.

            .. note::
                1. At least one of `agent_name`, `user_name`, or `run_name` is
                   required.
                2. During memory recording, these parameters become metadata
                   for the stored memories.
                3. **Important**: mem0 will extract memories from messages
                   containing role of "user" by default. If you want to
                   extract memories from messages containing role of
                   "assistant", you need to provide `agent_name`.
                4. During memory retrieval, only memories with matching
                   metadata values will be returned.

            model (`ChatModelBase | None`, optional):
                The chat model to use for the long-term memory. If
                mem0_config is provided, this will override the LLM
                configuration. If mem0_config is None, this is required.
            embedding_model (`EmbeddingModelBase | None`, optional):
                The embedding model to use for the long-term memory. If
                mem0_config is provided, this will override the embedder
                configuration. If mem0_config is None, this is required.
            vector_store_config (`VectorStoreConfig | None`, optional):
                The vector store config to use for the long-term memory.
                If mem0_config is provided, this will override the vector
                store configuration. If mem0_config is None and this is not
                provided, defaults to Qdrant with on_disk=True.
            mem0_config (`MemoryConfig | None`, optional):
                The mem0 config to use for the long-term memory.
                If provided, individual
                model/embedding_model/vector_store_config parameters will
                override the corresponding configurations in mem0_config. If
                None, a new MemoryConfig will be created using the provided
                parameters.
            default_memory_type (`str | None`, optional):
                The type of memory to use. Default is None, to create a
                semantic memory.
            suppress_mem0_logging (`bool`, optional):
                Whether to suppress mem0 logging. Default is True.

                .. note::
                    When using vector database QDRANT with mem0 1.0.3, you
                    may encounter validation errors:
                    "Error awaiting memory task (async): 6 validation errors
                    for PointStruct vector.list[float] Input should be a
                    valid list [type=list_type, ...]"
                    According to the mem0 community
                    (see https://github.com/mem0ai/mem0/issues/3780),
                    these error messages are harmless and can be safely
                    ignored. Setting `suppress_mem0_logging=True` (the
                    default) will suppress these error messages.
            **kwargs (`Any`):
                Forwarded to the config builder, which reads only
                `on_disk` (used for the default vector store).

        Raises:
            `ValueError`:
                If `agent_name`, `user_name` and `run_name` are all None, or
                if `mem0_config` is None and either `model` or
                `embedding_model` is None.
            `ImportError`:
                If the mem0 library is not installed.
        """
        super().__init__()

        # Suppress mem0 logging if requested
        self._setup_mem0_logging(suppress_mem0_logging)

        # Register agentscope providers with mem0
        self._register_agentscope_providers()

        # Create the custom config classes for agentscope providers dynamically
        _ASLlmConfig, _ASEmbedderConfig = _create_agentscope_config_classes()

        # Validate identifiers
        self._validate_identifiers(agent_name, user_name, run_name)

        # Store agent and user identifiers for memory management
        self.agent_id = agent_name
        self.user_id = user_name
        self.run_id = run_name

        # Configure mem0_config
        import mem0

        mem0_config = self._configure_mem0_config(
            mem0_config=mem0_config,
            model=model,
            embedding_model=embedding_model,
            vector_store_config=vector_store_config,
            _ASLlmConfig=_ASLlmConfig,
            _ASEmbedderConfig=_ASEmbedderConfig,
            **kwargs,
        )

        # Initialize the async memory instance with the configured settings
        self.long_term_working_memory = mem0.AsyncMemory(mem0_config)

        # Store the default memory type for future use
        self.default_memory_type = default_memory_type

    @staticmethod
    def _is_empty_result(results: Any) -> bool:
        """Tell whether a mem0 `add` result reports that nothing was stored,
        i.e. it is a dict whose "results" entry is empty."""
        return (
            isinstance(results, dict)
            and "results" in results
            and len(results["results"]) == 0
        )

    async def record_to_memory(
        self,
        thinking: str,
        content: list[str],
        **kwargs: Any,
    ) -> ToolResponse:
        """Use this function to record important information that you may
        need later. The target content should be specific and concise, e.g.
        who, when, where, do what, why, how, etc.

        Args:
            thinking (`str`):
                Your thinking and reasoning about what to record.
            content (`list[str]`):
                The content to remember, which is a list of strings.

        Returns:
            `ToolResponse`:
                A text response reporting either the recording result or
                the error that occurred.
        """
        # Multi-strategy recording approach to ensure content persistence.
        # Each later strategy runs only when the previous attempt returned
        # an empty "results" list:
        #
        # 1. Primary: record as a "user" role message. Mem0 extracts and
        #    infers memories from messages containing role of "user".
        #
        # 2. Fallback: record as an "assistant" role message. If agent_id
        #    is provided, mem0 will use the AGENT_MEMORY_EXTRACTION_PROMPT
        #    in mem0/mem0/configs/prompts.py to extract memories from
        #    messages containing role of "assistant"; otherwise it will
        #    still use the USER_MEMORY_EXTRACTION_PROMPT.
        #
        # 3. Last resort: record as "assistant" with infer=False. This
        #    bypasses mem0's inference mechanism, so mem0 only records the
        #    content as is.
        #
        # This graduated approach ensures that even if mem0's inference fails
        # to extract meaningful memories, the raw content is still preserved.

        try:
            if thinking:
                content = [thinking] + content
            text = "\n".join(content)

            def as_messages(role: str) -> list[dict]:
                # A fresh message list per attempt, never shared across calls
                return [{"role": role, "content": text, "name": role}]

            # Strategy 1: Record as user message first
            results = await self._mem0_record(
                as_messages("user"),
                **kwargs,
            )

            # Strategy 2: Fallback to assistant message
            if self._is_empty_result(results):
                results = await self._mem0_record(
                    as_messages("assistant"),
                    **kwargs,
                )

            # Strategy 3: Last resort - direct recording without inference
            if self._is_empty_result(results):
                results = await self._mem0_record(
                    as_messages("assistant"),
                    infer=False,
                    **kwargs,
                )

            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Successfully recorded content to memory "
                        f"{results}",
                    ),
                ],
            )

        except Exception as e:
            # Tool functions report failures as text instead of raising
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Error recording memory: {str(e)}",
                    ),
                ],
            )

    async def _search_memories(
        self,
        queries: list[str],
        limit: int,
    ) -> list[str]:
        """Run one mem0 search per query concurrently and flatten the hits.

        Args:
            queries (`list[str]`):
                The query strings, each issued as an independent search.
            limit (`int`):
                The maximum number of memories to retrieve per search.

        Returns:
            `list[str]`:
                For each non-empty search result, in query order, its
                memory texts followed by its formatted relations.
        """
        search_results = await asyncio.gather(
            *[
                self.long_term_working_memory.search(
                    query=query,
                    agent_id=self.agent_id,
                    user_id=self.user_id,
                    run_id=self.run_id,
                    limit=limit,
                )
                for query in queries
            ],
        )

        memories: list[str] = []
        for result in search_results:
            if not result:
                continue
            memories.extend(item["memory"] for item in result["results"])
            # Yields nothing when the result carries no "relations" entry
            memories.extend(self._format_relations(result))
        return memories

    async def retrieve_from_memory(
        self,
        keywords: list[str],
        limit: int = 5,
        **kwargs: Any,
    ) -> ToolResponse:
        """Retrieve the memory based on the given keywords.

        Args:
            keywords (`list[str]`):
                Short, targeted search phrases (for example, a person's name,
                a specific date, a location, or a phrase describing something
                you want to retrieve from the memory). Each keyword is issued
                as an independent query against the memory store.
            limit (`int`, optional):
                The maximum number of memories to retrieve per search, i.e.,
                the number of memories to retrieve for each keyword. Defaults
                to 5.

        Returns:
            `ToolResponse`:
                A ToolResponse containing the retrieved memories as
                newline-separated text, or an error message if the
                retrieval failed.
        """
        # NOTE: **kwargs is accepted for interface compatibility but is not
        # forwarded to the search.
        try:
            memories = await self._search_memories(keywords, limit)
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text="\n".join(memories),
                    ),
                ],
            )

        except Exception as e:
            # Tool functions report failures as text instead of raising
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Error retrieving memory: {str(e)}",
                    ),
                ],
            )

    async def record(
        self,
        msgs: list[Msg | None],
        memory_type: str | None = None,
        infer: bool = True,
        **kwargs: Any,
    ) -> dict:
        """Record the content to the long-term memory.

        The contents of all given messages are joined with newlines and
        recorded as a single "assistant" message.

        Args:
            msgs (`list[Msg | None]`):
                The messages to record to memory. A single `Msg` is also
                accepted, and falsy entries (e.g. None) are skipped.
            memory_type (`str | None`, optional):
                The type of memory to use. Default is None, to create a
                semantic memory. "procedural_memory" is explicitly used for
                procedural memories.
            infer (`bool`, optional):
                Whether to infer memory from the content. Default is True.
            **kwargs (`Any`):
                Additional keyword arguments for the mem0 recording.

        Returns:
            `dict`:
                The result from the memory recording operation.

        Raises:
            `TypeError`:
                If any remaining entry is not a `Msg` object.
        """
        if isinstance(msgs, Msg):
            msgs = [msgs]

        # Filter out None
        msg_list = [msg for msg in msgs if msg]
        if not all(isinstance(msg, Msg) for msg in msg_list):
            raise TypeError(
                "The input messages must be a list of Msg objects.",
            )

        messages = [
            {
                "role": "assistant",
                "content": "\n".join(str(msg.content) for msg in msg_list),
                "name": "assistant",
            },
        ]

        return await self._mem0_record(
            messages,
            memory_type=memory_type,
            infer=infer,
            **kwargs,
        )

    def _format_relations(self, result: dict) -> list:
        """Format relations from search result.

        Args:
            result (`dict`):
                The result from the memory search operation.

        Returns:
            `list`:
                The formatted relations, empty when the result has no
                "relations" entry. Each relation is a string in the
                format of:
                "{source} -- {relationship} -- {destination}"
        """
        if "relations" not in result:
            return []
        return [
            f"{relation['source']} -- "
            f"{relation['relationship']} -- "
            f"{relation['destination']}"
            for relation in result["relations"]
        ]

    async def _mem0_record(
        self,
        messages: str | list[dict],
        memory_type: str | None = None,
        infer: bool = True,
        **kwargs: Any,
    ) -> dict:
        """Record the content to the long-term memory.

        Args:
            messages (`str | list[dict]`):
                The content to remember, which is a string or a list of
                dictionaries representing messages.
            memory_type (`str | None`, optional):
                The type of memory to use. Default is None, which falls
                back to the instance's `default_memory_type`.
                "procedural_memory" is explicitly used for procedural
                memories.
            infer (`bool`, optional):
                Whether to infer memory from the content. Default is True.
            **kwargs (`Any`):
                Additional keyword arguments passed to mem0's `add`.

        Returns:
            `dict`:
                The result from the memory recording operation.
        """
        if memory_type is None:
            memory_type = self.default_memory_type

        return await self.long_term_working_memory.add(
            messages=messages,
            agent_id=self.agent_id,
            user_id=self.user_id,
            run_id=self.run_id,
            memory_type=memory_type,
            infer=infer,
            **kwargs,
        )

    async def retrieve(
        self,
        msg: Msg | list[Msg] | None,
        limit: int = 5,
        **kwargs: Any,
    ) -> str:
        """Retrieve the content from the long-term memory.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message to search for in the memory, which should be
                specific and concise, e.g. the person's name, the date, the
                location, etc. If None, nothing is searched and an empty
                string is returned.
            limit (`int`, optional):
                The maximum number of memories to retrieve per search, i.e.,
                the number of memories to retrieve for the message. If the
                message is a list of messages, the limit will be applied to
                each message. If the message is a single message, then the
                limit is the total number of memories to retrieve for the
                message. Defaults to 5.
            **kwargs (`Any`):
                Additional keyword arguments. Currently not forwarded to
                the search.

        Returns:
            `str`:
                The retrieved memories joined with newlines.

        Raises:
            `TypeError`:
                If `msg` is neither None, a `Msg`, nor a list of `Msg`
                objects.
        """
        # None is part of the accepted signature: there is nothing to query
        if msg is None:
            return ""

        if isinstance(msg, Msg):
            msg = [msg]

        if not isinstance(msg, list) or not all(
            isinstance(item, Msg) for item in msg
        ):
            raise TypeError(
                "The input message must be a Msg or a list of Msg objects.",
            )

        # Each message's content is serialized to JSON and used as a query
        queries = [
            json.dumps(item.to_dict()["content"], ensure_ascii=False)
            for item in msg
        ]

        return "\n".join(await self._search_memories(queries, limit))
|
||||
363
src/agentscope/memory/_long_term_memory/_mem0/_mem0_utils.py
Normal file
363
src/agentscope/memory/_long_term_memory/_mem0/_mem0_utils.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Utility classes for integrating AgentScope with mem0 library.
|
||||
|
||||
This module provides wrapper classes that allow AgentScope models to be used
|
||||
with the mem0 library for long-term memory functionality.
|
||||
"""
|
||||
import asyncio
|
||||
import atexit
|
||||
import threading
|
||||
from typing import Any, Coroutine, Dict, List, Literal
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
from ....embedding import EmbeddingModelBase
|
||||
from ....model import ChatModelBase, ChatResponse
|
||||
|
||||
|
||||
class _EventLoopManager:
    """Global event loop manager for running async operations in sync context.

    The manager lazily creates one event loop, runs it forever inside a
    daemon thread, and hands that same loop out on every call to
    ``get_loop``. Keeping a single long-lived loop means async model clients
    (like Ollama AsyncClient) remain bound to the same event loop across
    multiple calls, avoiding "Event loop is closed" errors.
    """

    # Seconds to wait for the loop to start and for its thread to finish
    _DEFAULT_TIMEOUT = 5.0

    def __init__(self) -> None:
        """Initialize the manager and register ``cleanup`` with ``atexit``."""
        self.loop: asyncio.AbstractEventLoop | None = None
        self.thread: threading.Thread | None = None
        # Serializes loop creation and teardown across threads
        self.lock = threading.Lock()
        # Set by the background thread once ``self.loop`` has been published
        self.loop_started = threading.Event()

        # Register cleanup function to be called on program exit
        atexit.register(self.cleanup)

    def get_loop(self) -> asyncio.AbstractEventLoop:
        """Get or create the persistent background event loop.

        A new loop (and a new daemon thread running it) is created when no
        loop exists yet or when the previous one has been closed.

        Returns:
            `asyncio.AbstractEventLoop`:
                The persistent event loop running in a background thread.

        Raises:
            `RuntimeError`: If the event loop fails to start within the
            timeout, or was not initialized by the background thread.
        """
        with self.lock:
            if self.loop is None or self.loop.is_closed():

                def run_loop() -> None:
                    """Run the event loop in the background thread."""
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    # Store the loop reference before starting
                    self.loop = loop
                    self.loop_started.set()
                    try:
                        loop.run_forever()
                    finally:
                        # The owning thread closes its own loop once it has
                        # stopped, so a stopped loop is never handed out
                        # again and ``cleanup`` never has to close a loop
                        # that is still running.
                        loop.close()

                # Clear the event before starting the thread
                self.loop_started.clear()

                # Create and start the background thread
                self.thread = threading.Thread(
                    target=run_loop,
                    daemon=True,
                    name="AgentScope-AsyncLoop",
                )
                self.thread.start()

                # Wait for the loop to be ready
                if not self.loop_started.wait(timeout=self._DEFAULT_TIMEOUT):
                    raise RuntimeError(
                        "Timeout waiting for event loop to start",
                    )

            # After waiting, self.loop should be set by the background thread
            if self.loop is None:
                raise RuntimeError("Event loop was not initialized properly")
            return self.loop

    def cleanup(self) -> None:
        """Cleanup the event loop and thread on program exit.

        Asks the loop to stop, waits up to ``_DEFAULT_TIMEOUT`` seconds for
        the background thread to finish, and forgets both so that a later
        ``get_loop`` call starts from scratch. Calling it repeatedly, or
        before any loop was created, is a no-op.
        """
        with self.lock:
            loop, thread = self.loop, self.thread

            if loop is not None and not loop.is_closed():
                # Stop the event loop gracefully
                loop.call_soon_threadsafe(loop.stop)

                # Wait for the thread to finish
                if thread is not None and thread.is_alive():
                    thread.join(timeout=self._DEFAULT_TIMEOUT)

                # ``loop.close()`` raises RuntimeError on a running loop, so
                # only close here when the loop thread is gone. If the join
                # timed out, the daemon thread closes the loop itself as
                # soon as ``run_forever`` returns.
                if thread is None or not thread.is_alive():
                    loop.close()

            self.loop = None
            self.thread = None
|
||||
|
||||
|
||||
# Global event loop manager instance: created once at import time and used by
# `_run_async_in_persistent_loop` to obtain the shared background loop.
_event_loop_manager = _EventLoopManager()
|
||||
|
||||
|
||||
def _run_async_in_persistent_loop(coro: Coroutine) -> Any:
    """Execute a coroutine on the shared background loop and wait for it.

    The loop is obtained from the module-level event loop manager, so every
    call schedules its coroutine onto the same loop. That matters for async
    clients like Ollama that bind to a specific event loop. The calling
    thread blocks until the coroutine has finished.

    Args:
        coro (`Coroutine`):
            The coroutine to run.

    Returns:
        `Any`:
            The result of the coroutine.

    Raises:
        `RuntimeError`:
            If the background event loop cannot be obtained. An exception
            raised inside the coroutine is re-raised here unchanged.
    """
    background_loop = _event_loop_manager.get_loop()
    return asyncio.run_coroutine_threadsafe(coro, background_loop).result()
|
||||
|
||||
|
||||
class AgentScopeLLM(LLMBase):
    """Adapter that lets mem0 drive an AgentScope chat model.

    mem0 calls ``generate_response`` synchronously; this wrapper forwards the
    request to the (async) AgentScope model on the persistent background
    event loop and converts the reply into the shape mem0 expects.
    """

    def __init__(self, config: BaseLlmConfig | None = None):
        """Initialize the AgentScopeLLM wrapper.

        Args:
            config (`BaseLlmConfig | None`, optional):
                Configuration object for the LLM. Default is None. Its
                ``model`` field must hold a `ChatModelBase` instance.

        Raises:
            `ValueError`:
                If ``config.model`` is missing or is not a `ChatModelBase`.
        """
        super().__init__(config)

        chat_model = self.config.model
        if chat_model is None:
            raise ValueError("`model` parameter is required")
        if not isinstance(chat_model, ChatModelBase):
            raise ValueError("`model` must be an instance of ChatModelBase")

        self.agentscope_model = chat_model

    def _parse_response(
        self,
        model_response: ChatResponse,
        has_tool: bool,
    ) -> str | dict:
        """Convert a model response into the format used by mem0.

        Only dict blocks of type ``text``, ``thinking`` and ``tool_use`` are
        considered; anything else is ignored. Thinking blocks are rendered
        as ``[Thinking: ...]`` and placed before the text blocks.

        Args:
            model_response (`ChatResponse`): The response from the model.
            has_tool (`bool`): Whether tools were supplied for this request,
                which selects the dict return shape.

        Returns:
            `str | dict`:
                If has_tool is True, a dict with "content" and "tool_calls"
                keys. Otherwise, the newline-joined text as a string.
        """
        plain_texts: list[str] = []
        thoughts: list[str] = []
        tool_calls = []

        for block in model_response.content:
            if not isinstance(block, dict):
                continue

            block_type = block.get("type")
            if block_type == "text":
                plain_texts.append(str(block.get("text", "")))
            elif block_type == "thinking":
                thoughts.append(f"[Thinking: {block.get('thinking', '')}]")
            elif block_type == "tool_use":
                tool_calls.append(
                    {
                        "name": block.get("name"),
                        "arguments": block.get("input", {}),
                    },
                )

        # Joining an empty list already yields "", so no special case needed
        joined_text = "\n".join(thoughts + plain_texts)

        if has_tool:
            # If tools were supplied, return the content and tool calls
            return {"content": joined_text, "tool_calls": tool_calls}
        return joined_text

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Any | None = None,
        tools: List[Dict] | None = None,
        tool_choice: str = "auto",
    ) -> str | dict:
        """Generate a response based on the given messages using agentscope.

        Args:
            messages (`List[Dict[str, str]]`):
                List of message dicts containing 'role' and 'content'.
                Messages whose role is not "system", "user" or "assistant"
                are dropped.
            response_format (`Any | None`, optional):
                Format of the response. Not used in AgentScope.
            tools (`List[Dict] | None`, optional):
                List of tools that the model can call. Forwarded to the
                AgentScope model as its ``tools`` argument; when not None
                the result is returned as a dict.
            tool_choice (`str`, optional):
                Tool choice method. Not used in AgentScope.

        Returns:
            `str | dict`:
                The generated response.

        Raises:
            `RuntimeError`:
                Wraps any error raised while preparing the messages,
                calling the model, or parsing its response.
        """
        # pylint: disable=unused-argument
        try:
            # Read both keys of every message before filtering, so that a
            # malformed entry still fails regardless of its role
            role_content_pairs = [
                (message["role"], message["content"]) for message in messages
            ]

            # Convert the messages to AgentScope's format
            agentscope_messages = [
                {"role": role, "content": content}
                for role, content in role_content_pairs
                if role in ("system", "user", "assistant")
            ]

            if not agentscope_messages:
                raise ValueError(
                    "No valid messages found in the messages list",
                )

            # Use the agentscope model to generate response (async call)
            async def _async_call() -> ChatResponse:
                # TODO: handle the streaming response or forbidden streaming
                #  mode
                return await self.agentscope_model(  # type: ignore
                    agentscope_messages,
                    tools=tools,
                )

            # Run in the persistent event loop
            # This ensures the model client (e.g., Ollama AsyncClient)
            # always runs in the same event loop, avoiding binding issues
            response = _run_async_in_persistent_loop(_async_call())
            has_tool = tools is not None

            if response.content:
                return self._parse_response(response, has_tool)

            # Empty content: return the empty value of the expected shape
            if has_tool:
                return {"content": "", "tool_calls": []}
            return ""

        except Exception as e:
            raise RuntimeError(
                f"Error generating response using agentscope model: {str(e)}",
            ) from e
|
||||
|
||||
|
||||
class AgentScopeEmbedding(EmbeddingBase):
    """Adapter that lets mem0 use an AgentScope embedding model.

    mem0 calls ``embed`` synchronously; this wrapper runs the (async)
    AgentScope embedding model on the persistent background event loop.
    """

    def __init__(self, config: BaseEmbedderConfig | None = None):
        """Initialize the AgentScopeEmbedding wrapper.

        Args:
            config (`BaseEmbedderConfig | None`, optional):
                Configuration object for the embedder. Default is None. Its
                ``model`` field must hold an `EmbeddingModelBase` instance.

        Raises:
            `ValueError`:
                If ``config.model`` is missing or is not an
                `EmbeddingModelBase`.
        """
        super().__init__(config)

        embedder = self.config.model
        if embedder is None:
            raise ValueError("`model` parameter is required")
        if not isinstance(embedder, EmbeddingModelBase):
            raise ValueError(
                "`model` must be an instance of EmbeddingModelBase",
            )

        self.agentscope_model = embedder

    def embed(
        self,
        text: str | List[str],
        memory_action: Literal[  # pylint: disable=unused-argument
            "add",
            "search",
            "update",
        ]
        | None = None,
    ) -> List[float]:
        """Get the embedding for the given text using AgentScope.

        Args:
            text (`str | List[str]`):
                The text to embed. A single string is wrapped in a list
                before being passed to the model.
            memory_action (`Literal["add", "search", "update"] | None`, \
            optional):
                The type of embedding to use. Must be one of "add", "search",
                or "update". Defaults to None. Currently ignored.

        Returns:
            `List[float]`:
                The first entry of the model response's ``embeddings``.

        Raises:
            `RuntimeError`:
                Wraps any error raised while calling the model or reading
                its response, including a first embedding that is None.
        """
        try:
            # The AgentScope embedding model is called with a list of texts
            if isinstance(text, str):
                texts = [text]
            else:
                texts = text

            # Use the agentscope model to generate embedding (async call)
            async def _async_call() -> Any:
                return await self.agentscope_model(texts)

            # Run in the persistent event loop
            # This ensures the model client (e.g., Ollama AsyncClient)
            # always runs in the same event loop, avoiding binding issues
            response = _run_async_in_persistent_loop(_async_call())

            # Only the first embedding is returned, whatever the number of
            # input texts
            first_vector = response.embeddings[0]
            if first_vector is None:
                raise ValueError("Failed to extract embedding from response")
            return first_vector

        except Exception as e:
            raise RuntimeError(
                f"Error generating embedding using agentscope model: {str(e)}",
            ) from e
|
||||
12
src/agentscope/memory/_long_term_memory/_reme/__init__.py
Normal file
12
src/agentscope/memory/_long_term_memory/_reme/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The reme memory module."""
|
||||
|
||||
from ._reme_personal_long_term_memory import ReMePersonalLongTermMemory
|
||||
from ._reme_task_long_term_memory import ReMeTaskLongTermMemory
|
||||
from ._reme_tool_long_term_memory import ReMeToolLongTermMemory
|
||||
|
||||
__all__ = [
|
||||
"ReMePersonalLongTermMemory",
|
||||
"ReMeTaskLongTermMemory",
|
||||
"ReMeToolLongTermMemory",
|
||||
]
|
||||
@@ -0,0 +1,364 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Base long-term memory implementation using ReMe library.
|
||||
|
||||
This module provides a base class for long-term memory implementations
|
||||
that integrate with the ReMe library. ReMe enables agents to maintain
|
||||
persistent, searchable memories across sessions and contexts.
|
||||
|
||||
The module handles the integration between AgentScope's memory system and
|
||||
the ReMe library, including:
|
||||
- Model configuration and API credential management
|
||||
- Context lifecycle management (async context managers)
|
||||
- Graceful handling of missing dependencies
|
||||
- Error handling with helpful installation instructions
|
||||
|
||||
Key Features:
|
||||
- Supports both DashScope and OpenAI model providers
|
||||
- Automatic extraction of API credentials and endpoints
|
||||
- Flexible configuration via config files or kwargs
|
||||
- Clear ImportError with installation instructions when reme_ai is missing
|
||||
|
||||
Dependencies:
|
||||
The ReMe library is an optional dependency that must be installed:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install reme-ai
|
||||
For more information, visit: https://github.com/modelscope/reMe
|
||||
|
||||
Subclasses:
|
||||
This base class is extended by specific memory type implementations:
|
||||
- ReMeToolLongTermMemory: For tool execution patterns and guidelines
|
||||
- ReMeTaskLongTermMemory: For task execution experiences and learnings
|
||||
- ReMePersonalLongTermMemory: For user preferences and personal information
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.embedding import OpenAITextEmbedding
|
||||
from agentscope.memory._long_term_memory._reme import ReMeToolLongTermMemory
|
||||
|
||||
# Initialize models
|
||||
model = OpenAIChatModel(model_name="gpt-4", api_key="...")
|
||||
embedding = OpenAITextEmbedding(
|
||||
model_name="text-embedding-3-small", api_key="...")
|
||||
|
||||
# Create memory instance
|
||||
memory = ReMeToolLongTermMemory(
|
||||
agent_name="my_agent",
|
||||
user_name="user_123",
|
||||
model=model,
|
||||
embedding_model=embedding
|
||||
)
|
||||
|
||||
# Use memory in async context
|
||||
async with memory:
|
||||
# Record tool execution
|
||||
await memory.record_to_memory(
|
||||
thinking="This tool worked well for data processing",
|
||||
content=['{"tool_name": "process_data", "success": true, ...}']
|
||||
)
|
||||
|
||||
# Retrieve tool guidelines
|
||||
result = await memory.retrieve_from_memory(
|
||||
keywords=["process_data"]
|
||||
)
|
||||
|
||||
"""
|
||||
from abc import ABCMeta
|
||||
from typing import Any
|
||||
|
||||
from .._long_term_memory_base import LongTermMemoryBase
|
||||
from ....embedding import DashScopeTextEmbedding, OpenAITextEmbedding
|
||||
from ....model import DashScopeChatModel, OpenAIChatModel
|
||||
|
||||
|
||||
class ReMeLongTermMemoryBase(LongTermMemoryBase, metaclass=ABCMeta):
    """Base class for ReMe-based long-term memory implementations.

    This class provides the foundation for integrating AgentScope with the ReMe
    library, enabling agents to maintain and retrieve long-term memories across
    different contexts.

    The ReMe library must be installed separately:
        pip install reme-ai

    If the library is not installed, constructing an instance raises an
    ImportError that contains the installation instructions.
    """

    # Fixed endpoint handed to ReMe for every DashScope model
    _DASHSCOPE_API_BASE = (
        "https://dashscope.aliyuncs.com/compatible-mode/v1"
    )

    @staticmethod
    def _client_attr(client: Any, name: str) -> str | None:
        """Return ``str(client.<name>)``, or None if it is missing or None.

        Avoids turning an absent value into the literal string "None".
        """
        value = getattr(client, name, None)
        return None if value is None else str(value)

    def __init__(
        self,
        agent_name: str | None = None,
        user_name: str | None = None,
        run_name: str | None = None,
        model: DashScopeChatModel | OpenAIChatModel | None = None,
        embedding_model: (
            DashScopeTextEmbedding | OpenAITextEmbedding | None
        ) = None,
        reme_config_path: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the ReMe-based long-term memory.

        This constructor extracts the API credentials, endpoints and model
        names from the given models and uses them to build the ReMe app.
        The app context itself is only started in ``__aenter__``.

        Args:
            agent_name (`str | None`, optional):
                Name identifier for the agent. Used for organizing
                memories by agent.
            user_name (`str | None`, optional):
                Unique identifier for the user or workspace. This maps
                to workspace_id in ReMe and helps isolate memories across
                different users/workspaces.
            run_name (`str | None`, optional):
                Name identifier for the current execution run or session.
            model (`DashScopeChatModel | OpenAIChatModel | None`, optional):
                The chat model to use for memory operations. The model's
                API credentials and endpoint will be extracted and
                passed to ReMe. Although the default is None, a supported
                model instance is required.
            embedding_model (`DashScopeTextEmbedding | OpenAITextEmbedding | \
            None`, optional):
                The embedding model to use for semantic memory retrieval.
                The model's API credentials and endpoint will be
                extracted and passed to ReMe. Although the default is None,
                a supported embedding model instance is required.
            reme_config_path (`str | None`, optional):
                Path to a custom ReMe configuration file. If not provided, ReMe
                will use its default configuration.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the
                ReMeApp constructor.
                These can include custom ReMe configuration parameters.

        Raises:
            `ValueError`:
                If the provided model is not a DashScopeChatModel or
                OpenAIChatModel, or if the embedding_model is not a
                DashScopeTextEmbedding or OpenAITextEmbedding.
            `ImportError`:
                If the reme_ai library is not installed.

        Example:
            .. code-block:: python

                from agentscope.model import OpenAIChatModel
                from agentscope.embedding import OpenAITextEmbedding
                from agentscope.memory._long_term_memory._reme import (
                    ReMeToolLongTermMemory,
                )

                # Initialize models
                model = OpenAIChatModel(
                    model_name="gpt-4",
                    api_key="your-api-key"
                )
                embedding = OpenAITextEmbedding(
                    model_name="text-embedding-3-small",
                    api_key="your-api-key"
                )

                # Create memory instance
                memory = ReMeToolLongTermMemory(
                    agent_name="my_agent",
                    user_name="user_123",
                    run_name="session_001",
                    model=model,
                    embedding_model=embedding
                )

                # Use with async context manager
                async with memory:
                    # Memory operations...
                    pass

        """
        super().__init__()

        # Store agent and workspace identifiers
        self.agent_name = agent_name
        # Maps to ReMe's workspace_id concept
        self.workspace_id = user_name
        self.run_name = run_name

        # Build configuration arguments for ReMeApp
        # These will be passed as command-line style config overrides
        config_args = []

        # Extract LLM API credentials based on model type
        # DashScope uses a fixed endpoint, OpenAI can have custom base_url
        if isinstance(model, DashScopeChatModel):
            llm_api_base = self._DASHSCOPE_API_BASE
            llm_api_key = model.api_key

        elif isinstance(model, OpenAIChatModel):
            llm_api_base = self._client_attr(model.client, "base_url")
            llm_api_key = self._client_attr(model.client, "api_key")

        else:
            raise ValueError(
                f"model must be a DashScopeChatModel or "
                f"OpenAIChatModel instance. "
                f"Got {type(model).__name__} instead.",
            )

        # Extract model name and add to config if provided
        llm_model_name = model.model_name

        if llm_model_name:
            config_args.append(f"llm.default.model_name={llm_model_name}")

        # Extract embedding model API credentials based on type
        # Similar to LLM, DashScope uses fixed endpoint,
        # OpenAI can be customized
        if isinstance(embedding_model, DashScopeTextEmbedding):
            embedding_api_base = self._DASHSCOPE_API_BASE
            embedding_api_key = embedding_model.api_key

        elif isinstance(embedding_model, OpenAITextEmbedding):
            embedding_api_base = self._client_attr(
                embedding_model.client,
                "base_url",
            )
            embedding_api_key = self._client_attr(
                embedding_model.client,
                "api_key",
            )

        else:
            raise ValueError(
                "embedding_model must be a DashScopeTextEmbedding or "
                "OpenAITextEmbedding instance. "
                f"Got {type(embedding_model).__name__} instead.",
            )

        # Extract embedding model name and add to config if provided
        embedding_model_name = embedding_model.model_name

        if embedding_model_name:
            config_args.append(
                f"embedding_model.default.model_name={embedding_model_name}",
            )

        # Only forward the dimensions when they are set; a None value would
        # otherwise be rendered as the text "None" inside the override
        dimensions = embedding_model.dimensions
        if dimensions is not None:
            config_args.append(
                "embedding_model.default.params="
                f'{{"dimensions": {dimensions}}}',
            )

        # Import ReMe lazily so that this module can be imported without the
        # optional dependency; instantiation still requires it
        try:
            from reme_ai import ReMeApp
        except ImportError as e:
            raise ImportError(
                "The 'reme_ai' library is required for ReMe-based "
                "long-term memory. Please install it by `pip install reme-ai`, "
                "and visit: https://github.com/modelscope/reMe for more "
                "information.",
            ) from e

        # Initialize ReMe with extracted configurations
        self.app = ReMeApp(
            *config_args,  # Config overrides as positional args
            llm_api_key=llm_api_key,
            llm_api_base=llm_api_base,
            embedding_api_key=embedding_api_key,
            embedding_api_base=embedding_api_base,
            # Optional custom config file
            config_path=reme_config_path,
            # Additional ReMe-specific configurations
            **kwargs,
        )

        # Track if the app context is active (started via __aenter__)
        self._app_started = False

    async def __aenter__(self) -> "ReMeLongTermMemoryBase":
        """Async context manager entry point.

        This method is called when entering an async context
        (using 'async with'). It enters the ReMe app context, if an app is
        set, and marks the app as started so that memory operations can be
        performed within the context block.

        Returns:
            `ReMeLongTermMemoryBase`:
                The memory instance itself, allowing it to be used in
                the context.

        Example:
            .. code-block:: python

                memory = ReMeToolLongTermMemory(
                    agent_name="my_agent",
                    model=model,
                    embedding_model=embedding
                )

                async with memory:
                    # Memory operations can be performed here
                    await memory.record_to_memory(
                        thinking="Recording tool usage",
                        content=[...]
                    )

        """
        if self.app is not None:
            await self.app.__aenter__()
            self._app_started = True
        return self

    async def __aexit__(
        self,
        exc_type: Any = None,
        exc_val: Any = None,
        exc_tb: Any = None,
    ) -> None:
        """Async context manager exit point.

        This method is called when exiting an async context (at the end
        of 'async with' block or when an exception occurs). It exits the
        ReMe app context and marks the app as stopped.

        Args:
            exc_type (`Any`):
                The type of exception that occurred, if any. None if no
                exception.
            exc_val (`Any`):
                The exception instance that occurred, if any. None if no
                exception.
            exc_tb (`Any`):
                The traceback object for the exception, if any. None if
                no exception.

        .. note:: If self.app is None the app cleanup is skipped.
         _app_started is always reset to False, even when the app's own
         cleanup raises.

        Example:
            .. code-block:: python

                async with memory:
                    try:
                        # Memory operations
                        await memory.record_to_memory(...)
                    except Exception as e:
                        # __aexit__ will be called even if an exception occurs
                        print(f"Error: {e}")
                # __aexit__ has been called and resources are cleaned up

        """
        try:
            if self.app is not None:
                await self.app.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset the flag even if the app cleanup failed
            self._app_started = False
|
||||
@@ -0,0 +1,414 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Personal memory implementation using ReMe library.
|
||||
|
||||
This module provides a personal memory implementation that integrates
|
||||
with the ReMe library to provide persistent personal memory storage and
|
||||
retrieval capabilities for AgentScope agents.
|
||||
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
|
||||
from ...._logging import logger
|
||||
from ....message import Msg, TextBlock
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
class ReMePersonalLongTermMemory(ReMeLongTermMemoryBase):
|
||||
"""Personal memory implementation using ReMe library."""
|
||||
|
||||
async def record_to_memory(
    self,
    thinking: str,
    content: list[str],
    **kwargs: Any,
) -> ToolResponse:
    """Record important user information to long-term memory.

    Record important user information to long-term memory for future
    reference.

    Use this function to save user's personal information,
    preferences, habits, and facts that you may need in future
    conversations. This enables you to provide personalized and
    contextually relevant responses.

    When to record:

    - User shares personal preferences (e.g., "I prefer homestays
      when traveling")
    - User mentions habits or routines (e.g., "I start work at 9 AM")
    - User states likes/dislikes (e.g., "I enjoy drinking green tea")
    - User provides personal facts (e.g., "I work as a software
      engineer")

    What to record: Be specific and structured. Include who, when,
    where, what, why, and how when relevant.

    Args:
        thinking (`str`):
            Your reasoning about why this information is worth
            recording and how it might be useful later.
        content (`list[str]`):
            List of specific facts to remember. Each string should be
            a clear, standalone piece of information. Examples:
            ["User prefers homestays in Hangzhou", "User likes
            visiting West Lake in the morning"].
        **kwargs (`Any`):
            Additional keyword arguments for the recording operation.

    Returns:
        `ToolResponse`:
            Confirmation message indicating successful memory
            recording.
    """
    logger.info(
        "[ReMePersonalMemory] Entering record_to_memory - "
        "thinking: %s, content: %s, kwargs: %s",
        thinking,
        content,
        kwargs,
    )

    # Memory operations are only allowed inside the 'async with' block
    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # The trajectory is: the thinking (when non-empty), then every
        # content item, all as user messages, closed by a fixed assistant
        # acknowledgment
        user_texts = [thinking] if thinking else []
        user_texts.extend(content)

        messages = [
            {"role": "user", "content": user_text}
            for user_text in user_texts
        ]
        messages.append(
            {
                "role": "assistant",
                "content": (
                    "I understand and will remember this "
                    "information."
                ),
            },
        )

        result = await self.app.async_execute(
            name="summary_personal_memory",
            workspace_id=self.workspace_id,
            trajectories=[{"messages": messages}],
            **kwargs,
        )

        # Extract metadata about stored memories if available
        stored_memories = result.get("metadata", {}).get("memory_list", [])

        if stored_memories:
            summary_text = (
                f"Successfully recorded {len(stored_memories)} "
                f"memory/memories to personal memory."
            )
        else:
            summary_text = "Memory recording completed."

        summary_block = TextBlock(type="text", text=summary_text)
        return ToolResponse(
            content=[summary_block],
            metadata={"result": result},
        )

    except Exception as e:
        # Failures are reported back to the caller as a tool response
        # instead of being raised
        logger.exception("Error recording memory: %s", str(e))
        error_block = TextBlock(
            type="text",
            text=f"Error recording memory: {str(e)}",
        )
        return ToolResponse(content=[error_block])
|
||||
|
||||
async def retrieve_from_memory(
    self,
    keywords: list[str],
    limit: int = 5,
    **kwargs: Any,
) -> ToolResponse:
    """Search and retrieve relevant information from long-term memory.

    .. note:: You should call this function BEFORE answering
       questions about the user's preferences, past information, or
       personal details. This ensures you provide accurate information
       based on stored memories rather than guessing.

    Use this when:

    - User asks "what do I like?", "what are my preferences?",
      "what do you know about me?"
    - User asks about their past behaviors, habits, or stated
      preferences
    - User refers to information they shared in previous
      conversations
    - You need to personalize responses based on user's history

    Args:
        keywords (`list[str]`):
            Keywords to search for in memory. Be specific and use
            multiple keywords for better results. Examples:
            ["travel preferences", "Hangzhou"], ["work habits",
            "morning routine"], ["food preferences", "tea"].
        limit (`int`, optional):
            The maximum number of memories to retrieve per search, i.e.,
            the number of memories to retrieve for each keyword. Defaults
            to 5.
        **kwargs (`Any`):
            Additional keyword arguments for the retrieval operation.

    Returns:
        `ToolResponse`:
            Retrieved memories matching the keywords. If no memories
            found, you'll receive a message indicating that.

    Raises:
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    logger.info(
        "[ReMePersonalMemory] Entering retrieve_from_memory - "
        "keywords: %s, kwargs: %s",
        keywords,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        sections = []

        # One retrieval flow execution per keyword; keywords that yield
        # an empty answer contribute nothing to the reply.
        for term in keywords:
            outcome = await self.app.async_execute(
                name="retrieve_personal_memory",
                workspace_id=self.workspace_id,
                query=term,
                top_k=limit,
                **kwargs,
            )
            found = outcome.get("answer", "")
            if not found:
                continue
            sections.append(f"Keyword '{term}':\n{found}")

        reply_text = (
            "\n\n".join(sections)
            if sections
            else "No memories found for the given keywords."
        )
        reply_block = TextBlock(type="text", text=reply_text)
        return ToolResponse(content=[reply_block])

    except Exception as e:
        # Report the failure back to the caller as text instead of raising.
        reason = str(e)
        logger.exception("Error retrieving memory: %s", reason)
        failure_block = TextBlock(
            type="text",
            text=f"Error retrieving memory: {reason}",
        )
        return ToolResponse(content=[failure_block])
|
||||
|
||||
async def record(
    self,
    msgs: list[Msg | None],
    **kwargs: Any,
) -> None:
    """Record the content to the long-term memory.

    This method converts AgentScope messages to ReMe's format and
    records them using the personal memory flow. Errors raised while
    converting or submitting the messages are logged and surfaced as a
    warning instead of being propagated.

    Args:
        msgs (`list[Msg | None]`):
            The messages to record to memory. A single `Msg` is wrapped
            into a list; falsy entries such as `None` are dropped.
        **kwargs (`Any`):
            Additional keyword arguments forwarded to the ReMe
            `summary_personal_memory` flow.

    Raises:
        `TypeError`:
            If any remaining entry is not a `Msg`.
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    candidates = [msgs] if isinstance(msgs, Msg) else msgs

    # Drop the None placeholders (any falsy entry is discarded).
    msg_list = [entry for entry in candidates if entry]
    if not msg_list:
        return

    if any(not isinstance(entry, Msg) for entry in msg_list):
        raise TypeError(
            "The input messages must be a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # Flatten every message into a {"role", "content"} dict whose
        # content is a plain string.
        messages = []
        for msg in msg_list:
            body = msg.content
            if isinstance(body, str):
                flattened = body
            elif isinstance(body, list):
                # Keep only the text / thinking payload of dict blocks.
                fragments = []
                for block in body:
                    if not isinstance(block, dict):
                        continue
                    if "text" in block:
                        fragments.append(block["text"])
                    elif "thinking" in block:
                        fragments.append(block["thinking"])
                flattened = "\n".join(fragments)
            else:
                flattened = str(body)

            messages.append({"role": msg.role, "content": flattened})

        trajectory = {"messages": messages}
        await self.app.async_execute(
            name="summary_personal_memory",
            workspace_id=self.workspace_id,
            trajectories=[trajectory],
            **kwargs,
        )

    except Exception as e:
        # Log the error but don't raise to maintain compatibility
        reason = str(e)
        logger.exception("Error recording messages to memory: %s", reason)
        import warnings

        warnings.warn(f"Error recording messages to memory: {reason}")
|
||||
|
||||
async def retrieve(
    self,
    msg: Msg | list[Msg] | None,
    limit: int = 5,
    **kwargs: Any,
) -> str:
    """Retrieve the content from the long-term memory.

    Only the last message is used to build the search query; its text
    and thinking blocks are joined with newlines. Errors raised during
    retrieval are logged and surfaced as a warning, and an empty string
    is returned.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message to search for in the memory, which should be
            specific and concise, e.g. the person's name, the date, the
            location, etc. When a list is given, only its last message
            is used as the query.
        limit (`int`, optional):
            The maximum number of memories to retrieve for the query.
            Defaults to 5.
        **kwargs (`Any`):
            Additional keyword arguments.

    Returns:
        `str`:
            The retrieved memory as a string. Empty when `msg` is
            `None` or an empty list, when no query text can be
            extracted, or when retrieval fails.

    Raises:
        `TypeError`:
            If `msg` is neither a `Msg` nor a list of `Msg` objects.
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    if msg is None:
        return ""

    if isinstance(msg, Msg):
        msg = [msg]

    if not isinstance(msg, list) or not all(
        isinstance(_, Msg) for _ in msg
    ):
        raise TypeError(
            "The input message must be a Msg or a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    # An empty list has no last message to query with. Return early so
    # that `msg[-1]` below cannot raise IndexError and be reported as a
    # retrieval failure.
    if not msg:
        return ""

    try:
        # Only use the last message's content for retrieval
        last_msg = msg[-1]
        query = ""

        if isinstance(last_msg.content, str):
            query = last_msg.content
        elif isinstance(last_msg.content, list):
            # Extract text from content blocks
            content_parts = []
            for block in last_msg.content:
                if isinstance(block, dict) and "text" in block:
                    content_parts.append(block["text"])
                elif isinstance(block, dict) and "thinking" in block:
                    content_parts.append(block["thinking"])
            query = "\n".join(content_parts)

        if not query:
            return ""

        # Retrieve using the query from the last message
        result = await self.app.async_execute(
            name="retrieve_personal_memory",
            workspace_id=self.workspace_id,
            query=query,
            top_k=limit,
            **kwargs,
        )

        return result.get("answer", "")

    except Exception as e:
        # Best-effort: log and warn, then fall back to an empty result.
        logger.exception("Error retrieving memory: %s", str(e))
        import warnings

        warnings.warn(f"Error retrieving memory: {str(e)}")
        return ""
|
||||
@@ -0,0 +1,436 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task memory implementation using ReMe library.
|
||||
|
||||
This module provides a task memory implementation that integrates
|
||||
with the ReMe library to learn from execution trajectories and
|
||||
retrieve relevant task experiences.
|
||||
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
|
||||
from ...._logging import logger
|
||||
from ....message import Msg, TextBlock
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
class ReMeTaskLongTermMemory(ReMeLongTermMemoryBase):
|
||||
"""Task memory implementation using ReMe library.
|
||||
|
||||
Task memory learns from execution trajectories and provides
|
||||
retrieval of relevant task experiences.
|
||||
|
||||
"""
|
||||
|
||||
async def record_to_memory(
    self,
    thinking: str,
    content: list[str],
    **kwargs: Any,
) -> ToolResponse:
    """Record task execution experiences and learnings.

    Record task execution experiences and learnings to long-term
    memory.

    Use this function to save valuable task-related knowledge that
    can help with future similar tasks. This enables learning from
    experience and improving over time.

    When to record:

    - After solving technical problems or completing tasks
    - When discovering useful techniques or approaches
    - After implementing solutions with specific steps
    - When learning best practices or important lessons

    What to record: Be detailed and actionable. Include:

    - Task description and context
    - Step-by-step execution details
    - Specific techniques and methods used
    - Results, outcomes, and effectiveness
    - Lessons learned and considerations

    Args:
        thinking (`str`):
            Your reasoning about why this task experience is valuable
            and what makes it worth remembering for future reference.
        content (`list[str]`):
            List of specific task insights to remember. Each string
            should be a clear, actionable piece of information.
            Examples: ["Add indexes on WHERE clause columns to speed
            up queries", "Use EXPLAIN ANALYZE to identify missing
            indexes"].
        **kwargs (`Any`):
            Additional keyword arguments. Can include 'score' (float)
            to indicate the quality/success of this approach
            (default: 1.0).

    Returns:
        `ToolResponse`:
            Confirmation message indicating successful memory
            recording.

    Raises:
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    logger.info(
        "[ReMeTaskMemory] Entering record_to_memory - "
        "thinking: %s, content: %s, kwargs: %s",
        thinking,
        content,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # The optional reasoning opens the trajectory as a user turn.
        dialogue = (
            [{"role": "user", "content": thinking}] if thinking else []
        )

        # Each insight becomes a user turn followed by a fixed assistant
        # acknowledgment, i.e. one user-assistant pair per item.
        for insight in content:
            dialogue.append({"role": "user", "content": insight})
            dialogue.append(
                {
                    "role": "assistant",
                    "content": "Task information recorded.",
                },
            )

        # 'score' belongs to the trajectory, so it is taken out of kwargs
        # before the remaining keywords are forwarded to the flow.
        trajectory = {
            "messages": dialogue,
            "score": kwargs.pop("score", 1.0),
        }
        outcome = await self.app.async_execute(
            name="summary_task_memory",
            workspace_id=self.workspace_id,
            trajectories=[trajectory],
            **kwargs,
        )

        confirmation = TextBlock(
            type="text",
            text=(
                f"Successfully recorded {len(content)} task memory/memories."
            ),
        )
        return ToolResponse(
            content=[confirmation],
            metadata={"result": outcome},
        )

    except Exception as e:
        # Report the failure back to the caller as text instead of raising.
        reason = str(e)
        logger.exception("Error recording task memory: %s", reason)
        failure = TextBlock(
            type="text",
            text=f"Error recording task memory: {reason}",
        )
        return ToolResponse(content=[failure])
|
||||
|
||||
async def retrieve_from_memory(
    self,
    keywords: list[str],
    limit: int = 5,
    **kwargs: Any,
) -> ToolResponse:
    """Search and retrieve relevant task experiences.

    Search and retrieve relevant task experiences from long-term
    memory.

    IMPORTANT: You should call this function BEFORE attempting to
    solve problems or answer technical questions. This ensures you
    leverage experiences and proven solutions rather than
    starting from scratch.

    Use this when:
    - Asked to solve a technical problem or implement a solution
    - Asked for recommendations, best practices, or approaches
    - Asked "what do you know about...?" or "have you seen this
      before?"
    - Dealing with tasks that may be similar to experiences
    - Need to recall specific techniques or methods

    Benefits of retrieving first:
    - Learn from past successes and mistakes
    - Provide more accurate, battle-tested solutions
    - Avoid reinventing the wheel
    - Give consistent, informed recommendations

    Args:
        keywords (`list[str]`):
            Keywords describing the task or problem domain. Be
            specific and use technical terms. Examples:
            ["database optimization", "slow queries"], ["API design",
            "rate limiting"], ["code refactoring", "Python"].
        limit (`int`, optional):
            The maximum number of memories to retrieve per search, i.e.,
            the number of memories to retrieve for each keyword. Defaults
            to 5.
        **kwargs (`Any`):
            Additional keyword arguments. Can include 'top_k' (int),
            which takes precedence over `limit` as the number of
            experiences to retrieve for each keyword.

    Returns:
        `ToolResponse`:
            Retrieved task experiences and learnings. If no relevant
            experiences found, you'll receive a message indicating
            that.

    Raises:
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    logger.info(
        "[ReMeTaskMemory] Entering retrieve_from_memory - "
        "keywords: %s, kwargs: %s",
        keywords,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # 'top_k' is a documented keyword. Take it out of kwargs so it is
        # not passed twice (explicitly and via **kwargs), which would
        # raise "got multiple values for keyword argument 'top_k'".
        top_k = kwargs.pop("top_k", limit)

        results = []

        # Search for each keyword
        for keyword in keywords:
            result = await self.app.async_execute(
                name="retrieve_task_memory",
                workspace_id=self.workspace_id,
                query=keyword,
                top_k=top_k,
                **kwargs,
            )

            # Extract the answer from the result
            answer = result.get("answer", "")
            if answer:
                results.append(f"Keyword '{keyword}':\n{answer}")

        # Combine all results
        if results:
            combined_text = "\n\n".join(results)
        else:
            combined_text = (
                "No task experiences found for the given keywords."
            )

        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=combined_text,
                ),
            ],
        )

    except Exception as e:
        # Report the failure back to the caller as text instead of raising.
        logger.exception("Error retrieving task memory: %s", str(e))
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=f"Error retrieving task memory: {str(e)}",
                ),
            ],
        )
|
||||
|
||||
async def record(
    self,
    msgs: list[Msg | None],
    **kwargs: Any,
) -> None:
    """Record the content to the task memory.

    This method converts AgentScope messages to ReMe's format and
    records them as a task execution trajectory. Errors raised while
    converting or submitting the messages are logged and surfaced as a
    warning instead of being propagated.

    Args:
        msgs (`list[Msg | None]`):
            The messages to record to memory. A single `Msg` is wrapped
            into a list; falsy entries such as `None` are dropped.
        **kwargs (`Any`):
            Additional keyword arguments for the recording.
            Can include 'score' (float) for trajectory scoring
            (default: 1.0).

    Raises:
        `TypeError`:
            If any remaining entry is not a `Msg`.
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    candidates = [msgs] if isinstance(msgs, Msg) else msgs

    # Drop the None placeholders (any falsy entry is discarded).
    msg_list = [entry for entry in candidates if entry]
    if not msg_list:
        return

    if any(not isinstance(entry, Msg) for entry in msg_list):
        raise TypeError(
            "The input messages must be a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # Flatten every message into a {"role", "content"} dict whose
        # content is a plain string.
        messages = []
        for msg in msg_list:
            body = msg.content
            if isinstance(body, str):
                flattened = body
            elif isinstance(body, list):
                # Keep only the text / thinking payload of dict blocks.
                fragments = []
                for block in body:
                    if not isinstance(block, dict):
                        continue
                    if "text" in block:
                        fragments.append(block["text"])
                    elif "thinking" in block:
                        fragments.append(block["thinking"])
                flattened = "\n".join(fragments)
            else:
                flattened = str(body)

            messages.append({"role": msg.role, "content": flattened})

        # 'score' travels with the trajectory (default 1.0); the other
        # keywords are forwarded to the flow untouched.
        trajectory = {
            "messages": messages,
            "score": kwargs.pop("score", 1.0),
        }
        await self.app.async_execute(
            name="summary_task_memory",
            workspace_id=self.workspace_id,
            trajectories=[trajectory],
            **kwargs,
        )

    except Exception as e:
        # Log the error but don't raise to maintain compatibility
        reason = str(e)
        logger.exception(
            "Error recording messages to task memory: %s",
            reason,
        )
        import warnings

        warnings.warn(
            f"Error recording messages to task memory: {reason}",
        )
|
||||
|
||||
async def retrieve(
    self,
    msg: Msg | list[Msg] | None,
    limit: int = 5,
    **kwargs: Any,
) -> str:
    """Retrieve relevant task experiences from memory.

    Only the last message is used to build the search query; its text
    and thinking blocks are joined with newlines. Errors raised during
    retrieval are logged and surfaced as a warning, and an empty string
    is returned.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message to search for relevant task experiences. When
            a list is given, only its last message is used as the
            query.
        limit (`int`, optional):
            The maximum number of memories to retrieve for the query.
            Defaults to 5.
        **kwargs (`Any`):
            Additional keyword arguments.

    Returns:
        `str`:
            The retrieved task experiences as a string. Empty when
            `msg` is `None` or an empty list, when no query text can be
            extracted, or when retrieval fails.

    Raises:
        `TypeError`:
            If `msg` is neither a `Msg` nor a list of `Msg` objects.
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    if msg is None:
        return ""

    if isinstance(msg, Msg):
        msg = [msg]

    if not isinstance(msg, list) or not all(
        isinstance(_, Msg) for _ in msg
    ):
        raise TypeError(
            "The input message must be a Msg or a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    # An empty list has no last message to query with. Return early so
    # that `msg[-1]` below cannot raise IndexError and be reported as a
    # retrieval failure.
    if not msg:
        return ""

    try:
        # Only use the last message's content for retrieval
        last_msg = msg[-1]
        query = ""

        if isinstance(last_msg.content, str):
            query = last_msg.content
        elif isinstance(last_msg.content, list):
            # Extract text from content blocks
            content_parts = []
            for block in last_msg.content:
                if isinstance(block, dict) and "text" in block:
                    content_parts.append(block["text"])
                elif isinstance(block, dict) and "thinking" in block:
                    content_parts.append(block["thinking"])
            query = "\n".join(content_parts)

        if not query:
            return ""

        # Retrieve using the query from the last message
        result = await self.app.async_execute(
            name="retrieve_task_memory",
            workspace_id=self.workspace_id,
            query=query,
            top_k=limit,
            **kwargs,
        )

        return result.get("answer", "")

    except Exception as e:
        # Best-effort: log and warn, then fall back to an empty result.
        logger.exception("Error retrieving task memory: %s", str(e))
        import warnings

        warnings.warn(f"Error retrieving task memory: {str(e)}")
        return ""
|
||||
@@ -0,0 +1,545 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tool memory implementation using ReMe library.
|
||||
|
||||
This module provides a tool memory implementation that integrates
|
||||
with the ReMe library to record tool execution results and retrieve
|
||||
tool usage guidelines.
|
||||
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
|
||||
from ...._logging import logger
|
||||
from ....message import Msg, TextBlock
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
class ReMeToolLongTermMemory(ReMeLongTermMemoryBase):
|
||||
"""Tool memory implementation using ReMe library.
|
||||
|
||||
Tool memory records tool execution results and generates usage
|
||||
guidelines from the execution history.
|
||||
|
||||
"""
|
||||
|
||||
async def record_to_memory(
    self,
    thinking: str,
    content: list[str],
    **kwargs: Any,
) -> ToolResponse:
    """Record tool execution results to build tool usage patterns.

    Record tool execution results to build a knowledge base of tool
    usage patterns.

    Use this function after successfully using tools to capture
    execution details, results, and performance metrics. Over time,
    this builds comprehensive usage guidelines and best practices
    for each tool.

    When to record:

    - After successfully executing any tool
    - After tool failures (to learn what doesn't work)
    - When discovering effective parameter combinations
    - After noteworthy tool usage patterns

    What to record: Each tool execution should include complete
    execution details.

    Args:
        thinking (`str`):
            Your reasoning about why this tool execution is worth
            recording. Mention what worked well, what could be
            improved, or lessons learned.
        content (`list[str]`):
            List of JSON strings, each representing a tool execution.
            Each JSON must have these fields:
            - create_time: Timestamp in format "YYYY-MM-DD HH:MM:SS"
            - tool_name: Name of the tool executed
            - input: Input parameters as a dict
            - output: Tool's output as a string
            - token_cost: Token cost (integer)
            - success: Whether execution succeeded (boolean)
            - time_cost: Execution time in seconds (float)

            Example: '{"create_time": "2024-01-01 10:00:00",
            "tool_name": "search", "input": {"query": "Python"},
            "output": "Found 10 results", "token_cost": 100,
            "success": true, "time_cost": 1.2}'
        **kwargs (`Any`):
            Additional keyword arguments for the recording operation.

    Returns:
        `ToolResponse`:
            Confirmation message with number of executions recorded
            and guidelines generated.

    Raises:
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    logger.info(
        "[ReMeToolMemory] Entering record_to_memory - "
        "thinking: %s, content: %s, kwargs: %s",
        thinking,
        content,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # Parse through the shared helper so this tool entry point and
        # the message-based `record` interpret items the same way
        # (unparseable items are skipped with a warning).
        (
            tool_call_results,
            tool_names_set,
        ) = self._parse_tool_call_results(content)

        if not tool_call_results:
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text="No valid tool call results to record.",
                    ),
                ],
            )

        # First, add the tool call results
        await self.app.async_execute(
            name="add_tool_call_result",
            workspace_id=self.workspace_id,
            tool_call_results=tool_call_results,
            **kwargs,
        )

        # Then, summarize the tool memory for the affected tools. The
        # summary flow only runs when at least one result named its tool.
        guidelines_generated = bool(tool_names_set)
        if guidelines_generated:
            await self.app.async_execute(
                name="summary_tool_memory",
                workspace_id=self.workspace_id,
                tool_names=list(tool_names_set),
                **kwargs,
            )

        num_results = len(tool_call_results)
        summary_text = (
            f"Successfully recorded {num_results} tool execution "
            f"result{'s' if num_results > 1 else ''}"
        )
        # Only claim that guidelines were generated when the summary
        # flow was actually executed.
        if guidelines_generated:
            summary_text += " and generated usage guidelines."
        else:
            summary_text += "."

        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=summary_text,
                ),
            ],
        )

    except Exception as e:
        # Report the failure back to the caller as text instead of raising.
        logger.exception("Error recording tool memory: %s", str(e))
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=f"Error recording tool memory: {str(e)}",
                ),
            ],
        )
|
||||
|
||||
async def retrieve_from_memory(
    self,
    keywords: list[str],
    limit: int = 5,
    **kwargs: Any,
) -> ToolResponse:
    """Retrieve usage guidelines and best practices for tools.

    Retrieve usage guidelines and best practices for specific tools.

    .. note:: You should call this function BEFORE using a tool,
       especially if you're uncertain about its proper usage or want to
       follow established best practices. This retrieves synthesized
       guidelines based on past tool executions.

    Use this when:

    - About to use a tool and want to know the best practices
    - Uncertain about tool parameters or usage patterns
    - Want to learn from past successful/failed tool executions
    - User asks "how should I use this tool?" or "what's the best
      way to..."
    - Need to understand tool performance characteristics or
      limitations

    Benefits of retrieving first:

    - Learn from accumulated tool usage experience
    - Avoid common mistakes and pitfalls
    - Use optimal parameter combinations
    - Understand tool performance and cost characteristics
    - Follow established best practices

    Args:
        keywords (`list[str]`):
            List of tool names to retrieve guidelines for. Use the
            exact tool names. Examples: ["search"],
            ["database_query", "cache_get"], ["api_call"].
        limit (`int`, optional):
            The maximum number of memories to retrieve per search, i.e.,
            the number of memories to retrieve for each keyword. Defaults
            to 5.
        **kwargs (`Any`):
            Additional keyword arguments for the retrieval operation.

    Returns:
        `ToolResponse`:
            Retrieved usage guidelines and best practices for the
            specified tools. If no guidelines exist yet, you'll
            receive a message indicating that.

    Raises:
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    logger.info(
        "[ReMeToolMemory] Entering retrieve_from_memory - "
        "keywords: %s, kwargs: %s",
        keywords,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # All requested tools are sent in a single call as one
        # comma-separated string.
        joined_names = ",".join(keywords)

        outcome = await self.app.async_execute(
            name="retrieve_tool_memory",
            workspace_id=self.workspace_id,
            tool_names=joined_names,
            top_k=limit,
            **kwargs,
        )

        # Fall back to a "nothing found" notice when the answer is empty.
        guidance = outcome.get("answer", "")
        reply_text = (
            guidance or f"No tool guidelines found for: {joined_names}"
        )

        reply_block = TextBlock(type="text", text=reply_text)
        return ToolResponse(content=[reply_block])

    except Exception as e:
        # Report the failure back to the caller as text instead of raising.
        reason = str(e)
        logger.exception("Error retrieving tool memory: %s", reason)
        failure_block = TextBlock(
            type="text",
            text=f"Error retrieving tool memory: {reason}",
        )
        return ToolResponse(content=[failure_block])
|
||||
|
||||
def _extract_content_from_messages(self, msg_list: list[Msg]) -> list[str]:
    """Collect the textual content of a sequence of messages.

    String content is taken as-is; list content is reduced to its text
    entries via `_extract_text_from_blocks`. Content of any other type
    contributes nothing.

    Args:
        msg_list (`list[Msg]`):
            List of messages to extract content from.

    Returns:
        `list[str]`:
            The extracted content strings, in message order.
    """
    extracted = []
    for message in msg_list:
        payload = message.content
        if isinstance(payload, list):
            extracted += self._extract_text_from_blocks(payload)
        elif isinstance(payload, str):
            extracted.append(payload)
    return extracted
|
||||
|
||||
def _extract_text_from_blocks(self, blocks: list) -> list[str]:
    """Pull the text out of a list of content blocks.

    Plain strings are kept unchanged. Dict blocks whose ``"type"`` is
    ``"text"`` contribute their ``"text"`` value (an empty string when
    the key is missing). Every other entry is ignored.

    Args:
        blocks (`list`):
            List of content blocks.

    Returns:
        `list[str]`:
            The extracted text strings, in block order.
    """
    collected = []
    for entry in blocks:
        if isinstance(entry, str):
            collected.append(entry)
        elif isinstance(entry, dict) and entry.get("type") == "text":
            collected.append(entry.get("text", ""))
    return collected
|
||||
|
||||
def _parse_tool_call_results(
    self,
    content_list: list[str],
) -> tuple[list[dict], set[str]]:
    """Parse JSON content strings into tool call results.

    Items that are not valid JSON, or whose JSON value is not an object
    (e.g. a bare number, string or array), are skipped with a warning
    so that one malformed item does not abort the whole batch.

    Args:
        content_list (`list[str]`):
            List of JSON strings to parse.

    Returns:
        `tuple[list[dict], set[str]]`:
            Tuple of (tool_call_results, tool_names_set), where
            tool_names_set holds the "tool_name" value of every parsed
            result that has one.
    """
    import json
    import warnings

    tool_call_results = []
    tool_names_set = set()

    for item in content_list:
        try:
            tool_call_result = json.loads(item)
        except json.JSONDecodeError as e:
            warnings.warn(
                f"Failed to parse tool call result JSON: {item}. "
                f"Error: {str(e)}",
            )
            continue

        # json.loads also accepts scalars and arrays. Those are not tool
        # call records, and the "tool_name" lookup below would raise
        # TypeError on them, so they are skipped.
        if not isinstance(tool_call_result, dict):
            warnings.warn(
                f"Skipping tool call result that is not a JSON object: "
                f"{item}",
            )
            continue

        tool_call_results.append(tool_call_result)
        if "tool_name" in tool_call_result:
            tool_names_set.add(tool_call_result["tool_name"])

    return tool_call_results, tool_names_set
|
||||
|
||||
async def record(
    self,
    msgs: list[Msg | None],
    **kwargs: Any,
) -> None:
    """Record the content to the tool memory.

    This method extracts content from messages and treats them as
    JSON strings representing tool_call_results, similar to
    record_to_memory. The parsed results are first stored through the
    ``add_tool_call_result`` operation; if any result named a tool, the
    ``summary_tool_memory`` operation is then run for those tools.

    Failures inside the extraction/storage phase are logged and turned
    into a warning instead of being raised.

    Args:
        msgs (`list[Msg | None]`):
            The messages to record to memory. Each message's content
            should be a JSON string or list of JSON strings
            representing tool_call_results. A single `Msg` is also
            accepted, and `None` is treated as "nothing to record".
        **kwargs (`Any`):
            Additional keyword arguments for the recording. They are
            forwarded to both app operations.

    Raises:
        `TypeError`:
            If any remaining entry is not a `Msg`.
        `RuntimeError`:
            If the app context has not been started.
    """
    import warnings

    # `None` used to crash in the filter below with an unrelated
    # "not iterable" TypeError; `retrieve` already treats it as a no-op.
    if msgs is None:
        return

    if isinstance(msgs, Msg):
        msgs = [msgs]

    # Filter out None
    msg_list = [_ for _ in msgs if _]
    if not msg_list:
        return

    if not all(isinstance(_, Msg) for _ in msg_list):
        raise TypeError(
            "The input messages must be a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # Extract content from messages and parse as tool_call_results
        content_list = self._extract_content_from_messages(msg_list)
        if not content_list:
            return

        # Parse each content item as a tool_call_result
        tool_call_results, tool_names_set = self._parse_tool_call_results(
            content_list,
        )
        if not tool_call_results:
            return

        # First, add the tool call results
        await self.app.async_execute(
            name="add_tool_call_result",
            workspace_id=self.workspace_id,
            tool_call_results=tool_call_results,
            **kwargs,
        )

        # Then, summarize the tool memory for the affected tools
        if tool_names_set:
            tool_names_list = list(tool_names_set)
            await self.app.async_execute(
                name="summary_tool_memory",
                workspace_id=self.workspace_id,
                tool_names=tool_names_list,
                **kwargs,
            )

    except Exception as e:
        # Log the error but don't raise to maintain compatibility
        logger.exception(
            "Error recording tool messages to memory: %s",
            str(e),
        )
        warnings.warn(
            f"Error recording tool messages to memory: {str(e)}",
        )
|
||||
|
||||
def _extract_tool_names_from_message(self, msg: Msg) -> str:
    """Turn a message's content into a single tool-name string.

    Args:
        msg (`Msg`):
            The message whose content is inspected.

    Returns:
        `str`:
            The content itself when it is a string; the ``"text"``
            values of all dict blocks joined by single spaces when it
            is a list; an empty string for any other content.
    """
    content = msg.content

    if isinstance(content, str):
        return content

    if not isinstance(content, list):
        return ""

    return " ".join(
        block["text"]
        for block in content
        if isinstance(block, dict) and "text" in block
    )
|
||||
|
||||
def _format_retrieve_result(self, result: Any) -> str:
|
||||
"""Format the retrieve result into a string.
|
||||
|
||||
Args:
|
||||
result (`Any`):
|
||||
Result from the retrieve operation.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
Formatted result string.
|
||||
"""
|
||||
if isinstance(result, dict) and "answer" in result:
|
||||
return result["answer"]
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
return str(result)
|
||||
|
||||
async def retrieve(
    self,
    msg: Msg | list[Msg] | None,
    limit: int = 5,
    **kwargs: Any,
) -> str:
    """Retrieve tool guidelines from memory.

    Only the last message is consulted: its content is turned into a
    tool-name string which is sent to the ``retrieve_tool_memory``
    operation.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message containing tool names or queries to
            retrieve guidelines for. `None` or an empty list yields an
            empty string.
        limit (`int`, optional):
            The maximum number of memories to retrieve. It is forwarded
            to the operation as ``top_k``. Defaults to 5.
        **kwargs (`Any`):
            Additional keyword arguments forwarded to the operation.

    Raises:
        `TypeError`:
            If `msg` is neither a `Msg` nor a list of `Msg` objects.
        `RuntimeError`:
            If the app context has not been started.

    Returns:
        `str`:
            The retrieved tool guidelines as a string, or an empty
            string when there is nothing to look up or the lookup
            failed (failures are logged and reported as a warning).
    """
    if msg is None:
        return ""

    if isinstance(msg, Msg):
        msg = [msg]

    if not isinstance(msg, list) or not all(
        isinstance(_, Msg) for _ in msg
    ):
        raise TypeError(
            "The input message must be a Msg or a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    # An empty list is valid input. Without this guard `msg[-1]` raised
    # IndexError, which was then logged as an exception and warned about
    # even though nothing went wrong.
    if not msg:
        return ""

    try:
        # Extract tool names from the last message
        last_msg = msg[-1]
        tool_names = self._extract_tool_names_from_message(last_msg)

        if not tool_names:
            return ""

        # Retrieve tool guidelines
        result = await self.app.async_execute(
            name="retrieve_tool_memory",
            workspace_id=self.workspace_id,
            tool_names=tool_names,
            top_k=limit,
            **kwargs,
        )

        return self._format_retrieve_result(result)

    except Exception as e:
        logger.exception("Error retrieving tool guidelines: %s", str(e))
        import warnings

        warnings.warn(f"Error retrieving tool guidelines: {str(e)}")
        return ""
|
||||
16
src/agentscope/memory/_working_memory/__init__.py
Normal file
16
src/agentscope/memory/_working_memory/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The working memory module in AgentScope, which provides various memory
|
||||
storage implementations. In AgentScope, such module is responsible for
|
||||
storing and managing the short-term memory with specific marks."""
|
||||
|
||||
from ._base import MemoryBase
|
||||
from ._in_memory_memory import InMemoryMemory
|
||||
from ._redis_memory import RedisMemory
|
||||
from ._sqlalchemy_memory import AsyncSQLAlchemyMemory
|
||||
|
||||
__all__ = [
|
||||
"MemoryBase",
|
||||
"InMemoryMemory",
|
||||
"RedisMemory",
|
||||
"AsyncSQLAlchemyMemory",
|
||||
]
|
||||
168
src/agentscope/memory/_working_memory/_base.py
Normal file
168
src/agentscope/memory/_working_memory/_base.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The memory base class."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from ...message import Msg
|
||||
from ...module import StateModule
|
||||
|
||||
|
||||
class MemoryBase(StateModule):
    """Common interface of the memory implementations in agentscope."""

    def __init__(self) -> None:
        """Create the base state: an empty compressed summary that is
        registered as part of the module state."""
        super().__init__()

        self._compressed_summary: str = ""

        self.register_state("_compressed_summary")

    async def update_compressed_summary(self, summary: str) -> None:
        """Replace the stored compressed summary.

        Args:
            summary (`str`):
                The summary text to keep from now on.
        """
        self._compressed_summary = summary

    @abstractmethod
    async def add(
        self,
        memories: Msg | list[Msg] | None,
        marks: str | list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Store one or more messages, optionally tagged with marks.

        Args:
            memories (`Msg | list[Msg] | None`):
                The message(s) to store.
            marks (`str | list[str] | None`, optional):
                The mark(s) to attach to the stored message(s). With
                `None`, no mark is attached.
            **kwargs (`Any`):
                Implementation-specific options.
        """

    @abstractmethod
    async def delete(
        self,
        msg_ids: list[str],
        **kwargs: Any,
    ) -> int:
        """Delete messages by ID.

        Args:
            msg_ids (`list[str]`):
                IDs of the messages to delete.
            **kwargs (`Any`):
                Implementation-specific options.

        Returns:
            `int`:
                How many messages were deleted.
        """

    async def delete_by_mark(
        self,
        mark: str | list[str],
        *args: Any,
        **kwargs: Any,
    ) -> int:
        """Delete messages by mark.

        This base version is only a placeholder; subclasses that
        support marks override it.

        Args:
            mark (`str | list[str]`):
                The mark(s) whose messages should be deleted.

        Raises:
            `NotImplementedError`:
                Always, in this base implementation.

        Returns:
            `int`:
                How many messages were deleted.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            f"The delete_by_mark method is not implemented in {owner} class.",
        )

    @abstractmethod
    async def size(self) -> int:
        """Report how many messages are stored.

        Returns:
            `int`:
                The current message count.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Drop the stored memory content."""

    @abstractmethod
    async def get_memory(
        self,
        mark: str | None = None,
        exclude_mark: str | None = None,
        prepend_summary: bool = True,
        **kwargs: Any,
    ) -> list[Msg]:
        """Fetch stored messages, optionally filtered by mark.

        .. note:: When both `mark` and `exclude_mark` are given, both
         filters apply.

        .. note:: `mark` and `exclude_mark` should not overlap.

        Args:
            mark (`str | None`, optional):
                Only return messages carrying this mark. With `None`,
                all messages are candidates.
            exclude_mark (`str | None`, optional):
                Leave out messages carrying this mark.
            prepend_summary (`bool`, defaults to True):
                Whether to put the compressed summary in front of the
                result as a message.
            **kwargs (`Any`):
                Implementation-specific options.

        Returns:
            `list[Msg]`:
                The matching messages.
        """

    async def update_messages_mark(
        self,
        new_mark: str | None,
        old_mark: str | None = None,
        msg_ids: list[str] | None = None,
    ) -> int:
        """Add, remove or replace marks on stored messages.

        - `msg_ids` restricts the update to the listed messages.
        - `old_mark` restricts the update to messages carrying that
         mark; without it, `new_mark` is added to every message (or to
         those selected by `msg_ids`).
        - `new_mark=None` removes the mark instead of setting one.

        This base version is only a placeholder; subclasses that
        support marks override it.

        Args:
            new_mark (`str | None`, optional):
                The mark to set, or `None` to remove the mark.
            old_mark (`str | None`, optional):
                The mark used to select messages. Ignored when `None`.
            msg_ids (`list[str] | None`, optional):
                The IDs used to select messages. Ignored when `None`.

        Raises:
            `NotImplementedError`:
                Always, in this base implementation.

        Returns:
            `int`:
                How many messages were updated.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            "The update_messages_mark method is not implemented in "
            f"{owner} class.",
        )
|
||||
305
src/agentscope/memory/_working_memory/_in_memory_memory.py
Normal file
305
src/agentscope/memory/_working_memory/_in_memory_memory.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The in-memory storage module for memory storage."""
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from ...message import Msg
|
||||
from ._base import MemoryBase
|
||||
|
||||
|
||||
class InMemoryMemory(MemoryBase):
    """The in-memory implementation of memory storage.

    Messages are kept in insertion order in ``self.content`` as
    ``(message, marks)`` pairs.
    """

    def __init__(self) -> None:
        """Initialize the in-memory storage."""
        super().__init__()
        # Use a list of tuples to store messages along with their marks
        self.content: list[tuple[Msg, list[str]]] = []

        # Register the state for serialization
        self.register_state("content")

    async def get_memory(
        self,
        mark: str | None = None,
        exclude_mark: str | None = None,
        prepend_summary: bool = True,
        **kwargs: Any,
    ) -> list[Msg]:
        """Get the messages from the memory by mark (if provided). Otherwise,
        get all messages.

        .. note:: If `mark` and `exclude_mark` are both provided, the messages
         will be filtered by both arguments.

        .. note:: `mark` and `exclude_mark` should not overlap.

        Args:
            mark (`str | None`, optional):
                The mark to filter messages. If `None`, return all messages.
            exclude_mark (`str | None`, optional):
                The mark to exclude messages. If provided, messages with
                this mark will be excluded from the results.
            prepend_summary (`bool`, defaults to True):
                Whether to prepend the compressed summary as a message.
                The summary is only prepended when it is non-empty.

        Raises:
            `TypeError`:
                If `mark` or `exclude_mark` is neither a string nor None.

        Returns:
            `list[Msg]`:
                The list of messages retrieved from the storage.
        """
        # Type checks
        if not (mark is None or isinstance(mark, str)):
            raise TypeError(
                f"The mark should be a string or None, but got {type(mark)}.",
            )

        if not (exclude_mark is None or isinstance(exclude_mark, str)):
            raise TypeError(
                f"The exclude_mark should be a string or None, but got "
                f"{type(exclude_mark)}.",
            )

        # Apply the include and exclude filters in one pass
        selected = [
            msg
            for msg, marks in self.content
            if (mark is None or mark in marks)
            and (exclude_mark is None or exclude_mark not in marks)
        ]

        if prepend_summary and self._compressed_summary:
            return [
                Msg(
                    "user",
                    self._compressed_summary,
                    "user",
                ),
                *selected,
            ]

        return selected

    async def add(
        self,
        memories: Msg | list[Msg] | None,
        marks: str | list[str] | None = None,
        allow_duplicates: bool = False,
        **kwargs: Any,
    ) -> None:
        """Add message(s) into the memory storage with the given mark
        (if provided).

        Stored messages and their mark lists are deep copies, so later
        changes to the caller's objects do not affect the storage.

        Args:
            memories (`Msg | list[Msg] | None`):
                The message(s) to be added.
            marks (`str | list[str] | None`, optional):
                The mark(s) to associate with the message(s). If `None`, no
                mark is associated.
            allow_duplicates (`bool`, defaults to `False`):
                Whether to allow duplicate messages in the storage. When
                `False`, a message is skipped if its ID is already stored
                or appeared earlier in the same call.

        Raises:
            `TypeError`:
                If `marks` is not a string, a list of strings, or None.
        """
        if memories is None:
            return

        if isinstance(memories, Msg):
            memories = [memories]

        if marks is None:
            marks = []
        elif isinstance(marks, str):
            marks = [marks]
        elif not isinstance(marks, list) or not all(
            isinstance(m, str) for m in marks
        ):
            raise TypeError(
                f"The mark should be a string, a list of strings, or None, "
                f"but got {type(marks)}.",
            )

        # Track IDs while inserting so that a message repeated inside the
        # same batch is caught too, not only one that was stored before.
        seen_ids = (
            None if allow_duplicates else {msg.id for msg, _ in self.content}
        )

        for msg in memories:
            if seen_ids is not None:
                if msg.id in seen_ids:
                    continue
                seen_ids.add(msg.id)
            self.content.append((deepcopy(msg), deepcopy(marks)))

    async def delete(
        self,
        msg_ids: list[str],
        **kwargs: Any,
    ) -> int:
        """Remove message(s) from the storage by their IDs.

        Args:
            msg_ids (`list[str]`):
                The list of message IDs to be removed.

        Returns:
            `int`:
                The number of messages removed.
        """
        initial_size = len(self.content)
        # Set lookup keeps this linear in the number of stored messages
        doomed = set(msg_ids)
        self.content = [
            (msg, marks) for msg, marks in self.content if msg.id not in doomed
        ]
        return initial_size - len(self.content)

    async def delete_by_mark(
        self,
        mark: str | list[str],
        **kwargs: Any,
    ) -> int:
        """Remove messages from the memory by their marks.

        A message is removed when it carries at least one of the given
        marks.

        Args:
            mark (`str | list[str]`):
                The mark(s) of the messages to be removed.

        Raises:
            `TypeError`:
                If the provided mark is not a string or a list of strings.

        Returns:
            `int`:
                The number of messages removed.
        """
        if isinstance(mark, str):
            mark = [mark]

        # Same contract as `add`: anything but a list of strings is
        # rejected up front with a descriptive message.
        if not isinstance(mark, list):
            raise TypeError(
                f"The mark should be a string or a list of strings, "
                f"but got {type(mark)}.",
            )

        if not all(isinstance(m, str) for m in mark):
            raise TypeError(
                f"The mark should be a string or a list of strings, "
                f"but got {type(mark)} with elements of types "
                f"{[type(m) for m in mark]}.",
            )

        initial_size = len(self.content)
        doomed = set(mark)
        self.content = [
            (msg, marks)
            for msg, marks in self.content
            if doomed.isdisjoint(marks)
        ]

        return initial_size - len(self.content)

    async def clear(self) -> None:
        """Clear all messages from the storage."""
        self.content.clear()

    async def size(self) -> int:
        """Get the number of messages in the storage.

        Returns:
            `int`:
                The number of messages in the storage.
        """
        return len(self.content)

    async def update_messages_mark(
        self,
        new_mark: str | None,
        old_mark: str | None = None,
        msg_ids: list[str] | None = None,
    ) -> int:
        """A unified method to update marks of messages in the storage (add,
        remove, or change marks).

        - If `msg_ids` is provided, the update will be applied to the messages
         with the specified IDs.
        - If `old_mark` is provided, the update will be applied to the
         messages with the specified old mark. Otherwise, the `new_mark` will
         be added to all messages (or those filtered by `msg_ids`).
        - If `new_mark` is `None`, the mark will be removed from the messages.

        Args:
            new_mark (`str | None`, optional):
                The new mark to set for the messages. If `None`, the mark
                will be removed.
            old_mark (`str | None`, optional):
                The old mark to filter messages. If `None`, this constraint
                is ignored.
            msg_ids (`list[str] | None`, optional):
                The list of message IDs to be updated. If `None`, this
                constraint is ignored.

        Returns:
            `int`:
                The number of messages updated. Every message that passes
                the `msg_ids` / `old_mark` filters is counted.
        """
        updated_count = 0

        for msg, marks in self.content:
            # If msg_ids is provided, skip messages not in the list
            if msg_ids is not None and msg.id not in msg_ids:
                continue

            # If old_mark is provided, skip messages that do not have the old
            # mark
            if old_mark is not None and old_mark not in marks:
                continue

            # The mark list is edited in place, so the stored pair sees the
            # change without being reassigned.
            if new_mark is None:
                # If new_mark is None, remove the old_mark
                if old_mark in marks:
                    marks.remove(old_mark)
            else:
                # If new_mark is provided, add or replace the old_mark
                if old_mark is not None and old_mark in marks:
                    marks.remove(old_mark)
                if new_mark not in marks:
                    marks.append(new_mark)

            updated_count += 1

        return updated_count

    def state_dict(self) -> dict:
        """Get the state dictionary for serialization.

        The mark lists in the result are copies, so a later mark update
        does not alter an already exported state.
        """
        return {
            **super().state_dict(),
            "content": [
                [msg.to_dict(), list(marks)] for msg, marks in self.content
            ],
        }

    def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
        """Load the state dictionary for deserialization.

        Args:
            state_dict (`dict`):
                The state to restore. Each ``"content"`` item is either a
                ``[message_dict, marks]`` pair or, for compatibility with
                older versions, a bare message dict.
            strict (`bool`, defaults to `True`):
                Whether a missing ``"content"`` key is an error.

        Raises:
            `KeyError`:
                If `strict` is set and ``"content"`` is missing.
            `ValueError`:
                If a content item has neither supported format.
        """
        if strict and "content" not in state_dict:
            raise KeyError(
                "The state_dict does not contain 'content' "
                "keys required for InMemoryMemory.",
            )

        self._compressed_summary = state_dict.get("_compressed_summary", "")

        self.content = []
        for item in state_dict.get("content", []):
            if isinstance(item, (tuple, list)) and len(item) == 2:
                msg_dict, marks = item
                msg = Msg.from_dict(msg_dict)
                # Copy so the memory does not share mark lists with the
                # caller's state dictionary.
                self.content.append((msg, list(marks)))

            elif isinstance(item, dict):
                # For compatibility with older versions
                msg = Msg.from_dict(item)
                self.content.append((msg, []))

            else:
                raise ValueError(
                    "Invalid item format in state_dict for InMemoryMemory.",
                )
|
||||
827
src/agentscope/memory/_working_memory/_redis_memory.py
Normal file
827
src/agentscope/memory/_working_memory/_redis_memory.py
Normal file
@@ -0,0 +1,827 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The redis based memory storage implementation."""
|
||||
import json
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from ._base import MemoryBase
|
||||
from ...message import Msg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
else:
|
||||
ConnectionPool = Any
|
||||
Redis = Any
|
||||
|
||||
|
||||
class RedisMemory(MemoryBase):
|
||||
"""Redis memory storage implementation, which supports session and user
|
||||
context.
|
||||
|
||||
.. note:: All the operations in this class are within a specific session
|
||||
and user context, identified by `session_id` and `user_id`. Cross-session
|
||||
or cross-user operations are not supported. For example, the
|
||||
`remove_messages` method will only remove messages that belong to the
|
||||
specified `session_id` and `user_id`.
|
||||
|
||||
.. note:: All Redis keys used by this class will be prefixed by `prefix`
|
||||
(if provided) to support multi-tenant / multi-app isolation.
|
||||
|
||||
**Mark Index Storage:**
|
||||
|
||||
This class maintains a `marks_index` (Redis Set) to efficiently track all
|
||||
mark names within a session. When a mark is created via `add_mark()`, the
|
||||
mark name is added to this set. This allows quick retrieval of all marks
|
||||
without scanning all Redis keys. The marks_index key pattern is:
|
||||
``user_id:{user_id}:session:{session_id}:marks_index``
|
||||
|
||||
Each individual mark stores its associated message IDs in a separate Redis
|
||||
List with the key pattern:
|
||||
``user_id:{user_id}:session:{session_id}:mark:{mark}``
|
||||
|
||||
"""
|
||||
|
||||
SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
|
||||
"""Redis key pattern (without prefix) for storing message IDs (ordered) for
|
||||
a specific session.
|
||||
"""
|
||||
|
||||
SESSION_PATTERN = "user_id:{user_id}:session:{session_id}:*"
|
||||
"""Redis key pattern (without prefix) for scanning all keys belong to
|
||||
a specific user and session."""
|
||||
|
||||
MARK_KEY = "user_id:{user_id}:session:{session_id}:mark:{mark}"
|
||||
"""Redis key pattern (without prefix) for storing message IDs that belong
|
||||
to a specific mark.
|
||||
"""
|
||||
|
||||
MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
|
||||
"""Redis key pattern (without prefix) for storing message payload data."""
|
||||
|
||||
MARKS_INDEX_KEY = "user_id:{user_id}:session:{session_id}:marks_index"
|
||||
"""Redis key pattern (without prefix) for storing all mark names as a set.
|
||||
This is used to avoid scanning all keys to find marks.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    session_id: str = "default_session",
    user_id: str = "default_user",
    host: str = "localhost",
    port: int = 6379,
    db: int = 0,
    password: str | None = None,
    connection_pool: ConnectionPool | None = None,
    key_prefix: str = "",
    key_ttl: int | None = None,
    **kwargs: Any,
) -> None:
    """Initialize the Redis based storage by connecting to the Redis
    server. You can provide either the connection parameters or an
    existing connection pool.

    Args:
        session_id (`str`, default to `"default_session"`):
            The session ID for the storage.
        user_id (`str`, default to `"default_user"`):
            The user ID for the storage.
        host (`str`, default to `"localhost"`):
            The Redis server host.
        port (`int`, default to `6379`):
            The Redis server port.
        db (`int`, default to `0`):
            The Redis database index.
        password (`str | None`, optional):
            The password for the Redis server, if required.
        connection_pool (`ConnectionPool | None`, optional):
            An optional Redis connection pool. If provided, it will be used
            instead of creating a new connection.
        key_prefix (`str`, default to `""`):
            Optional Redis key prefix prepended to every key generated by
            this storage. Useful for isolating keys across
            apps/environments (e.g. `"prod"`, `"staging"`, `"myapp"`).
            A falsy value is stored as an empty string.
        key_ttl (`int | None`, default to `None`):
            The expired time in seconds for each key. If provided, the
            expiration will be refreshed on every access (sliding TTL). If
            `None`, the keys will not expire. **Note** the ttl will update
            all session related keys, so it's not recommended to set for
            large sessions.
        **kwargs (`Any`):
            Additional keyword arguments to pass to the Redis client.

    Raises:
        `ImportError`:
            If the `redis` package is not installed.
    """
    try:
        import redis.asyncio as redis
    except ImportError as e:
        # The message names this class and the plain package: asyncio
        # support ships with `redis` itself.
        raise ImportError(
            "The 'redis' package is required for RedisMemory. "
            "Please install it via 'pip install redis'.",
        ) from e

    super().__init__()

    self.session_id = session_id
    self.user_id = user_id
    self.key_prefix = key_prefix or ""
    self.key_ttl = key_ttl

    # Responses are decoded to `str` by the client itself
    self._client = redis.Redis(
        host=host,
        port=port,
        db=db,
        password=password,
        connection_pool=connection_pool,
        decode_responses=True,
        **kwargs,
    )
|
||||
|
||||
def get_client(self) -> Redis:
    """Expose the Redis client this memory was constructed with.

    Returns:
        `Redis`:
            The client instance used for all storage operations.
    """
    client = self._client
    return client
|
||||
|
||||
def _decode_if_bytes(self, data: Any) -> Any:
|
||||
"""Helper method to decode bytes to str if needed.
|
||||
|
||||
Args:
|
||||
data (`Any`):
|
||||
The data to decode, which may be bytes, bytearray, or str.
|
||||
|
||||
Returns:
|
||||
`Any`:
|
||||
The decoded string if input was bytes/bytearray, otherwise
|
||||
the original data.
|
||||
"""
|
||||
if isinstance(data, (bytes, bytearray)):
|
||||
return data.decode("utf-8")
|
||||
return data
|
||||
|
||||
def _decode_list(self, data_list: list) -> list:
|
||||
"""Helper method to decode a list of potential bytes.
|
||||
|
||||
Args:
|
||||
data_list (`list`):
|
||||
A list that may contain bytes, bytearray, or str elements.
|
||||
|
||||
Returns:
|
||||
`list`:
|
||||
A list with all bytes/bytearray elements decoded to str.
|
||||
"""
|
||||
return [self._decode_if_bytes(item) for item in data_list]
|
||||
|
||||
def _get_session_key(self) -> str:
|
||||
"""Get the Redis key for the current session.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The Redis key for storing messages in the current session.
|
||||
"""
|
||||
return self.key_prefix + self.SESSION_KEY.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
def _get_session_pattern(self) -> str:
|
||||
"""Get the Redis key pattern for all keys in the current session.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The Redis key pattern for all keys in the current session.
|
||||
"""
|
||||
return self.key_prefix + self.SESSION_PATTERN.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
def _get_mark_key(self, mark: str) -> str:
|
||||
"""Get the Redis key for a specific mark.
|
||||
|
||||
Args:
|
||||
mark (`str`):
|
||||
The mark name.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The Redis key for storing message IDs with the given mark.
|
||||
"""
|
||||
return self.key_prefix + self.MARK_KEY.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
mark=mark,
|
||||
)
|
||||
|
||||
def _get_mark_pattern(self) -> str:
|
||||
"""Get the Redis key pattern for all marks in the current session.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The Redis key pattern for all mark keys.
|
||||
"""
|
||||
return self.key_prefix + self.MARK_KEY.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
mark="*",
|
||||
)
|
||||
|
||||
def _get_marks_index_key(self) -> str:
|
||||
"""Get the Redis key for the marks index set.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The Redis key for storing all mark names as a set.
|
||||
"""
|
||||
return self.key_prefix + self.MARKS_INDEX_KEY.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
def _extract_mark_from_key(self, mark_key: str) -> str:
|
||||
"""Extract the mark name from a full mark key.
|
||||
|
||||
Args:
|
||||
mark_key (`str`):
|
||||
The full Redis key for a mark.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The mark name extracted from the key.
|
||||
"""
|
||||
# Remove the prefix and the base pattern to get the mark name
|
||||
# Example: "prefix:user_id:xxx:session:yyy:mark:my_mark" -> "my_mark"
|
||||
prefix_pattern = self.key_prefix + self.MARK_KEY.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
mark="",
|
||||
)
|
||||
return mark_key.replace(prefix_pattern, "")
|
||||
|
||||
def _get_message_key(self, msg_id: str) -> str:
|
||||
"""Get the Redis key for a specific message.
|
||||
|
||||
Args:
|
||||
msg_id (`str`):
|
||||
The message ID.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The Redis key for storing the message data.
|
||||
"""
|
||||
return self.key_prefix + self.MESSAGE_KEY.format(
|
||||
user_id=self.user_id,
|
||||
session_id=self.session_id,
|
||||
msg_id=msg_id,
|
||||
)
|
||||
|
||||
async def _refresh_session_ttl(
    self,
    pipe: Any | None = None,
) -> None:
    """Re-apply `key_ttl` to every key of the current session.

    Does nothing when `key_ttl` is `None`. Otherwise the session's keys
    are found by scanning with the session pattern and an `expire`
    command is issued for each of them on the pipeline.

    Args:
        pipe (`Any | None`, optional):
            A Redis pipeline to add the `expire` commands to. When
            given, it is left for the caller to execute. When `None`,
            a pipeline is created here and executed before returning.
    """
    ttl = self.key_ttl
    if ttl is None:
        return

    # Only a pipeline created here is executed here
    owns_pipeline = pipe is None
    if owns_pipeline:
        pipe = self._client.pipeline()

    pattern = self._get_session_pattern()
    cursor = 0
    while True:
        cursor, batch = await self._client.scan(
            cursor,
            match=pattern,
            count=100,
        )
        # Keys may come back as bytes; normalise before use
        for key in self._decode_list(batch):
            await pipe.expire(key, ttl)
        # A zero cursor signals the end of the scan
        if cursor == 0:
            break

    if owns_pipeline:
        await pipe.execute()
|
||||
|
||||
async def _scan_and_migrate_marks(self) -> list[str]:
|
||||
"""Scan all mark keys and migrate them to the marks index.
|
||||
|
||||
This method is only called once for old data that doesn't have
|
||||
a marks index yet. After migration, the marks index will be
|
||||
maintained automatically.
|
||||
|
||||
Returns:
|
||||
`list[str]`:
|
||||
The list of all mark keys found.
|
||||
"""
|
||||
mark_keys = []
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await self._client.scan(
|
||||
cursor,
|
||||
match=self._get_mark_pattern(),
|
||||
count=50,
|
||||
)
|
||||
keys = self._decode_list(keys)
|
||||
mark_keys.extend(keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
# Build the marks index
|
||||
if mark_keys:
|
||||
pipe = self._client.pipeline()
|
||||
for mark_key in mark_keys:
|
||||
mark = self._extract_mark_from_key(mark_key)
|
||||
await pipe.sadd(self._get_marks_index_key(), mark)
|
||||
await pipe.execute()
|
||||
|
||||
return mark_keys
|
||||
|
||||
async def _get_all_mark_keys(self) -> list[str]:
|
||||
"""Get all mark keys, compatible with both old and new data.
|
||||
|
||||
For new data (with marks index), this method uses the index directly.
|
||||
For old data (without marks index), this method scans once and
|
||||
migrates to the new structure.
|
||||
|
||||
Returns:
|
||||
`list[str]`:
|
||||
The list of all mark keys.
|
||||
"""
|
||||
marks_index_key = self._get_marks_index_key()
|
||||
|
||||
# Try to read from the index first
|
||||
marks = await self._client.smembers(marks_index_key)
|
||||
if marks:
|
||||
# Index exists, use it
|
||||
marks = self._decode_list(list(marks))
|
||||
return [self._get_mark_key(mark) for mark in marks]
|
||||
|
||||
# Index doesn't exist, check if this is a new session
|
||||
session_exists = await self._client.exists(self._get_session_key())
|
||||
if not session_exists:
|
||||
# New session, no data at all, return empty
|
||||
return []
|
||||
|
||||
# Old session without index, need to scan and migrate (only once)
|
||||
mark_keys = await self._scan_and_migrate_marks()
|
||||
return mark_keys
|
||||
|
||||
async def get_memory(
    self,
    mark: str | None = None,
    exclude_mark: str | None = None,
    prepend_summary: bool = True,
    **kwargs: Any,
) -> list[Msg]:
    """Get the messages from the memory by mark (if provided). Otherwise,
    get all messages.

    .. note:: If `mark` and `exclude_mark` are both provided, the messages
     will be filtered by both arguments.

    .. note:: `mark` and `exclude_mark` should not overlap.

    Args:
        mark (`str | None`, optional):
            The mark to filter messages. If `None`, return all messages.
        exclude_mark (`str | None`, optional):
            The mark to exclude messages. If provided, messages with
            this mark will be excluded from the results.
        prepend_summary (`bool`, defaults to `True`):
            Whether to prepend the compressed summary (when one is set)
            as the first message of the returned list.

    Raises:
        `TypeError`:
            If `mark` or `exclude_mark` is neither a string nor `None`.

    Returns:
        `list[Msg]`:
            The list of messages retrieved from the storage.
    """
    # Type checks
    if not (mark is None or isinstance(mark, str)):
        raise TypeError(
            f"The mark should be a string or None, but got {type(mark)}.",
        )

    if not (exclude_mark is None or isinstance(exclude_mark, str)):
        raise TypeError(
            f"The exclude_mark should be a string or None, but got "
            f"{type(exclude_mark)}.",
        )

    # The session list holds all message IDs; a mark list holds only the
    # IDs carrying that mark.
    source_key = (
        self._get_session_key() if mark is None else self._get_mark_key(mark)
    )
    msg_ids = self._decode_list(
        await self._client.lrange(source_key, 0, -1),
    )

    # Exclude messages by exclude_mark. A set keeps the filter linear
    # instead of scanning the exclusion list once per message ID.
    if exclude_mark:
        excluded_ids = set(
            self._decode_list(
                await self._client.lrange(
                    self._get_mark_key(exclude_mark),
                    0,
                    -1,
                ),
            ),
        )
        msg_ids = [_ for _ in msg_ids if _ not in excluded_ids]

    # Use mget for batch retrieval to avoid N+1 queries
    messages: list[Msg] = []
    if msg_ids:
        msg_keys = [self._get_message_key(msg_id) for msg_id in msg_ids]
        for msg_data in await self._client.mget(msg_keys):
            # An ID whose payload key is missing yields None; skip it.
            if msg_data is None:
                continue
            msg_dict = json.loads(self._decode_if_bytes(msg_data))
            messages.append(Msg.from_dict(msg_dict))

    # Refresh TTLs
    await self._refresh_session_ttl()

    if prepend_summary and self._compressed_summary:
        return [
            Msg(
                "user",
                self._compressed_summary,
                "user",
            ),
            *messages,
        ]

    return messages
|
||||
|
||||
async def add(
    self,
    memories: Msg | list[Msg] | None,
    marks: str | list[str] | None = None,
    skip_duplicated: bool = True,
    **kwargs: Any,
) -> None:
    """Add message into the storage with the given mark (if provided).

    Args:
        memories (`Msg | list[Msg] | None`):
            The message(s) to be added. `None` is a no-op.
        marks (`str | list[str] | None`, optional):
            The mark(s) to associate with the message(s). If `None`, no
            mark is associated.
        skip_duplicated (`bool`, defaults to `True`):
            If `True`, skip messages whose ID already exists in the
            session list, as well as repeated IDs inside `memories`
            itself. If `False`, allow duplicate message IDs to be
            added to the session list (though the message data will be
            overwritten).
    """
    if memories is None:
        return

    if isinstance(memories, Msg):
        memories = [memories]

    # Normalize marks to a list
    if marks is None:
        mark_list = []
    elif isinstance(marks, str):
        mark_list = [marks]
    else:
        mark_list = marks

    # Filter out existing messages if skip_duplicated is True
    messages_to_add = memories
    if skip_duplicated:
        # Get all existing message IDs in the current session
        existing_msg_ids = await self._client.lrange(
            self._get_session_key(),
            0,
            -1,
        )
        seen_ids = set(self._decode_list(existing_msg_ids))

        # Drop messages already stored, and repeats within this batch
        # (the same ID would otherwise be pushed twice).
        messages_to_add = []
        for m in memories:
            if m.id in seen_ids:
                continue
            seen_ids.add(m.id)
            messages_to_add.append(m)

        # If all messages are duplicates, return early
        if not messages_to_add:
            return

    # Commands are queued on one pipeline and sent together
    pipe = self._client.pipeline()

    # Push message ids into the session list
    if messages_to_add:
        await pipe.rpush(
            self._get_session_key(),
            *[m.id for m in messages_to_add],
        )

    # Store message data and marks
    for m in messages_to_add:
        # Record the message data
        await pipe.set(
            self._get_message_key(m.id),
            json.dumps(m.to_dict(), ensure_ascii=False),
        )

        # Record the marks if provided
        for mark in mark_list:
            await pipe.rpush(self._get_mark_key(mark), m.id)
            # Maintain the marks index
            await pipe.sadd(self._get_marks_index_key(), mark)

    # Refresh TTLs of the keys that already exist
    await self._refresh_session_ttl(pipe=pipe)

    # The refresh above discovers keys with SCAN *before* this pipeline
    # runs, so keys created by the queued commands are not seen by it.
    # Expire them explicitly; these commands run after the writes.
    if self.key_ttl is not None and messages_to_add:
        new_keys = [self._get_session_key()]
        new_keys.extend(self._get_message_key(m.id) for m in messages_to_add)
        if mark_list:
            new_keys.extend(self._get_mark_key(mark) for mark in mark_list)
            new_keys.append(self._get_marks_index_key())
        for key in new_keys:
            await pipe.expire(key, self.key_ttl)

    await pipe.execute()
|
||||
|
||||
async def delete(
    self,
    msg_ids: list[str],
    **kwargs: Any,
) -> int:
    """Remove message(s) from the storage by their IDs.

    For each ID this queues, in order: removal from the session list,
    deletion of the message payload key, and removal from every known
    mark list.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be removed.

    Returns:
        `int`:
            The number of IDs whose removal from the session list
            reported at least one removed element.
    """
    if not msg_ids:
        return 0

    # Mark keys come from the index, or from a migration scan for old data
    mark_keys = await self._get_all_mark_keys()

    pipe = self._client.pipeline()
    for msg_id in msg_ids:
        # count=0 asks LREM to remove all occurrences
        await pipe.lrem(self._get_session_key(), 0, msg_id)
        await pipe.delete(self._get_message_key(msg_id))
        for mark_key in mark_keys:
            await pipe.lrem(mark_key, 0, msg_id)

    # Refresh TTLs
    await self._refresh_session_ttl(pipe=pipe)

    results = await pipe.execute()

    # Each ID contributed `stride` results; the first of each group is
    # the LREM on the session list.
    stride = 2 + len(mark_keys)
    return sum(
        1
        for position in range(len(msg_ids))
        if results[position * stride] > 0
    )
|
||||
|
||||
async def delete_by_mark(
|
||||
self,
|
||||
mark: str | list[str],
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
"""Remove messages from the storage by their marks.
|
||||
|
||||
Args:
|
||||
mark (`str | list[str]`):
|
||||
The mark(s) of the messages to be removed.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages removed.
|
||||
"""
|
||||
if isinstance(mark, str):
|
||||
mark = [mark]
|
||||
|
||||
total_removed = 0
|
||||
|
||||
for m in mark:
|
||||
mark_key = self._get_mark_key(m)
|
||||
msg_ids = await self._client.lrange(mark_key, 0, -1)
|
||||
msg_ids = self._decode_list(msg_ids)
|
||||
|
||||
if not msg_ids:
|
||||
continue
|
||||
|
||||
# Remove messages by IDs
|
||||
removed_count = await self.delete(
|
||||
msg_ids,
|
||||
)
|
||||
total_removed += removed_count
|
||||
|
||||
# Delete the mark list
|
||||
await self._client.delete(mark_key)
|
||||
|
||||
# Remove from the marks index
|
||||
await self._client.srem(self._get_marks_index_key(), m)
|
||||
|
||||
# Refresh TTLs
|
||||
await self._refresh_session_ttl()
|
||||
|
||||
return total_removed
|
||||
|
||||
async def clear(self) -> None:
    """Delete everything this session stores in Redis.

    Removes, through a single pipeline and in this order: each message
    payload key listed in the session list, the session list itself,
    every mark list, and the marks index.
    """
    stored_ids = self._decode_list(
        await self._client.lrange(self._get_session_key(), 0, -1),
    )

    # Mark keys come from the index, or from a migration scan for old data
    mark_keys = await self._get_all_mark_keys()

    pipe = self._client.pipeline()

    doomed_keys = [
        *(self._get_message_key(msg_id) for msg_id in stored_ids),
        self._get_session_key(),
        *mark_keys,
        self._get_marks_index_key(),
    ]
    for key in doomed_keys:
        await pipe.delete(key)

    await pipe.execute()
|
||||
|
||||
async def size(self) -> int:
    """Get the number of messages in the storage.

    The count is the length of the session list; the session TTLs are
    refreshed afterwards.

    Returns:
        `int`:
            The number of messages in the storage.
    """
    session_key = self._get_session_key()
    message_count = await self._client.llen(session_key)
    await self._refresh_session_ttl()
    return message_count
|
||||
|
||||
async def update_messages_mark(
|
||||
self,
|
||||
new_mark: str | None,
|
||||
old_mark: str | None = None,
|
||||
msg_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
"""A unified method to update marks of messages in the storage (add,
|
||||
remove, or change marks).
|
||||
|
||||
- If `msg_ids` is provided, the update will be applied to the messages
|
||||
with the specified IDs.
|
||||
- If `old_mark` is provided, the update will be applied to the
|
||||
messages with the specified old mark. Otherwise, the `new_mark` will
|
||||
be added to all messages (or those filtered by `msg_ids`).
|
||||
- If `new_mark` is `None`, the mark will be removed from the messages.
|
||||
|
||||
Args:
|
||||
new_mark (`str | None`, optional):
|
||||
The new mark to set for the messages. If `None`, the mark
|
||||
will be removed.
|
||||
old_mark (`str | None`, optional):
|
||||
The old mark to filter messages. If `None`, this constraint
|
||||
is ignored.
|
||||
msg_ids (`list[str] | None`, optional):
|
||||
The list of message IDs to be updated. If `None`, this
|
||||
constraint is ignored.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages updated.
|
||||
"""
|
||||
# Determine which message IDs to update
|
||||
# Get source key based on old_mark
|
||||
source_key = (
|
||||
self._get_mark_key(old_mark)
|
||||
if old_mark is not None
|
||||
else self._get_session_key()
|
||||
)
|
||||
mark_msg_ids = await self._client.lrange(source_key, 0, -1)
|
||||
mark_msg_ids = self._decode_list(mark_msg_ids)
|
||||
|
||||
# Check if we're removing all messages from old_mark
|
||||
removing_all_from_old_mark = old_mark is not None and (
|
||||
msg_ids is None or all(mid in set(msg_ids) for mid in mark_msg_ids)
|
||||
)
|
||||
|
||||
# Filter by msg_ids if provided
|
||||
if msg_ids is not None:
|
||||
msg_ids_set = set(msg_ids)
|
||||
mark_msg_ids = [mid for mid in mark_msg_ids if mid in msg_ids_set]
|
||||
|
||||
if not mark_msg_ids:
|
||||
return 0
|
||||
|
||||
# Get existing IDs in new_mark list once (if needed)
|
||||
existing_ids_set = set()
|
||||
new_mark_key = None
|
||||
if new_mark is not None:
|
||||
new_mark_key = self._get_mark_key(new_mark)
|
||||
existing_ids = await self._client.lrange(new_mark_key, 0, -1)
|
||||
existing_ids = self._decode_list(existing_ids)
|
||||
existing_ids_set = set(existing_ids)
|
||||
|
||||
# Use pipeline for batch operations
|
||||
pipe = self._client.pipeline()
|
||||
updated_count = 0
|
||||
|
||||
for msg_id in mark_msg_ids:
|
||||
# Remove from old_mark list if applicable
|
||||
if old_mark is not None:
|
||||
await pipe.lrem(
|
||||
self._get_mark_key(old_mark),
|
||||
0,
|
||||
msg_id,
|
||||
)
|
||||
|
||||
# Add to new_mark list only if not already present
|
||||
if new_mark is not None and msg_id not in existing_ids_set:
|
||||
await pipe.rpush(new_mark_key, msg_id)
|
||||
existing_ids_set.add(msg_id)
|
||||
# Maintain the marks index
|
||||
await pipe.sadd(self._get_marks_index_key(), new_mark)
|
||||
|
||||
# Count update only if we actually did something
|
||||
if old_mark is not None or new_mark is not None:
|
||||
updated_count += 1
|
||||
|
||||
# Clean up old_mark only if we removed ALL messages from it
|
||||
if old_mark is not None and removing_all_from_old_mark:
|
||||
old_mark_key = self._get_mark_key(old_mark)
|
||||
# After lrem operations, the old mark list will be empty
|
||||
# Delete the mark key and remove from index
|
||||
await pipe.delete(old_mark_key)
|
||||
await pipe.srem(self._get_marks_index_key(), old_mark)
|
||||
|
||||
await self._refresh_session_ttl(pipe=pipe)
|
||||
|
||||
await pipe.execute()
|
||||
return updated_count
|
||||
|
||||
async def close(self, close_connection_pool: bool | None = None) -> None:
|
||||
"""Close the Redis client connection.
|
||||
|
||||
Args:
|
||||
close_connection_pool (`bool | None`, optional):
|
||||
Decides whether to close the connection pool used by this
|
||||
Redis client, overriding Redis.auto_close_connection_pool.
|
||||
By default, let Redis.auto_close_connection_pool decide
|
||||
whether to close the connection pool
|
||||
"""
|
||||
await self._client.aclose(close_connection_pool=close_connection_pool)
|
||||
|
||||
async def __aenter__(self) -> "RedisMemory":
|
||||
"""Enter the async context manager.
|
||||
|
||||
Returns:
|
||||
`RedisMemory`:
|
||||
The memory instance itself.
|
||||
"""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: Any,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the session.
|
||||
|
||||
Args:
|
||||
exc_type (`type[BaseException] | None`):
|
||||
The exception type if an exception was raised.
|
||||
exc_value (`BaseException | None`):
|
||||
The exception instance if an exception was raised.
|
||||
traceback (`Any`):
|
||||
The traceback object if an exception was raised.
|
||||
"""
|
||||
await self.close()
|
||||
873
src/agentscope/memory/_working_memory/_sqlalchemy_memory.py
Normal file
873
src/agentscope/memory/_working_memory/_sqlalchemy_memory.py
Normal file
@@ -0,0 +1,873 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The SQLAlchemy database storage module, which supports storing messages in
|
||||
a SQL database using SQLAlchemy ORM (e.g., SQLite, PostgreSQL, MySQL)."""
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
String,
|
||||
JSON,
|
||||
BigInteger,
|
||||
ForeignKey,
|
||||
select,
|
||||
delete,
|
||||
update,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ._base import MemoryBase
|
||||
from ...message import Msg
|
||||
|
||||
Base: Any = declarative_base()
|
||||
|
||||
|
||||
class AsyncSQLAlchemyMemory(MemoryBase):
|
||||
"""The SQLAlchemy memory storage class for storing messages in a SQL
|
||||
database using SQLAlchemy ORM, such as SQLite, PostgreSQL, MySQL, etc.
|
||||
|
||||
.. note:: All the operations in this class are within a specific session
|
||||
and user context, identified by `session_id` and `user_id`. Cross-session
|
||||
or cross-user operations are not supported. For example, the
|
||||
`remove_messages` method will only remove messages that belong to the
|
||||
specified `session_id` and `user_id`.
|
||||
|
||||
"""
|
||||
|
||||
class MessageTable(Base):
    """The default message table definition."""

    #: The table name.
    __tablename__ = "message"

    #: The id column. The value f"{user_id}-{session_id}-{message_id}" is
    #: used as the primary key to ensure uniqueness across users and
    #: sessions.
    id = Column(String(255), primary_key=True)

    #: The message JSON content column.
    msg = Column(JSON, nullable=False)

    #: The relationship to the owning session.
    session = relationship("SessionTable", back_populates="messages")

    #: The foreign key to the session id.
    session_id = Column(String(255), ForeignKey("session.id"), nullable=False)

    #: The index column for ordering messages, so that messages can be
    #: retrieved in the order they were added.
    index = Column(BigInteger, nullable=False, index=True)
|
||||
|
||||
class MessageMarkTable(Base):
    """The default message mark table definition."""

    #: The table name.
    __tablename__ = "message_mark"

    #: The message id column; part of the composite primary key and a
    #: foreign key to `message.id` declared with ON DELETE CASCADE.
    msg_id = Column(
        String(255),
        ForeignKey("message.id", ondelete="CASCADE"),
        primary_key=True,
    )

    #: The mark column; the other part of the composite primary key.
    mark = Column(String(255), primary_key=True)
|
||||
|
||||
class SessionTable(Base):
    """The default session table definition."""

    #: The table name.
    __tablename__ = "session"

    #: The session id column.
    id = Column(String(255), primary_key=True)

    #: The relationship to the owning user.
    user = relationship("UserTable", back_populates="sessions")

    #: The foreign key to the user id.
    user_id = Column(String(255), ForeignKey("users.id"), nullable=False)

    #: The relationship to messages.
    messages = relationship("MessageTable", back_populates="session")
|
||||
|
||||
class UserTable(Base):
    """The default user table definition."""

    #: The table name.
    __tablename__ = "users"

    #: The user id column.
    id = Column(String(255), primary_key=True)

    #: The relationship to sessions.
    sessions = relationship("SessionTable", back_populates="user")
|
||||
|
||||
def __init__(
    self,
    engine_or_session: AsyncEngine | AsyncSession,
    session_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Initialize the AsyncSQLAlchemyMemory with a SQLAlchemy engine or
    session.

    Args:
        engine_or_session (`AsyncEngine | AsyncSession`):
            The SQLAlchemy asynchronous engine or session to use for
            database operations. When an engine is given, sessions are
            created internally from it; when a session is given, it is
            used as-is. If you're using a connection pool, maybe
            you want to pass in an `AsyncSession` instance.
        session_id (`str | None`, optional):
            The session ID for the messages. If `None` (or empty),
            `"default_session"` is used.
        user_id (`str | None`, optional):
            The user ID for the messages. If `None` (or empty),
            `"default_user"` is used.

    Raises:
        `ValueError`:
            If `engine_or_session` is neither a
            `sqlalchemy.ext.asyncio.AsyncEngine` nor a
            `sqlalchemy.ext.asyncio.AsyncSession`.
    """
    super().__init__()

    self._db_session: AsyncSession | None = None

    if isinstance(engine_or_session, AsyncEngine):
        # Internal mode: sessions are created on demand from this factory.
        self._session_factory = async_sessionmaker(
            bind=engine_or_session,
            expire_on_commit=False,
        )

    elif isinstance(engine_or_session, AsyncSession):
        # External mode: no factory, the given session is used directly.
        self._session_factory = None
        self._db_session = engine_or_session

    else:
        # Name both accepted types so the message matches the check above.
        raise ValueError(
            "The 'engine_or_session' parameter must be an instance of "
            "sqlalchemy.ext.asyncio.AsyncEngine or "
            "sqlalchemy.ext.asyncio.AsyncSession.",
        )

    self.session_id = session_id or "default_session"
    self.user_id = user_id or "default_user"

    # Flag to track if tables and records have been initialized
    self._initialized = False
|
||||
|
||||
def _make_message_id(self, msg_id: str) -> str:
|
||||
"""Generate a composite primary key for a message.
|
||||
|
||||
Args:
|
||||
msg_id (`str`):
|
||||
The original message ID.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The composite primary key in the format
|
||||
"{user_id}-{session_id}-{message_id}".
|
||||
"""
|
||||
return f"{self.user_id}-{self.session_id}-{msg_id}"
|
||||
|
||||
@property
def session(self) -> AsyncSession:
    """Get the current database session, creating one if it doesn't exist.

    Returns:
        `AsyncSession`:
            The current database session.

    Note:
        - Without an internal session factory (an external session was
          supplied), the stored session is returned unchanged.
        - With an internal session factory, a fresh session is created
          whenever the stored one is `None` or not active, and the
          `_initialized` flag is reset so initialization runs again.
    """
    factory = self._session_factory

    # Internal session: replace it when missing or no longer active
    if factory is not None and (
        self._db_session is None or not self._db_session.is_active
    ):
        self._db_session = factory()
        # Reset initialized flag when creating new session
        self._initialized = False

    # External session (managed by caller) falls straight through to here
    return self._db_session
|
||||
|
||||
async def _create_table(self) -> None:
    """Create the tables and make sure the user and session rows exist.

    Returns immediately when `_initialized` is already set. Otherwise
    all tables in `Base.metadata` are created through the session's
    bind, the user row and then the session row are inserted if absent,
    and `_initialized` is set.
    """
    # Skip if already initialized
    if self._initialized:
        return

    # Obtain the engine first
    engine: AsyncEngine = self.session.bind

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # (table, primary key, extra column values), user before session
    required_rows = (
        (self.UserTable, self.user_id, {}),
        (self.SessionTable, self.session_id, {"user_id": self.user_id}),
    )

    # Track if we need to commit
    needs_commit = False
    for table, record_id, extra_fields in required_rows:
        lookup = await self.session.execute(
            select(table).filter(table.id == record_id),
        )
        if lookup.scalar_one_or_none() is None:
            self.session.add(table(id=record_id, **extra_fields))
            needs_commit = True

    # Commit once if any records were added
    if needs_commit:
        await self.session.commit()

    # Mark as initialized
    self._initialized = True
|
||||
|
||||
async def get_memory(
    self,
    mark: str | None = None,
    exclude_mark: str | None = None,
    prepend_summary: bool = True,
    **kwargs: Any,
) -> list[Msg]:
    """Get the messages from the memory by mark (if provided). Otherwise,
    get all messages.

    .. note:: If `mark` and `exclude_mark` are both provided, the messages
     will be filtered by both arguments.

    .. note:: `mark` and `exclude_mark` should not overlap.

    Args:
        mark (`str | None`, optional):
            The mark to filter messages. If `None`, return all messages.
        exclude_mark (`str | None`, optional):
            The mark to exclude messages. If provided, messages with
            this mark will be excluded from the results.
        prepend_summary (`bool`, defaults to `True`):
            Whether to prepend the compressed summary (when one is set)
            as the first message of the returned list.

    Raises:
        `TypeError`:
            If `mark` or `exclude_mark` is neither a string nor `None`.

    Returns:
        `list[Msg]`:
            The messages of this session, ordered by their `index`.
    """
    # Type checks, `mark` first and then `exclude_mark`
    for label, value in (("mark", mark), ("exclude_mark", exclude_mark)):
        if value is not None and not isinstance(value, str):
            raise TypeError(
                f"The {label} should be a string or None, but got "
                f"{type(value)}.",
            )

    await self._create_table()

    message = self.MessageTable
    marking = self.MessageMarkTable

    # Start from the messages of this session only
    query = select(message).filter(message.session_id == self.session_id)

    # Keep only messages carrying `mark`
    if mark:
        query = query.join(
            marking,
            message.id == marking.msg_id,
        ).filter(marking.mark == mark)

    # Drop messages carrying `exclude_mark`
    if exclude_mark:
        session_message_ids = select(message.id).filter(
            message.session_id == self.session_id,
        )
        excluded_ids = (
            select(marking.msg_id)
            .filter(
                marking.msg_id.in_(session_message_ids),
                marking.mark == exclude_mark,
            )
            .scalar_subquery()
        )
        query = query.filter(message.id.notin_(excluded_ids))

    # Order by index to maintain message order
    outcome = await self.session.execute(query.order_by(message.index))
    msgs = [Msg.from_dict(row.msg) for row in outcome.scalars().all()]

    if not (prepend_summary and self._compressed_summary):
        return msgs

    summary_msg = Msg(
        "user",
        self._compressed_summary,
        "user",
    )
    return [summary_msg, *msgs]
|
||||
|
||||
async def add(
    self,
    memories: Msg | list[Msg] | None,
    marks: str | list[str] | None = None,
    skip_duplicated: bool = True,
    **kwargs: Any,
) -> None:
    """Add message into the storage with the given mark (if provided).

    Args:
        memories (`Msg | list[Msg] | None`):
            The message(s) to be added. `None` is a no-op.
        marks (`str | list[str] | None`, optional):
            The mark(s) to associate with the message(s). If `None`, no
            mark is associated. Repeated marks are applied once.
        skip_duplicated (`bool`, defaults to `True`):
            If `True`, skip messages with duplicate IDs that already exist
            in the storage, as well as repeated IDs inside `memories`
            itself. If `False`, raise an `IntegrityError` when
            attempting to add a message with an existing ID.

    Raises:
        `TypeError`:
            If `memories` is not a `Msg` or a list of `Msg`, or `marks`
            is not a string or a list of strings.
        `IntegrityError`:
            If a message with the same ID already exists in the storage
            and `skip_duplicated` is set to `False`.
    """
    if memories is None:
        return

    # Type checking
    if isinstance(memories, Msg):
        memories = [memories]
    elif not (
        isinstance(memories, list)
        and all(isinstance(_, Msg) for _ in memories)
    ):
        raise TypeError(
            "The 'memories' parameter must be a Msg instance or a list of "
            f"Msg instances, but got {type(memories)}.",
        )

    if isinstance(marks, str):
        marks = [marks]
    elif marks is not None and not (
        isinstance(marks, list) and all(isinstance(m, str) for m in marks)
    ):
        raise TypeError(
            "The 'marks' parameter must be a string or a list of strings, "
            f"but got {type(marks)}.",
        )

    # Create table if not exists
    await self._create_table()

    # If skip_duplicated is True, filter out existing messages
    messages_to_add = memories
    if skip_duplicated:
        result = await self.session.execute(
            select(self.MessageTable.id).filter(
                self.MessageTable.id.in_(
                    [self._make_message_id(m.id) for m in memories],
                ),
            ),
        )
        seen_ids = {row[0] for row in result.fetchall()}

        # Drop messages already stored, and repeats within this batch
        # (two rows with the same primary key would fail on insert).
        messages_to_add = []
        for m in memories:
            composite_id = self._make_message_id(m.id)
            if composite_id in seen_ids:
                continue
            seen_ids.add(composite_id)
            messages_to_add.append(m)

        # If all messages are duplicates, return early
        if not messages_to_add:
            return

    # Read the starting index once and number the batch from it
    start_index = await self._get_next_index()

    # Add messages to message table
    new_ids = [self._make_message_id(m.id) for m in messages_to_add]
    for offset, (composite_id, m) in enumerate(zip(new_ids, messages_to_add)):
        self.session.add(
            self.MessageTable(
                id=composite_id,
                msg=m.to_dict(),
                session_id=self.session_id,
                index=start_index + offset,
            ),
        )

    # Create mark records if marks are provided (use bulk insert)
    if marks and new_ids:
        # (msg_id, mark) is the primary key, so apply each mark once.
        unique_marks = list(dict.fromkeys(marks))
        mark_records = [
            {"msg_id": composite_id, "mark": mark}
            for composite_id in new_ids
            for mark in unique_marks
        ]

        if skip_duplicated:
            # Only the rows of the messages being added can collide, so
            # restrict the lookup to them instead of reading the table.
            result = await self.session.execute(
                select(
                    self.MessageMarkTable.msg_id,
                    self.MessageMarkTable.mark,
                ).filter(self.MessageMarkTable.msg_id.in_(new_ids)),
            )
            existing_marks = {(row[0], row[1]) for row in result.fetchall()}

            # Filter out existing mark combinations
            mark_records = [
                r
                for r in mark_records
                if (r["msg_id"], r["mark"]) not in existing_marks
            ]

        if mark_records:
            # Write the pending message rows first so the mark rows'
            # foreign keys point at existing messages.
            await self.session.flush()
            await self.session.run_sync(
                lambda session: session.bulk_insert_mappings(
                    self.MessageMarkTable,
                    mark_records,
                ),
            )

    await self.session.commit()
|
||||
|
||||
async def _get_next_index(self) -> int:
    """Compute the index to assign to the next message of this session.

    Returns:
        `int`:
            One more than the largest index currently stored for the
            session, or 0 when the session has no messages yet.
    """
    table = self.MessageTable

    # Only the single highest index is needed, so sort descending and
    # keep the first row.
    latest_index_query = (
        select(table.index)
        .filter(table.session_id == self.session_id)
        .order_by(table.index.desc())
        .limit(1)
    )

    outcome = await self.session.execute(latest_index_query)
    highest = outcome.scalar_one_or_none()

    if highest is None:
        return 0
    return highest + 1
|
||||
|
||||
async def size(self) -> int:
    """Count the messages stored for the current session.

    Returns:
        `int`:
            The number of message rows whose session id matches this
            memory's session id.
    """
    count_query = select(func.count(self.MessageTable.id)).filter(
        self.MessageTable.session_id == self.session_id,
    )
    outcome = await self.session.execute(count_query)
    return outcome.scalar_one()
|
||||
|
||||
async def clear(self) -> None:
    """Delete every message of the current session and their marks."""
    in_this_session = self.MessageTable.session_id == self.session_id

    # Sub-select of the message ids owned by this session; used to scope
    # the mark deletion.
    session_msg_ids = select(self.MessageTable.id).filter(in_this_session)

    remove_marks = delete(self.MessageMarkTable).where(
        self.MessageMarkTable.msg_id.in_(session_msg_ids),
    )
    remove_messages = delete(self.MessageTable).filter(in_this_session)

    # Marks go first: the sub-select above still needs the message rows.
    await self.session.execute(remove_marks)
    await self.session.execute(remove_messages)

    await self.session.commit()
|
||||
|
||||
async def delete_by_mark(
|
||||
self,
|
||||
mark: str | list[str],
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
"""Remove messages from the storage by their marks.
|
||||
|
||||
Args:
|
||||
mark (`str | list[str]`):
|
||||
The mark(s) of the messages to be removed.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages removed.
|
||||
"""
|
||||
if isinstance(mark, str):
|
||||
mark = [mark]
|
||||
|
||||
# First, find message IDs that have the specified marks
|
||||
query = (
|
||||
select(self.MessageTable.id)
|
||||
.join(
|
||||
self.MessageMarkTable,
|
||||
self.MessageTable.id == self.MessageMarkTable.msg_id,
|
||||
)
|
||||
.filter(
|
||||
self.MessageTable.session_id == self.session_id,
|
||||
self.MessageMarkTable.mark.in_(mark),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
msg_ids = [row[0] for row in result.all()]
|
||||
|
||||
if not msg_ids:
|
||||
return 0
|
||||
|
||||
# Store the count before deletion
|
||||
deleted_count = len(msg_ids)
|
||||
|
||||
# Delete marks first
|
||||
await self.session.execute(
|
||||
delete(self.MessageMarkTable).filter(
|
||||
self.MessageMarkTable.msg_id.in_(msg_ids),
|
||||
),
|
||||
)
|
||||
|
||||
# Then delete the messages
|
||||
await self.session.execute(
|
||||
delete(self.MessageTable).filter(
|
||||
self.MessageTable.session_id == self.session_id,
|
||||
self.MessageTable.id.in_(msg_ids),
|
||||
),
|
||||
)
|
||||
|
||||
await self.session.commit()
|
||||
return deleted_count
|
||||
|
||||
async def delete(
    self,
    msg_ids: list[str],
    **kwargs: Any,
) -> int:
    """Remove message(s) from the storage by their IDs.

    .. note:: Although MessageMarkTable has CASCADE delete on foreign key,
     we explicitly delete marks first for reliability across all database
     engines and configurations. SQLAlchemy's bulk delete bypasses
     ORM-level cascades, and SQLite requires foreign keys to be
     explicitly enabled.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be removed.

    Returns:
        `int`:
            The number of messages removed, i.e. how many of the given
            IDs actually existed in this session. Unknown or repeated
            IDs do not inflate the count.
    """
    if not msg_ids:
        return 0

    # Convert to composite keys
    composite_ids = [self._make_message_id(msg_id) for msg_id in msg_ids]

    # Count the rows that really exist before deleting, so the return
    # value reflects removed messages rather than requested IDs.
    result = await self.session.execute(
        select(self.MessageTable.id).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(composite_ids),
        ),
    )
    deleted_count = len(result.all())

    # Delete related marks first (explicit cleanup for reliability)
    await self.session.execute(
        delete(self.MessageMarkTable).filter(
            self.MessageMarkTable.msg_id.in_(composite_ids),
        ),
    )

    # Then delete the messages
    await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(composite_ids),
        ),
    )

    await self.session.commit()
    return deleted_count
|
||||
|
||||
async def update_messages_mark(
|
||||
self,
|
||||
new_mark: str | None,
|
||||
old_mark: str | None = None,
|
||||
msg_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
"""A unified method to update marks of messages in the storage (add,
|
||||
remove, or change marks).
|
||||
|
||||
- If `msg_ids` is provided, the update will be applied to the messages
|
||||
with the specified IDs.
|
||||
- If `old_mark` is provided, the update will be applied to the
|
||||
messages with the specified old mark. Otherwise, the `new_mark` will
|
||||
be added to all messages (or those filtered by `msg_ids`).
|
||||
- If `new_mark` is `None`, the mark will be removed from the messages.
|
||||
|
||||
Args:
|
||||
new_mark (`str | None`, optional):
|
||||
The new mark to set for the messages. If `None`, the mark
|
||||
will be removed.
|
||||
old_mark (`str | None`, optional):
|
||||
The old mark to filter messages. If `None`, this constraint
|
||||
is ignored.
|
||||
msg_ids (`list[str] | None`, optional):
|
||||
The list of message IDs to be updated. If `None`, this
|
||||
constraint is ignored.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages updated.
|
||||
"""
|
||||
|
||||
# Type checking
|
||||
if new_mark is not None and not isinstance(new_mark, str):
|
||||
raise ValueError(
|
||||
f"The 'new_mark' parameter must be a string or None, "
|
||||
f"but got {type(new_mark)}.",
|
||||
)
|
||||
|
||||
if old_mark is not None and not isinstance(old_mark, str):
|
||||
raise ValueError(
|
||||
f"The 'old_mark' parameter must be a string or None, "
|
||||
f"but got {type(old_mark)}.",
|
||||
)
|
||||
|
||||
if msg_ids is not None and not (
|
||||
isinstance(msg_ids, list)
|
||||
and all(isinstance(_, str) for _ in msg_ids)
|
||||
):
|
||||
raise ValueError(
|
||||
f"The 'msg_ids' parameter must be a list of strings or None, "
|
||||
f"but got {type(msg_ids)}.",
|
||||
)
|
||||
|
||||
# First obtain the message ids that belong to this session
|
||||
query = select(self.MessageTable).filter(
|
||||
self.MessageTable.session_id == self.session_id,
|
||||
)
|
||||
|
||||
# Filter by msg_ids if provided
|
||||
if msg_ids is not None:
|
||||
# Convert to composite keys
|
||||
composite_ids = [
|
||||
self._make_message_id(msg_id) for msg_id in msg_ids
|
||||
]
|
||||
query = query.filter(self.MessageTable.id.in_(composite_ids))
|
||||
|
||||
# Filter by old_mark if provided
|
||||
if old_mark is not None:
|
||||
query = query.join(
|
||||
self.MessageMarkTable,
|
||||
self.MessageTable.id == self.MessageMarkTable.msg_id,
|
||||
).filter(self.MessageMarkTable.mark == old_mark)
|
||||
|
||||
# Obtain the message records
|
||||
result = await self.session.execute(query)
|
||||
msg_ids = [str(_.id) for _ in result.scalars().all()]
|
||||
|
||||
# Return early if no messages found
|
||||
if not msg_ids:
|
||||
return 0
|
||||
|
||||
if new_mark:
|
||||
if old_mark:
|
||||
# Replace old_mark with new_mark
|
||||
return await self._replace_message_mark(
|
||||
msg_ids=msg_ids,
|
||||
old_mark=old_mark,
|
||||
new_mark=new_mark,
|
||||
)
|
||||
|
||||
# Add new_mark to the messages
|
||||
return await self._add_message_mark(
|
||||
msg_ids=msg_ids,
|
||||
mark=new_mark,
|
||||
)
|
||||
|
||||
# Remove all marks from the messages
|
||||
return await self._remove_message_mark(
|
||||
msg_ids=msg_ids,
|
||||
old_mark=old_mark,
|
||||
)
|
||||
|
||||
async def _replace_message_mark(
    self,
    msg_ids: list[str],
    old_mark: str,
    new_mark: str,
) -> int:
    """Replace the old mark with the new mark for the given messages by
    updating records in the message_mark table.

    Messages that already carry `new_mark` simply lose their `old_mark`
    record, so no (msg_id, mark) pair is ever written twice.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be updated.
        old_mark (`str`):
            The old mark to be replaced.
        new_mark (`str`):
            The new mark to be set.

    Returns:
        `int`:
            The number of messages updated.
    """
    # Renaming a mark to itself, or touching no messages, changes nothing.
    if msg_ids and old_mark != new_mark:
        # Which of the targeted messages already have new_mark?
        result = await self.session.execute(
            select(self.MessageMarkTable.msg_id).filter(
                self.MessageMarkTable.msg_id.in_(msg_ids),
                self.MessageMarkTable.mark == new_mark,
            ),
        )
        already_marked = {str(row[0]) for row in result.all()}

        clash_ids = [m for m in msg_ids if str(m) in already_marked]
        pending_ids = [m for m in msg_ids if str(m) not in already_marked]

        # For those messages an UPDATE would produce a second
        # (msg_id, new_mark) row; dropping the old row gives the same
        # end state without the duplicate.
        if clash_ids:
            await self.session.execute(
                delete(self.MessageMarkTable).filter(
                    self.MessageMarkTable.msg_id.in_(clash_ids),
                    self.MessageMarkTable.mark == old_mark,
                ),
            )

        if pending_ids:
            await self.session.execute(
                update(self.MessageMarkTable)
                .filter(
                    self.MessageMarkTable.msg_id.in_(pending_ids),
                    self.MessageMarkTable.mark == old_mark,
                )
                .values(mark=new_mark),
            )

    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int:
    """Mark the messages with the given mark by adding records to the
    message_mark table.

    Messages that already carry `mark` are left untouched, so calling
    this repeatedly does not insert the same (msg_id, mark) pair twice.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be marked.
        mark (`str`):
            The mark to be added to the messages.

    Returns:
        `int`:
            The number of messages marked.
    """
    if msg_ids:
        # Look up which messages already have this mark.
        result = await self.session.execute(
            select(self.MessageMarkTable.msg_id).filter(
                self.MessageMarkTable.msg_id.in_(msg_ids),
                self.MessageMarkTable.mark == mark,
            ),
        )
        already_marked = {str(row[0]) for row in result.all()}

        # dict.fromkeys drops repeated ids while keeping their order.
        mark_records = [
            {"msg_id": msg_id, "mark": mark}
            for msg_id in dict.fromkeys(msg_ids)
            if str(msg_id) not in already_marked
        ]

        # Use bulk insert for better performance
        if mark_records:
            await self.session.run_sync(
                lambda session: session.bulk_insert_mappings(
                    self.MessageMarkTable,
                    mark_records,
                ),
            )

    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def _remove_message_mark(
    self,
    msg_ids: list[str],
    old_mark: str | None,
) -> int:
    """Delete mark records of the given messages from the message_mark
    table.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be unmarked.
        old_mark (`str | None`):
            The mark to strip. When falsy (`None` or an empty string),
            every mark of the given messages is deleted.

    Returns:
        `int`:
            The number of messages unmarked.
    """
    mark_table = self.MessageMarkTable

    # Always scope by message id; narrow to one mark only when requested.
    conditions = [mark_table.msg_id.in_(msg_ids)]
    if old_mark:
        conditions.append(mark_table.mark == old_mark)

    statement = delete(mark_table)
    for condition in conditions:
        statement = statement.filter(condition)

    await self.session.execute(statement)
    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def close(self) -> None:
    """Close the database session and reset the memory's session state.

    The session is closed whenever one is held, regardless of its
    `is_active` flag: a session whose transaction has failed reports
    `is_active` as False, and that is exactly when it still needs to be
    closed. The cached session and the initialized flag are cleared even
    if closing raises.
    """
    db_session = self._db_session
    try:
        if db_session is not None:
            await db_session.close()
    finally:
        self._db_session = None
        self._initialized = False
|
||||
|
||||
async def __aenter__(self) -> "AsyncSQLAlchemyMemory":
    """Support `async with` by handing back this memory object.

    Returns:
        `AsyncSQLAlchemyMemory`:
            The same instance the context manager was entered on.
    """
    return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: Any,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the session.
|
||||
|
||||
Args:
|
||||
exc_type (`type[BaseException] | None`):
|
||||
The exception type if an exception was raised.
|
||||
exc_value (`BaseException | None`):
|
||||
The exception instance if an exception was raised.
|
||||
traceback (`Any`):
|
||||
The traceback object if an exception was raised.
|
||||
"""
|
||||
await self.close()
|
||||
Reference in New Issue
Block a user