Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
149 lines
4.4 KiB
Python
149 lines
4.4 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Example of training a ReAct agent on GSM8K with Trinity-RFT."""
|
|
from typing import Dict
|
|
|
|
from agentscope.tuner import (
|
|
tune,
|
|
DatasetConfig,
|
|
WorkflowOutput,
|
|
JudgeOutput,
|
|
TunerModelConfig,
|
|
AlgorithmConfig,
|
|
)
|
|
from agentscope.agent import ReActAgent
|
|
from agentscope.model import OpenAIChatModel
|
|
from agentscope.formatter import OpenAIChatFormatter
|
|
from agentscope.message import Msg
|
|
|
|
|
|
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """A simple workflow function using the ReAct agent to solve tasks.

    Args:
        task (`Dict`): The task to be solved. Its ``"question"`` entry is
            sent to the agent as the user message.
        model (`OpenAIChatModel`): The language model to use.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge. Not used in this workflow; must be ``None`` or
            empty.

    Returns:
        `WorkflowOutput`: The workflow output wrapping the agent's reply.

    Raises:
        AssertionError: If any auxiliary models are supplied.
        KeyError: If ``task`` has no ``"question"`` entry.
    """
    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    # This is a plain string literal (not an f-string and never passed to
    # `.format`), so braces are taken literally. Use a single pair so the
    # model is asked for `\boxed{}`, which is presumably the form the
    # boxed-answer reward in the judge extracts; a doubled `{{}}` would be
    # shown to the model verbatim.
    sys_prompt = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. You can write and "
        "execute Python code to perform calculation or verify your answer. "
        "You should return your final answer within \\boxed{}."
    )
    # NOTE(review): the prompt mentions executing Python code, but no
    # toolkit is passed to the agent here. Confirm whether a code-execution
    # tool should be registered.
    agent = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )
    response = await agent.reply(
        msg=Msg("user", task["question"], role="user"),
    )
    return WorkflowOutput(
        response=response,
    )
|
|
|
|
|
|
async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """A simple judge function to calculate reward based on agent's response.

    Args:
        task (`Dict`): The task information for the corresponding workflow.
            Its ``"answer"`` entry holds the reference solution. For raw
            GSM8K text, the final answer follows the ``####`` marker.
        response (`Msg`): The response generated by the corresponding workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for LLM-as-a-Judge
            usage. The keys are model names, and the values are the
            corresponding OpenAIChatModel instances. This judge is rule-based,
            so it must be ``None`` or empty.

    Returns:
        `JudgeOutput`: The summed reward from `MathBoxedRewardFn`, with the
        individual reward terms reported as metrics.

    Raises:
        AssertionError: If any auxiliary models are supplied.
        KeyError: If ``task`` has no ``"answer"`` entry.
    """
    # Imported inside the function, presumably so this module can be
    # imported without Trinity-RFT installed.
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    reward_fn = MathBoxedRewardFn()

    # Parse the ground truth from GSM8K raw text, which ends with
    # "#### <final answer>". Keep only the text after the last marker.
    # Non-string or marker-free answers are used verbatim.
    truth = task["answer"]
    if isinstance(truth, str) and "####" in truth:
        truth = truth.rpartition("####")[2].strip()
    else:
        truth = str(truth)

    # Parse the answer from the response message. `get_text_content()` may
    # return None (e.g. a reply with no text blocks). Score that as an empty
    # answer instead of handing None to the reward function.
    result = response.get_text_content() or ""

    reward_dict = reward_fn(
        response=result,
        truth=truth,
    )
    return JudgeOutput(
        reward=sum(reward_dict.values()),
        metrics=reward_dict,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Training data: the "main" configuration of GSM8K, train split.
    gsm8k_train = DatasetConfig(
        path="openai/gsm8k",
        name="main",
        split="train",
    )

    # Policy model to be tuned. The engine count and tensor-parallel size
    # presumably need to match the available GPUs; adjust as required.
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    # To train without a local GPU via Tinker, uncomment the configuration
    # below so it replaces the one above. To use a locally deployed TuFT
    # server instead, also set `base_url` to that server's address.
    #
    # from agentscope.tuner import TinkerConfig
    # tuner_model = TunerModelConfig(
    #     model_path="Qwen/Qwen3-4B-Instruct-2507",
    #     max_model_len=24576,
    #     max_tokens=16384,
    #     temperature=1.0,
    #     tinker_config=TinkerConfig(
    #         rank=16,
    #         base_url=None,
    #     ),
    # )

    # RL algorithm settings: multi-step GRPO with 8 rollouts per group.
    grpo_config = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        batch_size=32,
    )

    # Wire the workflow, the judge, the data and the configs together and
    # start tuning.
    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        train_dataset=gsm8k_train,
        model=tuner_model,
        algorithm=grpo_config,
    )
|