chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
27
src/agentscope/embedding/__init__.py
Normal file
27
src/agentscope/embedding/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding module in agentscope."""
|
||||
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._dashscope_embedding import DashScopeTextEmbedding
|
||||
from ._dashscope_multimodal_embedding import DashScopeMultiModalEmbedding
|
||||
from ._openai_embedding import OpenAITextEmbedding
|
||||
from ._gemini_embedding import GeminiTextEmbedding
|
||||
from ._ollama_embedding import OllamaTextEmbedding
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._file_cache import FileEmbeddingCache
|
||||
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingModelBase",
|
||||
"EmbeddingUsage",
|
||||
"EmbeddingResponse",
|
||||
"DashScopeTextEmbedding",
|
||||
"DashScopeMultiModalEmbedding",
|
||||
"OpenAITextEmbedding",
|
||||
"GeminiTextEmbedding",
|
||||
"OllamaTextEmbedding",
|
||||
"EmbeddingCacheBase",
|
||||
"FileEmbeddingCache",
|
||||
]
|
||||
63
src/agentscope/embedding/_cache_base.py
Normal file
63
src/agentscope/embedding/_cache_base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding cache base class."""
|
||||
from abc import abstractmethod
|
||||
from typing import List, Any
|
||||
|
||||
from ..types import (
|
||||
JSONSerializableObject,
|
||||
Embedding,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingCacheBase:
    """Interface that every embedding cache has to provide.

    A cache keeps embedding vectors under a caller-supplied identifier and
    returns them on request.

    .. note:: The methods below are decorated with ``abstractmethod``, but
     this class declares no ``ABC`` base (or ``ABCMeta`` metaclass), so
     Python does not prevent instantiating a subclass that leaves some of
     them unimplemented.
    """

    @abstractmethod
    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Save embeddings under the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embedding vectors to be cached.
            identifier (`JSONSerializableObject`):
                The key that tells these embeddings apart from others.
            overwrite (`bool`, defaults to `False`):
                If `True`, embeddings already cached under the same
                identifier are replaced.
            **kwargs (`Any`):
                Extra keyword arguments; not interpreted by this base class.
        """

    @abstractmethod
    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Look up the embeddings cached under the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The key of the embeddings to fetch.

        Returns:
            `List[Embedding] | None`:
                The cached embeddings, or `None` when nothing is found.
        """

    @abstractmethod
    async def remove(
        self,
        identifier: JSONSerializableObject,
    ) -> None:
        """Delete the embeddings cached under the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The key of the embeddings to delete.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Drop every cached embedding."""
|
||||
169
src/agentscope/embedding/_dashscope_embedding.py
Normal file
169
src/agentscope/embedding/_dashscope_embedding.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope embedding module in agentscope."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from .._logging import logger
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class DashScopeTextEmbedding(EmbeddingModelBase):
    """Text embedding model backed by the DashScope API.

    .. note:: From the `official documentation
     <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_:

     - The max batch size that the DashScope text embedding API supports is
       10 for the `text-embedding-v4` and `text-embedding-v3` models, and 25
       for the `text-embedding-v2` and `text-embedding-v1` models.
     - The max token limit for a single input is 8192 tokens for the `v4`
       and `v3` models, and 2048 tokens for the `v2` and `v1` models.

    """

    supported_modalities: list[str] = ["text"]
    """Only text input is accepted by this class."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Set up the DashScope text embedding model.

        Args:
            api_key (`str`):
                The DashScope API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector, see the
                `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_
                for the supported values.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                Optional cache instance consulted before, and filled after,
                every API call so that repeated requests are not re-sent.
        """
        super().__init__(model_name, dimensions)

        self.api_key = api_key
        self.embedding_cache = embedding_cache
        # Every request carries at most this many texts.
        self.batch_size_limit = 10

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """Embed one batch, answering from the cache when possible.

        Args:
            kwargs (`dict[str, Any]`):
                The keyword arguments forwarded to the DashScope API. The
                same dict is used as the cache identifier.

        Returns:
            `EmbeddingResponse`:
                The embeddings of the batch together with the usage data.

        Raises:
            `RuntimeError`:
                If the API answers with a status code other than 200.
        """
        cache = self.embedding_cache

        if cache:
            hit = await cache.retrieve(identifier=kwargs)
            if hit:
                # A cache hit costs neither tokens nor time.
                return EmbeddingResponse(
                    embeddings=hit,
                    usage=EmbeddingUsage(tokens=0, time=0),
                    source="cache",
                )

        import dashscope

        started_at = datetime.now()
        response = dashscope.embeddings.TextEmbedding.call(
            api_key=self.api_key,
            **kwargs,
        )
        elapsed = (datetime.now() - started_at).total_seconds()

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {response}",
            )

        vectors = [item["embedding"] for item in response.output["embeddings"]]

        if cache:
            await cache.store(identifier=kwargs, embeddings=vectors)

        return EmbeddingResponse(
            embeddings=vectors,
            usage=EmbeddingUsage(
                tokens=response.usage["total_tokens"],
                time=elapsed,
            ),
        )

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given texts with the DashScope embedding API.

        Inputs longer than ``batch_size_limit`` are split into several API
        calls whose results are concatenated in input order.

        Args:
            text (`List[str | TextBlock]`):
                The texts to embed, given as plain strings or as dicts with
                a "text" field.
            **kwargs (`Any`):
                Extra keyword arguments merged into every API request.

        Returns:
            `EmbeddingResponse`:
                All embeddings with the summed token and time usage. The
                source is "api" as soon as one batch was not served from
                the cache.

        Raises:
            `ValueError`:
                If an item is neither a string nor a dict with a "text"
                field.
        """
        gather_text = []
        for item in text:
            if isinstance(item, dict) and "text" in item:
                gather_text.append(item["text"])
            elif isinstance(item, str):
                gather_text.append(item)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        limit = self.batch_size_limit
        n_texts = len(gather_text)

        if n_texts > limit:
            logger.info(
                "The input texts (%d) will be embedded with %d API calls due "
                f"to the batch size limit of {self.batch_size_limit} for "
                f"DashScope embedding API.",
                n_texts,
                # Number of batches, rounded up.
                (n_texts + limit - 1) // limit,
            )

        # Send the texts batch by batch and merge the partial results.
        all_embeddings = []
        total_time = 0.0
        total_tokens = 0
        overall_source: Literal["cache", "api"] = "cache"

        for offset in range(0, n_texts, limit):
            request = {
                "input": gather_text[offset : offset + limit],
                "model": self.model_name,
                "dimension": self.dimensions,
                **kwargs,
            }

            partial = await self._call_api(request)

            all_embeddings.extend(partial.embeddings)
            total_time += partial.usage.time
            if partial.usage.tokens:
                total_tokens += partial.usage.tokens
            if partial.source == "api":
                overall_source = "api"

        return EmbeddingResponse(
            embeddings=all_embeddings,
            usage=EmbeddingUsage(
                tokens=total_tokens,
                time=total_time,
            ),
            source=overall_source,
        )
|
||||
244
src/agentscope/embedding/_dashscope_multimodal_embedding.py
Normal file
244
src/agentscope/embedding/_dashscope_multimodal_embedding.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope multimodal embedding model in agentscope."""
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import (
|
||||
VideoBlock,
|
||||
ImageBlock,
|
||||
TextBlock,
|
||||
)
|
||||
|
||||
|
||||
class DashScopeMultiModalEmbedding(EmbeddingModelBase):
    """The DashScope multimodal embedding API, supporting text, image and
    video embedding."""

    supported_modalities: list[str] = ["text", "image", "video"]
    """This class supports text, image and video input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Initialize the DashScope multimodal embedding model class.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model, e.g.
                "multimodal-embedding-v1", "tongyi-embedding-vision-plus".
            dimensions (`int | None`, defaults to `None`):
                The dimension of the embedding vector. When `None`, it is
                derived from the model name: 1152 for
                "tongyi-embedding-vision-plus*", 768 for
                "tongyi-embedding-vision-flash*" and 1024 otherwise. For
                the recognised model families any other value is rejected.
                Refer to the `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712517>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.

        Raises:
            `ValueError`:
                If `dimensions` is given but does not match the dimension
                required by a recognised model family.
        """
        path_doc = (
            "https://bailian.console.aliyun.com/?tab=api#/api/?type=model&"
            "url=2712517"
        )

        # (model-name prefix, batch size limit, required dimension). The
        # prefixes are mutually exclusive, so at most one row applies.
        known_models = (
            ("tongyi-embedding-vision-plus", 8, 1152),
            ("tongyi-embedding-vision-flash", 8, 768),
            ("multimodal-embedding-v", 1, 1024),
        )

        self.batch_size_limit = 1
        for prefix, batch_limit, required in known_models:
            if not model_name.startswith(prefix):
                continue

            self.batch_size_limit = batch_limit
            if dimensions is None:
                dimensions = required
            elif dimensions != required:
                raise ValueError(
                    f"The dimension of model {model_name} must be "
                    f"{required}, refer to the official documentation "
                    f"for more details: {path_doc}",
                )
            break

        # Unrecognised models fall back to 1024 when no dimension is given.
        super().__init__(
            model_name,
            1024 if dimensions is None else dimensions,
        )

        self.api_key = api_key
        self.embedding_cache = embedding_cache

    @staticmethod
    def _format_block(block: Any) -> dict[str, Any]:
        """Validate one input block and convert it into the single-key dict
        ("text", "image" or "video") that is sent to the API."""
        if (
            not isinstance(block, dict)
            or "type" not in block
            or block["type"] not in ["text", "image", "video"]
        ):
            raise ValueError(
                f"Invalid data : {block}. It should be a list of "
                "TextBlock, ImageBlock, or VideoBlock.",
            )

        kind = block["type"]

        if kind == "text":
            assert "text" in block, (
                f"Invalid text block: {block}. It should contain a "
                f"'text' field."
            )
            return {"text": block["text"]}

        if kind == "video":
            if block.get("source", {}).get("type") != "url":
                raise ValueError(
                    f"The multimodal embedding API only supports URL input "
                    f"for video data, but got {block}.",
                )
            return {"video": block["source"]["url"]}

        # Only image blocks reach this point.
        if "source" in block and block["source"].get("type") in [
            "base64",
            "url",
        ]:
            source = block["source"]
            if source["type"] == "base64":
                # Base64 data is passed to the API as a data URL.
                return {
                    "image": f'data:{source["media_type"]};'
                    f'base64,{source["data"]}',
                }
            return {"image": source["url"]}

        raise ValueError(
            f"Invalid block {block}. It should be a valid TextBlock, "
            f"ImageBlock, or VideoBlock.",
        )

    async def __call__(
        self,
        inputs: list[TextBlock | ImageBlock | VideoBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the DashScope multimodal embedding API, which accepts text,
        image, and video data.

        Inputs longer than ``batch_size_limit`` are split into several API
        calls whose results are concatenated in input order.

        Args:
            inputs (`list[TextBlock | ImageBlock | VideoBlock]`):
                The input data to be embedded. It can be a list of text,
                image, and video blocks. Videos must be given by URL;
                images by URL or base64 data.
            **kwargs (`Any`):
                Extra keyword arguments merged into every API request.

        Returns:
            `EmbeddingResponse`:
                The embedding response object, which contains the embeddings
                and usage information. The source is "api" as soon as one
                batch was not served from the cache.

        Raises:
            `ValueError`:
                If an input is not a valid text, image or video block.
        """
        formatted_data = [self._format_block(block) for block in inputs]

        # Handle the batch size limit of the DashScope multimodal embedding API
        collected_embeddings = []
        collected_time = 0.0
        collected_tokens = 0
        collected_source: Literal["cache", "api"] = "cache"

        for start in range(0, len(formatted_data), self.batch_size_limit):
            batch_kwargs = {
                "input": formatted_data[start : start + self.batch_size_limit],
                "model": self.model_name,
                **kwargs,
            }
            res = await self._call_api(batch_kwargs)

            collected_embeddings.extend(res.embeddings)
            collected_time += res.usage.time
            if res.usage.tokens:
                collected_tokens += res.usage.tokens
            if res.source == "api":
                collected_source = "api"

        return EmbeddingResponse(
            embeddings=collected_embeddings,
            usage=EmbeddingUsage(
                tokens=collected_tokens,
                time=collected_time,
            ),
            source=collected_source,
        )

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """Call the DashScope multimodal embedding API by the given
        arguments, answering from the cache when possible.

        Args:
            kwargs (`dict[str, Any]`):
                The keyword arguments of the API request. The dict is also
                the cache identifier and is left unmodified.

        Returns:
            `EmbeddingResponse`:
                The embeddings of the batch together with the usage data.

        Raises:
            `RuntimeError`:
                If the API answers with a status code other than 200.
        """
        # Search in cache first
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        import dashscope

        # Build the request on a copy: writing the API key into `kwargs`
        # would leak it into the caller's dict and into the cache identifier.
        request = {**kwargs, "api_key": self.api_key}

        start_time = datetime.now()
        res = dashscope.MultiModalEmbedding.call(**request)
        time = (datetime.now() - start_time).total_seconds()

        if res.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {res}",
            )

        embeddings = [item["embedding"] for item in res.output["embeddings"]]

        # Store under the same identifier that `retrieve` is queried with,
        # otherwise the cache would never be populated.
        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=embeddings,
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=res.usage.get("image_tokens", 0)
                + res.usage.get("input_tokens", 0),
                time=time,
            ),
            source="api",
        )
|
||||
45
src/agentscope/embedding/_embedding_base.py
Normal file
45
src/agentscope/embedding/_embedding_base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding model base class."""
|
||||
from typing import Any
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
|
||||
|
||||
class EmbeddingModelBase:
    """Common parent class of the embedding models."""

    model_name: str
    """Name of the embedding model."""

    supported_modalities: list[str]
    """Data modalities the model accepts, e.g. "text", "image", "video"."""

    dimensions: int
    """Size of the embedding vector."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
    ) -> None:
        """Record the model name and the embedding dimension.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector.
        """
        self.dimensions = dimensions
        self.model_name = model_name

    async def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given inputs; must be overridden by subclasses.

        Raises:
            `NotImplementedError`:
                Always, since the base class has no API to call.
        """
        raise NotImplementedError(
            f"The {self.__class__.__name__} class does not implement "
            f"the __call__ method.",
        )
|
||||
32
src/agentscope/embedding/_embedding_response.py
Normal file
32
src/agentscope/embedding/_embedding_response.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding response class."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, List
|
||||
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from .._utils._common import _get_timestamp
|
||||
from .._utils._mixin import DictMixin
|
||||
from ..types import Embedding
|
||||
|
||||
|
||||
@dataclass
class EmbeddingResponse(DictMixin):
    """Result of one embedding model invocation."""

    # NOTE(review): all defaults below go through `default_factory`, even the
    # constant ones, so the dataclass sets no class-level attribute for them.
    # Presumably `DictMixin` relies on that - confirm before simplifying.

    embeddings: List[Embedding]
    """The embedding vectors."""

    id: str = field(default_factory=lambda: _get_timestamp(True))
    """Identifier of the response, generated by `_get_timestamp(True)`."""

    created_at: str = field(default_factory=_get_timestamp)
    """Creation timestamp of the response."""

    type: Literal["embedding"] = field(default_factory=lambda: "embedding")
    """Kind of the response, always `embedding`."""

    usage: EmbeddingUsage | None = field(default_factory=lambda: None)
    """Usage data of the API invocation, when available."""

    source: Literal["cache", "api"] = field(default_factory=lambda: "api")
    """Whether the embeddings were read from the cache or from the API."""
|
||||
20
src/agentscope/embedding/_embedding_usage.py
Normal file
20
src/agentscope/embedding/_embedding_usage.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding usage class in agentscope."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from .._utils._mixin import DictMixin
|
||||
|
||||
|
||||
@dataclass
class EmbeddingUsage(DictMixin):
    """Resource usage of one embedding model API invocation."""

    time: float
    """Elapsed time in seconds."""

    tokens: int | None = field(default_factory=lambda: None)
    """Number of tokens consumed, when available."""

    type: Literal["embedding"] = field(default_factory=lambda: "embedding")
    """Kind of the usage record, always `embedding`."""
|
||||
187
src/agentscope/embedding/_file_cache.py
Normal file
187
src/agentscope/embedding/_file_cache.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""A file embedding cache implementation for storing and retrieving
|
||||
embeddings in binary files."""
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from .._logging import logger
|
||||
from ..types import (
|
||||
Embedding,
|
||||
JSONSerializableObject,
|
||||
)
|
||||
|
||||
|
||||
class FileEmbeddingCache(EmbeddingCacheBase):
    """The embedding cache class that stores the embeddings of each
    identifier in its own binary ``.npy`` file."""

    def __init__(
        self,
        cache_dir: str = "./.cache/embeddings",
        max_file_number: int | None = None,
        max_cache_size: int | None = None,
    ) -> None:
        """Initialize the file embedding cache class.

        Args:
            cache_dir (`str`, defaults to `"./.cache/embeddings"`):
                The directory to store the embedding files.
            max_file_number (`int | None`, defaults to `None`):
                The maximum number of files to keep in the cache directory. If
                exceeded, the oldest files will be removed.
            max_cache_size (`int | None`, defaults to `None`):
                The maximum size of the cache directory in MB. If exceeded,
                the oldest files will be removed until the size is within the
                limit.
        """
        self._cache_dir = os.path.abspath(cache_dir)
        self.max_file_number = max_file_number
        self.max_cache_size = max_cache_size

    @property
    def cache_dir(self) -> str:
        """The cache directory where the embedding files are stored. It is
        created on access if it does not exist yet."""
        if not os.path.exists(self._cache_dir):
            os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir

    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `True`, existing embeddings will be replaced;
                otherwise an existing entry is left untouched.

        Raises:
            `RuntimeError`:
                If the target path exists but is not a regular file.
        """
        path_file = os.path.join(
            self.cache_dir,
            self._get_filename(identifier),
        )

        if os.path.exists(path_file):
            if not os.path.isfile(path_file):
                raise RuntimeError(
                    f"Path {path_file} exists but is not a file.",
                )
            if not overwrite:
                # Keep the entry that is already cached.
                return

        np.save(path_file, embeddings)
        await self._maintain_cache_dir()

    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not found,
        return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).

        Returns:
            `List[Embedding] | None`:
                The cached embeddings as nested lists, or `None` if no file
                exists for the identifier.
        """
        path_file = os.path.join(
            self.cache_dir,
            self._get_filename(identifier),
        )

        if os.path.exists(path_file):
            return np.load(path_file).tolist()
        return None

    async def remove(self, identifier: JSONSerializableObject) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier of the embeddings to remove, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).

        Raises:
            `FileNotFoundError`:
                If nothing is cached under the identifier.
        """
        path_file = os.path.join(
            self.cache_dir,
            self._get_filename(identifier),
        )

        if not os.path.exists(path_file):
            raise FileNotFoundError(f"File {path_file} does not exist.")
        os.remove(path_file)

    async def clear(self) -> None:
        """Clear the cache directory by removing all ``.npy`` files."""
        cache_dir = self.cache_dir
        for filename in os.listdir(cache_dir):
            if filename.endswith(".npy"):
                os.remove(os.path.join(cache_dir, filename))

    def _get_cache_size(self) -> float:
        """Get the current size of the cache directory in MB, counting the
        ``.npy`` files only."""
        cache_dir = self.cache_dir
        total_size = 0
        for filename in os.listdir(cache_dir):
            if filename.endswith(".npy"):
                path_file = os.path.join(cache_dir, filename)
                if os.path.isfile(path_file):
                    total_size += os.path.getsize(path_file)
        return total_size / (1024.0 * 1024.0)

    @staticmethod
    def _get_filename(identifier: JSONSerializableObject) -> str:
        """Generate a filename based on the identifier: the SHA-256 hex
        digest of its JSON dump plus the ``.npy`` suffix."""
        json_str = json.dumps(identifier, ensure_ascii=False)
        return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + ".npy"

    async def _maintain_cache_dir(self) -> None:
        """Maintain the cache directory by removing old files if the number of
        files exceeds the maximum limit or if the cache size exceeds the
        maximum size. The files with the oldest modification time are
        removed first."""
        cache_dir = self.cache_dir

        # One scan collects name, mtime and size, so the size-based eviction
        # below needs no directory rescan after each deletion.
        files = []
        with os.scandir(cache_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(".npy"):
                    info = entry.stat()
                    files.append((entry.name, info.st_mtime, info.st_size))
        files.sort(key=lambda item: item[1])

        if self.max_file_number and len(files) > self.max_file_number:
            for file_name, _, _ in files[: -self.max_file_number]:
                os.remove(os.path.join(cache_dir, file_name))
                logger.info(
                    "Remove cached embedding file %s for limited number "
                    "of files (%d).",
                    file_name,
                    self.max_file_number,
                )
            files = files[-self.max_file_number :]

        if self.max_cache_size is None:
            return

        bytes_per_mb = 1024.0 * 1024.0
        remaining_bytes = sum(size for _, _, size in files)
        if remaining_bytes / bytes_per_mb <= self.max_cache_size:
            return

        removed_count = 0
        for file_name, _, size in files:
            os.remove(os.path.join(cache_dir, file_name))
            removed_count += 1
            remaining_bytes -= size
            if remaining_bytes / bytes_per_mb <= self.max_cache_size:
                break

        if removed_count:
            logger.info(
                "Remove %d cached embedding file(s) for limited "
                "cache size (%d MB).",
                removed_count,
                self.max_cache_size,
            )
|
||||
109
src/agentscope/embedding/_gemini_embedding.py
Normal file
109
src/agentscope/embedding/_gemini_embedding.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The gemini text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class GeminiTextEmbedding(EmbeddingModelBase):
    """Text embedding model backed by the Gemini API."""

    supported_modalities: list[str] = ["text"]
    """Only text input is accepted by this class."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 3072,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up the Gemini text embedding model.

        Args:
            api_key (`str`):
                The Gemini API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 3072):
                The dimension of the embedding vector, see the
                `official documentation
                <https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                Optional cache instance consulted before, and filled after,
                the API call so that repeated requests are not re-sent.
            **kwargs (`Any`):
                Extra keyword arguments passed to `genai.Client`.
        """
        from google import genai

        super().__init__(model_name, dimensions)

        self.client = genai.Client(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given texts with the Gemini embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The texts to embed, given as plain strings or as dicts with
                a "text" field.
            **kwargs (`Any`):
                Extra keyword arguments, sent as the `config` of the
                request.

        Returns:
            `EmbeddingResponse`:
                The embeddings, read from the cache when available. For API
                results only the elapsed time is reported as usage.

        Raises:
            `ValueError`:
                If an item is neither a string nor a dict with a "text"
                field.
        """
        # TODO: handle the batch size limit
        gather_text = []
        for item in text:
            if isinstance(item, dict) and "text" in item:
                gather_text.append(item["text"])
            elif isinstance(item, str):
                gather_text.append(item)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # NOTE(review): `self.dimensions` is not part of this request;
        # confirm whether it should be forwarded through the config.
        request = {
            "model": self.model_name,
            "contents": gather_text,
            "config": kwargs,
        }

        cache = self.embedding_cache

        if cache:
            hit = await cache.retrieve(identifier=request)
            if hit:
                # A cache hit costs neither tokens nor time.
                return EmbeddingResponse(
                    embeddings=hit,
                    usage=EmbeddingUsage(tokens=0, time=0),
                    source="cache",
                )

        started_at = datetime.now()
        response = self.client.models.embed_content(**request)
        elapsed = (datetime.now() - started_at).total_seconds()

        vectors = [item.values for item in response.embeddings]

        if cache:
            await cache.store(identifier=request, embeddings=vectors)

        return EmbeddingResponse(
            embeddings=vectors,
            usage=EmbeddingUsage(time=elapsed),
        )
|
||||
106
src/agentscope/embedding/_ollama_embedding.py
Normal file
106
src/agentscope/embedding/_ollama_embedding.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ollama text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Any
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ..embedding import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class OllamaTextEmbedding(EmbeddingModelBase):
    """Text embedding model backed by the Ollama API."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
        host: str | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up the Ollama async client and the optional cache.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector; it should be
                provided according to the model used.
            host (`str | None`, defaults to `None`):
                The host URL for the Ollama API.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache instance, used to cache embedding
                results and avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments passed to ``ollama.AsyncClient``.
        """
        # Imported lazily so the module can be loaded without ollama.
        import ollama

        super().__init__(model_name, dimensions)

        self.client = ollama.AsyncClient(host=host, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given texts with the Ollama embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. Each item is either a
                plain string or a dict with a ``"text"`` key.
            **kwargs (`Any`):
                Extra keyword arguments merged into the ``embed``
                request; they take precedence over the default fields.

        Returns:
            `EmbeddingResponse`:
                The embeddings, either taken from the cache (zero usage,
                ``source="cache"``) or returned by the API together with
                the elapsed request time.

        Raises:
            `ValueError`:
                If an item is neither a string nor a dict with a
                ``"text"`` key.
        """
        texts = []
        for item in text:
            if isinstance(item, str):
                texts.append(item)
                continue
            if isinstance(item, dict) and "text" in item:
                texts.append(item["text"])
                continue
            raise ValueError(
                "Input text must be a list of strings or TextBlock dicts.",
            )

        # The request dict is also used as the cache identifier.
        request = {
            "input": texts,
            "model": self.model_name,
            "dimensions": self.dimensions,
            **kwargs,
        }

        if self.embedding_cache:
            cached = await self.embedding_cache.retrieve(identifier=request)
            if cached:
                no_usage = EmbeddingUsage(tokens=0, time=0)
                return EmbeddingResponse(
                    embeddings=cached,
                    usage=no_usage,
                    source="cache",
                )

        started_at = datetime.now()
        response = await self.client.embed(**request)
        elapsed = (datetime.now() - started_at).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=request,
                embeddings=response.embeddings,
            )

        usage = EmbeddingUsage(time=elapsed)
        return EmbeddingResponse(
            embeddings=response.embeddings,
            usage=usage,
        )
|
||||
109
src/agentscope/embedding/_openai_embedding.py
Normal file
109
src/agentscope/embedding/_openai_embedding.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The OpenAI text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class OpenAITextEmbedding(EmbeddingModelBase):
    """Text embedding model backed by the OpenAI embeddings API."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Set up the OpenAI async client and the optional cache.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache instance, used to cache embedding
                results and avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments passed to ``openai.AsyncClient``.
        """
        # TODO: handle batch size limit and token limit
        # Imported lazily so the module can be loaded without openai.
        import openai

        super().__init__(model_name, dimensions)

        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given texts with the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. Each item is either a
                plain string or a dict with a ``"text"`` key.
            **kwargs (`Any`):
                Extra keyword arguments merged into the
                ``embeddings.create`` request; they take precedence over
                the default fields.

        Returns:
            `EmbeddingResponse`:
                The embeddings, either taken from the cache (zero usage,
                ``source="cache"``) or returned by the API together with
                the reported token count and elapsed request time.

        Raises:
            `ValueError`:
                If an item is neither a string nor a dict with a
                ``"text"`` key.
        """
        texts = []
        for item in text:
            if isinstance(item, str):
                texts.append(item)
                continue
            if isinstance(item, dict) and "text" in item:
                texts.append(item["text"])
                continue
            raise ValueError(
                "Input text must be a list of strings or TextBlock dicts.",
            )

        # The request dict is also used as the cache identifier.
        request = {
            "input": texts,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        if self.embedding_cache:
            cached = await self.embedding_cache.retrieve(identifier=request)
            if cached:
                no_usage = EmbeddingUsage(tokens=0, time=0)
                return EmbeddingResponse(
                    embeddings=cached,
                    usage=no_usage,
                    source="cache",
                )

        started_at = datetime.now()
        response = await self.client.embeddings.create(**request)
        elapsed = (datetime.now() - started_at).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=request,
                embeddings=[entry.embedding for entry in response.data],
            )

        usage = EmbeddingUsage(
            tokens=response.usage.total_tokens,
            time=elapsed,
        )
        return EmbeddingResponse(
            embeddings=[entry.embedding for entry in response.data],
            usage=usage,
        )
|
||||
Reference in New Issue
Block a user