diff --git a/core/pydantic_ai_agent.py b/core/pydantic_ai_agent.py index c8b9e41..46d8348 100755 --- a/core/pydantic_ai_agent.py +++ b/core/pydantic_ai_agent.py @@ -22,6 +22,10 @@ from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.openai import OpenAIProvider from dotenv import load_dotenv from utils.metrics_tracker import emit as metrics_emit +from utils.observability import emit_activity, build_trace_id +from core.quote_state_machine import QuoteStateMachine +from core.rules import Rule, RuleContext, RuleEngine, RuleResult +from services.risk_service import RiskService load_dotenv() @@ -209,16 +213,14 @@ class CustomerServiceAgent: @staticmethod def _activity_log(event: str, **kwargs): - safe = {} - for k, v in kwargs.items(): - if isinstance(v, str): - safe[k] = v[:240] - else: - safe[k] = v - try: - logger.info(f"[ACTIVITY] event={event} data={json.dumps(safe, ensure_ascii=False)}") - except Exception: - logger.info(f"[ACTIVITY] event={event} data={safe}") + emit_activity( + logger, + event=event, + trace_id=str(kwargs.pop("trace_id", "")), + customer_id=str(kwargs.pop("customer_id", "")), + result=str(kwargs.pop("result", "ok")), + **kwargs, + ) def __init__(self, skills_dir: str = "skills"): self.api_key = os.getenv("OPENAI_API_KEY") @@ -231,6 +233,9 @@ class CustomerServiceAgent: self.batch_quote_delay_turns = max(0, int(os.getenv("BATCH_QUOTE_DELAY_TURNS", "1"))) except Exception: self.batch_quote_delay_turns = 1 + self.quote_state_machine = QuoteStateMachine(delay_turns=self.batch_quote_delay_turns) + self.risk_service = RiskService() + self._pre_rule_engine = self._build_pre_rule_engine() if not self.api_key: raise ValueError("请设置 OPENAI_API_KEY 环境变量") @@ -430,7 +435,7 @@ class CustomerServiceAgent: 收图阶段回复默认走 AI 改写,失败时回退到固定模板。 """ # 首张收图先承接“我看一下”,避免机械地立刻催“发完统一报价”。 - if scene == "collect_ack" and len(state.pending_image_urls) <= 1: + if scene == "collect_ack" and len(state.pending_image_urls) == 1: first_ack = [ 
"收到了,我先看一下哈,稍等哈", "这张我收到了,我先看下,等我一下哈", @@ -1272,21 +1277,7 @@ class CustomerServiceAgent: @staticmethod def _refresh_quote_phase(state: ConversationState, phase_hint: str = ""): """统一维护收图报价状态机。""" - if phase_hint in {"idle", "collecting", "ready_to_quote", "waiting_result"}: - state.quote_phase = phase_hint - if phase_hint == "idle": - state.quote_ready_turns = 0 - return - if not state.pending_image_urls: - state.quote_phase = "idle" - state.quote_ready_turns = 0 - return - if state.quote_phase in {"ready_to_quote", "waiting_result"}: - return - if state.pending_image_urls and state.pending_requirements: - state.quote_phase = "collecting" - return - state.quote_phase = "collecting" + QuoteStateMachine().refresh(state, phase_hint=phase_hint) def _should_defer_batch_quote(self, state: ConversationState, mark_ready: bool = False) -> bool: """ @@ -1294,19 +1285,13 @@ class CustomerServiceAgent: - 首次进入 ready_to_quote 时按配置等待 N 轮 - 等待轮次归零后,本轮即可报价 """ - if mark_ready and state.quote_phase != "ready_to_quote": - state.quote_phase = "ready_to_quote" - state.quote_ready_turns = max(0, int(self.batch_quote_delay_turns)) - if state.quote_phase == "ready_to_quote" and state.quote_ready_turns > 0: - state.quote_ready_turns -= 1 - return True - return False + self.quote_state_machine.delay_turns = max(0, int(self.batch_quote_delay_turns)) + return self.quote_state_machine.should_defer_batch_quote(state, mark_ready=mark_ready) def _mark_quote_ready(self, state: ConversationState): """仅标记 ready 状态,不消费等待轮次。""" - if state.quote_phase != "ready_to_quote": - state.quote_phase = "ready_to_quote" - state.quote_ready_turns = max(0, int(self.batch_quote_delay_turns)) + self.quote_state_machine.delay_turns = max(0, int(self.batch_quote_delay_turns)) + self.quote_state_machine.mark_ready(state) def _build_reject_message(self, reason: str = "") -> str: templates = [ @@ -1779,10 +1764,163 @@ class CustomerServiceAgent: clean = msg.strip().rstrip("!!??。.~~") return clean in 
self._COOLDOWN_PATTERNS + def _build_pre_rule_engine(self) -> RuleEngine: + return RuleEngine( + rules=[ + Rule( + name="meaningless_short_text", + priority=10, + predicate=self._rule_pred_meaningless_short_text, + action=self._rule_act_meaningless_short_text, + ), + Rule( + name="cooldown_silent", + priority=20, + predicate=self._rule_pred_cooldown_silent, + action=self._rule_act_cooldown_silent, + ), + Rule( + name="manual_risk_block", + priority=30, + predicate=self._rule_pred_manual_risk_block, + action=self._rule_act_manual_risk_block, + ), + Rule( + name="text_risk_block", + priority=40, + predicate=self._rule_pred_text_risk_block, + action=self._rule_act_text_risk_block, + ), + ] + ) + + async def _rule_pred_meaningless_short_text(self, ctx: RuleContext) -> bool: + message: CustomerMessage = ctx.get("message") + return _is_meaningless_short_text(message.msg) + + async def _rule_act_meaningless_short_text(self, ctx: RuleContext) -> RuleResult: + message: CustomerMessage = ctx.get("message") + state: ConversationState = ctx.get("state") + trace_id = ctx.get("trace_id", "") + ping = random.choice(("嗯咯", "嗯啦", "嗯", "哦")) + state.last_reply_at = datetime.now() + self._activity_log( + "agent_ping_reply", + trace_id=trace_id, + customer_id=message.from_id, + msg=message.msg, + reply=ping, + ) + return RuleResult( + matched=True, + stop=True, + action="agent_ping_reply", + payload={"response": AgentResponse(reply=ping, should_reply=True, need_transfer=False)}, + ) + + async def _rule_pred_cooldown_silent(self, ctx: RuleContext) -> bool: + message: CustomerMessage = ctx.get("message") + state: ConversationState = ctx.get("state") + return self._in_cooldown(state, message.msg) + + async def _rule_act_cooldown_silent(self, ctx: RuleContext) -> RuleResult: + message: CustomerMessage = ctx.get("message") + state: ConversationState = ctx.get("state") + trace_id = ctx.get("trace_id", "") + elapsed = int((datetime.now() - state.last_reply_at).total_seconds()) if 
state.last_reply_at else 0 + print(f"[Agent] 冷却期静默(距上次回复 {elapsed}s):{message.msg!r}") + self._activity_log( + "agent_cooldown_silent", + trace_id=trace_id, + customer_id=message.from_id, + elapsed_s=elapsed, + ) + return RuleResult( + matched=True, + stop=True, + action="agent_cooldown_silent", + payload={"response": AgentResponse(reply="", should_reply=False, need_transfer=False)}, + ) + + async def _rule_pred_manual_risk_block(self, ctx: RuleContext) -> bool: + message: CustomerMessage = ctx.get("message") + decision = self.risk_service.check_manual_block(message.from_id) + ctx.set("manual_risk_decision", decision) + return decision.blocked + + async def _rule_act_manual_risk_block(self, ctx: RuleContext) -> RuleResult: + message: CustomerMessage = ctx.get("message") + trace_id = ctx.get("trace_id", "") + decision = ctx.get("manual_risk_decision") + self._activity_log( + "agent_manual_risk_reject", + trace_id=trace_id, + customer_id=message.from_id, + risk=(decision.profile if decision else {}), + ) + return RuleResult( + matched=True, + stop=True, + action="agent_manual_risk_reject", + payload={ + "response": AgentResponse( + reply="这边无法继续为你处理该类需求,给你转人工专员对接。", + should_reply=True, + need_transfer=True, + transfer_msg=TRANSFER_MESSAGE, + ) + }, + ) + + async def _rule_pred_text_risk_block(self, ctx: RuleContext) -> bool: + message: CustomerMessage = ctx.get("message") + decision = await self.risk_service.check_text_block( + message.msg, + political_detector=self._is_political_inquiry, + map_detector=self._is_map_inquiry, + ) + ctx.set("text_risk_decision", decision) + return decision.blocked + + async def _rule_act_text_risk_block(self, ctx: RuleContext) -> RuleResult: + message: CustomerMessage = ctx.get("message") + state: ConversationState = ctx.get("state") + trace_id = ctx.get("trace_id", "") + decision = ctx.get("text_risk_decision") + state.pending_image_urls.clear() + state.pending_requirements.clear() + self._sync_pending_quote_state(message.from_id, 
state) + + reject_text = self.risk_service.build_reject_text(decision.category if decision else "other") + reply = await self._rewrite_reply_with_ai( + message=message, + state=state, + reply=reject_text, + scene="risk_reject", + ) + state.last_reply_at = datetime.now() + print(f"{self.C_REPLY}[REPLY->CUSTOMER]{self.C_RESET} {reply}") + self._activity_log( + "agent_risk_reject", + trace_id=trace_id, + customer_id=message.from_id, + risk_category=(decision.category if decision else "other"), + risk_source=(decision.source if decision else "unknown"), + reply=reply, + ) + return RuleResult( + matched=True, + stop=True, + action="agent_risk_reject", + payload={"response": AgentResponse(reply=reply, should_reply=True, need_transfer=False)}, + ) + async def process_message(self, message: CustomerMessage) -> AgentResponse: """处理客户消息并生成回复""" + trace_id = build_trace_id(message.acc_id, message.from_id, message.msg_id, message.msg[:64]) self._activity_log( "agent_inbound", + trace_id=trace_id, acc_id=message.acc_id, customer_id=message.from_id, msg=message.msg, @@ -1791,97 +1929,12 @@ class CustomerServiceAgent: metrics_emit("inbound_msg", customer_id=message.from_id, acc_id=message.acc_id) # 获取或创建对话状态 state = self._get_conversation_state(message.from_id) - - # 无意义短句承接:单独回一句口语,不进入复杂决策 - if _is_meaningless_short_text(message.msg): - ping = random.choice(("嗯咯", "嗯啦", "嗯", "哦")) - state.last_reply_at = datetime.now() - self._activity_log("agent_ping_reply", customer_id=message.from_id, msg=message.msg, reply=ping) - return AgentResponse(reply=ping, should_reply=True, need_transfer=False) - - # 冷却期检测:近期已回复 + 纯打招呼 → 静默 - if self._in_cooldown(state, message.msg): - elapsed = int((datetime.now() - state.last_reply_at).total_seconds()) - print(f"[Agent] 冷却期静默(距上次回复 {elapsed}s):{message.msg!r}") - self._activity_log("agent_cooldown_silent", customer_id=message.from_id, elapsed_s=elapsed) - return AgentResponse(reply="", should_reply=False, need_transfer=False) - - # 
前置风控:客户文本一旦命中政治/敏感询问,直接拒绝,避免“发图我看看”类答非所问 - try: - # 人工风控:标记为不接单的客户直接转人工 - manual_risk = risk_db.evaluate_customer(message.from_id) - if bool(manual_risk.get("do_not_serve")): - self._activity_log( - "agent_manual_risk_reject", - customer_id=message.from_id, - risk=manual_risk, - ) - return AgentResponse( - reply="这边无法继续为你处理该类需求,给你转人工专员对接。", - should_reply=True, - need_transfer=True, - transfer_msg=TRANSFER_MESSAGE, - ) - - from utils.content_filter import should_block_customer_smart - risk_hit, risk_category, _risk_reason = await should_block_customer_smart(message.msg) - map_hit = self._is_map_inquiry(message.msg) or (risk_category == "map") - political_hit = self._is_political_inquiry(message.msg) or (risk_category == "political") - if risk_hit or political_hit or map_hit: - # 命中敏感询问时清空待报价队列,避免旧图残留污染后续会话 - state.pending_image_urls.clear() - state.pending_requirements.clear() - self._sync_pending_quote_state(message.from_id, state) - reject_text = "地图这类不做哈,这边不接地图相关需求。" - if risk_category == "sexual": - reject_text = "这类不做哈,涉黄擦边内容都不接。" - elif risk_category == "violent": - reject_text = "这类不做哈,暴力血腥相关都不接。" - elif political_hit and not map_hit: - reject_text = "这类不做哈,政治相关图片和人物都不接。" - reply = await self._rewrite_reply_with_ai( - message=message, - state=state, - reply=reject_text, - scene="risk_reject", - ) - state.last_reply_at = datetime.now() - print(f"{self.C_REPLY}[REPLY->CUSTOMER]{self.C_RESET} {reply}") - self._activity_log( - "agent_risk_reject", - customer_id=message.from_id, - map_hit=map_hit, - political_hit=political_hit, - risk_category=risk_category, - reply=reply, - ) - return AgentResponse(reply=reply, should_reply=True, need_transfer=False) - except Exception: - map_hit = self._is_map_inquiry(message.msg) - political_hit = self._is_political_inquiry(message.msg) - if political_hit or map_hit: - state.pending_image_urls.clear() - state.pending_requirements.clear() - self._sync_pending_quote_state(message.from_id, state) - reject_text = 
"地图这类不做哈,这边不接地图相关需求。" - if political_hit and not map_hit: - reject_text = "这类不做哈,政治相关图片和人物都不接。" - reply = await self._rewrite_reply_with_ai( - message=message, - state=state, - reply=reject_text, - scene="risk_reject", - ) - state.last_reply_at = datetime.now() - print(f"{self.C_REPLY}[REPLY->CUSTOMER]{self.C_RESET} {reply}") - self._activity_log( - "agent_risk_reject", - customer_id=message.from_id, - map_hit=map_hit, - political_hit=political_hit, - reply=reply, - ) - return AgentResponse(reply=reply, should_reply=True, need_transfer=False) + pre_ctx = RuleContext(data={"message": message, "state": state, "trace_id": trace_id}) + pre_result = await self._pre_rule_engine.run(pre_ctx) + if pre_result.stop: + response = pre_result.payload.get("response") + if isinstance(response, AgentResponse): + return response # 检测售前/售后 new_stage = self._detect_stage(message.msg) diff --git a/core/quote_state_machine.py b/core/quote_state_machine.py new file mode 100644 index 0000000..ac6338e --- /dev/null +++ b/core/quote_state_machine.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + + +class QuoteStateLike(Protocol): + pending_image_urls: list + pending_requirements: list + quote_phase: str + quote_ready_turns: int + + +@dataclass +class QuoteStateMachine: + delay_turns: int = 1 + + def refresh(self, state: QuoteStateLike, phase_hint: str = "") -> None: + if phase_hint in {"idle", "collecting", "ready_to_quote", "waiting_result"}: + state.quote_phase = phase_hint + if phase_hint == "idle": + state.quote_ready_turns = 0 + return + + if not state.pending_image_urls: + state.quote_phase = "idle" + state.quote_ready_turns = 0 + return + + if state.quote_phase in {"ready_to_quote", "waiting_result"}: + return + + state.quote_phase = "collecting" + + def mark_ready(self, state: QuoteStateLike) -> None: + if state.quote_phase != "ready_to_quote": + state.quote_phase = "ready_to_quote" + state.quote_ready_turns 
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, List, Optional


@dataclass
class RuleContext:
    """Mutable key/value bag shared by all rules during one engine run."""

    # Arbitrary per-run values; predicates may stash results here for actions.
    data: Dict[str, Any] = field(default_factory=dict)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under *key*, or *default* when absent."""
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Store *value* under *key*, replacing any previous value."""
        self.data[key] = value


@dataclass
class RuleResult:
    """Outcome of a rule action (or of a whole engine run)."""

    matched: bool = False  # True once some rule's predicate held
    stop: bool = False  # True when the chain must end with this result
    action: str = ""  # label of what was done; defaults to the rule name
    payload: Dict[str, Any] = field(default_factory=dict)


Predicate = Callable[[RuleContext], Awaitable[bool]]
Action = Callable[[RuleContext], Awaitable[RuleResult]]


@dataclass
class Rule:
    """A named predicate/action pair; lower ``priority`` runs earlier."""

    name: str
    priority: int
    predicate: Predicate
    action: Action


class RuleEngine:
    """Priority-ordered async rule chain."""

    def __init__(self, rules: Optional[List[Rule]] = None):
        # sorted() is stable, so rules sharing a priority keep insertion order.
        self._rules: List[Rule] = sorted(rules or [], key=lambda r: r.priority)

    def add_rule(self, rule: Rule) -> None:
        """Register *rule*, keeping the chain ordered by ascending priority."""
        self._rules.append(rule)
        self._rules.sort(key=lambda r: r.priority)

    async def run(self, ctx: RuleContext) -> RuleResult:
        """Evaluate rules in priority order against *ctx*.

        Returns:
            The result of the first matching rule whose action sets ``stop``.
            If rules matched but none asked to stop, the result of the last
            matching rule (``matched=True``, ``stop=False``). If nothing
            matched, ``RuleResult(action="no_match")``.

        Raises:
            TypeError: if a rule's action returns ``None``.
        """
        last_match: Optional[RuleResult] = None
        # Iterate over a snapshot so an action may call add_rule() safely.
        for rule in tuple(self._rules):
            if not await rule.predicate(ctx):
                continue
            result = await rule.action(ctx)
            if result is None:
                raise TypeError(
                    f"action of rule {rule.name!r} returned None, expected RuleResult"
                )
            # The predicate held, so the result is a match by definition.
            result.matched = True
            if not result.action:
                result.action = rule.name
            if result.stop:
                return result
            last_match = result
        if last_match is not None:
            return last_match
        return RuleResult(matched=False, stop=False, action="no_match")
customer_id=data.get("from_id", ""), msg=msg_text, ) # 调用 Agent + _t0 = time.monotonic() response = await self.agent.process_message(customer_msg) + self._activity_log( + "agent_process_done", + trace_id=trace_id, + acc_id=data.get("acc_id", ""), + customer_id=data.get("from_id", ""), + result="ok", + latency_ms=int((time.monotonic() - _t0) * 1000), + should_reply=bool(response.should_reply), + need_transfer=bool(response.need_transfer), + ) # 检查是否需要转接人工 if response.need_transfer: logger.info("Agent 决定转接人工") self._activity_log( "agent_transfer", + trace_id=trace_id, acc_id=data.get("acc_id", ""), customer_id=data.get("from_id", ""), transfer_msg=response.transfer_msg, @@ -932,6 +946,7 @@ class QingjianAPIClient: logger.info(f"Agent 回复: {response.reply}") self._activity_log( "agent_reply", + trace_id=trace_id, acc_id=data.get("acc_id", ""), customer_id=data.get("from_id", ""), reply=response.reply, @@ -955,6 +970,7 @@ class QingjianAPIClient: logger.info("Agent 决定不回复此消息") self._activity_log( "agent_no_reply", + trace_id=trace_id, acc_id=data.get("acc_id", ""), customer_id=data.get("from_id", ""), ) @@ -963,6 +979,7 @@ class QingjianAPIClient: logger.error(f"Agent 处理失败: {e}") self._activity_log( "agent_process_error", + trace_id=data.get("_trace_id", ""), acc_id=data.get("acc_id", ""), customer_id=data.get("from_id", ""), error=str(e), @@ -1659,10 +1676,12 @@ class QingjianAPIClient: original_msg: 收到的原始消息字典 reply_content: 回复内容(文本或本地文件路径/http地址) """ + trace_id = original_msg.get("_trace_id", "") if not self.websocket: print(f"[{self.get_time()}] 错误: 未连接到服务器") self._activity_log( "send_reply_skipped", + trace_id=trace_id, reason="websocket_not_connected", acc_id=original_msg.get("acc_id", ""), customer_id=original_msg.get("from_id", ""), @@ -1687,6 +1706,7 @@ class QingjianAPIClient: ) self._activity_log( "send_reply_throttled", + trace_id=trace_id, key=ckey, cooldown_s=cooldown, msg=str(reply_content), @@ -1716,10 +1736,12 @@ class QingjianAPIClient: 
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

_log = logging.getLogger(__name__)


@dataclass
class RiskDecision:
    """Verdict of one risk check."""

    blocked: bool
    category: str = "none"
    reason: str = ""
    source: str = "none"  # manual/ai_filter/fallback
    profile: Optional[Dict[str, Any]] = None  # raw customer profile, manual checks only


class RiskService:
    """Customer-level and text-level risk checks used by the pre-rule chain."""

    # Categories from the smart filter that outrank keyword (map/political) hits
    # when choosing the rejection text.
    _FILTER_FIRST = ("sexual", "violent")

    _REJECT_TEXTS = {
        "map": "地图这类不做哈,这边不接地图相关需求。",
        "sexual": "这类不做哈,涉黄擦边内容都不接。",
        "violent": "这类不做哈,暴力血腥相关都不接。",
        "political": "这类不做哈,政治相关图片和人物都不接。",
    }
    _REJECT_DEFAULT = "这类不做哈,这边不接这类需求。"

    def evaluate_customer(self, customer_id: str) -> Dict[str, Any]:
        """Return the risk profile stored for *customer_id*."""
        # Imported lazily so this module loads (and the pure helpers can be
        # unit-tested) even when the risk database is unavailable.
        from db.customer_risk_db import risk_db

        return risk_db.evaluate_customer(customer_id)

    def check_manual_block(self, customer_id: str) -> RiskDecision:
        """Block customers whose profile carries a truthy ``do_not_serve``.

        A failed profile lookup is logged and treated as "not blocked", so a
        risk-database outage does not abort message processing.
        """
        try:
            profile = self.evaluate_customer(customer_id) or {}
        except Exception:
            _log.exception("manual risk lookup failed for customer %s", customer_id)
            return RiskDecision(blocked=False, reason="lookup_failed", source="manual")
        if profile.get("do_not_serve"):
            return RiskDecision(
                blocked=True,
                category="manual_block",
                reason="do_not_serve",
                source="manual",
                profile=profile,
            )
        return RiskDecision(blocked=False, source="manual", profile=profile)

    async def check_text_block(
        self,
        text: str,
        *,
        political_detector: Callable[[str], bool],
        map_detector: Callable[[str], bool],
    ) -> RiskDecision:
        """Decide whether *text* must be rejected.

        Runs the smart content filter; when it cannot be imported or fails,
        only the two keyword detectors are used and ``source`` is
        ``"fallback"``. Exceptions raised by the detectors themselves
        propagate to the caller.
        """
        try:
            # Kept inside the try so an import failure also takes the fallback.
            from utils.content_filter import should_block_customer_smart

            hit, category, reason = await should_block_customer_smart(text)
            source = "ai_filter"
            reason = reason or "sensitive_text"
        except Exception:
            _log.exception("smart content filter unavailable; using keyword fallback")
            hit, category, reason, source = False, "", "fallback_match", "fallback"

        map_hit = bool(map_detector(text)) or category == "map"
        political_hit = bool(political_detector(text)) or category == "political"
        if not (hit or map_hit or political_hit):
            return RiskDecision(blocked=False, category="none", source=source)
        return RiskDecision(
            blocked=True,
            category=self._resolve_category(category, map_hit, political_hit),
            reason=reason,
            source=source,
        )

    @classmethod
    def _resolve_category(cls, category: str, map_hit: bool, political_hit: bool) -> str:
        """Pick the category that selects the rejection text.

        Order: filter-reported sexual/violent, then map, then political,
        then the filter's own category, then ``"other"``.
        """
        if category in cls._FILTER_FIRST:
            return category
        if map_hit:
            return "map"
        if political_hit:
            return "political"
        return category or "other"

    @staticmethod
    def build_reject_text(category: str) -> str:
        """Return the customer-facing rejection sentence for *category*."""
        return RiskService._REJECT_TEXTS.get(category, RiskService._REJECT_DEFAULT)
self.pending_requirements = [] + self.quote_phase = "idle" + self.quote_ready_turns = 0 + + +class GoldenReplayTests(unittest.TestCase): + def test_replay_collect_then_ready_then_quote(self): + sm = QuoteStateMachine(delay_turns=1) + st = _State() + + replay = [ + {"event": "image", "url": "a.jpg", "want_phase": "collecting"}, + {"event": "image", "url": "b.jpg", "want_phase": "collecting"}, + {"event": "finish", "want_phase": "ready_to_quote", "want_defer": True}, + {"event": "progress", "want_phase": "ready_to_quote", "want_defer": False}, + ] + + for step in replay: + if step["event"] == "image": + st.pending_image_urls.append(step["url"]) + sm.refresh(st) + self.assertEqual(st.quote_phase, step["want_phase"]) + elif step["event"] == "finish": + deferred = sm.should_defer_batch_quote(st, mark_ready=True) + self.assertEqual(st.quote_phase, step["want_phase"]) + self.assertEqual(deferred, step["want_defer"]) + elif step["event"] == "progress": + deferred = sm.should_defer_batch_quote(st, mark_ready=False) + self.assertEqual(st.quote_phase, step["want_phase"]) + self.assertEqual(deferred, step["want_defer"]) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/test_rule_engine.py b/tests/test_rule_engine.py new file mode 100644 index 0000000..3bc8e5b --- /dev/null +++ b/tests/test_rule_engine.py @@ -0,0 +1,57 @@ +import unittest + +from core.quote_state_machine import QuoteStateMachine +from core.rules import Rule, RuleContext, RuleEngine, RuleResult +from services.risk_service import RiskService + + +class _StubState: + def __init__(self): + self.pending_image_urls = [] + self.pending_requirements = [] + self.quote_phase = "idle" + self.quote_ready_turns = 0 + + +class RuleEngineTests(unittest.IsolatedAsyncioTestCase): + async def test_rule_engine_priority_and_stop(self): + async def pred_true(_ctx): + return True + + async def act_first(_ctx): + return RuleResult(matched=True, stop=True, action="first", payload={"x": 1}) + + async def 
act_second(_ctx): + return RuleResult(matched=True, stop=True, action="second", payload={"x": 2}) + + engine = RuleEngine( + [ + Rule(name="r2", priority=20, predicate=pred_true, action=act_second), + Rule(name="r1", priority=10, predicate=pred_true, action=act_first), + ] + ) + out = await engine.run(RuleContext()) + self.assertTrue(out.matched) + self.assertEqual(out.action, "first") + self.assertEqual(out.payload["x"], 1) + + def test_quote_state_machine_transitions(self): + sm = QuoteStateMachine(delay_turns=1) + st = _StubState() + st.pending_image_urls = ["u1"] + sm.refresh(st) + self.assertEqual(st.quote_phase, "collecting") + self.assertTrue(sm.should_defer_batch_quote(st, mark_ready=True)) + self.assertEqual(st.quote_phase, "ready_to_quote") + self.assertEqual(st.quote_ready_turns, 0) + self.assertFalse(sm.should_defer_batch_quote(st, mark_ready=False)) + + def test_risk_service_reject_text(self): + svc = RiskService() + self.assertIn("地图", svc.build_reject_text("map")) + self.assertIn("政治", svc.build_reject_text("political")) + self.assertIn("涉黄", svc.build_reject_text("sexual")) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/utils/observability.py b/utils/observability.py new file mode 100644 index 0000000..a06a7d3 --- /dev/null +++ b/utils/observability.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import hashlib +import json +import logging +import time +from typing import Any, Dict, Optional + + +def build_trace_id(*parts: str) -> str: + raw = "|".join(str(p or "") for p in parts) + digest = hashlib.md5(raw.encode("utf-8")).hexdigest() + return digest[:16] + + +def emit_activity( + logger: logging.Logger, + *, + event: str, + trace_id: str = "", + customer_id: str = "", + result: str = "ok", + latency_ms: Optional[int] = None, + **fields: Any, +) -> None: + payload: Dict[str, Any] = { + "trace_id": trace_id or "-", + "customer_id": customer_id or "-", + "event": event, + "result": result, + } + if latency_ms is 
not None: + payload["latency_ms"] = int(max(0, latency_ms)) + for k, v in (fields or {}).items(): + if isinstance(v, str): + payload[k] = v[:400] + else: + payload[k] = v + try: + logger.info(f"[ACTIVITY] {json.dumps(payload, ensure_ascii=False)}") + except Exception: + logger.info(f"[ACTIVITY] {payload}") + + +class ActivityTimer: + def __init__( + self, + *, + logger: logging.Logger, + event: str, + trace_id: str = "", + customer_id: str = "", + **fields: Any, + ): + self.logger = logger + self.event = event + self.trace_id = trace_id + self.customer_id = customer_id + self.fields = fields + self.start = time.monotonic() + + def ok(self, **fields: Any) -> None: + elapsed_ms = int((time.monotonic() - self.start) * 1000) + merged = dict(self.fields) + merged.update(fields) + emit_activity( + self.logger, + event=self.event, + trace_id=self.trace_id, + customer_id=self.customer_id, + result="ok", + latency_ms=elapsed_ms, + **merged, + ) + + def fail(self, error: str, **fields: Any) -> None: + elapsed_ms = int((time.monotonic() - self.start) * 1000) + merged = dict(self.fields) + merged.update(fields) + emit_activity( + self.logger, + event=self.event, + trace_id=self.trace_id, + customer_id=self.customer_id, + result="error", + latency_ms=elapsed_ms, + error=error, + **merged, + )