chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
109
src/agentscope/embedding/_openai_embedding.py
Normal file
109
src/agentscope/embedding/_openai_embedding.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The OpenAI text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class OpenAITextEmbedding(EmbeddingModelBase):
    """OpenAI text embedding model class."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI text embedding model class.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments forwarded unchanged to
                `openai.AsyncClient`.
        """
        # TODO: handle batch size limit and token limit

        # Imported inside the constructor, so the `openai` package is only
        # needed once this class is actually instantiated.
        import openai

        super().__init__(model_name, dimensions)

        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. Each item is either a plain
                string or a dict carrying a ``"text"`` key.
            **kwargs (`Any`):
                Extra keyword arguments merged into the request sent to
                `client.embeddings.create`. They take precedence over the
                defaults built here (`input`, `model`, `dimensions`,
                `encoding_format`).

        Returns:
            `EmbeddingResponse`:
                The embeddings together with usage information. When the
                result comes from the cache, the usage reports zero tokens
                and zero time and `source` is set to ``"cache"``.

        Raises:
            `ValueError`:
                If an item in `text` is neither a string nor a dict with a
                ``"text"`` key.
        """
        gather_text = []
        for item in text:
            if isinstance(item, dict) and "text" in item:
                gather_text.append(item["text"])
            elif isinstance(item, str):
                gather_text.append(item)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # Caller-supplied kwargs come last so they override the defaults.
        # The full request also serves as the cache identifier.
        request_kwargs = {
            "input": gather_text,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        # Compare against None explicitly: a cache object that defines
        # `__len__`/`__bool__` may be falsy while empty and must still be used.
        if self.embedding_cache is not None:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=request_kwargs,
            )
            # A falsy result (None or empty) is treated as a cache miss.
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = await self.client.embeddings.create(**request_kwargs)
        elapsed = (datetime.now() - start_time).total_seconds()

        # Build the list once and reuse it for both the cache and the result.
        embeddings = [entry.embedding for entry in response.data]

        if self.embedding_cache is not None:
            await self.embedding_cache.store(
                identifier=request_kwargs,
                embeddings=embeddings,
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=response.usage.total_tokens,
                time=elapsed,
            ),
        )
|
||||
Reference in New Issue
Block a user