chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
25
src/agentscope/tuner/__init__.py
Normal file
25
src/agentscope/tuner/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The learning module of AgentScope, including RL and SFT."""
|
||||
|
||||
from ._tune import tune
|
||||
from ._dataset import DatasetConfig
|
||||
from ._judge import JudgeType, JudgeOutput
|
||||
from ._workflow import WorkflowType, WorkflowOutput
|
||||
from ._algorithm import AlgorithmConfig
|
||||
from ._model import TunerModelConfig, TinkerConfig
|
||||
from ._config import check_judge_function, check_workflow_function
|
||||
|
||||
|
||||
__all__ = [
|
||||
"tune",
|
||||
"AlgorithmConfig",
|
||||
"WorkflowType",
|
||||
"WorkflowOutput",
|
||||
"JudgeType",
|
||||
"JudgeOutput",
|
||||
"DatasetConfig",
|
||||
"TunerModelConfig",
|
||||
"TinkerConfig",
|
||||
"check_workflow_function",
|
||||
"check_judge_function",
|
||||
]
|
||||
42
src/agentscope/tuner/_algorithm.py
Normal file
42
src/agentscope/tuner/_algorithm.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AlgorithmConfig definition for tuner."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AlgorithmConfig(BaseModel):
    """Algorithm configuration for tuning.

    Groups the algorithm-level hyperparameters of a tuning run: which
    algorithm to use, its learning rate, the rollout group size, the
    training batch size, and the checkpoint/evaluation intervals.
    """

    # Name of the tuning algorithm. The adjacent string literals below are
    # concatenated by the parser, so each piece must carry its own trailing
    # space; otherwise sentences (and the URL) get fused together.
    algorithm_type: str = Field(
        description=(
            "The tuning algorithm type "
            "e.g., 'multi_step_grpo', 'sft'. "
            "Please refer to https://github.com/agentscope-ai/Trinity-RFT "
            "for all supported algorithms. We recommend 'multi_step_grpo' "
            "for most agent tuning scenarios."
        ),
        default="multi_step_grpo",
    )
    # Optimizer learning rate.
    learning_rate: float = Field(
        description="The learning rate for the algorithm.",
        default=1e-6,
    )
    # Number of rollouts per task for group-based algorithms.
    group_size: int = Field(
        description=(
            "The group size for algorithms "
            "requiring group rollout, e.g., GRPO."
        ),
        default=8,
    )
    # Batch size used for each training step.
    batch_size: int = Field(
        description="The batch size of each training step.",
        default=32,
    )
    # How many steps pass between two model checkpoints.
    save_interval_steps: int = Field(
        description="The interval steps to save the model.",
        default=100,
    )
    # How many steps pass between two evaluations.
    eval_interval_steps: int = Field(
        description="The interval steps to evaluate the model.",
        default=100,
    )
|
||||
267
src/agentscope/tuner/_config.py
Normal file
267
src/agentscope/tuner/_config.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Configuration conversion for tuner."""
|
||||
from typing import Any, Callable, List, Tuple
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
|
||||
from ._workflow import WorkflowType
|
||||
from ._judge import JudgeType
|
||||
from ._model import TunerModelConfig
|
||||
from ._dataset import DatasetConfig
|
||||
from ._algorithm import AlgorithmConfig
|
||||
|
||||
|
||||
def _set_if_not_none(obj: Any, field: str, value: Any) -> None:
|
||||
"""Set the field of obj to value if value is not None."""
|
||||
if value is not None:
|
||||
setattr(obj, field, value)
|
||||
|
||||
|
||||
def _to_trinity_config(
    *,
    config_path: str | None = None,
    workflow_func: WorkflowType | None = None,
    judge_func: JudgeType | None = None,
    model: TunerModelConfig | None = None,
    auxiliary_models: dict[str, TunerModelConfig] | None = None,
    train_dataset: DatasetConfig | None = None,
    eval_dataset: DatasetConfig | None = None,
    algorithm: AlgorithmConfig | None = None,
    project_name: str | None = None,
    experiment_name: str | None = None,
    monitor_type: str | None = None,
) -> Any:
    """Convert tuner arguments to a Trinity-RFT compatible configuration.

    A base configuration is loaded from ``config_path`` (or from the
    built-in default template when ``config_path`` is None). Every argument
    that is not None then overrides the corresponding fields of that base
    configuration.

    Args:
        config_path (`str | None`): Path to a Trinity-RFT yaml file. None
            selects the default template.
        workflow_func (`WorkflowType | None`): The workflow function,
            stored in the ``workflow_args`` of every taskset.
        judge_func (`JudgeType | None`): The judge function, stored in the
            ``workflow_args`` of every taskset when given.
        model (`TunerModelConfig | None`): The model to be tuned.
        auxiliary_models (`dict[str, TunerModelConfig] | None`): Extra
            models appended to the explorer's auxiliary models, keyed by
            name.
        train_dataset (`DatasetConfig | None`): The training dataset.
        eval_dataset (`DatasetConfig | None`): The evaluation dataset,
            appended to the evaluation tasksets.
        algorithm (`AlgorithmConfig | None`): Algorithm hyperparameters.
        project_name (`str | None`): Overrides the project name.
        experiment_name (`str | None`): Overrides the experiment name. When
            None and the default template is used, a timestamped name is
            generated.
        monitor_type (`str | None`): Overrides the monitor type.

    Returns:
        `Any`: The populated Trinity-RFT ``Config`` object.

    Raises:
        `ValueError`: If neither ``train_dataset`` nor the loaded
            configuration provides a training taskset.
    """
    from trinity.common.config import (
        Config,
        TasksetConfig,
        InferenceModelConfig,
        TinkerConfig,
    )

    config, auto_config = _load_config_from_path_or_default(config_path)
    assert isinstance(config, Config), "Loaded config is not valid."

    _set_if_not_none(config, "project", project_name)
    if experiment_name is not None:
        # An explicit name always wins (previously it was silently ignored).
        config.name = experiment_name
    elif auto_config:
        # Default template: make each run distinguishable by a timestamp.
        config.name = "Experiment-" + datetime.now().strftime(
            "%Y%m%d%H%M%S",
        )

    # `monitor` is a config section (see the default template's
    # `monitor.monitor_type`), so only its type field is overridden rather
    # than replacing the whole section with a string.
    _set_if_not_none(config.monitor, "monitor_type", monitor_type)

    workflow_name = "agentscope_workflow_adapter_v1"
    if train_dataset is not None:
        if config.buffer.explorer_input.taskset is None:
            config.buffer.explorer_input.taskset = TasksetConfig(
                name="train_taskset",
                path=train_dataset.path,
                split=train_dataset.split,
                subset_name=train_dataset.name,
            )
        else:
            config.buffer.explorer_input.taskset.path = train_dataset.path
            config.buffer.explorer_input.taskset.split = train_dataset.split
            config.buffer.explorer_input.taskset.subset_name = (
                train_dataset.name
            )
        config.buffer.total_epochs = train_dataset.total_epochs
        config.buffer.total_steps = train_dataset.total_steps

    taskset = config.buffer.explorer_input.taskset
    if taskset is None:
        # Fail with an actionable message instead of an AttributeError on
        # None a few lines below.
        raise ValueError(
            "No training taskset is configured. Please provide "
            "`train_dataset` or define `buffer.explorer_input.taskset` "
            "in the configuration file.",
        )
    taskset.default_workflow_type = workflow_name
    config.buffer.explorer_input.default_workflow_type = workflow_name
    workflow_args = {
        "workflow_func": workflow_func,
    }
    if judge_func is not None:
        workflow_args["judge_func"] = judge_func

    taskset.workflow_args.update(workflow_args)

    if model is not None:
        model_config = model.get_config()
        config.model.model_path = model_config["model_path"]
        config.model.max_model_len = model_config["max_model_len"]
        config.model.max_response_tokens = model.max_tokens
        config.explorer.rollout_model = InferenceModelConfig(
            **model_config,
        )
        config.explorer.rollout_model.enable_history = True
        if model.tinker_config is not None:
            config.model.tinker = TinkerConfig(
                **model.tinker_config.get_config(),
            )
            config.model.tinker.enable = True
    if auxiliary_models is not None:
        for name, aux_chat_model in auxiliary_models.items():
            aux_config = InferenceModelConfig(
                **aux_chat_model.get_config(),
            )
            aux_config.name = name
            config.explorer.auxiliary_models.append(
                aux_config,
            )
    if eval_dataset is not None:
        config.buffer.explorer_input.eval_tasksets.append(
            TasksetConfig(
                name="eval_taskset",
                path=eval_dataset.path,
                split=eval_dataset.split,
                subset_name=eval_dataset.name,
            ),
        )
    # Every evaluation taskset (from the config file or added above) runs
    # the same workflow/judge functions as the training taskset.
    for eval_taskset in config.buffer.explorer_input.eval_tasksets:
        eval_taskset.workflow_args.update(workflow_args)
    if algorithm is not None:
        config.algorithm.algorithm_type = algorithm.algorithm_type
        config.algorithm.repeat_times = algorithm.group_size
        config.algorithm.optimizer.lr = algorithm.learning_rate
        config.buffer.batch_size = algorithm.batch_size
        config.trainer.save_interval = algorithm.save_interval_steps
        config.explorer.eval_interval = algorithm.eval_interval_steps
    return config
