Files
tw2/src/agentscope/embedding/_openai_embedding.py
jimi a842f1861f
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
chore: initial import of standalone agentscope project
2026-03-02 18:21:40 +08:00

110 lines
3.4 KiB
Python

# -*- coding: utf-8 -*-
"""The OpenAI text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class OpenAITextEmbedding(EmbeddingModelBase):
    """OpenAI text embedding model class."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI text embedding model class.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments passed unchanged to
                `openai.AsyncClient`.
        """
        # TODO: handle batch size limit and token limit

        # `openai` is imported here, when an instance is constructed, rather
        # than at module import time.
        import openai

        super().__init__(model_name, dimensions)
        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. Each item is either a plain
                string or a dict carrying a `"text"` key.
            **kwargs (`Any`):
                Extra keyword arguments merged into the request. They take
                precedence over the values built here (`input`, `model`,
                `dimensions` and `encoding_format`).

        Returns:
            `EmbeddingResponse`:
                The embeddings and their usage. When the result comes from
                the cache, the usage reports zero tokens and zero time and
                the source is `"cache"`.

        Raises:
            `ValueError`:
                If an item is neither a string nor a dict with a `"text"`
                key.
        """
        gather_text = []
        for item in text:
            if isinstance(item, dict) and "text" in item:
                gather_text.append(item["text"])
            elif isinstance(item, str):
                gather_text.append(item)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # The merged request doubles as the cache identifier, so any
        # caller-supplied option is part of the cache key.
        kwargs = {
            "input": gather_text,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        # Compare with None instead of relying on truthiness: a cache object
        # that defines `__len__`/`__bool__` is falsy while empty and would
        # otherwise never be queried or populated.
        if self.embedding_cache is not None:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            # A falsy result (e.g. `None`) is treated as a cache miss.
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = await self.client.embeddings.create(**kwargs)
        elapsed = (datetime.now() - start_time).total_seconds()

        # Extract the vectors once and reuse them for the cache and response.
        embeddings = [entry.embedding for entry in response.data]

        if self.embedding_cache is not None:
            # Hand the cache its own outer list so the list returned to the
            # caller is never the same object as the one given to the cache.
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=list(embeddings),
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=response.usage.total_tokens,
                time=elapsed,
            ),
        )