chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled

This commit is contained in:
2026-03-02 18:21:40 +08:00
commit a842f1861f
561 changed files with 91892 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
"""The memory module."""
from ._working_memory import (
MemoryBase,
InMemoryMemory,
RedisMemory,
AsyncSQLAlchemyMemory,
)
from ._long_term_memory import (
LongTermMemoryBase,
Mem0LongTermMemory,
ReMePersonalLongTermMemory,
ReMeTaskLongTermMemory,
ReMeToolLongTermMemory,
)
__all__ = [
# Working memory
"MemoryBase",
"InMemoryMemory",
"RedisMemory",
"AsyncSQLAlchemyMemory",
# Long-term memory
"LongTermMemoryBase",
"Mem0LongTermMemory",
"ReMePersonalLongTermMemory",
"ReMeTaskLongTermMemory",
"ReMeToolLongTermMemory",
]

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""The long-term memory module for AgentScope."""
from ._long_term_memory_base import LongTermMemoryBase
from ._mem0 import Mem0LongTermMemory
from ._reme import (
ReMePersonalLongTermMemory,
ReMeTaskLongTermMemory,
ReMeToolLongTermMemory,
)
__all__ = [
"LongTermMemoryBase",
"Mem0LongTermMemory",
"ReMePersonalLongTermMemory",
"ReMeTaskLongTermMemory",
"ReMeToolLongTermMemory",
]

View File

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""The long-term memory base class."""
from typing import Any
from agentscope.message import Msg
from agentscope.module import StateModule
from agentscope.tool import ToolResponse
class LongTermMemoryBase(StateModule):
"""The long-term memory base class, which should be a time-series
memory management system.
The `record_to_memory` and `retrieve_from_memory` methods are two tool
functions for agent to manage the long-term memory voluntarily. You can
choose not to implement these two functions.
The `record` and `retrieve` methods are for developers to use. For example,
retrieving/recording memory at the beginning of each reply, and adding
the retrieved memory to the system prompt.
"""
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> Any:
"""A developer-designed method to record information from the given
input message(s) to the long-term memory."""
raise NotImplementedError(
"The `record` method is not implemented. ",
)
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""A developer-designed method to retrieve information from the
long-term memory based on the given input message(s). The retrieved
information will be added to the system prompt of the agent."""
raise NotImplementedError(
"The `retrieve` method is not implemented. ",
)
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Use this function to record important information that you may
need later. The target content should be specific and concise, e.g.
who, when, where, do what, why, how, etc.
Args:
thinking (`str`):
Your thinking and reasoning about what to record
content (`list[str]`):
The content to remember, which is a list of strings.
"""
raise NotImplementedError(
"The `record_to_memory` method is not implemented. "
"You can implement it in your own long-term memory class.",
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Retrieve the memory based on the given keywords.
Args:
keywords (`list[str]`):
The keywords to search for in the memory, which should be
specific and concise, e.g. the person's name, the date, the
location, etc.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 5.
Returns:
`list[Msg]`:
A list of messages that match the keywords.
"""
raise NotImplementedError(
"The `retrieve_from_memory` method is not implemented. "
"You can implement it in your own long-term memory class.",
)

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""The Mem0 long-term memory module for AgentScope."""
from ._mem0_long_term_memory import Mem0LongTermMemory
__all__ = [
"Mem0LongTermMemory",
]

View File

@@ -0,0 +1,746 @@
# -*- coding: utf-8 -*-
"""Long-term memory implementation using mem0 library.
This module provides a long-term memory implementation that integrates
with the mem0 library to provide persistent memory storage and retrieval
capabilities for AgentScope agents.
"""
import asyncio
import json
from typing import Any, TYPE_CHECKING
from importlib import metadata
from pydantic import field_validator
from ....embedding import EmbeddingModelBase
from .._long_term_memory_base import LongTermMemoryBase
from ....message import Msg, TextBlock
from ....model import ChatModelBase
from ....tool import ToolResponse
if TYPE_CHECKING:
from mem0.configs.base import MemoryConfig
from mem0.vector_stores.configs import VectorStoreConfig
else:
MemoryConfig = Any
VectorStoreConfig = Any
def _create_agentscope_config_classes() -> tuple:
"""Create custom config classes for agentscope providers."""
from mem0.embeddings.configs import EmbedderConfig
from mem0.llms.configs import LlmConfig
class _ASLlmConfig(LlmConfig):
"""Custom LLM config class that updates the validate_config method.
Attention: in mem0, the validate_config hardcodes the provider, so we
need to override the validate_config method to support the agentscope
providers. We will follow up with the mem0 to improve this.
"""
@field_validator("config")
@classmethod
def validate_config(cls, v: Any, values: Any) -> Any:
"""Validate the LLM configuration."""
from mem0.utils.factory import LlmFactory
provider = values.data.get("provider")
if provider in LlmFactory.provider_to_class:
return v
raise ValueError(f"Unsupported LLM provider: {provider}")
class _ASEmbedderConfig(EmbedderConfig):
"""Custom embedder config class that updates the validate_config
method."""
@field_validator("config")
@classmethod
def validate_config(cls, v: Any, values: Any) -> Any:
"""Validate the embedder configuration."""
from mem0.utils.factory import EmbedderFactory
provider = values.data.get("provider")
if provider in EmbedderFactory.provider_to_class:
return v
raise ValueError(f"Unsupported Embedder provider: {provider}")
return _ASLlmConfig, _ASEmbedderConfig
class Mem0LongTermMemory(LongTermMemoryBase):
"""A class that implements the LongTermMemoryBase interface using mem0."""
@staticmethod
def _setup_mem0_logging(suppress_mem0_logging: bool) -> None:
"""Suppress mem0 logging if requested.
Args:
suppress_mem0_logging (`bool`):
Whether to suppress mem0 logging. See class docstring for
details on QDRANT validation errors when using mem0 1.0.3.
"""
if suppress_mem0_logging:
import logging
logging.getLogger("mem0").setLevel(logging.CRITICAL)
logging.getLogger("mem0.memory").setLevel(logging.CRITICAL)
logging.getLogger("mem0.memory.main").setLevel(logging.CRITICAL)
@staticmethod
def _register_agentscope_providers() -> None:
"""Register the agentscope providers with mem0.
Raises:
`ImportError`:
If the mem0 library is not installed.
"""
try:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.utils.factory import LlmFactory, EmbedderFactory
from packaging import version
# Check mem0 version
current_version = metadata.version("mem0ai")
is_mem0_version_low = version.parse(
current_version,
) <= version.parse("0.1.115")
# Register the agentscope providers with mem0
EmbedderFactory.provider_to_class["agentscope"] = (
"agentscope.memory._long_term_memory._mem0."
"_mem0_utils.AgentScopeEmbedding"
)
if is_mem0_version_low:
# For mem0 version <= 0.1.115, use the old style
LlmFactory.provider_to_class["agentscope"] = (
"agentscope.memory._long_term_memory._mem0."
"_mem0_utils.AgentScopeLLM"
)
else:
# For mem0 version > 0.1.115, use the new style
LlmFactory.provider_to_class["agentscope"] = (
"agentscope.memory._long_term_memory._mem0."
"_mem0_utils.AgentScopeLLM",
BaseLlmConfig,
)
except ImportError as e:
raise ImportError(
"Please install the mem0 library by `pip install mem0ai`",
) from e
@staticmethod
def _validate_identifiers(
agent_name: str | None,
user_name: str | None,
run_name: str | None,
) -> None:
"""Validate that at least one identifier is provided.
Args:
agent_name (`str | None`):
The name of the agent.
user_name (`str | None`):
The name of the user.
run_name (`str | None`):
The name of the run/session.
Raises:
`ValueError`:
If all identifiers are None.
"""
if agent_name is None and user_name is None and run_name is None:
raise ValueError(
"at least one of agent_name, user_name, and run_name is "
"required",
)
@staticmethod
def _configure_mem0_config(
mem0_config: MemoryConfig | None,
model: ChatModelBase | None,
embedding_model: EmbeddingModelBase | None,
vector_store_config: VectorStoreConfig | None,
_ASLlmConfig: type,
_ASEmbedderConfig: type,
**kwargs: Any,
) -> MemoryConfig:
"""Configure the mem0 MemoryConfig object.
Args:
mem0_config (`MemoryConfig | None`):
The existing mem0 config, if any.
model (`ChatModelBase | None`):
The chat model to use.
embedding_model (`EmbeddingModelBase | None`):
The embedding model to use.
vector_store_config (`VectorStoreConfig | None`):
The vector store config to use.
_ASLlmConfig (`type`):
The custom LLM config class for agentscope.
_ASEmbedderConfig (`type`):
The custom embedder config class for agentscope.
**kwargs (`Any`):
Additional keyword arguments, including 'on_disk' for
vector store configuration.
Returns:
`MemoryConfig`:
The configured MemoryConfig object.
Raises:
`ValueError`:
If `mem0_config` is None and either `model` or
`embedding_model` is None.
"""
import mem0
if mem0_config is not None:
# Case 1: mem0_config is provided - override specific
# configurations if individual params are given
# Override LLM configuration if model is provided
if model is not None:
mem0_config.llm = _ASLlmConfig(
provider="agentscope",
config={"model": model},
)
# Override embedder configuration if embedding_model is provided
if embedding_model is not None:
mem0_config.embedder = _ASEmbedderConfig(
provider="agentscope",
config={"model": embedding_model},
)
# Override vector store configuration if vector_store_config is
# provided
if vector_store_config is not None:
mem0_config.vector_store = vector_store_config
else:
# Case 2: mem0_config is not provided - create new configuration
# from individual parameters
# Validate that required parameters are provided
if model is None or embedding_model is None:
raise ValueError(
"model and embedding_model are required if mem0_config "
"is not provided",
)
# Create new MemoryConfig with provided LLM and embedder
mem0_config = mem0.configs.base.MemoryConfig(
llm=_ASLlmConfig(
provider="agentscope",
config={"model": model},
),
embedder=_ASEmbedderConfig(
provider="agentscope",
config={"model": embedding_model},
),
)
# Set vector store configuration
if vector_store_config is not None:
# Use provided vector store configuration
mem0_config.vector_store = vector_store_config
else:
# Use default Qdrant configuration with on-disk storage for
# persistence set on_disk to True to enable persistence,
# otherwise it will be in memory only
on_disk = kwargs.get("on_disk", True)
mem0_config.vector_store = (
mem0.vector_stores.configs.VectorStoreConfig(
config={"on_disk": on_disk},
)
)
return mem0_config
def __init__(
self,
agent_name: str | None = None,
user_name: str | None = None,
run_name: str | None = None,
model: ChatModelBase | None = None,
embedding_model: EmbeddingModelBase | None = None,
vector_store_config: VectorStoreConfig | None = None,
mem0_config: MemoryConfig | None = None,
default_memory_type: str | None = None,
suppress_mem0_logging: bool = True,
**kwargs: Any,
) -> None:
"""Initialize the Mem0LongTermMemory instance
Args:
agent_name (`str | None`, optional):
The name of the agent. Default is None.
user_name (`str | None`, optional):
The name of the user. Default is None.
run_name (`str | None`, optional):
The name of the run/session. Default is None.
.. note::
1. At least one of `agent_name`, `user_name`, or `run_name` is
required.
2. During memory recording, these parameters become metadata
for the stored memories.
3. **Important**: mem0 will extract memories from messages
containing role of "user" by default. If you want to
extract memories from messages containing role of
"assistant", you need to provide `agent_name`.
4. During memory retrieval, only memories with matching
metadata values will be returned.
model (`ChatModelBase | None`, optional):
The chat model to use for the long-term memory. If
mem0_config is provided, this will override the LLM
configuration. If mem0_config is None, this is required.
embedding_model (`EmbeddingModelBase | None`, optional):
The embedding model to use for the long-term memory. If
mem0_config is provided, this will override the embedder
configuration. If mem0_config is None, this is required.
vector_store_config (`VectorStoreConfig | None`, optional):
The vector store config to use for the long-term memory.
If mem0_config is provided, this will override the vector store
configuration. If mem0_config is None and this is not
provided, defaults to Qdrant with on_disk=True.
mem0_config (`MemoryConfig | None`, optional):
The mem0 config to use for the long-term memory.
If provided, individual
model/embedding_model/vector_store_config parameters will
override the corresponding configurations in mem0_config. If
None, a new MemoryConfig will be created using the provided
parameters.
default_memory_type (`str | None`, optional):
The type of memory to use. Default is None, to create a
semantic memory.
suppress_mem0_logging (`bool`, optional):
Whether to suppress mem0 logging. Default is True.
.. note::
When using vector database QDRANT with mem0 1.0.3, you may
encounter validation errors:
"Error awaiting memory task (async): 6 validation errors for
PointStruct vector.list[float] Input should be a valid list
[type=list_type, ...]"
According to the mem0 community
(see https://github.com/mem0ai/mem0/issues/3780),
these error messages are harmless and can be safely ignored.
Setting `suppress_mem0_logging=True` (the default) will
suppress these error messages.
Raises:
`ValueError`:
If `mem0_config` is None and either `model` or
`embedding_model` is None.
"""
super().__init__()
# Suppress mem0 logging if requested
self._setup_mem0_logging(suppress_mem0_logging)
# Register agentscope providers with mem0
self._register_agentscope_providers()
# Create the custom config classes for agentscope providers dynamically
_ASLlmConfig, _ASEmbedderConfig = _create_agentscope_config_classes()
# Validate identifiers
self._validate_identifiers(agent_name, user_name, run_name)
# Store agent and user identifiers for memory management
self.agent_id = agent_name
self.user_id = user_name
self.run_id = run_name
# Configure mem0_config
import mem0
mem0_config = self._configure_mem0_config(
mem0_config=mem0_config,
model=model,
embedding_model=embedding_model,
vector_store_config=vector_store_config,
_ASLlmConfig=_ASLlmConfig,
_ASEmbedderConfig=_ASEmbedderConfig,
**kwargs,
)
# Initialize the async memory instance with the configured settings
self.long_term_working_memory = mem0.AsyncMemory(mem0_config)
# Store the default memory type for future use
self.default_memory_type = default_memory_type
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Use this function to record important information that you may
need later. The target content should be specific and concise, e.g.
who, when, where, do what, why, how, etc.
Args:
thinking (`str`):
Your thinking and reasoning about what to record.
content (`list[str]`):
The content to remember, which is a list of strings.
"""
# Multi-strategy recording approach to ensure content persistence:
#
# This method employs a three-tier fallback strategy to maximize
# successful memory recording:
#
# 1. Primary: Record as "user" role message
# - This is the default approach for capturing user-related
# content
# - Mem0 extracts and infers memories from messages containing
# role of "user"
#
# 2. Fallback (if agent_id exists): Record as "assistant" role
# message
# - Triggered when primary recording yields no results
# - In this case, mem0 will use the AGENT_MEMORY_EXTRACTION_PROMPT
# in mem0/mem0/configs/prompts.py to extract memories from
# messages containing role of "assistant", if agent_id is
# provided, otherwise it will use the
# USER_MEMORY_EXTRACTION_PROMPT in mem0/mem0/configs/prompts.py
# to extract memories.
#
# 3. Last resort: Record as "assistant" with infer=False
# - Used when both previous attempts yield no results
# - Bypasses mem0's inference mechanism, which means no
# inference is performed, mem0 will only record the content
# as is.
#
# This graduated approach ensures that even if mem0's inference fails
# to extract meaningful memories, the raw content is still preserved.
try:
if thinking:
content = [thinking] + content
# Strategy 1: Record as user message first
results = await self._mem0_record(
[
{
"role": "user",
"content": "\n".join(content),
"name": "user",
},
],
**kwargs,
)
# Strategy 2: Fallback to assistant message. In this case, if
# agent_id is provided, mem0 will use the
# AGENT_MEMORY_EXTRACTION_PROMPT in mem0/mem0/configs/prompts.py
# to extract memories from messages containing role of
# "assistant". If agent_id is not provided, mem0 will still use
# the USER_MEMORY_EXTRACTION_PROMPT in
# mem0/mem0/configs/prompts.py to extract memories.
if (
results
and isinstance(results, dict)
and "results" in results
and len(results["results"]) == 0
):
results = await self._mem0_record(
[
{
"role": "assistant",
"content": "\n".join(content),
"name": "assistant",
},
],
**kwargs,
)
# Strategy 3: Last resort - direct recording without inference.
# In this case, mem0 will not use any prompts to extract
# memories, it will only record the content as is.
if (
results
and isinstance(results, dict)
and "results" in results
and len(results["results"]) == 0
):
results = await self._mem0_record(
[
{
"role": "assistant",
"content": "\n".join(content),
"name": "assistant",
},
],
infer=False,
**kwargs,
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Successfully recorded content to memory "
f"{results}",
),
],
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Retrieve the memory based on the given keywords.
Args:
keywords (`list[str]`):
Short, targeted search phrases (for example, a person's name,
a specific date, a location, or a phrase describing something
you want to retrieve from the memory). Each keyword is issued
as an independent query against the memory store.
limit (`int`, optional):
The maximum number of memories to retrieve per search,
defaults to 5.
i.e.,the number of memories to retrieve for
each keyword.
Returns:
`ToolResponse`:
A ToolResponse containing the retrieved memories as JSON text.
"""
try:
results = []
search_coroutines = [
self.long_term_working_memory.search(
query=keyword,
agent_id=self.agent_id,
user_id=self.user_id,
run_id=self.run_id,
limit=limit,
)
for keyword in keywords
]
search_results = await asyncio.gather(*search_coroutines)
for result in search_results:
if result:
results.extend(
[item["memory"] for item in result["results"]],
)
if "relations" in result.keys():
results.extend(
self._format_relations(result),
)
return ToolResponse(
content=[
TextBlock(
type="text",
text="\n".join(results),
),
],
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving memory: {str(e)}",
),
],
)
async def record(
self,
msgs: list[Msg | None],
memory_type: str | None = None,
infer: bool = True,
**kwargs: Any,
) -> dict:
"""Record the content to the long-term memory.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory.
memory_type (`str | None`, optional):
The type of memory to use. Default is None, to create a
semantic memory. "procedural_memory" is explicitly used for
procedural memories.
infer (`bool`, optional):
Whether to infer memory from the content. Default is True.
**kwargs (`Any`):
Additional keyword arguments for the mem0 recording.
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
messages = [
{
"role": "assistant",
"content": "\n".join([str(_.content) for _ in msg_list]),
"name": "assistant",
},
]
results = await self._mem0_record(
messages,
memory_type=memory_type,
infer=infer,
**kwargs,
)
return results
def _format_relations(self, result: dict) -> list:
"""Format relations from search result.
Args:
result (`dict`):
The result from the memory search operation.
Returns:
`list`:
The formatted relations.
Each relation is a string in the format of:
"{source} -- {relationship} -- {destination}"
"""
if "relations" not in result:
return []
return [
f"{relation['source']} -- "
f"{relation['relationship']} -- "
f"{relation['destination']}"
for relation in result["relations"]
]
async def _mem0_record(
self,
messages: str | list[dict],
memory_type: str | None = None,
infer: bool = True,
**kwargs: Any,
) -> dict:
"""Record the content to the long-term memory.
Args:
messages (`str`):
The content to remember, which is a string or a list of
dictionaries representing messages.
memory_type (`str | None`, optional):
The type of memory to use. Default is None, to create a
semantic memory. "procedural_memory" is explicitly used for
procedural memories.
infer (`bool`, optional):
Whether to infer memory from the content. Default is True.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`dict`:
The result from the memory recording operation.
"""
results = await self.long_term_working_memory.add(
messages=messages,
agent_id=self.agent_id,
user_id=self.user_id,
run_id=self.run_id,
memory_type=(
memory_type
if memory_type is not None
else self.default_memory_type
),
infer=infer,
**kwargs,
)
return results
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve the content from the long-term memory.
Args:
msg (`Msg | list[Msg] | None`):
The message to search for in the memory, which should be
specific and concise, e.g. the person's name, the date, the
location, etc.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. if the
message is a list of messages, the limit will be applied to
each message. If the message is a single message, then the
limit is the total number of memories to retrieve for the
message. Defaults to 5.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved memory
"""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
msg_strs = [
json.dumps(_.to_dict()["content"], ensure_ascii=False) for _ in msg
]
results = []
search_coroutines = [
self.long_term_working_memory.search(
query=item,
agent_id=self.agent_id,
user_id=self.user_id,
run_id=self.run_id,
limit=limit,
)
for item in msg_strs
]
search_results = await asyncio.gather(*search_coroutines)
for result in search_results:
if result:
results.extend(
[memory["memory"] for memory in result["results"]],
)
if "relations" in result.keys():
results.extend(
self._format_relations(result),
)
return "\n".join(results)