|
||||
|
||||
|
||||
def _load_config_from_path_or_default(
    config_path: str | None,
) -> Tuple[Any, bool]:
    """Load configuration from the given path or default template.

    When no path is given, the built-in default template is written to a
    temporary yaml file and loaded through the same ``load_config`` call
    used for user-provided files.

    Args:
        config_path (`str | None`): The path to the configuration file.
    Returns:
        `Tuple[Any, bool]`: The loaded configuration and a boolean
        indicating whether the default template was used.
    """
    from trinity.common.config import (
        Config,
        load_config,
    )
    import os
    import tempfile
    import yaml

    template_used = False
    if config_path is None:
        default_config = {
            "project": "AgentScope",
            "name": "Experiment",
            "checkpoint_root_dir": "./checkpoints",
            "algorithm": {
                "algorithm_type": "multi_step_grpo",
            },
            "buffer": {
                "total_epochs": 1,
            },
            "explorer": {
                "runner_per_model": 16,
                "max_timeout": 3600,
                "max_repeat_times_per_runner": 1,
            },
            "synchronizer": {
                "sync_style": "dynamic_by_explorer",
                "sync_method": "nccl",
                "sync_interval": 1,
                "sync_timeout": 7200,
            },
            "trainer": {
                "save_interval": 100,
            },
            "monitor": {
                "monitor_type": "tensorboard",
            },
        }
        # A NamedTemporaryFile cannot be reopened by name while it is still
        # open on Windows, so the template is written into a temporary
        # directory and closed before it is loaded. The directory (and the
        # file) is removed when the `with` block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "default_config.yaml")
            with open(tmp_path, "w", encoding="utf-8") as tmp:
                yaml.dump(default_config, tmp)
            config = load_config(tmp_path)
        template_used = True
    else:
        config = load_config(config_path)

    assert isinstance(config, Config), "Loaded config is not valid."
    return config, template_used
