chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
84
src/agentscope/token/_huggingface_token_counter.py
Normal file
84
src/agentscope/token/_huggingface_token_counter.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The huggingface token counter class."""
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agentscope.token._token_base import TokenCounterBase
|
||||
|
||||
|
||||
class HuggingFaceTokenCounter(TokenCounterBase):
    """Token counter backed by a HuggingFace tokenizer and its chat
    template."""

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        use_mirror: bool = False,
        use_fast: bool = False,
        trust_remote_code: bool = False,
        **kwargs: Any,
    ) -> None:
        """Load the tokenizer and verify that it carries a chat template.

        Args:
            pretrained_model_name_or_path (`str`):
                The model name or local path handed to
                `AutoTokenizer.from_pretrained` to load the tokenizer.
            use_mirror (`bool`, defaults to `False`):
                When `True`, the `HF_ENDPOINT` environment variable of the
                current process is set to `https://hf-mirror.com` before
                the tokenizer is loaded.
            use_fast (`bool`, defaults to `False`):
                Forwarded unchanged to `AutoTokenizer.from_pretrained`.
            trust_remote_code (`bool`, defaults to `False`):
                Forwarded unchanged to `AutoTokenizer.from_pretrained`.
            **kwargs:
                Any further keyword arguments, forwarded unchanged to
                `AutoTokenizer.from_pretrained`.

        Raises:
            `ValueError`:
                If the loaded tokenizer's `chat_template` is `None`.
        """
        if use_mirror:
            # This mutates the process-wide environment, not only this
            # instance.
            os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

        # Imported lazily and only after the optional endpoint override
        # above -- presumably so the library picks the override up when it
        # is first imported; verify if transformers may be loaded earlier.
        from transformers import AutoTokenizer

        loaded_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            use_fast=use_fast,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        self.tokenizer = loaded_tokenizer

        if loaded_tokenizer.chat_template is not None:
            return

        # `count` relies on `apply_chat_template`, so a tokenizer without a
        # template is rejected up front.
        raise ValueError(
            f"The tokenizer for model {pretrained_model_name_or_path} in "
            f"transformers does not have chat template.",
        )

    async def count(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        **kwargs: Any,
    ) -> int:
        """Count the tokens produced by applying the tokenizer's chat
        template to the given messages.

        Args:
            messages (`list[dict]`):
                A list of message dictionaries.
            tools (`list[dict] | None`, defaults to `None`):
                The JSON schema of the tools, which is passed to the chat
                template and therefore included in the token count.
            **kwargs (`Any`):
                Additional keyword arguments forwarded to
                `apply_chat_template`, e.g. `chat_template`, `padding`,
                etc.

        Returns:
            `int`:
                The length of the first entry of the tokenized result.
        """
        template_options = {
            "add_generation_prompt": False,
            "tokenize": True,
            "return_tensors": "np",
            "tools": tools,
        }
        encoded_batch = self.tokenizer.apply_chat_template(
            messages,
            **template_options,
            **kwargs,
        )

        # Only the first entry is measured -- presumably the result is a
        # batch holding a single token sequence; confirm against the
        # installed transformers version.
        first_sequence = encoded_batch[0]
        return len(first_sequence)
|
||||
Reference in New Issue
Block a user