View File

@@ -0,0 +1,363 @@
# -*- coding: utf-8 -*-
"""Utility classes for integrating AgentScope with mem0 library.
This module provides wrapper classes that allow AgentScope models to be used
with the mem0 library for long-term memory functionality.
"""
import asyncio
import atexit
import threading
from typing import Any, Coroutine, Dict, List, Literal
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.configs.llms.base import BaseLlmConfig
from mem0.embeddings.base import EmbeddingBase
from mem0.llms.base import LLMBase
from ....embedding import EmbeddingModelBase
from ....model import ChatModelBase, ChatResponse
class _EventLoopManager:
"""Global event loop manager for running async operations in sync context.
This manager creates and maintains a persistent background event loop
that runs in a separate daemon thread. This ensures that async model
clients (like Ollama AsyncClient) remain bound to the same event loop
across multiple calls, avoiding "Event loop is closed" errors.
"""
_DEFAULT_TIMEOUT = 5.0 # Default timeout in seconds
def __init__(self) -> None:
"""Initialize the event loop manager."""
self.loop: asyncio.AbstractEventLoop | None = None
self.thread: threading.Thread | None = None
self.lock = threading.Lock()
self.loop_started = threading.Event()
# Register cleanup function to be called on program exit
atexit.register(self.cleanup)
def get_loop(self) -> asyncio.AbstractEventLoop:
"""Get or create the persistent background event loop.
Returns:
`asyncio.AbstractEventLoop`:
The persistent event loop running in a background thread.
Raises:
`RuntimeError`: If the event loop fails to start within the
timeout.
"""
with self.lock:
if self.loop is None or self.loop.is_closed():
def run_loop() -> None:
"""Run the event loop in the background thread."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Store the loop reference before starting
self.loop = loop
self.loop_started.set()
loop.run_forever()
# Clear the event before starting the thread
self.loop_started.clear()
# Create and start the background thread
self.thread = threading.Thread(
target=run_loop,
daemon=True,
name="AgentScope-AsyncLoop",
)
self.thread.start()
# Wait for the loop to be ready
if not self.loop_started.wait(timeout=self._DEFAULT_TIMEOUT):
raise RuntimeError(
"Timeout waiting for event loop to start",
)
# After waiting, self.loop should be set by the background thread
assert (
self.loop is not None
), "Event loop was not initialized properly"
return self.loop
def cleanup(self) -> None:
"""Cleanup the event loop and thread on program exit."""
with self.lock:
if self.loop is not None and not self.loop.is_closed():
# Stop the event loop gracefully
self.loop.call_soon_threadsafe(self.loop.stop)
# Wait for the thread to finish
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=self._DEFAULT_TIMEOUT)
# Close the loop
self.loop.close()
self.loop = None
self.thread = None
# Global event loop manager instance
_event_loop_manager = _EventLoopManager()
def _run_async_in_persistent_loop(coro: Coroutine) -> Any:
"""Run an async coroutine in the persistent background event loop.
This function uses a global event loop manager to ensure that all
async operations run in the same event loop, which is crucial for
async clients like Ollama that bind to a specific event loop.
Args:
coro (`Coroutine`):
The coroutine to run.
Returns:
`Any`:
The result of the coroutine.
Raises:
`RuntimeError`:
If there's an error running the coroutine.
"""
loop = _event_loop_manager.get_loop()
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result()
class AgentScopeLLM(LLMBase):
"""Wrapper for the AgentScope LLM.
This class is a wrapper for the AgentScope LLM. It is used to generate
responses using the AgentScope LLM in mem0.
"""
def __init__(self, config: BaseLlmConfig | None = None):
"""Initialize the AgentScopeLLM wrapper.
Args:
config (`BaseLlmConfig | None`, optional):
Configuration object for the LLM. Default is None.
"""
super().__init__(config)
if self.config.model is None:
raise ValueError("`model` parameter is required")
if not isinstance(self.config.model, ChatModelBase):
raise ValueError("`model` must be an instance of ChatModelBase")
self.agentscope_model = self.config.model
def _parse_response(
self,
model_response: ChatResponse,
has_tool: bool,
) -> str | dict:
"""Parse the model response into a string or
a dict to follow the mem0 library's format.
Args:
model_response (`ChatResponse`): The response from the model.
has_tool (`bool`): Whether there are tool calls in the response.
Returns:
`str | dict`:
The parsed response. If has_tool is True, return a dict
with "content" and "tool_calls" keys. Otherwise, return
a string.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_parts = []
for block in model_response.content:
# Handle TextBlock
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(str(block.get("text", "")))
# Handle ThinkingBlock
elif isinstance(block, dict) and block.get("type") == "thinking":
thinking_parts.append(
f"[Thinking: {block.get('thinking', '')}]",
)
# Handle ToolUseBlock
elif isinstance(block, dict) and block.get("type") == "tool_use":
tool_name = block.get("name")
tool_input = block.get("input", {})
tool_parts.append(
{
"name": tool_name,
"arguments": tool_input,
},
)
text_part = thinking_parts + text_parts
if has_tool:
# If there are tool calls, return the content and tool calls
return {
"content": "\n".join(text_part) if len(text_part) > 0 else "",
"tool_calls": tool_parts,
}
else:
return "\n".join(text_part) if len(text_part) > 0 else ""
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Any | None = None,
tools: List[Dict] | None = None,
tool_choice: str = "auto",
) -> str | dict:
"""Generate a response based on the given messages using agentscope.
Args:
messages (`List[Dict[str, str]]`):
List of message dicts containing 'role' and 'content'.
response_format (`Any | None`, optional):
Format of the response. Not used in AgentScope.
tools (`List[Dict] | None`, optional):
List of tools that the model can call. Not used in AgentScope.
tool_choice (`str`, optional):
Tool choice method. Not used in AgentScope.
Returns:
`str | dict`:
The generated response.
"""
# pylint: disable=unused-argument
try:
# Convert the messages to AgentScope's format
agentscope_messages = []
for message in messages:
role = message["role"]
content = message["content"]
if role in ["system", "user", "assistant"]:
agentscope_messages.append(
{"role": role, "content": content},
)
if not agentscope_messages:
raise ValueError(
"No valid messages found in the messages list",
)
# Use the agentscope model to generate response (async call)
async def _async_call() -> ChatResponse:
# TODO: handle the streaming response or forbidden streaming
# mode
return await self.agentscope_model( # type: ignore
agentscope_messages,
tools=tools,
)
# Run in the persistent event loop
# This ensures the model client (e.g., Ollama AsyncClient)
# always runs in the same event loop, avoiding binding issues
response = _run_async_in_persistent_loop(
_async_call(),
)
has_tool = tools is not None
# Extract text from the response content blocks
if not response.content:
if has_tool:
return {
"content": "",
"tool_calls": [],
}
else:
return ""
return self._parse_response(response, has_tool)
except Exception as e:
raise RuntimeError(
f"Error generating response using agentscope model: {str(e)}",
) from e
class AgentScopeEmbedding(EmbeddingBase):
"""Wrapper for the AgentScope Embedding model.
This class is a wrapper for the AgentScope Embedding model. It is used
to generate embeddings using the AgentScope Embedding model in mem0.
"""
def __init__(self, config: BaseEmbedderConfig | None = None):
"""Initialize the AgentScopeEmbedding wrapper.
Args:
config (`BaseEmbedderConfig | None`, optional):
Configuration object for the embedder. Default is None.
"""
super().__init__(config)
if self.config.model is None:
raise ValueError("`model` parameter is required")
if not isinstance(self.config.model, EmbeddingModelBase):
raise ValueError(
"`model` must be an instance of EmbeddingModelBase",
)
self.agentscope_model = self.config.model
def embed(
self,
text: str | List[str],
memory_action: Literal[ # pylint: disable=unused-argument
"add",
"search",
"update",
]
| None = None,
) -> List[float]:
"""Get the embedding for the given text using AgentScope.
Args:
text (`str | List[str]`):
The text to embed.
memory_action (`Literal["add", "search", "update"] | None`, \
optional):
The type of embedding to use. Must be one of "add", "search",
or "update". Defaults to None.
Returns:
`List[float]`:
The embedding vector.
"""
try:
# Convert single text to list for AgentScope embedding model
text_list = [text] if isinstance(text, str) else text
# Use the agentscope model to generate embedding (async call)
async def _async_call() -> Any:
response = await self.agentscope_model(text_list)
return response
# Run in the persistent event loop
# This ensures the model client (e.g., Ollama AsyncClient)
# always runs in the same event loop, avoiding binding issues
response = _run_async_in_persistent_loop(
_async_call(),
)
# Extract the embedding vector from the first Embedding object
# response.embeddings is a list of Embedding objects
# Each Embedding object has an 'embedding' attribute containing
# the vector
embedding = response.embeddings[0]
if embedding is None:
raise ValueError("Failed to extract embedding from response")
return embedding
except Exception as e:
raise RuntimeError(
f"Error generating embedding using agentscope model: {str(e)}",
) from e

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""The reme memory module."""
from ._reme_personal_long_term_memory import ReMePersonalLongTermMemory
from ._reme_task_long_term_memory import ReMeTaskLongTermMemory
from ._reme_tool_long_term_memory import ReMeToolLongTermMemory
__all__ = [
"ReMePersonalLongTermMemory",
"ReMeTaskLongTermMemory",
"ReMeToolLongTermMemory",
]

View File

@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""Base long-term memory implementation using ReMe library.
This module provides a base class for long-term memory implementations
that integrate with the ReMe library. ReMe enables agents to maintain
persistent, searchable memories across sessions and contexts.
The module handles the integration between AgentScope's memory system and
the ReMe library, including:
- Model configuration and API credential management
- Context lifecycle management (async context managers)
- Graceful handling of missing dependencies
- Error handling with helpful installation instructions
Key Features:
- Supports both DashScope and OpenAI model providers
- Automatic extraction of API credentials and endpoints
- Flexible configuration via config files or kwargs
- Safe fallback behavior when reme_ai is not installed
Dependencies:
The ReMe library is an optional dependency that must be installed:
.. code-block:: bash
pip install reme-ai
For more information, visit: https://github.com/modelscope/reMe
Subclasses:
This base class is extended by specific memory type implementations:
- ReMeToolLongTermMemory: For tool execution patterns and guidelines
- ReMeTaskLongTermMemory: For task execution experiences and learnings
- ReMePersonalLongTermMemory: For user preferences and personal information
Example:
.. code-block:: python
from agentscope.models import OpenAIChatModel
from agentscope.embedding import OpenAITextEmbedding
from agentscope.memory._reme import ReMeToolLongTermMemory
# Initialize models
model = OpenAIChatModel(model_name="gpt-4", api_key="...")
embedding = OpenAITextEmbedding(
model_name="text-embedding-3-small", api_key="...")
# Create memory instance
memory = ReMeToolLongTermMemory(
agent_name="my_agent",
user_name="user_123",
model=model,
embedding_model=embedding
)
# Use memory in async context
async with memory:
# Record tool execution
await memory.record_to_memory(
thinking="This tool worked well for data processing",
content=['{"tool_name": "process_data", "success": true, ...}']
)
# Retrieve tool guidelines
result = await memory.retrieve_from_memory(
keywords=["process_data"]
)
"""
from abc import ABCMeta
from typing import Any
from .._long_term_memory_base import LongTermMemoryBase
from ....embedding import DashScopeTextEmbedding, OpenAITextEmbedding
from ....model import DashScopeChatModel, OpenAIChatModel
class ReMeLongTermMemoryBase(LongTermMemoryBase, metaclass=ABCMeta):
"""Base class for ReMe-based long-term memory implementations.
This class provides the foundation for integrating AgentScope with the ReMe
library, enabling agents to maintain and retrieve long-term memories across
different contexts.
The ReMe library must be installed separately:
pip install reme-ai
If the library is not installed, a warning will be issued during
initialization,
and runtime errors with installation instructions will be raised
when attempting
to use memory operations.
"""
def __init__(
self,
agent_name: str | None = None,
user_name: str | None = None,
run_name: str | None = None,
model: DashScopeChatModel | OpenAIChatModel | None = None,
embedding_model: (
DashScopeTextEmbedding | OpenAITextEmbedding | None
) = None,
reme_config_path: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize the ReMe-based long-term memory.
This constructor sets up the connection to the ReMe
library and configures
the necessary models for memory operations. The ReMe app
will be initialized
with the provided model configurations.
Args:
agent_name (`str | None`, optional):
Name identifier for the agent. Used for organizing
memories by agent.
user_name (`str | None`, optional):
Unique identifier for the user or workspace. This maps
to workspace_id in ReMe and helps isolate memories across
different users/workspaces.
run_name (`str | None`, optional):
Name identifier for the current execution run or session.
model (`DashScopeChatModel | OpenAIChatModel | None`, optional):
The chat model to use for memory operations. The model's
API credentials and endpoint will be extracted and
passed to ReMe.
embedding_model (`DashScopeTextEmbedding | OpenAITextEmbedding | \
None`, optional):
The embedding model to use for semantic memory retrieval.
The model's API credentials and endpoint will be
extracted and passed to ReMe.
reme_config_path (`str | None`, optional):
Path to a custom ReMe configuration file. If not provided, ReMe
will use its default configuration.
**kwargs (`Any`):
Additional keyword arguments to pass to the
ReMeApp constructor.
These can include custom ReMe configuration parameters.
Raises:
`ValueError`:
If the provided model is not a DashScopeChatModel or
OpenAIChatModel, or if the embedding_model is not a
DashScopeTextEmbedding or OpenAITextEmbedding.
Note:
If the reme_ai library is not installed, a warning will be
issued and self.app will be set to None. Subsequent memory
operations will raise RuntimeError with installation
instructions.
Example:
.. code-block:: python
from agentscope.models import OpenAIChatModel
from agentscope.embedding import OpenAITextEmbedding
from agentscope.memory._reme import ReMeToolLongTermMemory
# Initialize models
model = OpenAIChatModel(
model_name="gpt-4",
api_key="your-api-key"
)
embedding = OpenAITextEmbedding(
model_name="text-embedding-3-small",
api_key="your-api-key"
)
# Create memory instance
memory = ReMeToolLongTermMemory(
agent_name="my_agent",
user_name="user_123",
run_name="session_001",
model=model,
embedding_model=embedding
)
# Use with async context manager
async with memory:
# Memory operations...
pass
"""
super().__init__()
# Store agent and workspace identifiers
self.agent_name = agent_name
# Maps to ReMe's workspace_id concept
self.workspace_id = user_name
self.run_name = run_name
# Build configuration arguments for ReMeApp
# These will be passed as command-line style config overrides
config_args = []
# Extract LLM API credentials based on model type
# DashScope uses a fixed endpoint, OpenAI can have custom base_url
if isinstance(model, DashScopeChatModel):
llm_api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
llm_api_key = model.api_key
elif isinstance(model, OpenAIChatModel):
llm_api_base = str(getattr(model.client, "base_url", None))
llm_api_key = str(getattr(model.client, "api_key", None))
else:
raise ValueError(
f"model must be a DashScopeChatModel or "
f"OpenAIChatModel instance. "
f"Got {type(model).__name__} instead.",
)
# Extract model name and add to config if provided
llm_model_name = model.model_name
if llm_model_name:
config_args.append(f"llm.default.model_name={llm_model_name}")
# Extract embedding model API credentials based on type
# Similar to LLM, DashScope uses fixed endpoint,
# OpenAI can be customized
if isinstance(embedding_model, DashScopeTextEmbedding):
embedding_api_base = (
"https://dashscope.aliyuncs.com/compatible-mode/v1"
)
embedding_api_key = embedding_model.api_key
elif isinstance(embedding_model, OpenAITextEmbedding):
embedding_api_base = str(
getattr(embedding_model.client, "base_url", None),
)
embedding_api_key = str(
getattr(embedding_model.client, "api_key", None),
)
else:
raise ValueError(
"embedding_model must be a DashScopeTextEmbedding or "
"OpenAITextEmbedding instance. "
f"Got {type(embedding_model).__name__} instead.",
)
# Extract embedding model name and add to config if provided
embedding_model_name = embedding_model.model_name
if embedding_model_name:
config_args.append(
f"embedding_model.default.model_name={embedding_model_name}",
)
dimensions = embedding_model.dimensions
config_args.append(
f'embedding_model.default.params={{"dimensions": {dimensions}}}',
)
# Attempt to import and initialize ReMe
# If import fails, set app to None and issue a warning
# This allows the class to be instantiated even without
# reme_ai installed
try:
from reme_ai import ReMeApp
except ImportError as e:
raise ImportError(
"The 'reme_ai' library is required for ReMe-based "
"long-term memory. Please install it by `pip install reme-ai`,"
"and visit: https://github.com/modelscope/reMe for more "
"information.",
) from e
# Initialize ReMe with extracted configurations
self.app = ReMeApp(
*config_args, # Config overrides as positional args
llm_api_key=llm_api_key,
llm_api_base=llm_api_base,
embedding_api_key=embedding_api_key,
embedding_api_base=embedding_api_base,
# Optional custom config file
config_path=reme_config_path,
# Additional ReMe-specific configurations
**kwargs,
)
# Track if the app context is active (started via __aenter__)
self._app_started = False
async def __aenter__(self) -> "ReMeLongTermMemoryBase":
"""Async context manager entry point.
This method is called when entering an async context
(using 'async with'). It initializes the ReMe app context if
available, enabling memory operations within the context block.
Returns:
`ReMeLongTermMemoryBase`:
The memory instance itself, allowing it to be used in
the context.
Example:
.. code-block:: python
memory = ReMeToolLongTermMemory(
agent_name="my_agent",
model=model,
embedding_model=embedding
)
async with memory:
# Memory operations can be performed here
await memory.record_to_memory(
thinking="Recording tool usage",
content=[...]
)
"""
if self.app is not None:
await self.app.__aenter__()
self._app_started = True
return self
async def __aexit__(
self,
exc_type: Any = None,
exc_val: Any = None,
exc_tb: Any = None,
) -> None:
"""Async context manager exit point.
This method is called when exiting an async context (at the end
of 'async with' block or when an exception occurs). It properly
cleans up the ReMe app context and resources.
Args:
exc_type (`Any`):
The type of exception that occurred, if any. None if no
exception.
exc_val (`Any`):
The exception instance that occurred, if any. None if no
exception.
exc_tb (`Any`):
The traceback object for the exception, if any. None if
no exception.
.. note:: This method will gracefully handle the case where self.app
is None (reme_ai not installed) by skipping the cleanup but still
marking the app as stopped. It will also always set _app_started
to False, ensuring the memory state is properly reset.
Example:
.. code-block:: python
async with memory:
try:
# Memory operations
await memory.record_to_memory(...)
except Exception as e:
# __aexit__ will be called even if an exception occurs
print(f"Error: {e}")
# __aexit__ has been called and resources are cleaned up
"""
if self.app is not None:
await self.app.__aexit__(exc_type, exc_val, exc_tb)
self._app_started = False

View File

@@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
"""Personal memory implementation using ReMe library.
This module provides a personal memory implementation that integrates
with the ReMe library to provide persistent personal memory storage and
retrieval capabilities for AgentScope agents.
"""
from typing import Any
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
from ...._logging import logger
from ....message import Msg, TextBlock
from ....tool import ToolResponse
class ReMePersonalLongTermMemory(ReMeLongTermMemoryBase):
"""Personal memory implementation using ReMe library."""
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Record important user information to long-term memory.
Record important user information to long-term memory for future
reference.
Use this function to save user's personal information,
preferences, habits, and facts that you may need in future
conversations. This enables you to provide personalized and
contextually relevant responses.
When to record:
- User shares personal preferences (e.g., "I prefer homestays
when traveling")
- User mentions habits or routines (e.g., "I start work at 9 AM")
- User states likes/dislikes (e.g., "I enjoy drinking green tea")
- User provides personal facts (e.g., "I work as a software
engineer")
What to record: Be specific and structured. Include who, when,
where, what, why, and how when relevant.
Args:
thinking (`str`):
Your reasoning about why this information is worth
recording and how it might be useful later.
content (`list[str]`):
List of specific facts to remember. Each string should be
a clear, standalone piece of information. Examples:
["User prefers homestays in Hangzhou", "User likes
visiting West Lake in the morning"].
**kwargs (`Any`):
Additional keyword arguments for the recording operation.
Returns:
`ToolResponse`:
Confirmation message indicating successful memory
recording.
"""
logger.info(
"[ReMePersonalMemory] Entering record_to_memory - "
"thinking: %s, content: %s, kwargs: %s",
thinking,
content,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Prepare messages for personal memory recording
messages = []
# Add thinking as a user message if provided
if thinking:
messages.append(
{
"role": "user",
"content": thinking,
},
)
# Add content items as user messages
for item in content:
messages.append(
{
"role": "user",
"content": item,
},
)
# Add a simple assistant acknowledgment
messages.append(
{
"role": "assistant",
"content": (
"I understand and will remember this "
"information."
),
},
)
result = await self.app.async_execute(
name="summary_personal_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
},
],
**kwargs,
)
# Extract metadata about stored memories if available
metadata = result.get("metadata", {})
memory_list = metadata.get("memory_list", [])
if memory_list:
summary_text = (
f"Successfully recorded {len(memory_list)} "
f"memory/memories to personal memory."
)
else:
summary_text = "Memory recording completed."
return ToolResponse(
content=[
TextBlock(
type="text",
text=summary_text,
),
],
metadata={"result": result},
)
except Exception as e:
logger.exception("Error recording memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Search and retrieve relevant information from long-term memory.
.. note:: You should call this function BEFORE answering
questions about the user's preferences, past information, or
personal details. This ensures you provide accurate information
based on stored memories rather than guessing.
Use this when:
- User asks "what do I like?", "what are my preferences?",
"what do you know about me?"
- User asks about their past behaviors, habits, or stated
preferences
- User refers to information they shared in previous
conversations
- You need to personalize responses based on user's history
Args:
keywords (`list[str]`):
Keywords to search for in memory. Be specific and use
multiple keywords for better results. Examples:
["travel preferences", "Hangzhou"], ["work habits",
"morning routine"], ["food preferences", "tea"].
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 3.
**kwargs (`Any`):
Additional keyword arguments for the retrieval operation.
Returns:
`ToolResponse`:
Retrieved memories matching the keywords. If no memories
found, you'll receive a message indicating that.
"""
logger.info(
"[ReMePersonalMemory] Entering retrieve_from_memory - "
"keywords: %s, kwargs: %s",
keywords,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
results = []
# Search for each keyword
for keyword in keywords:
result = await self.app.async_execute(
name="retrieve_personal_memory",
workspace_id=self.workspace_id,
query=keyword,
top_k=limit,
**kwargs,
)
# Extract the answer from the result
answer = result.get("answer", "")
if answer:
results.append(f"Keyword '{keyword}':\n{answer}")
# Combine all results
if results:
combined_text = "\n\n".join(results)
else:
combined_text = "No memories found for the given keywords."
return ToolResponse(
content=[
TextBlock(
type="text",
text=combined_text,
),
],
)
except Exception as e:
logger.exception("Error retrieving memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving memory: {str(e)}",
),
],
)
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> None:
"""Record the content to the long-term memory.
This method converts AgentScope messages to ReMe's format and
records them using the personal memory flow.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory.
**kwargs (`Any`):
Additional keyword arguments for the mem0 recording.
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not msg_list:
return
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Convert AgentScope messages to ReMe format
messages = []
for msg in msg_list:
# Extract content as string
if isinstance(msg.content, str):
content_str = msg.content
elif isinstance(msg.content, list):
# Join content blocks into a single string
content_parts = []
for block in msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
content_str = "\n".join(content_parts)
else:
content_str = str(msg.content)
messages.append(
{
"role": msg.role,
"content": content_str,
},
)
await self.app.async_execute(
name="summary_personal_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
},
],
**kwargs,
)
except Exception as e:
# Log the error but don't raise to maintain compatibility
logger.exception("Error recording messages to memory: %s", str(e))
import warnings
warnings.warn(f"Error recording messages to memory: {str(e)}")
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve the content from the long-term memory.
Args:
msg (`Msg | list[Msg] | None`):
The message to search for in the memory, which should be
specific and concise, e.g. the person's name, the date, the
location, etc.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. If the
message is a list of messages, the limit applies to each
message. If the message is a single message, the limit is the
total number of memories to retrieve for that message. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved memory as a string.
"""
if msg is None:
return ""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Only use the last message's content for retrieval
last_msg = msg[-1]
query = ""
if isinstance(last_msg.content, str):
query = last_msg.content
elif isinstance(last_msg.content, list):
# Extract text from content blocks
content_parts = []
for block in last_msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
query = "\n".join(content_parts)
if not query:
return ""
# Retrieve using the query from the last message
result = await self.app.async_execute(
name="retrieve_personal_memory",
workspace_id=self.workspace_id,
query=query,
top_k=limit,
**kwargs,
)
return result.get("answer", "")
except Exception as e:
logger.exception("Error retrieving memory: %s", str(e))
import warnings
warnings.warn(f"Error retrieving memory: {str(e)}")
return ""

View File

@@ -0,0 +1,436 @@
# -*- coding: utf-8 -*-
"""Task memory implementation using ReMe library.
This module provides a task memory implementation that integrates
with the ReMe library to learn from execution trajectories and
retrieve relevant task experiences.
"""
from typing import Any
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
from ...._logging import logger
from ....message import Msg, TextBlock
from ....tool import ToolResponse
class ReMeTaskLongTermMemory(ReMeLongTermMemoryBase):
"""Task memory implementation using ReMe library.
Task memory learns from execution trajectories and provides
retrieval of relevant task experiences.
"""
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Record task execution experiences and learnings.
Record task execution experiences and learnings to long-term
memory.
Use this function to save valuable task-related knowledge that
can help with future similar tasks. This enables learning from
experience and improving over time.
When to record:
- After solving technical problems or completing tasks
- When discovering useful techniques or approaches
- After implementing solutions with specific steps
- When learning best practices or important lessons
What to record: Be detailed and actionable. Include:
- Task description and context
- Step-by-step execution details
- Specific techniques and methods used
- Results, outcomes, and effectiveness
- Lessons learned and considerations
Args:
thinking (`str`):
Your reasoning about why this task experience is valuable
and what makes it worth remembering for future reference.
content (`list[str]`):
List of specific task insights to remember. Each string
should be a clear, actionable piece of information.
Examples: ["Add indexes on WHERE clause columns to speed
up queries", "Use EXPLAIN ANALYZE to identify missing
indexes"].
**kwargs (`Any`):
Additional keyword arguments. Can include 'score' (float)
to indicate the quality/success of this approach
(default: 1.0).
Returns:
`ToolResponse`:
Confirmation message indicating successful memory
recording.
"""
logger.info(
"[ReMeTaskMemory] Entering record_to_memory - "
"thinking: %s, content: %s, kwargs: %s",
thinking,
content,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Prepare messages for task memory recording
messages = []
# Add thinking as a user message if provided
if thinking:
messages.append(
{
"role": "user",
"content": thinking,
},
)
# Add content items as user-assistant pairs
for item in content:
messages.append(
{
"role": "user",
"content": item,
},
)
# Add a simple assistant acknowledgment
messages.append(
{
"role": "assistant",
"content": "Task information recorded.",
},
)
result = await self.app.async_execute(
name="summary_task_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
"score": kwargs.pop("score", 1.0),
},
],
**kwargs,
)
# Extract metadata if available
summary_text = (
f"Successfully recorded {len(content)} task memory/memories."
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=summary_text,
),
],
metadata={"result": result},
)
except Exception as e:
logger.exception("Error recording task memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording task memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Search and retrieve relevant task experiences.
Search and retrieve relevant task experiences from long-term
memory.
IMPORTANT: You should call this function BEFORE attempting to
solve problems or answer technical questions. This ensures you
leverage experiences and proven solutions rather than
starting from scratch.
Use this when:
- Asked to solve a technical problem or implement a solution
- Asked for recommendations, best practices, or approaches
- Asked "what do you know about...?" or "have you seen this
before?"
- Dealing with tasks that may be similar to experiences
- Need to recall specific techniques or methods
Benefits of retrieving first:
- Learn from past successes and mistakes
- Provide more accurate, battle-tested solutions
- Avoid reinventing the wheel
- Give consistent, informed recommendations
Args:
keywords (`list[str]`):
Keywords describing the task or problem domain. Be
specific and use technical terms. Examples:
["database optimization", "slow queries"], ["API design",
"rate limiting"], ["code refactoring", "Python"].
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments. Can include 'top_k' (int)
to specify number of experiences to retrieve
(default: 3).
Returns:
`ToolResponse`:
Retrieved task experiences and learnings. If no relevant
experiences found, you'll receive a message indicating
that.
"""
logger.info(
"[ReMeTaskMemory] Entering retrieve_from_memory - "
"keywords: %s, kwargs: %s",
keywords,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
results = []
# Search for each keyword
for keyword in keywords:
result = await self.app.async_execute(
name="retrieve_task_memory",
workspace_id=self.workspace_id,
query=keyword,
top_k=limit,
**kwargs,
)
# Extract the answer from the result
answer = result.get("answer", "")
if answer:
results.append(f"Keyword '{keyword}':\n{answer}")
# Combine all results
if results:
combined_text = "\n\n".join(results)
else:
combined_text = (
"No task experiences found for the given keywords."
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=combined_text,
),
],
)
except Exception as e:
logger.exception("Error retrieving task memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving task memory: {str(e)}",
),
],
)
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> None:
"""Record the content to the task memory.
This method converts AgentScope messages to ReMe's format and
records them as a task execution trajectory.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory.
**kwargs (`Any`):
Additional keyword arguments for the recording.
Can include 'score' (float) for trajectory scoring
(default: 1.0).
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not msg_list:
return
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Convert AgentScope messages to ReMe format
messages = []
for msg in msg_list:
# Extract content as string
if isinstance(msg.content, str):
content_str = msg.content
elif isinstance(msg.content, list):
# Join content blocks into a single string
content_parts = []
for block in msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
content_str = "\n".join(content_parts)
else:
content_str = str(msg.content)
messages.append(
{
"role": msg.role,
"content": content_str,
},
)
# Extract score from kwargs if provided, default to 1.0
score = kwargs.pop("score", 1.0)
await self.app.async_execute(
name="summary_task_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
"score": score,
},
],
**kwargs,
)
except Exception as e:
# Log the error but don't raise to maintain compatibility
logger.exception(
"Error recording messages to task memory: %s",
str(e),
)
import warnings
warnings.warn(
f"Error recording messages to task memory: {str(e)}",
)
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve relevant task experiences from memory.
Args:
msg (`Msg | list[Msg] | None`):
The message to search for relevant task experiences.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. If the
message is a list of messages, the limit applies to each
message. If the message is a single message, the limit is the
total number of memories to retrieve for that message. Defaults
to 3.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved task experiences as a string.
"""
if msg is None:
return ""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Only use the last message's content for retrieval
last_msg = msg[-1]
query = ""
if isinstance(last_msg.content, str):
query = last_msg.content
elif isinstance(last_msg.content, list):
# Extract text from content blocks
content_parts = []
for block in last_msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
query = "\n".join(content_parts)
if not query:
return ""
# Retrieve using the query from the last message
result = await self.app.async_execute(
name="retrieve_task_memory",
workspace_id=self.workspace_id,
query=query,
top_k=limit,
**kwargs,
)
return result.get("answer", "")
except Exception as e:
logger.exception("Error retrieving task memory: %s", str(e))
import warnings
warnings.warn(f"Error retrieving task memory: {str(e)}")
return ""

View File

@@ -0,0 +1,545 @@
# -*- coding: utf-8 -*-
"""Tool memory implementation using ReMe library.
This module provides a tool memory implementation that integrates
with the ReMe library to record tool execution results and retrieve
tool usage guidelines.
"""
from typing import Any
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
from ...._logging import logger
from ....message import Msg, TextBlock
from ....tool import ToolResponse
class ReMeToolLongTermMemory(ReMeLongTermMemoryBase):
"""Tool memory implementation using ReMe library.
Tool memory records tool execution results and generates usage
guidelines from the execution history.
"""
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Record tool execution results to build tool usage patterns.
Record tool execution results to build a knowledge base of tool
usage patterns.
Use this function after successfully using tools to capture
execution details, results, and performance metrics. Over time,
this builds comprehensive usage guidelines and best practices
for each tool.
When to record:
- After successfully executing any tool
- After tool failures (to learn what doesn't work)
- When discovering effective parameter combinations
- After noteworthy tool usage patterns
What to record: Each tool execution should include complete
execution details.
Args:
thinking (`str`):
Your reasoning about why this tool execution is worth
recording. Mention what worked well, what could be
improved, or lessons learned.
content (`list[str]`):
List of JSON strings, each representing a tool execution.
Each JSON must have these fields:
- create_time: Timestamp in format "YYYY-MM-DD HH:MM:SS"
- tool_name: Name of the tool executed
- input: Input parameters as a dict
- output: Tool's output as a string
- token_cost: Token cost (integer)
- success: Whether execution succeeded (boolean)
- time_cost: Execution time in seconds (float)
Example: '{"create_time": "2024-01-01 10:00:00",
"tool_name": "search", "input": {"query": "Python"},
"output": "Found 10 results", "token_cost": 100,
"success": true, "time_cost": 1.2}'
**kwargs (`Any`):
Additional keyword arguments for the recording operation.
Returns:
`ToolResponse`:
Confirmation message with number of executions recorded
and guidelines generated.
"""
logger.info(
"[ReMeToolMemory] Entering record_to_memory - "
"thinking: %s, content: %s, kwargs: %s",
thinking,
content,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
import json
# Parse each content item as a tool_call_result
tool_call_results = []
tool_names_set = set()
for item in content:
try:
# Parse JSON string to dict
tool_call_result = json.loads(item)
tool_call_results.append(tool_call_result)
# Track tool names for summary
if "tool_name" in tool_call_result:
tool_names_set.add(tool_call_result["tool_name"])
except json.JSONDecodeError as e:
# Skip invalid JSON items
import warnings
warnings.warn(
f"Failed to parse tool call result JSON: {item}. "
f"Error: {str(e)}",
)
continue
if not tool_call_results:
return ToolResponse(
content=[
TextBlock(
type="text",
text="No valid tool call results to record.",
),
],
)
# First, add the tool call results
await self.app.async_execute(
name="add_tool_call_result",
workspace_id=self.workspace_id,
tool_call_results=tool_call_results,
**kwargs,
)
# Then, summarize the tool memory for the affected tools
if tool_names_set:
tool_names_list = list(tool_names_set)
await self.app.async_execute(
name="summary_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names_list,
**kwargs,
)
num_results = len(tool_call_results)
summary_text = (
f"Successfully recorded {num_results} tool execution "
f"result{'s' if num_results > 1 else ''} and generated "
f"usage guidelines."
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=summary_text,
),
],
)
except Exception as e:
logger.exception("Error recording tool memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording tool memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Retrieve usage guidelines and best practices for tools.
Retrieve usage guidelines and best practices for specific tools.
.. note:: You should call this function BEFORE using a tool,
especially if you're uncertain about its proper usage or want to
follow established best practices. This retrieves synthesized
guidelines based on past tool executions.
Use this when:
- About to use a tool and want to know the best practices
- Uncertain about tool parameters or usage patterns
- Want to learn from past successful/failed tool executions
- User asks "how should I use this tool?" or "what's the best
way to..."
- Need to understand tool performance characteristics or
limitations
Benefits of retrieving first:
- Learn from accumulated tool usage experience
- Avoid common mistakes and pitfalls
- Use optimal parameter combinations
- Understand tool performance and cost characteristics
- Follow established best practices
Args:
keywords (`list[str]`):
List of tool names to retrieve guidelines for. Use the
exact tool names. Examples: ["search"],
["database_query", "cache_get"], ["api_call"].
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments for the retrieval operation.
Returns:
`ToolResponse`:
Retrieved usage guidelines and best practices for the
specified tools. If no guidelines exist yet, you'll
receive a message indicating that.
"""
logger.info(
"[ReMeToolMemory] Entering retrieve_from_memory - "
"keywords: %s, kwargs: %s",
keywords,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Join all tool names with comma
tool_names = ",".join(keywords)
# Retrieve tool guidelines for all tools at once
result = await self.app.async_execute(
name="retrieve_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names,
top_k=limit,
**kwargs,
)
# Extract the answer from the result
answer = result.get("answer", "")
if answer:
combined_text = answer
else:
combined_text = f"No tool guidelines found for: {tool_names}"
return ToolResponse(
content=[
TextBlock(
type="text",
text=combined_text,
),
],
)
except Exception as e:
logger.exception("Error retrieving tool memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving tool memory: {str(e)}",
),
],
)
def _extract_content_from_messages(self, msg_list: list[Msg]) -> list[str]:
"""Extract content strings from messages.
Args:
msg_list (`list[Msg]`):
List of messages to extract content from.
Returns:
`list[str]`:
List of extracted content strings.
"""
content_list = []
for msg in msg_list:
if isinstance(msg.content, str):
content_list.append(msg.content)
elif isinstance(msg.content, list):
content_list.extend(
self._extract_text_from_blocks(msg.content),
)
return content_list
def _extract_text_from_blocks(self, blocks: list) -> list[str]:
"""Extract text from content blocks.
Args:
blocks (`list`):
List of content blocks.
Returns:
`list[str]`:
List of extracted text strings.
"""
texts = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "text":
texts.append(block.get("text", ""))
elif isinstance(block, str):
texts.append(block)
return texts
def _parse_tool_call_results(
self,
content_list: list[str],
) -> tuple[list[dict], set[str]]:
"""Parse JSON content strings into tool call results.
Args:
content_list (`list[str]`):
List of JSON strings to parse.
Returns:
`tuple[list[dict], set[str]]`:
Tuple of (tool_call_results, tool_names_set).
"""
import json
import warnings
tool_call_results = []
tool_names_set = set()
for item in content_list:
try:
tool_call_result = json.loads(item)
tool_call_results.append(tool_call_result)
if "tool_name" in tool_call_result:
tool_names_set.add(tool_call_result["tool_name"])
except json.JSONDecodeError as e:
warnings.warn(
f"Failed to parse tool call result JSON: {item}. "
f"Error: {str(e)}",
)
return tool_call_results, tool_names_set
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> None:
"""Record the content to the tool memory.
This method extracts content from messages and treats them as
JSON strings representing tool_call_results, similar to
record_to_memory.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory. Each message's content
should be a JSON string or list of JSON strings
representing tool_call_results.
**kwargs (`Any`):
Additional keyword arguments for the recording.
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not msg_list:
return
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Extract content from messages and parse as tool_call_results
content_list = self._extract_content_from_messages(msg_list)
if not content_list:
return
# Parse each content item as a tool_call_result
tool_call_results, tool_names_set = self._parse_tool_call_results(
content_list,
)
if not tool_call_results:
return
# First, add the tool call results
await self.app.async_execute(
name="add_tool_call_result",
workspace_id=self.workspace_id,
tool_call_results=tool_call_results,
**kwargs,
)
# Then, summarize the tool memory for the affected tools
if tool_names_set:
tool_names_list = list(tool_names_set)
await self.app.async_execute(
name="summary_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names_list,
**kwargs,
)
except Exception as e:
# Log the error but don't raise to maintain compatibility
logger.exception(
"Error recording tool messages to memory: %s",
str(e),
)
import warnings
warnings.warn(
f"Error recording tool messages to memory: {str(e)}",
)
def _extract_tool_names_from_message(self, msg: Msg) -> str:
"""Extract tool names from a message.
Args:
msg (`Msg`):
Message to extract tool names from.
Returns:
`str`:
Extracted tool names as a string.
"""
if isinstance(msg.content, str):
return msg.content
if isinstance(msg.content, list):
content_parts = []
for block in msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
return " ".join(content_parts)
return ""
def _format_retrieve_result(self, result: Any) -> str:
"""Format the retrieve result into a string.
Args:
result (`Any`):
Result from the retrieve operation.
Returns:
`str`:
Formatted result string.
"""
if isinstance(result, dict) and "answer" in result:
return result["answer"]
if isinstance(result, str):
return result
return str(result)
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve tool guidelines from memory.
Retrieve tool guidelines from memory based on message content.
Args:
msg (`Msg | list[Msg] | None`):
The message containing tool names or queries to
retrieve guidelines for.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. If the
message is a list of messages, the limit applies to each
message. If the message is a single message, the limit is the
total number of memories to retrieve for that message. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved tool guidelines as a string.
"""
if msg is None:
return ""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Extract tool names from the last message
last_msg = msg[-1]
tool_names = self._extract_tool_names_from_message(last_msg)
if not tool_names:
return ""
# Retrieve tool guidelines
result = await self.app.async_execute(
name="retrieve_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names,
top_k=limit,
**kwargs,
)
return self._format_retrieve_result(result)
except Exception as e:
logger.exception("Error retrieving tool guidelines: %s", str(e))
import warnings
warnings.warn(f"Error retrieving tool guidelines: {str(e)}")
return ""

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The working memory module in AgentScope, which provides various memory
storage implementations. In AgentScope, such module is responsible for
storing and managing the short-term memory with specific marks."""
from ._base import MemoryBase
from ._in_memory_memory import InMemoryMemory
from ._redis_memory import RedisMemory
from ._sqlalchemy_memory import AsyncSQLAlchemyMemory
__all__ = [
"MemoryBase",
"InMemoryMemory",
"RedisMemory",
"AsyncSQLAlchemyMemory",
]

View File

@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
"""The memory base class."""
from abc import abstractmethod
from typing import Any
from ...message import Msg
from ...module import StateModule
class MemoryBase(StateModule):
"""The base class for memory in agentscope."""
def __init__(self) -> None:
"""Initialize the memory base."""
super().__init__()
self._compressed_summary: str = ""
self.register_state("_compressed_summary")
async def update_compressed_summary(self, summary: str) -> None:
"""Update the compressed summary of the memory.
Args:
summary (`str`):
The new compressed summary.
"""
self._compressed_summary = summary
@abstractmethod
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
**kwargs: Any,
) -> None:
"""Add message(s) into the memory storage with the given mark
(if provided).
Args:
memories (`Msg | list[Msg] | None`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
"""
@abstractmethod
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
async def delete_by_mark(
self,
mark: str | list[str],
*args: Any,
**kwargs: Any,
) -> int:
"""Remove messages from the memory by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Raises:
`TypeError`:
If the provided mark is not a string or a list of strings.
Returns:
`int`:
The number of messages removed.
"""
raise NotImplementedError(
"The delete_by_mark method is not implemented in "
f"{self.__class__.__name__} class.",
)
@abstractmethod
async def size(self) -> int:
"""Get the number of messages in the storage.
Returns:
`int`:
The number of messages in the storage.
"""
@abstractmethod
async def clear(self) -> None:
"""Clear the memory content."""
@abstractmethod
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
raise NotImplementedError(
"The update_messages_mark method is not implemented in "
f"{self.__class__.__name__} class.",
)

