Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
197 lines
7.4 KiB
Python
197 lines
7.4 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import re
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
from .agent_tools import (
|
||
tool_detect_external_contact,
|
||
tool_detect_intent,
|
||
tool_detect_order_status,
|
||
tool_detect_risk,
|
||
tool_extract_image_urls,
|
||
tool_extract_size_pairs,
|
||
tool_is_meaningless_short,
|
||
)
|
||
from .config import AGENT_MAX_ITERS, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME
|
||
from .models import Decision, DecisionModel, RouteModel
|
||
from .rules import rules_prompt
|
||
|
||
|
||
def _ensure_agentscope_importable() -> None:
    """Prepend ``<repo_root>/src`` to ``sys.path`` when that directory exists.

    ``repo_root`` is taken as ``Path(__file__).resolve().parents[2]``. The
    call is idempotent: nothing is inserted if the directory is missing or
    already on ``sys.path``. It runs before the lazy ``agentscope`` imports
    in ``_AgentRuntime.__init__`` — presumably so a source checkout under
    ``src/`` can be imported (TODO confirm the repo layout).
    """
    src_path = str(Path(__file__).resolve().parents[2] / "src")
    if src_path in sys.path or not Path(src_path).exists():
        return
    sys.path.insert(0, src_path)
|
||
|
||
|
||
class _AgentRuntime:
    """Shared agentscope ``ReActAgent`` wiring used by every business agent.

    Subclasses supply a name and a system prompt. This class builds an
    OpenAI-compatible chat model, registers the local detection tools and
    exposes:

    * ``self.agent`` -- the configured ``ReActAgent``;
    * ``self.Msg``   -- agentscope's ``Msg`` class, so callers can build
      messages without importing agentscope themselves;
    * static helpers that turn an agent reply into a plain ``dict``.
    """

    def __init__(self, name: str, sys_prompt: str):
        """Build the ReAct agent.

        Args:
            name: Agent name forwarded to ``ReActAgent``.
            sys_prompt: System prompt forwarded to ``ReActAgent``.

        Raises:
            RuntimeError: If ``OPENAI_API_KEY`` is empty/unset.
        """
        _ensure_agentscope_importable()

        # Imported here, after the sys.path adjustment above, rather than at
        # module top level.
        from agentscope.agent import ReActAgent
        from agentscope.formatter import OpenAIChatFormatter
        from agentscope.memory import InMemoryMemory
        from agentscope.message import Msg
        from agentscope.model import OpenAIChatModel
        from agentscope.tool import Toolkit

        if not OPENAI_API_KEY:
            raise RuntimeError("OPENAI_API_KEY 未设置")

        self.Msg = Msg

        toolkit = Toolkit()
        # Registration order is preserved from the original explicit calls.
        for tool in (
            tool_detect_intent,
            tool_extract_image_urls,
            tool_detect_order_status,
            tool_extract_size_pairs,
            tool_detect_risk,
            tool_detect_external_contact,
            tool_is_meaningless_short,
        ):
            toolkit.register_tool_function(tool)

        model = OpenAIChatModel(
            model_name=OPENAI_MODEL_NAME,
            api_key=OPENAI_API_KEY,
            stream=False,
            client_kwargs={"base_url": OPENAI_BASE_URL},
            generate_kwargs={"temperature": 0.1},
        )

        self.agent = ReActAgent(
            name=name,
            sys_prompt=sys_prompt,
            model=model,
            formatter=OpenAIChatFormatter(),
            toolkit=toolkit,
            memory=InMemoryMemory(),
            # Clamp so a zero/negative config value still allows one iteration.
            max_iters=max(1, AGENT_MAX_ITERS),
        )

    @staticmethod
    def _extract_json(text: str) -> dict[str, Any] | None:
        """Return the first JSON object embedded in *text*, or ``None``.

        Each ``{`` is tried in turn as the start of a JSON object. Prose or
        stray braces before/after the object (e.g. a trailing "备注: {x}")
        therefore do not break extraction, unlike a greedy first-``{`` to
        last-``}`` regex.
        """
        if not text:
            return None
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
                continue
            # Decoding from "{" can only yield a dict.
            return obj
        return None

    @staticmethod
    def _msg_to_text(msg: Any) -> str:
        """Best-effort conversion of an agent reply to plain text.

        Order of attempts:

        1. ``msg.get_text_content()`` if present and it returns a ``str``.
           Any exception it raises is deliberately ignored.
        2. ``msg.content`` when it is a ``str``.
        3. ``msg.content`` when it is a list of blocks: the non-blank
           ``text`` of each block, joined with newlines. Blocks may be
           dicts (e.g. ``{"type": "text", "text": ...}``) or objects with a
           ``text`` attribute.
        4. ``str(msg)``.
        """
        getter = getattr(msg, "get_text_content", None)
        if callable(getter):
            try:
                text = getter()
            except Exception:
                text = None  # best-effort: fall back to raw content below
            if isinstance(text, str):
                return text

        content = getattr(msg, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, dict):
                    text = block.get("text")
                else:
                    text = getattr(block, "text", None)
                if isinstance(text, str) and text.strip():
                    parts.append(text)
            return "\n".join(parts)
        return str(msg)

    @staticmethod
    def _extract_structured(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
        """Return the structured-output dict carried in a reply's metadata.

        A dict stored under one of the wrapper keys ``structured_output``,
        ``result``, ``output`` or ``json`` wins. Otherwise the metadata dict
        itself is assumed to be the structured output.

        NOTE(review): which layout agentscope uses depends on the installed
        version -- confirm.

        Returns ``None`` when *metadata* is not a dict. An empty metadata dict
        is returned as-is, which callers treat as falsy and fall back on text
        parsing.
        """
        if not isinstance(metadata, dict):
            return None
        for key in ("structured_output", "result", "output", "json"):
            nested = metadata.get(key)
            if isinstance(nested, dict):
                return nested
        return metadata
|
||
|
||
|
||
class RouterAgent(_AgentRuntime):
    """Agent that classifies a conversation context into a business route.

    It only chooses a route (pre_sales / quote / after_sales / risk); its
    system prompt forbids replying to the customer directly.
    """

    # Routes accepted from the model; anything else falls back to "pre_sales".
    _VALID_ROUTES = frozenset({"pre_sales", "quote", "after_sales", "risk"})

    def __init__(self) -> None:
        super().__init__(
            "RouterAgent",
            rules_prompt()
            + "\n你是路由Agent。只输出路由 pre_sales/quote/after_sales/risk,不直接回复客户。"
            + " 你必须基于上下文语义路由,禁止关键词硬匹配。",
        )

    async def route(self, context: dict[str, Any]) -> tuple[str, str]:
        """Ask the model to route *context*.

        Args:
            context: JSON-serialisable conversation context; it is embedded
                in the prompt via ``json.dumps``.

        Returns:
            ``(route, reason)``. ``route`` is one of ``_VALID_ROUTES``; the
            model's answer is whitespace-stripped and lower-cased first (the
            same normalisation ``_decide_with_model`` applies to actions), and
            unknown or missing routes become ``"pre_sales"``. ``reason`` is the
            stripped model-supplied reason, or ``""``.
        """
        prompt = f"按上下文路由到 pre_sales/quote/after_sales/risk。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
        res = await self.agent(self.Msg("user", prompt, "user"), structured_model=RouteModel)
        # Prefer structured output; fall back to JSON scraped from reply text.
        obj = (
            self._extract_structured(getattr(res, "metadata", None))
            or self._extract_json(self._msg_to_text(res))
            or {}
        )
        route = str(obj.get("route", "") or "").strip().lower()
        if route not in self._VALID_ROUTES:
            route = "pre_sales"
        return route, str(obj.get("reason", "") or "").strip()
|
||
|
||
|
||
class QuoteAgent(_AgentRuntime):
    """Agent for the quoting stage.

    Its prompt covers receiving images, triggering a quote, replying with it
    and updating quote-stage state. It is told to prefer any
    ``image_quote_analysis`` present in the context.
    """

    def __init__(self) -> None:
        role_prompt = (
            "\n你是报价Agent。负责收图、报价触发、报价回复和报价阶段状态更新。"
            "\n若上下文里有 image_quote_analysis,优先参考其 诉求类型/可做性/复杂度/建议报价 来决定回复语气与报价动作。"
        )
        super().__init__("QuoteAgent", rules_prompt() + role_prompt)

    async def decide(self, context: dict[str, Any]) -> Decision:
        """Return the model's quote-stage ``Decision`` for *context*.

        *context* is embedded in the prompt via ``json.dumps``.
        """
        context_json = json.dumps(context, ensure_ascii=False)
        return await _decide_with_model(self, "你负责报价相关决策。\n上下文:\n" + context_json)
|
||
|
||
|
||
class AfterSalesAgent(_AgentRuntime):
    """Agent for after-sales handling.

    Its prompt covers refunds, re-sends, dissatisfaction and related state
    progression.
    """

    def __init__(self) -> None:
        role_prompt = "\n你是售后Agent。负责退款/重发/不满意等售后处理与状态推进。"
        super().__init__("AfterSalesAgent", rules_prompt() + role_prompt)

    async def decide(self, context: dict[str, Any]) -> Decision:
        """Return the model's after-sales ``Decision`` for *context*.

        *context* is embedded in the prompt via ``json.dumps``.
        """
        context_json = json.dumps(context, ensure_ascii=False)
        return await _decide_with_model(self, "你负责售后相关决策。\n上下文:\n" + context_json)
|
||
|
||
|
||
class RiskAgent(_AgentRuntime):
    """Agent focused on risk identification and risk-action decisions."""

    def __init__(self) -> None:
        role_prompt = "\n你是风控Agent。专注风险识别与风险动作决策。"
        super().__init__("RiskAgent", rules_prompt() + role_prompt)

    async def decide(self, context: dict[str, Any]) -> Decision:
        """Return the model's risk-control ``Decision`` for *context*.

        *context* is embedded in the prompt via ``json.dumps``.
        """
        context_json = json.dumps(context, ensure_ascii=False)
        return await _decide_with_model(self, "你负责风控相关决策。\n上下文:\n" + context_json)
|
||
|
||
|
||
class PreSalesAgent(_AgentRuntime):
    """Agent for pre-sales conversations.

    Its prompt covers taking inquiries, receiving images, clarifying needs
    and the steps before handing over to quoting.
    """

    def __init__(self) -> None:
        role_prompt = "\n你是售前Agent。处理咨询承接、收图、澄清需求与转报价前动作。"
        super().__init__("PreSalesAgent", rules_prompt() + role_prompt)

    async def decide(self, context: dict[str, Any]) -> Decision:
        """Return the model's pre-sales ``Decision`` for *context*.

        *context* is embedded in the prompt via ``json.dumps``.
        """
        context_json = json.dumps(context, ensure_ascii=False)
        return await _decide_with_model(self, "你负责售前相关决策。\n上下文:\n" + context_json)
|
||
|
||
|
||
async def _decide_with_model(rt: _AgentRuntime, prompt: str) -> Decision:
    """Run *rt*'s agent on *prompt* and coerce its reply into a ``Decision``.

    The payload is taken from structured-output metadata, else from JSON
    scraped out of the reply text, else ``{}``. Coercion rules:

    * ``action`` is stripped and lower-cased. Values outside
      reply/quote/transfer/noop/update_state become ``"reply"``.
    * ``reply``, ``transfer_msg``, ``quote_mode``, ``reason`` become stripped
      strings (missing or falsy values become ``""``).
    * ``state_patch`` is kept only if it is a dict, otherwise ``{}``.
    """
    reply_msg = await rt.agent(rt.Msg("user", prompt, "user"), structured_model=DecisionModel)
    payload = (
        rt._extract_structured(getattr(reply_msg, "metadata", None))
        or rt._extract_json(rt._msg_to_text(reply_msg))
        or {}
    )

    def _field(key: str) -> str:
        # Missing / None / empty values all collapse to "".
        return str(payload.get(key, "") or "").strip()

    allowed_actions = {"reply", "quote", "transfer", "noop", "update_state"}
    action = _field("action").lower()
    patch = payload.get("state_patch")

    return Decision(
        action=action if action in allowed_actions else "reply",
        reply=_field("reply"),
        transfer_msg=_field("transfer_msg"),
        quote_mode=_field("quote_mode"),
        state_patch=patch if isinstance(patch, dict) else {},
        reason=_field("reason"),
    )
|