Files
tw2/examples/tuner/react_agent/main.py
codex-bot a64378956a
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
chore: initialize sandbox and overwrite remote content
2026-03-02 22:32:27 +08:00

149 lines
4.4 KiB
Python

# -*- coding: utf-8 -*-
"""Example of training a ReAct agent on GSM8K with Trinity-RFT."""
from typing import Dict
from agentscope.tuner import (
tune,
DatasetConfig,
WorkflowOutput,
JudgeOutput,
TunerModelConfig,
AlgorithmConfig,
)
from agentscope.agent import ReActAgent
from agentscope.model import OpenAIChatModel
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """Solve one task with a freshly built ReAct agent.

    Args:
        task (`Dict`):
            The task to solve. Its ``"question"`` entry is sent to the agent
            as the user message.
        model (`OpenAIChatModel`):
            The chat model that drives the agent.
        auxiliary_models (`Dict[str, OpenAIChatModel] | None`):
            Additional chat models (e.g. for LLM-as-a-Judge). This workflow
            uses none, so the argument must be ``None`` or empty.

    Returns:
        `WorkflowOutput`: The workflow output wrapping the agent's reply.
    """
    assert not (
        auxiliary_models is not None and len(auxiliary_models) > 0
    ), "No auxiliary models are used in this workflow."

    # The pieces below concatenate into a single system prompt string.
    sys_prompt = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. "
        "You can write and execute Python code to perform calculation or "
        "verify your answer. "
        "You should return your final answer within \\boxed{{}}."
    )

    solver = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )

    question = Msg("user", task["question"], role="user")
    reply = await solver.reply(msg=question)
    return WorkflowOutput(response=reply)
async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """Score the agent's response against the GSM8K reference answer.

    Args:
        task (`Dict`):
            The task information for the corresponding workflow. Its
            ``"answer"`` entry holds the reference; when it is a string
            containing ``"####"``, the text after that marker is used as
            the ground truth.
        response (`Msg`):
            The response generated by the corresponding workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel] | None`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. This judge uses none, so the argument
            must be ``None`` or empty.

    Returns:
        `JudgeOutput`: The total reward (the sum of all values returned by
        the reward function) together with the individual values as metrics.
    """
    # Function-scope import: Trinity-RFT is only needed once the judge runs.
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    reward_fn = MathBoxedRewardFn()

    # Parse the ground truth from the raw GSM8K answer text.
    truth = task["answer"]
    if isinstance(truth, str) and "####" in truth:
        truth = truth.split("####")[1].strip()
    else:
        truth = str(truth)

    # `get_text_content()` can return `None` when the message carries no
    # text block; score that as an empty answer instead of handing `None`
    # to the reward function.
    result = response.get_text_content() or ""

    reward_dict = reward_fn(
        response=result,
        truth=truth,
    )
    return JudgeOutput(
        reward=sum(reward_dict.values()),
        metrics=reward_dict,
    )
if __name__ == "__main__":
    # Training algorithm settings.
    grpo_algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        batch_size=32,
    )

    # Training data: the `train` split of `openai/gsm8k`.
    gsm8k_train = DatasetConfig(
        path="openai/gsm8k",
        name="main",
        split="train",
    )

    # Model to be tuned, together with its sampling/serving settings.
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    # No GPU available? Uncomment the block below to use Tinker instead.
    # To use a locally deployed TuFT server, set `base_url` to its address.
    #
    # from agentscope.tuner import TinkerConfig
    # tuner_model = TunerModelConfig(
    #     model_path="Qwen/Qwen3-4B-Instruct-2507",
    #     max_model_len=24576,
    #     max_tokens=16384,
    #     temperature=1.0,
    #     tinker_config=TinkerConfig(
    #         rank=16,
    #         base_url=None,
    #     ),
    # )

    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        train_dataset=gsm8k_train,
        model=tuner_model,
        algorithm=grpo_algorithm,
    )