Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""The OpenAI text embedding model class."""
|
|
from datetime import datetime
|
|
from typing import Any, List
|
|
|
|
from ._embedding_response import EmbeddingResponse
|
|
from ._embedding_usage import EmbeddingUsage
|
|
from ._cache_base import EmbeddingCacheBase
|
|
from ._embedding_base import EmbeddingModelBase
|
|
from ..message import TextBlock
|
|
|
|
|
|
class OpenAITextEmbedding(EmbeddingModelBase):
    """OpenAI text embedding model class."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI text embedding model class.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
            **kwargs (`Any`):
                Additional keyword arguments forwarded unchanged to
                `openai.AsyncClient`.
        """
        # TODO: handle batch size limit and token limit

        # Imported lazily so that `openai` is only required when this
        # class is actually instantiated.
        import openai

        super().__init__(model_name, dimensions)

        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. Each item is either a string
                or a dict with a `"text"` key.
            **kwargs (`Any`):
                Additional keyword arguments merged into the request sent to
                `client.embeddings.create`. They take precedence over the
                default `input`, `model`, `dimensions` and
                `encoding_format` entries.

        Returns:
            `EmbeddingResponse`:
                The embeddings together with the usage information. When
                the result is served from the cache, the usage is zeroed
                and `source` is set to `"cache"`.

        Raises:
            `ValueError`:
                If an item of `text` is neither a string nor a dict
                containing a `"text"` key.
        """
        gather_text = []
        for item in text:
            if isinstance(item, dict) and "text" in item:
                gather_text.append(item["text"])
            elif isinstance(item, str):
                gather_text.append(item)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # Caller-supplied kwargs come last so they override the defaults.
        # The complete request also serves as the cache identifier.
        request_kwargs = {
            "input": gather_text,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        # Compare against None explicitly: a cache object that defines
        # `__len__`/`__bool__` may be falsy while empty, and must still be
        # consulted and populated.
        if self.embedding_cache is not None:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=request_kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = await self.client.embeddings.create(**request_kwargs)
        elapsed = (datetime.now() - start_time).total_seconds()

        # Build the embedding list once and reuse it for cache and response.
        embeddings = [entry.embedding for entry in response.data]

        if self.embedding_cache is not None:
            await self.embedding_cache.store(
                identifier=request_kwargs,
                embeddings=embeddings,
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=response.usage.total_tokens,
                time=elapsed,
            ),
        )
|