View File

@@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
"""The in-memory storage module for memory storage."""
from copy import deepcopy
from typing import Any
from ...message import Msg
from ._base import MemoryBase
class InMemoryMemory(MemoryBase):
"""The in-memory implementation of memory storage."""
def __init__(self) -> None:
"""Initialize the in-memory storage."""
super().__init__()
# Use a list of tuples to store messages along with their marks
self.content: list[tuple[Msg, list[str]]] = []
# Register the state for serialization
self.register_state("content")
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Raises:
`TypeError`:
If the provided mark is not a string or None.
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
# Type checks
if not (mark is None or isinstance(mark, str)):
raise TypeError(
f"The mark should be a string or None, but got {type(mark)}.",
)
if not (exclude_mark is None or isinstance(exclude_mark, str)):
raise TypeError(
f"The exclude_mark should be a string or None, but got "
f"{type(exclude_mark)}.",
)
# Filter messages based on mark
filtered_content = [
(msg, marks)
for msg, marks in self.content
if mark is None or mark in marks
]
# Further filter messages based on exclude_mark
if exclude_mark is not None:
filtered_content = [
(msg, marks)
for msg, marks in filtered_content
if exclude_mark not in marks
]
if prepend_summary and self._compressed_summary:
return [
Msg(
"user",
self._compressed_summary,
"user",
),
*[msg for msg, _ in filtered_content],
]
return [msg for msg, _ in filtered_content]
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
allow_duplicates: bool = False,
**kwargs: Any,
) -> None:
"""Add message(s) into the memory storage with the given mark
(if provided).
Args:
memories (`Msg | list[Msg] | None`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
allow_duplicates (`bool`, defaults to `False`):
Whether to allow duplicate messages in the storage.
"""
if memories is None:
return
if isinstance(memories, Msg):
memories = [memories]
if marks is None:
marks = []
elif isinstance(marks, str):
marks = [marks]
elif not isinstance(marks, list) or not all(
isinstance(m, str) for m in marks
):
raise TypeError(
f"The mark should be a string, a list of strings, or None, "
f"but got {type(marks)}.",
)
if not allow_duplicates:
existing_ids = {msg.id for msg, _ in self.content}
memories = [msg for msg in memories if msg.id not in existing_ids]
for msg in memories:
self.content.append((deepcopy(msg), deepcopy(marks)))
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
initial_size = len(self.content)
self.content = [
(msg, marks)
for msg, marks in self.content
if msg.id not in msg_ids
]
return initial_size - len(self.content)
async def delete_by_mark(
self,
mark: str | list[str],
**kwargs: Any,
) -> int:
"""Remove messages from the memory by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Raises:
`TypeError`:
If the provided mark is not a string or a list of strings.
Returns:
`int`:
The number of messages removed.
"""
if isinstance(mark, str):
mark = [mark]
if isinstance(mark, list) and not all(
isinstance(m, str) for m in mark
):
raise TypeError(
f"The mark should be a string or a list of strings, "
f"but got {type(mark)} with elements of types "
f"{[type(m) for m in mark]}.",
)
initial_size = len(self.content)
for m in mark:
self.content = [
(msg, marks) for msg, marks in self.content if m not in marks
]
return initial_size - len(self.content)
async def clear(self) -> None:
"""Clear all messages from the storage."""
self.content.clear()
async def size(self) -> int:
"""Get the number of messages in the storage.
Returns:
`int`:
The number of messages in the storage.
"""
return len(self.content)
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
updated_count = 0
for idx, (msg, marks) in enumerate(self.content):
# If msg_ids is provided, skip messages not in the list
if msg_ids is not None and msg.id not in msg_ids:
continue
# If old_mark is provided, skip messages that do not have the old
# mark
if old_mark is not None and old_mark not in marks:
continue
# If new_mark is None, remove the old_mark
if new_mark is None:
if old_mark in marks:
marks.remove(old_mark)
updated_count += 1
else:
# If new_mark is provided, add or replace the old_mark
if old_mark is not None and old_mark in marks:
marks.remove(old_mark)
if new_mark not in marks:
marks.append(new_mark)
updated_count += 1
self.content[idx] = (msg, marks)
return updated_count
def state_dict(self) -> dict:
"""Get the state dictionary for serialization."""
return {
**super().state_dict(),
"content": [[msg.to_dict(), marks] for msg, marks in self.content],
}
def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
"""Load the state dictionary for deserialization."""
if strict and "content" not in state_dict:
raise KeyError(
"The state_dict does not contain 'content' "
"keys required for InMemoryMemory.",
)
self._compressed_summary = state_dict.get("_compressed_summary", "")
self.content = []
for item in state_dict.get("content", []):
if isinstance(item, (tuple, list)) and len(item) == 2:
msg_dict, marks = item
msg = Msg.from_dict(msg_dict)
self.content.append((msg, marks))
elif isinstance(item, dict):
# For compatibility with older versions
msg = Msg.from_dict(item)
self.content.append((msg, []))
else:
raise ValueError(
"Invalid item format in state_dict for InMemoryMemory.",
)