|
||||
|
||||
|
||||
def check_workflow_function(
    func: Callable,
) -> None:
    """Validate that ``func`` has the signature of a WorkflowType.

    The function must declare the parameters ``task`` and ``model``, and
    may additionally declare ``auxiliary_models`` and ``logger``.

    Args:
        func (Callable): The function to check.
    """
    _check_function_signature(
        func,
        essential_params=["task", "model"],
        optional_params=["auxiliary_models", "logger"],
    )
|
||||
|
||||
|
||||
def check_judge_function(
    func: Callable,
) -> None:
    """Validate that ``func`` has the signature of a JudgeType.

    The function must declare the parameters ``task`` and ``response``, and
    may additionally declare ``auxiliary_models`` and ``logger``.

    Args:
        func (Callable): The function to check.
    """
    _check_function_signature(
        func,
        essential_params=["task", "response"],
        optional_params=["auxiliary_models", "logger"],
    )
|
||||
|
||||
|
||||
def _check_function_signature(
|
||||
func: Callable,
|
||||
essential_params: List[str],
|
||||
optional_params: List[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Check if the given function has the required signature.
|
||||
|
||||
Args:
|
||||
func (`Callable`): The function to check.
|
||||
essential_params (`List[str]`): List of essential parameter names
|
||||
that must be present in the function.
|
||||
optional_params (`List[str] | None`): List of optional parameter names
|
||||
that can be present in the function.
|
||||
"""
|
||||
if optional_params is None:
|
||||
optional_params = []
|
||||
|
||||
sig = inspect.signature(func)
|
||||
actual_params = []
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
# *args and **kwargs are not allowed
|
||||
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
||||
raise ValueError(f"*args parameter is not allowed: *{param_name}")
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
raise ValueError(
|
||||
f"**kwargs parameter is not allowed: **{param_name}",
|
||||
)
|
||||
actual_params.append(param_name)
|
||||
|
||||
# Convert to sets for easier comparison
|
||||
actual_params_set = set(actual_params)
|
||||
essential_params_set = set(essential_params)
|
||||
optional_params_set = set(optional_params)
|
||||
allowed_params_set = essential_params_set | optional_params_set
|
||||
|
||||
# Check 1: All essential parameters are present
|
||||
missing_essential = essential_params_set - actual_params_set
|
||||
if missing_essential:
|
||||
raise ValueError(
|
||||
f"Missing essential parameters: {sorted(missing_essential)}",
|
||||
)
|
||||
|
||||
# Check 2: Whether there are disallowed parameters
|
||||
extra_params = actual_params_set - allowed_params_set
|
||||
if extra_params:
|
||||
raise ValueError(
|
||||
f"Contains disallowed parameters: {sorted(extra_params)}",
|
||||
)
|
||||
61
src/agentscope/tuner/_dataset.py
Normal file
61
src/agentscope/tuner/_dataset.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""DatasetConfig definition for tuner."""
|
||||
from itertools import islice
|
||||
from typing import List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatasetConfig(BaseModel):
    """Describe the dataset used for tuning.

    The fields follow the huggingface dataset format: the dataset at
    ``path`` is loaded by AgentScope with `datasets.load_dataset`.
    """

    # Location of the dataset (required).
    path: str = Field(
        description="Path to your dataset.",
    )
    # Optional dataset configuration (subset) name.
    name: str | None = Field(
        default=None,
        description="The name of the dataset configuration.",
    )
    # Which split of the dataset to read.
    split: str | None = Field(
        default="train",
        description="The dataset split to use.",
    )
    # Number of passes over the dataset.
    total_epochs: int = Field(
        default=1,
        description="Total number of epochs to run.",
    )
    # Optional step budget that takes precedence over `total_epochs`.
    total_steps: int | None = Field(
        default=None,
        description=(
            "Total number of steps to run. "
            "If set, it will override total_epochs."
        ),
    )

    def preview(self, n: int = 5) -> List:
        """Print and return the first samples of the dataset.

        The dataset is opened in streaming mode, the first ``n`` samples
        are taken, printed as indented JSON, and returned.

        Args:
            n (`int`): Number of samples to preview. Defaults to 5.

        Returns:
            `List`: The previewed samples.

        Raises:
            `ImportError`: If the `datasets` library is not installed.
        """
        try:
            from datasets import load_dataset
        except ImportError as err:
            raise ImportError(
                "The `datasets` library is not installed. "
                "Please install it with `pip install datasets`.",
            ) from err
        import json

        stream = load_dataset(
            path=self.path,
            name=self.name,
            split=self.split,
            streaming=True,
        )
        # Only the first `n` items are pulled from the stream.
        head = [*islice(stream, n)]
        rendered = json.dumps(head, indent=2, ensure_ascii=False)
        print(rendered)
        return head
|
||||
44
src/agentscope/tuner/_judge.py
Normal file
44
src/agentscope/tuner/_judge.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The judge module for tuner."""
|
||||
from typing import Any, Callable, Dict, Awaitable
|
||||
from logging import Logger
|
||||
from pydantic import BaseModel, Field
|
||||
from ..model import ChatModelBase
|
||||
|
||||
|
||||
class JudgeOutput(BaseModel):
    """Result returned by a judge function."""

    # Required scalar reward.
    reward: float = Field(
        description="The reward value assigned by the judge function.",
    )

    # Optional named metrics reported alongside the reward.
    metrics: Dict[str, float] | None = Field(
        default=None,
        description="Metrics from the judge function.",
    )
|
||||
|
||||
|
||||
# Type of a judge function used for tuning.
#
# A judge function is an awaitable callable taking, in order:
#     task (`Dict`):
#         The task information for the corresponding workflow.
#     response (`Any`):
#         The `response` field of the WorkflowOutput produced by the
#         corresponding workflow.
#     auxiliary_models (`Dict[str, ChatModelBase] | None`, optional):
#         Additional chat models available for LLM-as-a-Judge usage,
#         keyed by model name.
#     logger (`Logger | None`, optional):
#         A logger for reporting information while the judge runs.
#
# It resolves to:
#     `JudgeOutput`:
#         The reward assigned by the judge function, together with
#         optional metrics.
JudgeType = Callable[
    [Dict, Any, Dict[str, ChatModelBase] | None, Logger | None],
    Awaitable[JudgeOutput],
]
|
||||
148
src/agentscope/tuner/_model.py
Normal file
148
src/agentscope/tuner/_model.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TunerModelConfig definition."""
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TunerModelConfig(BaseModel):
    """Settings of a model that takes part in tuning."""

    # Where the model checkpoint lives (required).
    model_path: str = Field(
        description="The path to the model checkpoint.",
    )

    # Total token budget: context plus generation (required).
    max_model_len: int = Field(
        description=(
            "The maximum length of the model, including context"
            " and generated tokens."
        ),
    )

    # Sampling parameters.
    temperature: float = Field(
        default=1.0,
        description="Sampling temperature.",
    )

    top_p: float = Field(
        default=1.0,
        description="Top-p sampling parameter.",
    )

    max_tokens: int = Field(
        default=8192,
        description="Maximum tokens for generation.",
    )

    enable_thinking: bool | None = Field(
        default=None,
        description=(
            "Whether to enable thinking capability. "
            "Only applicable for Qwen3 series models."
        ),
    )

    # Inference deployment parameters.
    tensor_parallel_size: int = Field(
        default=1,
        description="The tensor parallel size for model inference.",
    )

    inference_engine_num: int = Field(
        default=1,
        description="The number of engines for model inference.",
    )

    # Output parsers.
    tool_call_parser: str = Field(
        default="hermes",
        description=(
            "The tool call parser to use. The default setting "
            "is for Qwen3 series models."
        ),
    )

    reasoning_parser: str = Field(
        default="deepseek_r1",
        description=(
            "The reasoning parser to use. The default "
            "setting is for Qwen3 series models."
        ),
    )

    # Optional Tinker settings; None means Tinker is not used.
    tinker_config: TinkerConfig | None = Field(
        default=None,
        description=(
            "The configuration for Tinker. " "If None, Tinker is not used."
        ),
    )

    def get_config(self) -> Dict[str, Any]:
        """Build the inference configuration dictionary for this model.

        Note that the engine count is exported under the key
        ``engine_num``, and the two ``enable_*`` flags are always True.

        Returns:
            `Dict[str, Any]`: The model configuration dictionary.
        """
        exported = dict(
            model_path=self.model_path,
            max_model_len=self.max_model_len,
            tensor_parallel_size=self.tensor_parallel_size,
            engine_num=self.inference_engine_num,
            tool_call_parser=self.tool_call_parser,
            reasoning_parser=self.reasoning_parser,
            enable_openai_api=True,
            enable_auto_tool_choice=True,
        )
        return exported
|
||||
|
||||
|
||||
class TinkerConfig(BaseModel):
    """Settings for running a model through Tinker."""

    # LoRA rank.
    rank: int = Field(
        default=16,
        description="The LoRA rank of the Tinker model.",
    )

    # Seed for the LoRA weight initialization; None means random.
    seed: int | None = Field(
        default=None,
        description=(
            "The seed for initializing LoRA weights in the model. "
            "If None, weights are initialized randomly."
        ),
    )

    # Which parts of the network receive LoRA adapters.
    train_mlp: bool = Field(
        default=True,
        description="Whether to add LoRA to the MLP layers.",
    )

    train_attn: bool = Field(
        default=True,
        description="Whether to add LoRA to the attention layers.",
    )

    train_unembed: bool = Field(
        default=True,
        description="Whether to add LoRA to the unembedding layer.",
    )

    # Service endpoint; None selects the default service URL.
    base_url: str | None = Field(
        default=None,
        description=(
            "The base URL for Tinker services. If None, the default "
            "service URL is used."
        ),
    )

    def get_config(self) -> Dict[str, Any]:
        """Build the Tinker configuration dictionary.

        Returns:
            `Dict[str, Any]`: The Tinker model configuration dictionary,
            holding every field of this class under its own name.
        """
        exported_fields = (
            "rank",
            "seed",
            "train_mlp",
            "train_attn",
            "train_unembed",
            "base_url",
        )
        return {key: getattr(self, key) for key in exported_fields}
|
||||
96
src/agentscope/tuner/_tune.py
Normal file
96
src/agentscope/tuner/_tune.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The main entry point for agent learning."""
|
||||
import os
|
||||
from ._workflow import WorkflowType
|
||||
from ._judge import JudgeType
|
||||
from ._model import TunerModelConfig
|
||||
from ._dataset import DatasetConfig
|
||||
from ._config import (
|
||||
_to_trinity_config,
|
||||
check_judge_function,
|
||||
check_workflow_function,
|
||||
)
|
||||
from ._algorithm import AlgorithmConfig
|
||||
|
||||
|
||||
def tune(
    *,
    workflow_func: WorkflowType,
    judge_func: JudgeType | None = None,
    train_dataset: DatasetConfig | None = None,
    eval_dataset: DatasetConfig | None = None,
    model: TunerModelConfig | None = None,
    auxiliary_models: dict[str, TunerModelConfig] | None = None,
    algorithm: AlgorithmConfig | None = None,
    project_name: str | None = None,
    experiment_name: str | None = None,
    monitor_type: str | None = None,
    config_path: str | None = None,
) -> None:
    """Tune an agent workflow with the given configuration.

    The workflow (and judge) signatures are validated, the arguments are
    converted into a Trinity-RFT configuration, and the run is launched.
    When the environment variable ``USE_ALIYUN_PAI_DLC`` equals ``"1"``, a
    ray cluster is set up before the run and stopped afterwards.

    Args:
        workflow_func (`WorkflowType`): The learning workflow function
            to execute.
        judge_func (`JudgeType`, optional): The judge function used to
            evaluate the workflow output. Defaults to None.
        train_dataset (`DatasetConfig`, optional): The training dataset for
            the learning process. Defaults to None.
        eval_dataset (`DatasetConfig`, optional): The evaluation dataset for
            the learning process. Defaults to None.
        model (`TunerModelConfig`, optional): The model to be tuned.
            Defaults to None.
        auxiliary_models (`dict[str, TunerModelConfig]`, optional): A
            dictionary of auxiliary models for LLM-as-a-Judge
            or for acting as other agents in multi-agent scenarios.
            Defaults to None.
        algorithm (`AlgorithmConfig`, optional): The tuning algorithm
            configuration. Defaults to None.
        project_name (`str`, optional): Name of the project.
            Defaults to None.
        experiment_name (`str`, optional): Name of the experiment.
            Leave None to use timestamp. Defaults to None.
        monitor_type (`str`, optional): Type of the monitor to use.
            Could be one of 'tensorboard', 'wandb', 'mlflow', 'swanlab'.
            Leave None to use tensorboard. Defaults to None.
        config_path (`str`, optional): Path to a trinity yaml configuration
            file. If provided, only `workflow_func` is necessary, other
            arguments will override the corresponding fields in the config.
            Defaults to None.

    Raises:
        `ImportError`: If Trinity-RFT is not installed.
    """
    try:
        from trinity.cli.launcher import run_stage
        from trinity.utils.dlc_utils import setup_ray_cluster, stop_ray_cluster
    except ImportError as err:
        raise ImportError(
            "Trinity-RFT is not installed. Please install it with "
            "`pip install trinity-rft`.",
        ) from err

    # Reject malformed user functions before any configuration work.
    check_workflow_function(workflow_func)
    if judge_func is not None:
        check_judge_function(judge_func)

    trinity_config = _to_trinity_config(
        config_path=config_path,
        workflow_func=workflow_func,
        judge_func=judge_func,
        model=model,
        auxiliary_models=auxiliary_models,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        algorithm=algorithm,
        project_name=project_name,
        experiment_name=experiment_name,
        monitor_type=monitor_type,
    )

    on_dlc = os.getenv("USE_ALIYUN_PAI_DLC", "0") == "1"
    if on_dlc:
        ray_address = setup_ray_cluster(namespace="agentscope")
        trinity_config.cluster.ray_address = ray_address
    try:
        checked_config = trinity_config.check_and_update()
        return run_stage(config=checked_config)
    finally:
        # Stop the cluster even if validation or the run itself fails.
        if on_dlc:
            stop_ray_cluster(namespace="agentscope")
|
||||
56
src/agentscope/tuner/_workflow.py
Normal file
56
src/agentscope/tuner/_workflow.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The workflow module for tuner."""
|
||||
from logging import Logger
|
||||
from typing import Any, Callable, Dict, Awaitable
|
||||
from pydantic import BaseModel, Field
|
||||
from ..model import ChatModelBase
|
||||
|
||||
|
||||
class WorkflowOutput(BaseModel):
    """Result returned by a workflow function."""

    # Optional reward computed directly by the workflow.
    reward: float | None = Field(
        default=None,
        description=(
            "The reward obtained from the workflow function. "
            "Used for direct reward output."
        ),
    )
    # Optional raw response, handed to the judge function as its input.
    response: Any | None = Field(
        default=None,
        description=(
            "The response generated by the workflow function. "
            "Used as judge input."
        ),
    )

    # Optional named metrics reported by the workflow.
    metrics: Dict[str, float] | None = Field(
        default=None,
        description="Metrics from the workflow function.",
    )
|
||||
|
||||
|
||||
# Type of an agent workflow function used for tuning.
#
# A workflow function is an awaitable callable taking, in order:
#     task (`Dict`):
#         The task information for the workflow run.
#     model (`ChatModelBase`):
#         The primary chat model used in the workflow; this is the main
#         model being tuned.
#     auxiliary_models (`Dict[str, ChatModelBase] | None`, optional):
#         Additional chat models available for LLM-as-a-Judge usage,
#         keyed by model name. These auxiliary models are not tuned
#         during the workflow.
#     logger (`Logger | None`, optional):
#         A logger for reporting information while the workflow runs.
#
# It resolves to:
#     `WorkflowOutput`:
#         The workflow execution results, including optional reward, raw
#         response and metrics.
WorkflowType = Callable[
    [Dict, ChatModelBase, Dict[str, ChatModelBase] | None, Logger | None],
    Awaitable[WorkflowOutput],
]
|
||||
Reference in New Issue
Block a user