View File

@@ -0,0 +1,827 @@
# -*- coding: utf-8 -*-
"""The redis based memory storage implementation."""
import json
from typing import Any, TYPE_CHECKING
from ._base import MemoryBase
from ...message import Msg
if TYPE_CHECKING:
from redis.asyncio import ConnectionPool, Redis
else:
ConnectionPool = Any
Redis = Any
class RedisMemory(MemoryBase):
"""Redis memory storage implementation, which supports session and user
context.
.. note:: All the operations in this class are within a specific session
and user context, identified by `session_id` and `user_id`. Cross-session
or cross-user operations are not supported. For example, the
`remove_messages` method will only remove messages that belong to the
specified `session_id` and `user_id`.
.. note:: All Redis keys used by this class will be prefixed by `prefix`
(if provided) to support multi-tenant / multi-app isolation.
**Mark Index Storage:**
This class maintains a `marks_index` (Redis Set) to efficiently track all
mark names within a session. When a mark is created via `add_mark()`, the
mark name is added to this set. This allows quick retrieval of all marks
without scanning all Redis keys. The marks_index key pattern is:
``user_id:{user_id}:session:{session_id}:marks_index``
Each individual mark stores its associated message IDs in a separate Redis
List with the key pattern:
``user_id:{user_id}:session:{session_id}:mark:{mark}``
"""
SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
"""Redis key pattern (without prefix) for storing message IDs (ordered) for
a specific session.
"""
SESSION_PATTERN = "user_id:{user_id}:session:{session_id}:*"
"""Redis key pattern (without prefix) for scanning all keys belong to
a specific user and session."""
MARK_KEY = "user_id:{user_id}:session:{session_id}:mark:{mark}"
"""Redis key pattern (without prefix) for storing message IDs that belong
to a specific mark.
"""
MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
"""Redis key pattern (without prefix) for storing message payload data."""
MARKS_INDEX_KEY = "user_id:{user_id}:session:{session_id}:marks_index"
"""Redis key pattern (without prefix) for storing all mark names as a set.
This is used to avoid scanning all keys to find marks.
"""
def __init__(
self,
session_id: str = "default_session",
user_id: str = "default_user",
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str | None = None,
connection_pool: ConnectionPool | None = None,
key_prefix: str = "",
key_ttl: int | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Redis based storage by connecting to the Redis
server. You can provide either the connection parameters or an
existing connection pool.
Args:
session_id (`str`, default to `"default_session"`):
The session ID for the storage.
user_id (`str`, default to `"default_user"`):
The user ID for the storage.
host (`str`, default to `"localhost"`):
The Redis server host.
port (`int`, default to `6379`):
The Redis server port.
db (`int`, default to `0`):
The Redis database index.
password (`str | None`, optional):
The password for the Redis server, if required.
connection_pool (`ConnectionPool | None`, optional):
An optional Redis connection pool. If provided, it will be used
instead of creating a new connection.
key_prefix (`str`, default to `""`):
Optional Redis key prefix prepended to every key generated by
this storage. Useful for isolating keys across
apps/environments (e.g. `"prod"`, `"staging"`, `"myapp"`).
key_ttl (`int | None`, default to `None`):
The expired time in seconds for each key. If provided, the
expiration will be refreshed on every access (sliding TTL). If
`None`, the keys will not expire. **Note** the ttl will update
all session related keys, so it's not recommended to set for
large sessions.
**kwargs (`Any`):
Additional keyword arguments to pass to the Redis client.
"""
try:
import redis.asyncio as redis
except ImportError as e:
raise ImportError(
"The 'redis' package is required for RedisStorage. "
"Please install it via 'pip install redis[async]'.",
) from e
super().__init__()
self.session_id = session_id
self.user_id = user_id
self.key_prefix = key_prefix or ""
self.key_ttl = key_ttl
self._client = redis.Redis(
host=host,
port=port,
db=db,
password=password,
connection_pool=connection_pool,
decode_responses=True,
**kwargs,
)
def get_client(self) -> Redis:
"""Get the underlying Redis client.
Returns:
`Redis`:
The Redis client instance.
"""
return self._client
def _decode_if_bytes(self, data: Any) -> Any:
"""Helper method to decode bytes to str if needed.
Args:
data (`Any`):
The data to decode, which may be bytes, bytearray, or str.
Returns:
`Any`:
The decoded string if input was bytes/bytearray, otherwise
the original data.
"""
if isinstance(data, (bytes, bytearray)):
return data.decode("utf-8")
return data
def _decode_list(self, data_list: list) -> list:
"""Helper method to decode a list of potential bytes.
Args:
data_list (`list`):
A list that may contain bytes, bytearray, or str elements.
Returns:
`list`:
A list with all bytes/bytearray elements decoded to str.
"""
return [self._decode_if_bytes(item) for item in data_list]
def _get_session_key(self) -> str:
"""Get the Redis key for the current session.
Returns:
`str`:
The Redis key for storing messages in the current session.
"""
return self.key_prefix + self.SESSION_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
)
def _get_session_pattern(self) -> str:
"""Get the Redis key pattern for all keys in the current session.
Returns:
`str`:
The Redis key pattern for all keys in the current session.
"""
return self.key_prefix + self.SESSION_PATTERN.format(
user_id=self.user_id,
session_id=self.session_id,
)
def _get_mark_key(self, mark: str) -> str:
"""Get the Redis key for a specific mark.
Args:
mark (`str`):
The mark name.
Returns:
`str`:
The Redis key for storing message IDs with the given mark.
"""
return self.key_prefix + self.MARK_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
mark=mark,
)
def _get_mark_pattern(self) -> str:
"""Get the Redis key pattern for all marks in the current session.
Returns:
`str`:
The Redis key pattern for all mark keys.
"""
return self.key_prefix + self.MARK_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
mark="*",
)
def _get_marks_index_key(self) -> str:
"""Get the Redis key for the marks index set.
Returns:
`str`:
The Redis key for storing all mark names as a set.
"""
return self.key_prefix + self.MARKS_INDEX_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
)
def _extract_mark_from_key(self, mark_key: str) -> str:
"""Extract the mark name from a full mark key.
Args:
mark_key (`str`):
The full Redis key for a mark.
Returns:
`str`:
The mark name extracted from the key.
"""
# Remove the prefix and the base pattern to get the mark name
# Example: "prefix:user_id:xxx:session:yyy:mark:my_mark" -> "my_mark"
prefix_pattern = self.key_prefix + self.MARK_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
mark="",
)
return mark_key.replace(prefix_pattern, "")
def _get_message_key(self, msg_id: str) -> str:
"""Get the Redis key for a specific message.
Args:
msg_id (`str`):
The message ID.
Returns:
`str`:
The Redis key for storing the message data.
"""
return self.key_prefix + self.MESSAGE_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
msg_id=msg_id,
)
async def _refresh_session_ttl(
self,
pipe: Any | None = None,
) -> None:
"""Refresh the TTL for the session keys (if `key_ttl` is set).
Args:
pipe (`Any | None`, optional):
An optional Redis pipeline to use. If `None`, a new pipeline
will be created and executed immediately. If provided, the
expired commands will be added to the pipeline without
executing it.
"""
if self.key_ttl is None:
return
# Create a new pipeline if not provided
should_execute = pipe is None
if pipe is None:
pipe = self._client.pipeline()
cursor = 0
while True:
cursor, keys = await self._client.scan(
cursor,
match=self._get_session_pattern(),
count=100,
)
# Decode keys if they are bytes
keys = self._decode_list(keys)
for key in keys:
await pipe.expire(key, self.key_ttl)
if cursor == 0:
break
if should_execute:
await pipe.execute()
async def _scan_and_migrate_marks(self) -> list[str]:
"""Scan all mark keys and migrate them to the marks index.
This method is only called once for old data that doesn't have
a marks index yet. After migration, the marks index will be
maintained automatically.
Returns:
`list[str]`:
The list of all mark keys found.
"""
mark_keys = []
cursor = 0
while True:
cursor, keys = await self._client.scan(
cursor,
match=self._get_mark_pattern(),
count=50,
)
keys = self._decode_list(keys)
mark_keys.extend(keys)
if cursor == 0:
break
# Build the marks index
if mark_keys:
pipe = self._client.pipeline()
for mark_key in mark_keys:
mark = self._extract_mark_from_key(mark_key)
await pipe.sadd(self._get_marks_index_key(), mark)
await pipe.execute()
return mark_keys
async def _get_all_mark_keys(self) -> list[str]:
"""Get all mark keys, compatible with both old and new data.
For new data (with marks index), this method uses the index directly.
For old data (without marks index), this method scans once and
migrates to the new structure.
Returns:
`list[str]`:
The list of all mark keys.
"""
marks_index_key = self._get_marks_index_key()
# Try to read from the index first
marks = await self._client.smembers(marks_index_key)
if marks:
# Index exists, use it
marks = self._decode_list(list(marks))
return [self._get_mark_key(mark) for mark in marks]
# Index doesn't exist, check if this is a new session
session_exists = await self._client.exists(self._get_session_key())
if not session_exists:
# New session, no data at all, return empty
return []
# Old session without index, need to scan and migrate (only once)
mark_keys = await self._scan_and_migrate_marks()
return mark_keys
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
# Type checks
if not (mark is None or isinstance(mark, str)):
raise TypeError(
f"The mark should be a string or None, but got {type(mark)}.",
)
if not (exclude_mark is None or isinstance(exclude_mark, str)):
raise TypeError(
f"The exclude_mark should be a string or None, but got "
f"{type(exclude_mark)}.",
)
if mark is None:
# Obtain the message IDs from the session list
msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)
else:
# Obtain the message IDs from the mark list
msg_ids = await self._client.lrange(
self._get_mark_key(mark),
0,
-1,
)
msg_ids = self._decode_list(msg_ids)
# Exclude messages by exclude_mark
if exclude_mark:
exclude_msg_ids = await self._client.lrange(
self._get_mark_key(exclude_mark),
0,
-1,
)
exclude_msg_ids = self._decode_list(exclude_msg_ids)
msg_ids = [_ for _ in msg_ids if _ not in exclude_msg_ids]
# Use mget for batch retrieval to avoid N+1 queries
messages: list[Msg] = []
if msg_ids:
msg_keys = [self._get_message_key(msg_id) for msg_id in msg_ids]
msg_data_list = await self._client.mget(msg_keys)
for msg_data in msg_data_list:
if msg_data is not None:
# Decode if bytes
msg_data = self._decode_if_bytes(msg_data)
msg_dict = json.loads(msg_data)
messages.append(Msg.from_dict(msg_dict))
# Refresh TTLs
await self._refresh_session_ttl()
if prepend_summary and self._compressed_summary:
return [
Msg(
"user",
self._compressed_summary,
"user",
),
*messages,
]
return messages
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
skip_duplicated: bool = True,
**kwargs: Any,
) -> None:
"""Add message into the storage with the given mark (if provided).
Args:
memories (`Msg | list[Msg]`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
skip_duplicated (`bool`, defaults to `True`):
If `True`, skip messages with duplicate IDs that already exist
in the storage. If `False`, allow duplicate message IDs to be
added to the session list (though the message data will be
overwritten).
"""
if memories is None:
return
if isinstance(memories, Msg):
memories = [memories]
# Normalize marks to a list
if marks is None:
mark_list = []
elif isinstance(marks, str):
mark_list = [marks]
else:
mark_list = marks
# Filter out existing messages if skip_duplicated is True
messages_to_add = memories
if skip_duplicated:
# Get all existing message IDs in the current session
existing_msg_ids = await self._client.lrange(
self._get_session_key(),
0,
-1,
)
existing_msg_ids = self._decode_list(existing_msg_ids)
existing_msg_ids_set = set(existing_msg_ids)
# Filter out messages that already exist
messages_to_add = [
m for m in memories if m.id not in existing_msg_ids_set
]
# If all messages are duplicates, return early
if not messages_to_add:
return
# Use pipeline for atomic operations
pipe = self._client.pipeline()
# Push message ids into the session list
if messages_to_add:
await pipe.rpush(
self._get_session_key(),
*[m.id for m in messages_to_add],
)
# Store message data and marks
for m in messages_to_add:
# Record the message data
await pipe.set(
self._get_message_key(m.id),
json.dumps(m.to_dict(), ensure_ascii=False),
)
# Record the marks if provided
for mark in mark_list:
await pipe.rpush(self._get_mark_key(mark), m.id)
# Maintain the marks index
await pipe.sadd(self._get_marks_index_key(), mark)
# Refresh TTLs
await self._refresh_session_ttl(pipe=pipe)
await pipe.execute()
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
if not msg_ids:
return 0
# Get all mark keys using the new method (compatible with old data)
mark_keys = await self._get_all_mark_keys()
pipe = self._client.pipeline()
for msg_id in msg_ids:
# Remove from the session (0 means remove all occurrences)
await pipe.lrem(self._get_session_key(), 0, msg_id)
# Remove the message data
await pipe.delete(self._get_message_key(msg_id))
# Remove from all marks
for mark_key in mark_keys:
await pipe.lrem(mark_key, 0, msg_id)
# Refresh TTLs
await self._refresh_session_ttl(pipe=pipe)
results = await pipe.execute()
# Count actual deletions from lrem results (every 3rd result
# starting from 0)
removed_count = sum(
1
for i in range(
0,
len(msg_ids) * (2 + len(mark_keys)),
2 + len(mark_keys),
)
if results[i] > 0
)
return removed_count
async def delete_by_mark(
self,
mark: str | list[str],
**kwargs: Any,
) -> int:
"""Remove messages from the storage by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Returns:
`int`:
The number of messages removed.
"""
if isinstance(mark, str):
mark = [mark]
total_removed = 0
for m in mark:
mark_key = self._get_mark_key(m)
msg_ids = await self._client.lrange(mark_key, 0, -1)
msg_ids = self._decode_list(msg_ids)
if not msg_ids:
continue
# Remove messages by IDs
removed_count = await self.delete(
msg_ids,
)
total_removed += removed_count
# Delete the mark list
await self._client.delete(mark_key)
# Remove from the marks index
await self._client.srem(self._get_marks_index_key(), m)
# Refresh TTLs
await self._refresh_session_ttl()
return total_removed
async def clear(self) -> None:
"""Clear all messages belong to this session from the storage."""
msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)
msg_ids = self._decode_list(msg_ids)
# Get all mark keys using the new method (compatible with old data)
mark_keys = await self._get_all_mark_keys()
pipe = self._client.pipeline()
for msg_id in msg_ids:
# Remove the message data
await pipe.delete(self._get_message_key(msg_id))
# Delete the session list
await pipe.delete(self._get_session_key())
# Delete all mark lists
for mark_key in mark_keys:
await pipe.delete(mark_key)
# Delete the marks index
await pipe.delete(self._get_marks_index_key())
await pipe.execute()
async def size(self) -> int:
"""Get the number of messages in the storage.
Returns:
`int`:
The number of messages in the storage.
"""
size = await self._client.llen(self._get_session_key())
await self._refresh_session_ttl()
return size
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
# Determine which message IDs to update
# Get source key based on old_mark
source_key = (
self._get_mark_key(old_mark)
if old_mark is not None
else self._get_session_key()
)
mark_msg_ids = await self._client.lrange(source_key, 0, -1)
mark_msg_ids = self._decode_list(mark_msg_ids)
# Check if we're removing all messages from old_mark
removing_all_from_old_mark = old_mark is not None and (
msg_ids is None or all(mid in set(msg_ids) for mid in mark_msg_ids)
)
# Filter by msg_ids if provided
if msg_ids is not None:
msg_ids_set = set(msg_ids)
mark_msg_ids = [mid for mid in mark_msg_ids if mid in msg_ids_set]
if not mark_msg_ids:
return 0
# Get existing IDs in new_mark list once (if needed)
existing_ids_set = set()
new_mark_key = None
if new_mark is not None:
new_mark_key = self._get_mark_key(new_mark)
existing_ids = await self._client.lrange(new_mark_key, 0, -1)
existing_ids = self._decode_list(existing_ids)
existing_ids_set = set(existing_ids)
# Use pipeline for batch operations
pipe = self._client.pipeline()
updated_count = 0
for msg_id in mark_msg_ids:
# Remove from old_mark list if applicable
if old_mark is not None:
await pipe.lrem(
self._get_mark_key(old_mark),
0,
msg_id,
)
# Add to new_mark list only if not already present
if new_mark is not None and msg_id not in existing_ids_set:
await pipe.rpush(new_mark_key, msg_id)
existing_ids_set.add(msg_id)
# Maintain the marks index
await pipe.sadd(self._get_marks_index_key(), new_mark)
# Count update only if we actually did something
if old_mark is not None or new_mark is not None:
updated_count += 1
# Clean up old_mark only if we removed ALL messages from it
if old_mark is not None and removing_all_from_old_mark:
old_mark_key = self._get_mark_key(old_mark)
# After lrem operations, the old mark list will be empty
# Delete the mark key and remove from index
await pipe.delete(old_mark_key)
await pipe.srem(self._get_marks_index_key(), old_mark)
await self._refresh_session_ttl(pipe=pipe)
await pipe.execute()
return updated_count
async def close(self, close_connection_pool: bool | None = None) -> None:
"""Close the Redis client connection.
Args:
close_connection_pool (`bool | None`, optional):
Decides whether to close the connection pool used by this
Redis client, overriding Redis.auto_close_connection_pool.
By default, let Redis.auto_close_connection_pool decide
whether to close the connection pool
"""
await self._client.aclose(close_connection_pool=close_connection_pool)
async def __aenter__(self) -> "RedisMemory":
"""Enter the async context manager.
Returns:
`RedisMemory`:
The memory instance itself.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: Any,
) -> None:
"""Exit the async context manager and close the session.
Args:
exc_type (`type[BaseException] | None`):
The exception type if an exception was raised.
exc_value (`BaseException | None`):
The exception instance if an exception was raised.
traceback (`Any`):
The traceback object if an exception was raised.
"""
await self.close()

View File

@@ -0,0 +1,873 @@
# -*- coding: utf-8 -*-
"""The SQLAlchemy database storage module, which supports storing messages in
a SQL database using SQLAlchemy ORM (e.g., SQLite, PostgreSQL, MySQL)."""
from typing import Any
from sqlalchemy import (
Column,
String,
JSON,
BigInteger,
ForeignKey,
select,
delete,
update,
func,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
)
from sqlalchemy.orm import declarative_base, relationship
from ._base import MemoryBase
from ...message import Msg
Base: Any = declarative_base()
class AsyncSQLAlchemyMemory(MemoryBase):
"""The SQLAlchemy memory storage class for storing messages in a SQL
database using SQLAlchemy ORM, such as SQLite, PostgreSQL, MySQL, etc.
.. note:: All the operations in this class are within a specific session
and user context, identified by `session_id` and `user_id`. Cross-session
or cross-user operations are not supported. For example, the
`remove_messages` method will only remove messages that belong to the
specified `session_id` and `user_id`.
"""
class MessageTable(Base):
"""The default message table definition."""
__tablename__ = "message"
"""The table name"""
id = Column(String(255), primary_key=True)
"""The id column, we use the f"{user_id}-{session_id}-{message_id}"
as the primary key to ensure uniqueness across users and sessions."""
msg = Column(JSON, nullable=False)
"""The message JSON content column"""
session = relationship(
"SessionTable",
back_populates="messages",
)
"""The foreign key to the session id relationship"""
session_id = Column(
String(255),
ForeignKey("session.id"),
nullable=False,
)
"""The foreign key to the session id"""
index = Column(BigInteger, nullable=False, index=True)
"""The index column for ordering messages, so that we can retrieve
messages in the order they were added."""
class MessageMarkTable(Base):
"""The default message mark table definition."""
__tablename__ = "message_mark"
"""The table name"""
msg_id = Column(
String(255),
ForeignKey("message.id", ondelete="CASCADE"),
primary_key=True,
)
"""The message id column"""
mark = Column(String(255), primary_key=True)
"""The mark column"""
class SessionTable(Base):
"""The default session table definition."""
__tablename__ = "session"
"""The table name"""
id = Column(String(255), primary_key=True)
"""The session id column"""
user = relationship("UserTable", back_populates="sessions")
"""The foreign key to the user id relationship"""
user_id = Column(String(255), ForeignKey("users.id"), nullable=False)
"""The foreign key to the user id"""
messages = relationship("MessageTable", back_populates="session")
"""The relationship to messages"""
class UserTable(Base):
"""The default user table definition."""
__tablename__ = "users"
"""The table name"""
id = Column(String(255), primary_key=True)
"""The user id column"""
sessions = relationship("SessionTable", back_populates="user")
"""The relationship to sessions"""
def __init__(
self,
engine_or_session: AsyncEngine | AsyncSession,
session_id: str | None = None,
user_id: str | None = None,
) -> None:
"""Initialize the SqlAlchemyDBStorage with a SQLAlchemy session.
Args:
engine_or_session (`AsyncEngine | AsyncSession`):
The SQLAlchemy asynchronous engine or session to use for
database operations. If you're using a connection pool, maybe
you want to pass in an `AsyncSession` instance.
session_id (`str | None`, optional):
The session ID for the messages. If `None`, a default session
ID will be used.
user_id (`str | None`, optional):
The user ID for the messages. If `None`, a default user ID
will be used.
Raises:
`ValueError`:
If the `engine` parameter is not an instance of
`sqlalchemy.ext.asyncio.AsyncEngine` or `sqlalchemy.
ext.asyncio.AsyncSession`.
"""
super().__init__()
self._db_session: AsyncSession | None = None
if isinstance(engine_or_session, AsyncEngine):
self._session_factory = async_sessionmaker(
bind=engine_or_session,
expire_on_commit=False,
)
elif isinstance(engine_or_session, AsyncSession):
self._session_factory = None
self._db_session = engine_or_session
else:
raise ValueError(
"The 'engine_or_session' parameter must be an instance of "
"sqlalchemy.ext.asyncio.AsyncEngine.",
)
self.session_id = session_id or "default_session"
self.user_id = user_id or "default_user"
# Flag to track if tables and records have been initialized
self._initialized = False
def _make_message_id(self, msg_id: str) -> str:
"""Generate a composite primary key for a message.
Args:
msg_id (`str`):
The original message ID.
Returns:
`str`:
The composite primary key in the format
"{user_id}-{session_id}-{message_id}".
"""
return f"{self.user_id}-{self.session_id}-{msg_id}"
@property
def session(self) -> AsyncSession:
"""Get the current database session, creating one if it doesn't exist.
Returns:
`AsyncSession`:
The current database session.
Note:
- If an external session was provided, it will be returned as-is
- If using internal session factory, a new session will be created
if the current one is None or inactive, and _initialized flag
will be reset to ensure proper re-initialization
"""
# External session: return as-is (managed by caller)
if self._session_factory is None:
return self._db_session
# Internal session: check validity and recreate if needed
if self._db_session is None or not self._db_session.is_active:
self._db_session = self._session_factory()
# Reset initialized flag when creating new session
self._initialized = False
return self._db_session
async def _create_table(self) -> None:
"""Create tables in database."""
# Skip if already initialized
if self._initialized:
return
# Obtain the engine first
engine: AsyncEngine = self.session.bind
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Track if we need to commit
needs_commit = False
# Create user record if not exists
result = await self.session.execute(
select(self.UserTable).filter(
self.UserTable.id == self.user_id,
),
)
user_record = result.scalar_one_or_none()
if user_record is None:
user_record = self.UserTable(
id=self.user_id,
)
self.session.add(user_record)
needs_commit = True
# Create session record if not exists
result = await self.session.execute(
select(self.SessionTable).filter(
self.SessionTable.id == self.session_id,
),
)
session_record = result.scalar_one_or_none()
if session_record is None:
session_record = self.SessionTable(
id=self.session_id,
user_id=self.user_id,
)
self.session.add(session_record)
needs_commit = True
# Commit once if any records were added
if needs_commit:
await self.session.commit()
# Mark as initialized
self._initialized = True
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Raises:
`TypeError`:
If the provided mark is not a string or None.
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
# Type checks
if mark is not None and not isinstance(mark, str):
raise TypeError(
f"The mark should be a string or None, but got {type(mark)}.",
)
if exclude_mark is not None and not isinstance(exclude_mark, str):
raise TypeError(
f"The exclude_mark should be a string or None, but got "
f"{type(exclude_mark)}.",
)
await self._create_table()
# Step 1: First filter by session_id to narrow down the dataset
# This ensures the database uses the session_id index first
base_query = select(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
)
# Step 2: Apply mark filtering if provided
if mark:
# Join with mark table only on the session-filtered messages
base_query = base_query.join(
self.MessageMarkTable,
self.MessageTable.id == self.MessageMarkTable.msg_id,
).filter(
self.MessageMarkTable.mark == mark,
)
# Step 3: Apply exclude_mark filtering if provided
if exclude_mark:
# Use a subquery to find message IDs with the exclude_mark
# within the current session only
exclude_subquery = (
select(self.MessageMarkTable.msg_id)
.filter(
self.MessageMarkTable.msg_id.in_(
select(self.MessageTable.id).filter(
self.MessageTable.session_id == self.session_id,
),
),
self.MessageMarkTable.mark == exclude_mark,
)
.scalar_subquery()
)
# Exclude messages whose IDs are in the subquery
base_query = base_query.filter(
self.MessageTable.id.notin_(exclude_subquery),
)
# Step 4: Order by index to maintain message order
query = base_query.order_by(self.MessageTable.index)
result = await self.session.execute(query)
results = result.scalars().all()
msgs = [Msg.from_dict(result.msg) for result in results]
if prepend_summary and self._compressed_summary:
return [
Msg(
"user",
self._compressed_summary,
"user",
),
*msgs,
]
return msgs
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
skip_duplicated: bool = True,
**kwargs: Any,
) -> None:
"""Add message into the storage with the given mark (if provided).
Args:
memories (`Msg | list[Msg] | None`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
skip_duplicated (`bool`, defaults to `True`):
If `True`, skip messages with duplicate IDs that already exist
in the storage. If `False`, raise an `IntegrityError` when
attempting to add a message with an existing ID.
Raises:
`IntegrityError`:
If a message with the same ID already exists in the storage
and `skip_duplicated` is set to `False`.
"""
if memories is None:
return
# Type checking
if isinstance(memories, Msg):
memories = [memories]
elif not (
isinstance(memories, list)
and all(isinstance(_, Msg) for _ in memories)
):
raise TypeError(
"The 'memories' parameter must be a Msg instance or a list of "
f"Msg instances, but got {type(memories)}.",
)
if isinstance(marks, str):
marks = [marks]
elif marks is not None and not (
isinstance(marks, list) and all(isinstance(m, str) for m in marks)
):
raise TypeError(
"The 'marks' parameter must be a string or a list of strings, "
f"but got {type(marks)}.",
)
# Create table if not exists
await self._create_table()
# If skip_duplicated is True, filter out existing messages
messages_to_add = memories
if skip_duplicated:
existing_msg_ids = set()
result = await self.session.execute(
select(self.MessageTable.id).filter(
self.MessageTable.id.in_(
[self._make_message_id(m.id) for m in memories],
),
),
)
existing_msg_ids = {row[0] for row in result.fetchall()}
messages_to_add = [
m
for m in memories
if self._make_message_id(m.id) not in existing_msg_ids
]
# If all messages are duplicates, return early
if not messages_to_add:
return
# Get the starting index once to avoid race conditions
start_index = await self._get_next_index()
# Add messages to message table
for i, m in enumerate(messages_to_add):
message_record = self.MessageTable(
id=self._make_message_id(m.id),
msg=m.to_dict(),
session_id=self.session_id,
index=start_index + i,
)
self.session.add(message_record)
# Create mark records if marks are provided (use bulk insert)
if marks:
mark_records = [
{"msg_id": self._make_message_id(msg.id), "mark": mark}
for msg in messages_to_add
for mark in marks
]
if mark_records:
if skip_duplicated:
# Query existing mark combinations to avoid duplicates
result = await self.session.execute(
select(
self.MessageMarkTable.msg_id,
self.MessageMarkTable.mark,
),
)
existing_marks = {
(row[0], row[1]) for row in result.fetchall()
}
# Filter out existing mark combinations
mark_records = [
r
for r in mark_records
if (r["msg_id"], r["mark"]) not in existing_marks
]
if mark_records:
await self.session.run_sync(
lambda session: session.bulk_insert_mappings(
self.MessageMarkTable,
mark_records,
),
)
await self.session.commit()
async def _get_next_index(self) -> int:
"""Get the next index for a new message in the current session.
Returns:
`int`:
The next index value.
"""
result = await self.session.execute(
select(self.MessageTable.index)
.filter(self.MessageTable.session_id == self.session_id)
.order_by(self.MessageTable.index.desc())
.limit(1),
)
max_index = result.scalar_one_or_none()
return (max_index + 1) if max_index is not None else 0
async def size(self) -> int:
"""Get the size of the messages in the storage."""
result = await self.session.execute(
select(func.count(self.MessageTable.id)).filter(
self.MessageTable.session_id == self.session_id,
),
)
return result.scalar_one()
async def clear(self) -> None:
"""Clear all messages from the storage."""
# Delete all marks for messages in this session
await self.session.execute(
delete(self.MessageMarkTable).where(
self.MessageMarkTable.msg_id.in_(
select(self.MessageTable.id).filter(
self.MessageTable.session_id == self.session_id,
),
),
),
)
# Then delete all messages
await self.session.execute(
delete(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
),
)
await self.session.commit()
async def delete_by_mark(
self,
mark: str | list[str],
**kwargs: Any,
) -> int:
"""Remove messages from the storage by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Returns:
`int`:
The number of messages removed.
"""
if isinstance(mark, str):
mark = [mark]
# First, find message IDs that have the specified marks
query = (
select(self.MessageTable.id)
.join(
self.MessageMarkTable,
self.MessageTable.id == self.MessageMarkTable.msg_id,
)
.filter(
self.MessageTable.session_id == self.session_id,
self.MessageMarkTable.mark.in_(mark),
)
)
result = await self.session.execute(query)
msg_ids = [row[0] for row in result.all()]
if not msg_ids:
return 0
# Store the count before deletion
deleted_count = len(msg_ids)
# Delete marks first
await self.session.execute(
delete(self.MessageMarkTable).filter(
self.MessageMarkTable.msg_id.in_(msg_ids),
),
)
# Then delete the messages
await self.session.execute(
delete(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
self.MessageTable.id.in_(msg_ids),
),
)
await self.session.commit()
return deleted_count
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
.. note:: Although MessageMarkTable has CASCADE delete on foreign key,
we explicitly delete marks first for reliability across all database
engines and configurations. SQLAlchemy's bulk delete bypasses
ORM-level cascades, and SQLite requires foreign keys to be
explicitly enabled.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
# Convert to composite keys
composite_ids = [self._make_message_id(msg_id) for msg_id in msg_ids]
if not composite_ids:
return 0
# Store the count before deletion
deleted_count = len(composite_ids)
# Delete related marks first (explicit cleanup for reliability)
await self.session.execute(
delete(self.MessageMarkTable).filter(
self.MessageMarkTable.msg_id.in_(composite_ids),
),
)
# Then delete the messages
await self.session.execute(
delete(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
self.MessageTable.id.in_(composite_ids),
),
)
await self.session.commit()
return deleted_count
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
# Type checking
if new_mark is not None and not isinstance(new_mark, str):
raise ValueError(
f"The 'new_mark' parameter must be a string or None, "
f"but got {type(new_mark)}.",
)
if old_mark is not None and not isinstance(old_mark, str):
raise ValueError(
f"The 'old_mark' parameter must be a string or None, "
f"but got {type(old_mark)}.",
)
if msg_ids is not None and not (
isinstance(msg_ids, list)
and all(isinstance(_, str) for _ in msg_ids)
):
raise ValueError(
f"The 'msg_ids' parameter must be a list of strings or None, "
f"but got {type(msg_ids)}.",
)
# First obtain the message ids that belong to this session
query = select(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
)
# Filter by msg_ids if provided
if msg_ids is not None:
# Convert to composite keys
composite_ids = [
self._make_message_id(msg_id) for msg_id in msg_ids
]
query = query.filter(self.MessageTable.id.in_(composite_ids))
# Filter by old_mark if provided
if old_mark is not None:
query = query.join(
self.MessageMarkTable,
self.MessageTable.id == self.MessageMarkTable.msg_id,
).filter(self.MessageMarkTable.mark == old_mark)
# Obtain the message records
result = await self.session.execute(query)
msg_ids = [str(_.id) for _ in result.scalars().all()]
# Return early if no messages found
if not msg_ids:
return 0
if new_mark:
if old_mark:
# Replace old_mark with new_mark
return await self._replace_message_mark(
msg_ids=msg_ids,
old_mark=old_mark,
new_mark=new_mark,
)
# Add new_mark to the messages
return await self._add_message_mark(
msg_ids=msg_ids,
mark=new_mark,
)
# Remove all marks from the messages
return await self._remove_message_mark(
msg_ids=msg_ids,
old_mark=old_mark,
)
async def _replace_message_mark(
self,
msg_ids: list[str],
old_mark: str,
new_mark: str,
) -> int:
"""Replace the old mark with the new mark for the given messages by
updating records in the message_mark table.
Args:
msg_ids (`list[str]`):
The list of message IDs to be updated.
old_mark (`str`):
The old mark to be replaced.
new_mark (`str`):
The new mark to be set.
Returns:
`int`:
The number of messages updated.
"""
await self.session.execute(
update(self.MessageMarkTable)
.filter(
self.MessageMarkTable.msg_id.in_(msg_ids),
self.MessageMarkTable.mark == old_mark,
)
.values(mark=new_mark),
)
await self.session.commit()
return len(msg_ids)
async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int:
"""Mark the messages with the given mark by adding records to the
message_mark table.
Args:
msg_ids (`list[str]`):
The list of message IDs to be marked.
mark (`str`):
The mark to be added to the messages.
Returns:
`int`:
The number of messages marked.
"""
# Use bulk insert for better performance
mark_records = [{"msg_id": msg_id, "mark": mark} for msg_id in msg_ids]
if mark_records:
await self.session.run_sync(
lambda session: session.bulk_insert_mappings(
self.MessageMarkTable,
mark_records,
),
)
await self.session.commit()
return len(msg_ids)
async def _remove_message_mark(
self,
msg_ids: list[str],
old_mark: str | None,
) -> int:
"""Remove marks from the messages by deleting records from the
message_mark table.
Args:
msg_ids (`list[str]`):
The list of message IDs to be unmarked.
old_mark (`str | None`):
The old mark to be removed. If `None`, all marks will be
removed from the messages.
Returns:
`int`:
The number of messages unmarked.
"""
delete_query = delete(self.MessageMarkTable).filter(
self.MessageMarkTable.msg_id.in_(msg_ids),
)
if old_mark:
delete_query = delete_query.filter(
self.MessageMarkTable.mark == old_mark,
)
await self.session.execute(delete_query)
await self.session.commit()
return len(msg_ids)
async def close(self) -> None:
"""Close the database session."""
if self._db_session and self._db_session.is_active:
await self._db_session.close()
self._db_session = None
self._initialized = False
async def __aenter__(self) -> "AsyncSQLAlchemyMemory":
"""Enter the async context manager.
Returns:
`AsyncSQLAlchemyMemory`:
The memory instance itself.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: Any,
) -> None:
"""Exit the async context manager and close the session.
Args:
exc_type (`type[BaseException] | None`):
The exception type if an exception was raised.
exc_value (`BaseException | None`):
The exception instance if an exception was raised.
traceback (`Any`):
The traceback object if an exception was raised.
"""
await self.close()