chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled

This commit is contained in:
codex-bot
2026-03-02 22:32:27 +08:00
commit a64378956a
584 changed files with 93604 additions and 0 deletions

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
import json
from typing import Any
from agentscope.tool import ToolResponse
from .rules import (
detect_intent,
detect_order_status,
extract_image_urls,
extract_size_pairs_m,
has_map_or_political_risk,
has_porn_risk,
is_meaningless_short,
requests_external_contact,
)
def _tool_ok(value: Any) -> ToolResponse:
"""Wrap python values into AgentScope ToolResponse."""
try:
text = json.dumps(value, ensure_ascii=False)
except Exception:
text = str(value)
return ToolResponse(
content=[{"type": "text", "text": text}],
metadata={"value": value},
)
def tool_detect_intent(msg: str) -> ToolResponse:
"""识别客户当前意图: image/pricing/greeting/external_contact/finish_or_quote_trigger/nonsense/unknown。"""
return _tool_ok(detect_intent(msg or ""))
def tool_extract_image_urls(msg: str) -> ToolResponse:
"""提取消息中的图片 URL 列表。"""
return _tool_ok(extract_image_urls(msg or ""))
def tool_detect_order_status(goods_order: str) -> ToolResponse:
"""识别订单状态: paid/pending_payment/refund/unknown。"""
return _tool_ok(detect_order_status(goods_order or ""))
def tool_extract_size_pairs(msg: str) -> ToolResponse:
"""提取尺寸对,单位米。返回 [(w, h), ...]。"""
return _tool_ok(extract_size_pairs_m(msg or ""))
def tool_detect_risk(msg: str, goods_name: str = "") -> ToolResponse:
"""检测风险:地图政治、黄暴。"""
text = msg or ""
gname = goods_name or ""
return _tool_ok(
{
"map_or_political": has_map_or_political_risk(text, gname),
"porn": has_porn_risk(text),
}
)
def tool_detect_external_contact(msg: str) -> ToolResponse:
"""检测是否索要外部联系方式(微信/QQ/手机号等)。"""
return _tool_ok(requests_external_contact(msg or ""))
def tool_is_meaningless_short(msg: str) -> ToolResponse:
"""检测是否无意义短句(嗯/哦/ok等"""
return _tool_ok(is_meaningless_short(msg or ""))

194
qingjian_cs/app/agents.py Normal file
View File

@@ -0,0 +1,194 @@
from __future__ import annotations
import json
import re
import sys
from pathlib import Path
from typing import Any
from .agent_tools import (
tool_detect_external_contact,
tool_detect_intent,
tool_detect_order_status,
tool_detect_risk,
tool_extract_image_urls,
tool_extract_size_pairs,
tool_is_meaningless_short,
)
from .config import AGENT_MAX_ITERS, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME
from .models import Decision, DecisionModel, RouteModel
from .rules import rules_prompt
def _ensure_agentscope_importable() -> None:
repo_root = Path(__file__).resolve().parents[2]
src_dir = repo_root / "src"
if src_dir.exists() and str(src_dir) not in sys.path:
sys.path.insert(0, str(src_dir))
class _AgentRuntime:
def __init__(self, name: str, sys_prompt: str):
_ensure_agentscope_importable()
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
from agentscope.tool import Toolkit
if not OPENAI_API_KEY:
raise RuntimeError("OPENAI_API_KEY 未设置")
self.Msg = Msg
toolkit = Toolkit()
toolkit.register_tool_function(tool_detect_intent)
toolkit.register_tool_function(tool_extract_image_urls)
toolkit.register_tool_function(tool_detect_order_status)
toolkit.register_tool_function(tool_extract_size_pairs)
toolkit.register_tool_function(tool_detect_risk)
toolkit.register_tool_function(tool_detect_external_contact)
toolkit.register_tool_function(tool_is_meaningless_short)
model = OpenAIChatModel(
model_name=OPENAI_MODEL_NAME,
api_key=OPENAI_API_KEY,
stream=False,
client_kwargs={"base_url": OPENAI_BASE_URL},
generate_kwargs={"temperature": 0.1},
)
self.agent = ReActAgent(
name=name,
sys_prompt=sys_prompt,
model=model,
formatter=OpenAIChatFormatter(),
toolkit=toolkit,
memory=InMemoryMemory(),
max_iters=max(1, AGENT_MAX_ITERS),
)
@staticmethod
def _extract_json(text: str) -> dict[str, Any] | None:
m = re.search(r"\{[\s\S]*\}", text or "")
if not m:
return None
try:
return json.loads(m.group(0))
except Exception:
return None
@staticmethod
def _msg_to_text(msg: Any) -> str:
try:
if hasattr(msg, "get_text_content"):
v = msg.get_text_content()
if isinstance(v, str):
return v
except Exception:
pass
c = getattr(msg, "content", None)
if isinstance(c, str):
return c
if isinstance(c, list):
out: list[str] = []
for b in c:
t = getattr(b, "text", None)
if isinstance(t, str) and t.strip():
out.append(t)
return "\n".join(out)
return str(msg)
@staticmethod
def _extract_structured(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
if not isinstance(metadata, dict):
return None
candidates = [metadata, metadata.get("structured_output"), metadata.get("result"), metadata.get("output"), metadata.get("json")]
for obj in candidates:
if isinstance(obj, dict):
return obj
return None
class RouterAgent(_AgentRuntime):
def __init__(self) -> None:
super().__init__(
"RouterAgent",
rules_prompt()
+ "\n你是路由Agent。只输出路由 pre_sales/quote/after_sales/risk不直接回复客户。"
+ " 你必须基于上下文语义路由,禁止关键词硬匹配。",
)
async def route(self, context: dict[str, Any]) -> tuple[str, str]:
prompt = f"按上下文路由到 pre_sales/quote/after_sales/risk。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
res = await self.agent(self.Msg("user", prompt, "user"), structured_model=RouteModel)
obj = self._extract_structured(getattr(res, "metadata", None)) or self._extract_json(self._msg_to_text(res)) or {}
route = str(obj.get("route", "pre_sales") or "pre_sales")
if route not in {"pre_sales", "quote", "after_sales", "risk"}:
route = "pre_sales"
return route, str(obj.get("reason", "") or "")
class QuoteAgent(_AgentRuntime):
def __init__(self) -> None:
super().__init__(
"QuoteAgent",
rules_prompt() + "\n你是报价Agent。负责收图、报价触发、报价回复和报价阶段状态更新。",
)
async def decide(self, context: dict[str, Any]) -> Decision:
prompt = f"你负责报价相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
return await _decide_with_model(self, prompt)
class AfterSalesAgent(_AgentRuntime):
def __init__(self) -> None:
super().__init__(
"AfterSalesAgent",
rules_prompt() + "\n你是售后Agent。负责退款/重发/不满意等售后处理与状态推进。",
)
async def decide(self, context: dict[str, Any]) -> Decision:
prompt = f"你负责售后相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
return await _decide_with_model(self, prompt)
class RiskAgent(_AgentRuntime):
def __init__(self) -> None:
super().__init__(
"RiskAgent",
rules_prompt() + "\n你是风控Agent。专注风险识别与风险动作决策。",
)
async def decide(self, context: dict[str, Any]) -> Decision:
prompt = f"你负责风控相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
return await _decide_with_model(self, prompt)
class PreSalesAgent(_AgentRuntime):
def __init__(self) -> None:
super().__init__(
"PreSalesAgent",
rules_prompt() + "\n你是售前Agent。处理咨询承接、收图、澄清需求与转报价前动作。",
)
async def decide(self, context: dict[str, Any]) -> Decision:
prompt = f"你负责售前相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
return await _decide_with_model(self, prompt)
async def _decide_with_model(rt: _AgentRuntime, prompt: str) -> Decision:
res = await rt.agent(rt.Msg("user", prompt, "user"), structured_model=DecisionModel)
obj = rt._extract_structured(getattr(res, "metadata", None)) or rt._extract_json(rt._msg_to_text(res)) or {}
action = str(obj.get("action", "reply") or "reply").strip().lower()
if action not in {"reply", "quote", "transfer", "noop", "update_state"}:
action = "reply"
return Decision(
action=action,
reply=str(obj.get("reply", "") or "").strip(),
transfer_msg=str(obj.get("transfer_msg", "") or "").strip(),
quote_mode=str(obj.get("quote_mode", "") or "").strip(),
state_patch=obj.get("state_patch") if isinstance(obj.get("state_patch"), dict) else {},
reason=str(obj.get("reason", "") or "").strip(),
)

View File

@@ -0,0 +1,131 @@
from __future__ import annotations
import asyncio
import os
import sys
import tempfile
import uuid
from pathlib import Path
from typing import Any
import requests
from dotenv import load_dotenv
from .config import AUTO_DRAW_ENDPOINT, AUTO_DRAW_TIMEOUT_SECONDS
def _add_legacy_tw_path() -> None:
root = os.getenv("LEGACY_TW_ROOT", r"D:\main\sandbox\tw_terminator").strip()
if not root:
return
p = Path(root)
# 先加载 legacy 项目的 .env确保 service_tuhui_upload 在 import 时拿到正确账号配置
legacy_env = p / ".env"
if legacy_env.exists():
load_dotenv(legacy_env, override=True)
if p.exists() and str(p) not in sys.path:
sys.path.insert(0, str(p))
async def _draw_via_legacy_tw(
image_url: str,
customer_id: str,
requirement: str = "",
) -> dict[str, Any]:
# 优先使用当前项目中拷贝过来的 service_gemini
from services.service_gemini import GeminiExtractV2Service # type: ignore
# 上传模块暂时仍走 legacy你后续可替换为新项目本地上传实现
_add_legacy_tw_path()
from services.service_tuhui_upload import upload_to_tuhui # type: ignore
prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本"
# 1) 下载原图到本地临时文件
input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg")
output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg")
headers = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/122.0.0.0 Safari/537.36"
),
"Referer": "https://www.taobao.com/",
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
}
resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
if resp.status_code != 200:
return {"ok": False, "error": f"download_http_{resp.status_code}"}
with open(input_path, "wb") as f:
f.write(resp.content)
# 2) 直调你原来的 service_gemini 作图 API
service = GeminiExtractV2Service()
ok_extract, msg_extract, _ = await service.extract_pattern(
input_path=input_path,
output_path=output_path,
custom_prompt=prompt,
aspect_ratio="1:1",
)
if not ok_extract:
return {"ok": False, "error": f"extract_failed:{msg_extract}"}
if not os.path.exists(output_path):
return {"ok": False, "error": "extract_no_output_file"}
# 3) 上传图绘,返回可外发 URL
ok, link, _ = await upload_to_tuhui(
output_path,
title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图",
description="AI自动作图预览",
price=1,
)
if not ok:
return {"ok": False, "error": str(link)}
return {"ok": True, "url": str(link)}
def _draw_via_http_endpoint(image_url: str, customer_id: str, requirement: str = "") -> dict[str, Any]:
if not AUTO_DRAW_ENDPOINT:
return {"ok": False, "error": "AUTO_DRAW_ENDPOINT not configured"}
payload = {
"image_url": image_url,
"customer_id": customer_id,
"requirement": requirement,
}
resp = requests.post(AUTO_DRAW_ENDPOINT, json=payload, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
if resp.status_code != 200:
return {"ok": False, "error": f"http_{resp.status_code}:{resp.text[:200]}"}
data = resp.json() if resp.text else {}
url = str(data.get("url", "") or data.get("preview_url", "") or "")
if not url:
return {"ok": False, "error": "missing_preview_url"}
return {"ok": True, "url": url}
async def auto_draw_preview(
image_url: str,
customer_id: str,
requirement: str = "",
) -> dict[str, Any]:
"""
统一自动作图入口:
1) 优先走 tw_terminator 的 service_gemini 直调链路
2) 失败时回退 AUTO_DRAW_ENDPOINT
"""
try:
return await _draw_via_legacy_tw(image_url=image_url, customer_id=customer_id, requirement=requirement)
except Exception as e:
legacy_error = str(e)
try:
data = await asyncio.to_thread(
_draw_via_http_endpoint,
image_url,
customer_id,
requirement,
)
if data.get("ok"):
return data
return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{data.get('error','unknown')}"}
except Exception as e:
return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{e}"}

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
import asyncio
from typing import Any
import requests
from .config import TIANWANG_CALLBACK_URL
async def post_tianwang_callback(event: str, data: dict[str, Any], extra: dict[str, Any] | None = None, timeout_s: int = 5) -> tuple[bool, int, str]:
payload = {
'event': event,
'data': data,
'extra': extra or {},
}
def _post() -> tuple[bool, int, str]:
try:
resp = requests.post(TIANWANG_CALLBACK_URL, json=payload, timeout=timeout_s)
ok = 200 <= resp.status_code < 300
return ok, resp.status_code, (resp.text or '')[:300]
except Exception as e:
return False, 0, str(e)
return await asyncio.to_thread(_post)

370
qingjian_cs/app/client.py Normal file
View File

@@ -0,0 +1,370 @@
import asyncio
import json
import re
import time
from collections import defaultdict
import websockets
from .callbacks import post_tianwang_callback
from .auto_draw import auto_draw_preview
from .config import (
AUTO_DRAW_ENABLED,
AUTO_QUOTE_WAIT_SECONDS,
MESSAGE_DEBOUNCE_SECONDS,
QINGJIAN_WS_URI,
SHORT_REPLY_MAX_CHARS,
)
from .logger import setup_logger
from .observability import activity_event, build_trace_id
from .orchestrator import Orchestrator
from .rules import extract_image_urls, prefilter_message
class QingjianClient:
def __init__(self) -> None:
self.logger = setup_logger()
self.uri = QINGJIAN_WS_URI
self.reply_id = "tb001"
self.websocket = None
self.running = True
self.orchestrator = Orchestrator()
self.pending_msgs: dict[str, list[dict]] = defaultdict(list)
self.debounce_tasks: dict[str, asyncio.Task] = {}
self.pending_images: dict[str, list[str]] = defaultdict(list)
self.auto_quote_tasks: dict[str, asyncio.Task] = {}
self.last_reply_key: dict[str, str] = {}
self.recent_outbound: list[tuple[str, str, str, float]] = []
@staticmethod
def _customer_key(data: dict) -> str:
return f"{data.get('acc_id','')}:{data.get('from_id','')}"
@staticmethod
def _msg_text(data: dict) -> str:
return str(data.get("msg", "") or "").strip()
def _debounce_seconds(self, msg: str) -> float:
if extract_image_urls(msg):
return 2.5
return float(MESSAGE_DEBOUNCE_SECONDS)
async def send_message(self, message: dict) -> None:
if not self.websocket:
return
await self.websocket.send(json.dumps(message, ensure_ascii=False))
self.logger.info("[发送] %s", message.get("msg", ""))
async def send_reply(self, data: dict, text: str, trace_id: str = "-") -> None:
text = str(text or "").strip()
if not text:
return
text = self._shorten_reply(text)
msg = {
"msg_id": "",
"acc_id": data.get("acc_id", ""),
"msg": text,
"from_id": data.get("from_id", ""),
"from_name": data.get("from_name", data.get("from_id", "")),
"cy_id": data.get("from_id", ""),
"acc_type": data.get("acc_type", "AliWorkbench"),
"msg_type": 0,
"cy_name": data.get("from_name", data.get("from_id", "")),
}
activity_event(self.logger, "send_reply_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text)
await self.send_message(msg)
self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), text, time.monotonic()))
if len(self.recent_outbound) > 200:
self.recent_outbound = self.recent_outbound[-200:]
activity_event(self.logger, "send_reply_success", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text)
async def send_image(self, data: dict, image_url: str, trace_id: str = "-") -> None:
image_url = str(image_url or "").strip()
if not image_url:
return
msg = {
"msg_id": "",
"acc_id": data.get("acc_id", ""),
"msg": image_url,
"from_id": data.get("from_id", ""),
"from_name": data.get("from_name", data.get("from_id", "")),
"cy_id": data.get("from_id", ""),
"acc_type": data.get("acc_type", "AliWorkbench"),
"msg_type": 1,
"cy_name": data.get("from_name", data.get("from_id", "")),
}
activity_event(self.logger, "send_image_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url)
await self.send_message(msg)
self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), image_url, time.monotonic()))
if len(self.recent_outbound) > 200:
self.recent_outbound = self.recent_outbound[-200:]
activity_event(self.logger, "send_image_success", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url)
@staticmethod
def _clean_text(text: str) -> str:
t = str(text or "").strip()
t = re.sub(r"\s+", "", t)
return t
def _shorten_reply(self, text: str) -> str:
max_len = max(8, int(SHORT_REPLY_MAX_CHARS))
t = str(text or "").strip()
t = self._humanize_reply(t)
if len(t) <= max_len:
return t
parts = re.split(r"[。!?!?]", t)
head = next((p.strip() for p in parts if p and p.strip()), t)
if len(head) > max_len:
head = head[:max_len].rstrip(",;: ")
return head or t[:max_len]
@staticmethod
def _humanize_reply(text: str) -> str:
t = str(text or "").strip()
return t
@staticmethod
def _is_invalid_ai_reply(text: str) -> bool:
t = str(text or "").strip().lower()
if not t:
return True
if "i noticed that you have interrupted me" in t:
return True
if t.startswith("action:") or t.startswith("{"):
return True
return False
def _fallback_reply(self, action: str) -> str:
if action == "transfer":
return "我先给你转人工处理。"
return "收到,我先处理一下。"
def _is_outbound_echo(self, data: dict, msg: str) -> bool:
"""
轻简可能会把我方刚发送文本回推为“收到消息”。
对“短时间完全相同文本”做回环拦截,兼容 acc/from 对调回推,避免无限对话。
"""
in_acc = str(data.get("acc_id", ""))
in_from = str(data.get("from_id", ""))
in_msg = self._clean_text(msg)
now = time.monotonic()
if not in_msg:
return False
for out_acc, out_to, out_msg, ts in reversed(self.recent_outbound):
if (now - ts) > 120:
break
if self._clean_text(out_msg) != in_msg:
continue
if (out_acc == in_acc and out_to == in_from) or (out_acc == in_from and out_to == in_acc):
return True
return False
async def _handle_decision(self, data: dict, merged_msg: str, *, auto_quote: bool = False) -> None:
key = self._customer_key(data)
trace_id = build_trace_id(data.get("acc_id", ""), data.get("from_id", ""), merged_msg)
t0 = time.perf_counter()
urls = extract_image_urls(merged_msg)
if urls:
for u in urls:
if u not in self.pending_images[key]:
self.pending_images[key].append(u)
context = {
"customer_key": key,
"acc_id": data.get("acc_id", ""),
"customer_id": data.get("from_id", ""),
"goods_name": data.get("goods_name", ""),
"goods_order": data.get("goods_order", ""),
"msg": merged_msg,
"intent": "unknown",
"pending_images": len(self.pending_images[key]),
"auto_quote_trigger": auto_quote,
"last_reply": self.last_reply_key.get(key, ""),
}
activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"])
route, decision, state = await self.orchestrator.decide(context)
latency_ms = int((time.perf_counter() - t0) * 1000)
activity_event(
self.logger,
"agent_process_done",
trace_id=trace_id,
customer_id=context["customer_id"],
route=route,
action=decision.action,
reason=decision.reason,
latency_ms=latency_ms,
after_sales_stage=state.get("after_sales_stage", "new"),
)
if decision.action == "transfer":
text = (decision.transfer_msg or "").strip()
if self._is_invalid_ai_reply(text):
text = self._fallback_reply("transfer")
await self.send_reply(data, text, trace_id=trace_id)
self.last_reply_key[key] = text
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "transfer", "reply": text})
return
if decision.action == "quote":
if AUTO_DRAW_ENABLED and self.pending_images.get(key):
latest_image = self.pending_images[key][-1]
activity_event(
self.logger,
"auto_draw_start",
trace_id=trace_id,
customer_id=context["customer_id"],
image_url=latest_image,
)
draw_res = await auto_draw_preview(
image_url=latest_image,
customer_id=context["customer_id"],
requirement=merged_msg,
)
if draw_res.get("ok"):
preview_url = str(draw_res.get("url", "") or "")
await self.send_reply(data, "先给你做了预览图。", trace_id=trace_id)
await self.send_image(data, preview_url, trace_id=trace_id)
final_text = "看下预览,满意再拍下付款。"
await self.send_reply(data, final_text, trace_id=trace_id)
self.last_reply_key[key] = final_text
# 预览完成后清掉当前批次,避免同一图重复触发
self.pending_images[key].clear()
activity_event(
self.logger,
"auto_draw_success",
trace_id=trace_id,
customer_id=context["customer_id"],
preview_url=preview_url,
)
await post_tianwang_callback(
"message_processed",
data,
extra={"trace_id": trace_id, "route": route, "action": "quote", "reply": final_text, "auto_draw": True},
)
return
activity_event(
self.logger,
"auto_draw_fail",
trace_id=trace_id,
customer_id=context["customer_id"],
error=str(draw_res.get("error", "unknown")),
)
text = (decision.reply or "").strip()
if self._is_invalid_ai_reply(text):
text = self._fallback_reply("quote")
if self.last_reply_key.get(key) != text:
await self.send_reply(data, text, trace_id=trace_id)
self.last_reply_key[key] = text
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "quote", "reply": text})
return
if decision.action == "noop":
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "noop", "reply": ""})
return
text = (decision.reply or "").strip()
if self._is_invalid_ai_reply(text):
text = self._fallback_reply("reply")
if self.last_reply_key.get(key) != text:
await self.send_reply(data, text, trace_id=trace_id)
self.last_reply_key[key] = text
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "reply", "reply": text})
if self.pending_images[key] and key not in self.auto_quote_tasks:
self.auto_quote_tasks[key] = asyncio.create_task(self._auto_quote_later(data))
async def _auto_quote_later(self, data: dict) -> None:
key = self._customer_key(data)
try:
await asyncio.sleep(AUTO_QUOTE_WAIT_SECONDS)
if self.pending_images.get(key):
await self._handle_decision(data, "", auto_quote=True)
finally:
self.auto_quote_tasks.pop(key, None)
async def _flush_customer(self, key: str) -> None:
queue = self.pending_msgs.get(key, [])
if not queue:
return
merged = "".join([self._msg_text(x) for x in queue if self._msg_text(x)])
data = queue[-1]
self.pending_msgs[key].clear()
await self._handle_decision(data, merged)
async def _debounce_enqueue(self, data: dict) -> None:
key = self._customer_key(data)
msg = self._msg_text(data)
self.pending_msgs[key].append(data)
if key in self.debounce_tasks:
self.debounce_tasks[key].cancel()
wait_s = self._debounce_seconds(msg)
activity_event(self.logger, "debounce_enqueue", customer_id=data.get("from_id", "-"), key=key, queue_size=len(self.pending_msgs[key]), wait_s=wait_s)
async def later() -> None:
try:
await asyncio.sleep(wait_s)
await self._flush_customer(key)
except asyncio.CancelledError:
return
finally:
self.debounce_tasks.pop(key, None)
self.debounce_tasks[key] = asyncio.create_task(later())
async def _on_message(self, raw: str) -> None:
try:
data = json.loads(raw)
except Exception:
self.logger.info("[非JSON] %s", raw)
return
msg_type = int(data.get("msg_type", 0) or 0)
msg = self._msg_text(data)
rule = prefilter_message(msg, msg_type)
self.logger.info("[收消息] acc=%s from=%s type=%s msg=%s", data.get("acc_id", ""), data.get("from_id", ""), msg_type, msg)
await post_tianwang_callback("message_received", data, extra={"msg_type": msg_type})
if self._is_outbound_echo(data, msg):
activity_event(
self.logger,
"inbound_ignored",
customer_id=data.get("from_id", "-"),
reason="outbound_echo_loop_guard",
)
return
if rule.ignore:
activity_event(self.logger, "inbound_ignored", customer_id=data.get("from_id", "-"), reason=rule.reason)
return
patched = dict(data)
patched["msg"] = rule.normalized_msg or msg
if msg_type == 1:
await self._handle_decision(patched, patched["msg"])
return
await self._debounce_enqueue(patched)
async def _serve(self) -> None:
while self.running:
try:
self.logger.info("[连接] %s", self.uri)
async with websockets.connect(self.uri) as ws:
self.websocket = ws
self.logger.info("[连接成功]")
async for raw in ws:
await self._on_message(raw)
except Exception as e:
self.logger.info("[连接异常] %s", e)
await asyncio.sleep(3)
def run(self) -> None:
asyncio.run(self._serve())

37
qingjian_cs/app/config.py Normal file
View File

@@ -0,0 +1,37 @@
import os
from pathlib import Path
try:
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parents[1] / '.env')
except Exception:
pass
QINGJIAN_WS_URI = os.getenv("QINGJIAN_WS_URI", "ws://127.0.0.1:9528")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3").strip()
OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "doubao-seed-2-0-pro-260215").strip()
MESSAGE_DEBOUNCE_SECONDS = int(os.getenv("MESSAGE_DEBOUNCE_SECONDS", "6"))
AUTO_QUOTE_WAIT_SECONDS = int(os.getenv("AUTO_QUOTE_WAIT_SECONDS", "18"))
AGENT_MAX_ITERS = int(os.getenv("AGENT_MAX_ITERS", "3"))
FAST_ROUTE_ENABLED = os.getenv("FAST_ROUTE_ENABLED", "1").strip() in {"1", "true", "True", "yes", "on"}
SHORT_REPLY_MAX_CHARS = int(os.getenv("SHORT_REPLY_MAX_CHARS", "20"))
STORE_BACKEND = os.getenv("STORE_BACKEND", "sqlite").strip().lower()
STORE_SQLITE_PATH = os.getenv("STORE_SQLITE_PATH", "").strip()
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1").strip()
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
MYSQL_USER = os.getenv("MYSQL_USER", "root").strip()
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "").strip()
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs").strip()
MYSQL_TABLE_PREFIX = os.getenv("MYSQL_TABLE_PREFIX", "qjcs_").strip()
HTTP_HOST = os.getenv("HTTP_HOST", "127.0.0.1").strip()
HTTP_PORT = int(os.getenv("HTTP_PORT", "6060"))
TIANWANG_CALLBACK_URL = os.getenv("TIANWANG_CALLBACK_URL", "http://139.199.3.75:18789/api/callback").strip()
AUTO_DRAW_ENABLED = os.getenv("AUTO_DRAW_ENABLED", "1").strip() in {"1", "true", "True", "yes", "on"}
AUTO_DRAW_ENDPOINT = os.getenv("AUTO_DRAW_ENDPOINT", "").strip()
AUTO_DRAW_TIMEOUT_SECONDS = int(os.getenv("AUTO_DRAW_TIMEOUT_SECONDS", "25"))

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from flask import Flask, jsonify, request
from .logger import setup_logger
from .task_manager import TaskManager
def create_http_app(task_manager: TaskManager | None = None) -> Flask:
app = Flask(__name__)
logger = setup_logger()
tm = task_manager or TaskManager()
@app.get('/api/health')
def health():
return jsonify({'ok': True})
@app.post('/api/task/receive')
def receive_task():
payload = request.get_json(silent=True) or {}
task_id = tm.create_task(payload)
logger.info('[任务] receive task_id=%s', task_id)
return jsonify({'ok': True, 'task_id': task_id})
@app.post('/api/task/cancel')
def cancel_task():
body = request.get_json(silent=True) or {}
task_id = str(body.get('task_id', '')).strip()
if not task_id:
return jsonify({'ok': False, 'error': 'task_id required'}), 400
ok = tm.cancel_task(task_id)
return jsonify({'ok': ok, 'task_id': task_id})
@app.get('/api/task/status/<task_id>')
def task_status(task_id: str):
task = tm.get_task(task_id)
if not task:
return jsonify({'ok': False, 'error': 'not found'}), 404
return jsonify({'ok': True, 'task': task})
@app.get('/api/task/list')
def task_list():
limit = int(request.args.get('limit', 100))
return jsonify({'ok': True, 'tasks': tm.list_tasks(limit=limit)})
return app
def run_http_server(host: str, port: int, task_manager: TaskManager | None = None) -> None:
app = create_http_app(task_manager=task_manager)
app.run(host=host, port=port, debug=False, use_reloader=False)

22
qingjian_cs/app/logger.py Normal file
View File

@@ -0,0 +1,22 @@
import logging
import sys
def setup_logger() -> logging.Logger:
logger = logging.getLogger("qingjian_cs")
if logger.handlers:
return logger
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", "%H:%M:%S")
handler.setFormatter(formatter)
logger.addHandler(handler)
# 降低 AgentScope 内部推理/格式器日志噪音,保留本项目活动日志。
logging.getLogger("agentscope").setLevel(logging.ERROR)
logging.getLogger("agentscope.formatter").setLevel(logging.ERROR)
logging.getLogger("agentscope.agent").setLevel(logging.ERROR)
logging.getLogger("_openai_formatter").setLevel(logging.ERROR)
logging.getLogger("_react_agent").setLevel(logging.ERROR)
return logger

30
qingjian_cs/app/models.py Normal file
View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
from pydantic import BaseModel, Field
class DecisionModel(BaseModel):
action: Literal["reply", "quote", "transfer", "noop", "update_state"] = Field(description="唯一动作")
reply: str = Field(default="", description="给客户的回复")
transfer_msg: str = Field(default="", description="转人工提示")
quote_mode: Literal["flush_pending", "analyze_current_or_recent", "collect_only", ""] = Field(default="")
state_patch: dict = Field(default_factory=dict, description="状态增量")
reason: str = Field(default="", description="内部原因")
class RouteModel(BaseModel):
route: Literal["pre_sales", "quote", "after_sales", "risk"] = Field(description="路由目标")
reason: str = Field(default="")
@dataclass
class Decision:
action: str
reply: str = ""
transfer_msg: str = ""
quote_mode: str = ""
state_patch: dict | None = None
reason: str = ""

View File

@@ -0,0 +1,23 @@
import json
from datetime import datetime
from typing import Any
def now_ts() -> str:
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
def build_trace_id(acc_id: str, customer_id: str, msg: str) -> str:
base = f"{acc_id}|{customer_id}|{msg}|{now_ts()}"
return hex(abs(hash(base)))[2:18]
def activity_event(logger, event: str, *, trace_id: str = '-', customer_id: str = '-', result: str = 'ok', **kwargs: Any) -> None:
payload = {
'trace_id': trace_id or '-',
'customer_id': customer_id or '-',
'event': event,
'result': result,
**kwargs,
}
logger.info('[活动日志] %s', json.dumps(payload, ensure_ascii=False))

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Any
from .agents import AfterSalesAgent, PreSalesAgent, QuoteAgent, RiskAgent, RouterAgent
from .models import Decision
from .state_machine import evolve_after_sales_state, migrate_state_schema
from .store import ConversationStore
class Orchestrator:
def __init__(self) -> None:
self.router = RouterAgent()
self.pre_sales = PreSalesAgent()
self.quote = QuoteAgent()
self.after_sales = AfterSalesAgent()
self.risk = RiskAgent()
self.store = ConversationStore()
async def decide(self, context: dict[str, Any]) -> tuple[str, Decision, dict[str, Any]]:
customer_key = context["customer_key"]
session = self.store.get_session(customer_key)
prev_state = migrate_state_schema(session.get("state", {}))
prev_route = session.get("route", "pre_sales")
# 统一改为语义决策:不走关键词意图/订单硬判定。
intent = "unknown"
order_status = "unknown"
merged_ctx = {
**context,
"session_state": prev_state,
"previous_route": prev_route,
"intent": intent,
"order_status": order_status,
}
route, route_reason = await self.router.route(merged_ctx)
if route == "quote":
decision = await self.quote.decide(merged_ctx)
elif route == "after_sales":
decision = await self.after_sales.decide(merged_ctx)
elif route == "risk":
decision = await self.risk.decide(merged_ctx)
else:
decision = await self.pre_sales.decide(merged_ctx)
merged_state = {**prev_state, **(decision.state_patch or {})}
new_state = evolve_after_sales_state(
merged_state,
route=route,
action=decision.action,
intent=intent,
order_status=order_status,
msg=str(context.get("msg", "") or ""),
)
self.store.upsert_session(customer_key, context.get("acc_id", ""), context.get("customer_id", ""), route, new_state)
self.store.append_event(
customer_key,
"decision",
{
"route": route,
"route_reason": route_reason,
"action": decision.action,
"reason": decision.reason,
"after_sales_stage": new_state.get("after_sales_stage", "new"),
},
)
return route, decision, new_state

137
qingjian_cs/app/rules.py Normal file
View File

@@ -0,0 +1,137 @@
import re
from dataclasses import dataclass
IMAGE_URL_RE = re.compile(r"https?://[^\s]+(?:\.jpg|\.jpeg|\.png|\.webp|\.bmp|\.gif)(?:\?[^\s]*)?", re.I)
SIZE_RE = re.compile(r"(\d+(?:\.\d+)?)\s*(米|m|M)\s*[xX*乘]\s*(\d+(?:\.\d+)?)\s*(米|m|M)")
@dataclass
class RuleResult:
ignore: bool = False
normalized_msg: str = ""
reason: str = ""
def extract_customer_text_from_shop_card(msg: str) -> str:
if "[进店卡片]" not in (msg or ""):
return ""
prefix = msg.split("#*#[进店卡片]", 1)[0].strip()
if prefix and prefix not in {"你好", "您好", "在吗"}:
return prefix
return prefix
def detect_order_status(order_text: str) -> str:
# 订单状态交给主决策 AI 从上下文语义判断。
return "unknown"
def extract_size_pairs_m(msg: str) -> list[tuple[float, float]]:
out: list[tuple[float, float]] = []
for m in SIZE_RE.finditer(msg or ""):
w = float(m.group(1))
h = float(m.group(3))
out.append((w, h))
return out
def has_map_or_political_risk(msg: str, goods_name: str = "") -> bool:
# 风险由 RiskAgent 语义判断。
return False
def has_porn_risk(msg: str) -> bool:
# 风险由 RiskAgent 语义判断。
return False
def requests_external_contact(msg: str) -> bool:
# 外联风险由 RiskAgent 语义判断。
return False
def is_meaningless_short(msg: str) -> bool:
# 无意义短句由主决策 AI 语义判断。
return False
def prefilter_message(msg: str, msg_type: int) -> RuleResult:
m = (msg or "").strip()
if not m:
return RuleResult(ignore=True, reason="empty")
if msg_type not in (0, 1):
return RuleResult(ignore=True, reason="unsupported_msg_type")
if "" in m and " 转交给 " in m:
return RuleResult(ignore=True, reason="transfer_notice")
if "Gemini 店铺消息,跳过" in m:
return RuleResult(ignore=True, reason="system_echo")
if "[进店卡片]" in m:
t = extract_customer_text_from_shop_card(m)
if t:
return RuleResult(ignore=False, normalized_msg=t, reason="shop_card_with_text")
return RuleResult(ignore=True, reason="pure_shop_card")
return RuleResult(ignore=False, normalized_msg=m, reason="normal")
def detect_intent(msg: str) -> str:
m = (msg or "").lower()
if IMAGE_URL_RE.search(m):
return "image"
# 其余意图交给 AI 语义判断。
return "unknown"
def extract_image_urls(msg: str) -> list[str]:
return IMAGE_URL_RE.findall(msg or "")
def rules_prompt() -> str:
return (
"你是淘宝图像服务客服系统的统一决策AI。必须按以下 MASTER_RULES 执行。\n"
"只输出 JSON 决策,不要解释过程。\n"
"动作 action 只能是: reply / quote / transfer / noop / update_state。\n\n"
"HARD_RULES(最高优先级,必须先判断):\n"
"1) 命中以下任一类,禁止报价(action=quote)\n"
" - 政治/涉政/政治人物/政治事件/政治宣传;\n"
" - 地图类(地图/地形图/行政区划图/卫星地图等)\n"
" - 黄暴/擦边/色情/露点/明显违规内容;\n"
" - 客户索要站外联系方式(微信/QQ/手机号)\n"
" - 文本出现超大尺寸需求(如超长边或超大面积,明显超出常规制作能力)。\n"
"2) 命中硬规则时,只能 action=reply 或 action=transfer\n"
" - 可直接拒绝的action=replyreply简短明确边界\n"
" - 风险高或不确定的action=transfer给 transfer_msg。\n"
"3) 命中硬规则后禁止改口:后续客户追问“能不能做”,仍保持同一结论。\n"
"4) 多图场景若部分可做、部分不可做,必须明确“哪张可做、哪张不做”,禁止含糊表述。\n\n"
"MASTER_RULES:\n"
"A. 统一动作语义\n"
"1) reply: 直接回复客户。\n"
"2) quote: 触发报价或触发看图后报价流程。\n"
"3) transfer: 转人工,必须给 transfer_msg。\n"
"4) noop: 当前不需要回复。\n"
"5) update_state: 仅更新状态,不对外发消息。\n\n"
"B. 售前与报价\n"
"1) 客户发图: 优先承接,可继续收图;不强行一次报完。\n"
"2) 客户询价且已有图(当前图/待处理图/最近图): 优先 action=quote。\n"
"3) 客户无图询价: action=reply引导其先发图。\n"
"4) 客户说“发完了/就这些/报价吧”: 若有图则 action=quote。\n"
"5) 不能承诺“一模一样原图必找到”,可说先看图评估。\n"
"6) 尺寸很大或要求高还原时,不夸张承诺,先说明可评估后给结论。\n\n"
"C. 订单阶段\n"
"1) 已付款: 可回复“已安排处理/正在处理/完成后发你确认”。\n"
"2) 待付款: 可提示付款,但不与客户争执;必要时先给预览再引导付款。\n"
"3) 退款/售后诉求: 进入售后语境,保持克制,必要时转人工。\n\n"
"D. 风控与合规\n"
"1) 涉政治/地图边界/黄暴/违规内容: 按 HARD_RULES 执行,禁止报价。\n"
"2) 客户索要微信/QQ/手机号等站外联系方式: 不外呼,站内引导。\n"
"3) 高风险不确定时,不硬答,给保守回复或转人工。\n\n"
"E. 对话质量\n"
"1) 单次只做一个动作,不混合。\n"
"2) 避免重复同一句话;若语义相同,换表达。\n"
"3) reply 必须短: 优先 1 句口语化避免AI腔。\n"
"4) 不要输出思考过程,不要输出 tool_use 文本给客户。\n"
"5) 若上下文不足,先澄清 1 个关键问题,不要连续追问。\n\n"
"F. 店铺人格\n"
"1) 按店铺/账号口吻说话,像真人客服,不要机械模板。\n"
"2) 语气友好直接不啰嗦不说“作为AI”。\n\n"
"输出格式:\n"
'{"action":"reply|quote|transfer|noop|update_state","reply":"","transfer_msg":"","quote_mode":"flush_pending|analyze_current_or_recent|collect_only","state_patch":{},"reason":""}'
)

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import Any
AFTER_SALES_STAGES = {
"new",
"waiting_material",
"quoted",
"processing",
"waiting_feedback",
"done",
"refunding",
"transferred",
}
def migrate_state_schema(state: dict[str, Any] | None) -> dict[str, Any]:
src = dict(state or {})
# 兼容旧字段
if "after_sales_stage" not in src:
old = src.get("aftersales_stage") or src.get("status") or "new"
src["after_sales_stage"] = str(old)
if src.get("after_sales_stage") not in AFTER_SALES_STAGES:
src["after_sales_stage"] = "new"
if "quote_count" not in src:
src["quote_count"] = int(src.get("quotes", 0) or 0)
if "image_count" not in src:
src["image_count"] = int(src.get("images", 0) or 0)
if "last_intent" not in src:
src["last_intent"] = str(src.get("intent", "unknown") or "unknown")
if "version" not in src:
src["version"] = 2
return src
def evolve_after_sales_state(prev_state: dict[str, Any], *, route: str, action: str, intent: str, order_status: str, msg: str) -> dict[str, Any]:
s = migrate_state_schema(prev_state)
stage = s.get("after_sales_stage", "new")
if action == "transfer" or route == "risk":
stage = "transferred"
elif route == "quote" and action == "quote":
stage = "quoted"
s["quote_count"] = int(s.get("quote_count", 0)) + 1
elif route == "after_sales":
if order_status == "paid":
stage = "processing"
elif stage == "new":
stage = "waiting_material"
elif route == "pre_sales":
if intent == "image":
stage = "waiting_material"
s["after_sales_stage"] = stage
s["last_intent"] = intent
s["last_order_status"] = order_status or "unknown"
s["version"] = 2
return s

251
qingjian_cs/app/store.py Normal file
View File

@@ -0,0 +1,251 @@
from __future__ import annotations
import json
import re
import sqlite3
from pathlib import Path
from typing import Any
from .config import (
MYSQL_DATABASE,
MYSQL_HOST,
MYSQL_PASSWORD,
MYSQL_PORT,
MYSQL_TABLE_PREFIX,
MYSQL_USER,
STORE_BACKEND,
STORE_SQLITE_PATH,
)
from .state_machine import migrate_state_schema
DB_PATH = Path(__file__).resolve().parents[1] / "qingjian_cs.db"
def _safe_prefix(v: str) -> str:
p = re.sub(r"[^a-zA-Z0-9_]", "", (v or "").strip())
return p or "qjcs_"
class ConversationStore:
def __init__(self, backend: str | None = None, db_path: str | None = None) -> None:
self.backend = (backend or STORE_BACKEND or "sqlite").lower()
self.db_path = db_path or STORE_SQLITE_PATH or str(DB_PATH)
self.prefix = _safe_prefix(MYSQL_TABLE_PREFIX)
self.sessions_table = f"{self.prefix}sessions"
self.events_table = f"{self.prefix}events"
self._init_db()
def _sqlite_conn(self):
return sqlite3.connect(self.db_path)
def _mysql_conn(self):
import pymysql
return pymysql.connect(
host=MYSQL_HOST,
port=MYSQL_PORT,
user=MYSQL_USER,
password=MYSQL_PASSWORD,
database=MYSQL_DATABASE,
charset="utf8mb4",
autocommit=False,
)
def _conn(self):
if self.backend == "mysql":
return self._mysql_conn()
return self._sqlite_conn()
def _init_db(self) -> None:
if self.backend == "mysql":
self._init_mysql()
else:
self._init_sqlite()
def _ensure_sqlite_column(self, conn: sqlite3.Connection, table: str, col: str, ddl: str) -> None:
cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
if col not in cols:
conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
def _init_sqlite(self) -> None:
t_s = self.sessions_table
t_e = self.events_table
with self._sqlite_conn() as c:
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_s} (
customer_key TEXT PRIMARY KEY,
acc_id TEXT,
customer_id TEXT,
route TEXT,
state_json TEXT,
after_sales_stage TEXT,
state_version INTEGER DEFAULT 2,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_e} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
customer_key TEXT,
event TEXT,
payload_json TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
self._ensure_sqlite_column(c, t_s, "after_sales_stage", "after_sales_stage TEXT")
self._ensure_sqlite_column(c, t_s, "state_version", "state_version INTEGER DEFAULT 2")
def _init_mysql(self) -> None:
t_s = self.sessions_table
t_e = self.events_table
conn = self._mysql_conn()
try:
with conn.cursor() as c:
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_s} (
customer_key VARCHAR(255) PRIMARY KEY,
acc_id VARCHAR(255),
customer_id VARCHAR(255),
route VARCHAR(64),
state_json JSON,
after_sales_stage VARCHAR(64),
state_version INT DEFAULT 2,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_after_sales_stage (after_sales_stage)
) CHARACTER SET utf8mb4
"""
)
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_e} (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
customer_key VARCHAR(255),
event VARCHAR(128),
payload_json JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_customer_key (customer_key),
INDEX idx_event (event)
) CHARACTER SET utf8mb4
"""
)
conn.commit()
finally:
conn.close()
def get_session(self, customer_key: str) -> dict[str, Any]:
t_s = self.sessions_table
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=%s",
(customer_key,),
)
else:
c.execute(
f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=?",
(customer_key,),
)
row = c.fetchone()
if not row:
return {"route": "pre_sales", "state": migrate_state_schema({})}
if isinstance(row, dict):
vals = [row.get("acc_id"), row.get("customer_id"), row.get("route"), row.get("state_json"), row.get("after_sales_stage"), row.get("state_version")]
else:
vals = list(row)
raw_state = vals[3]
try:
if isinstance(raw_state, dict):
state = raw_state
else:
state = json.loads(raw_state or "{}")
except Exception:
state = {}
state = migrate_state_schema(state)
if vals[4] and not state.get("after_sales_stage"):
state["after_sales_stage"] = vals[4]
if vals[5] and not state.get("version"):
state["version"] = int(vals[5])
return {
"acc_id": vals[0],
"customer_id": vals[1],
"route": vals[2] or "pre_sales",
"state": state,
}
finally:
conn.close()
def upsert_session(self, customer_key: str, acc_id: str, customer_id: str, route: str, state: dict[str, Any]) -> None:
t_s = self.sessions_table
state = migrate_state_schema(state)
state_json = json.dumps(state or {}, ensure_ascii=False)
after_sales_stage = str(state.get("after_sales_stage", "new") or "new")
state_version = int(state.get("version", 2) or 2)
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"""
INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version)
VALUES(%s,%s,%s,%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE
acc_id=VALUES(acc_id),
customer_id=VALUES(customer_id),
route=VALUES(route),
state_json=VALUES(state_json),
after_sales_stage=VALUES(after_sales_stage),
state_version=VALUES(state_version)
""",
(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version),
)
else:
c.execute(
f"""
INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version)
VALUES(?,?,?,?,?,?,?)
ON CONFLICT(customer_key) DO UPDATE SET
acc_id=excluded.acc_id,
customer_id=excluded.customer_id,
route=excluded.route,
state_json=excluded.state_json,
after_sales_stage=excluded.after_sales_stage,
state_version=excluded.state_version,
updated_at=CURRENT_TIMESTAMP
""",
(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version),
)
conn.commit()
finally:
conn.close()
def append_event(self, customer_key: str, event: str, payload: dict[str, Any]) -> None:
t_e = self.events_table
payload_json = json.dumps(payload or {}, ensure_ascii=False)
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(%s,%s,%s)",
(customer_key, event, payload_json),
)
else:
c.execute(
f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(?,?,?)",
(customer_key, event, payload_json),
)
conn.commit()
finally:
conn.close()

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import json
import sqlite3
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
DB_PATH = Path(__file__).resolve().parents[1] / 'task_db.sqlite'
class TaskManager:
def __init__(self, db_path: str | None = None) -> None:
self.db_path = db_path or str(DB_PATH)
self._init_db()
def _conn(self):
return sqlite3.connect(self.db_path)
def _init_db(self) -> None:
with self._conn() as c:
c.execute('''
CREATE TABLE IF NOT EXISTS tasks (
task_id TEXT PRIMARY KEY,
status TEXT NOT NULL,
payload_json TEXT,
result_json TEXT,
created_at TEXT,
updated_at TEXT
)
''')
def create_task(self, payload: dict[str, Any]) -> str:
task_id = uuid.uuid4().hex
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
with self._conn() as c:
c.execute(
'INSERT INTO tasks(task_id,status,payload_json,result_json,created_at,updated_at) VALUES(?,?,?,?,?,?)',
(task_id, 'queued', json.dumps(payload, ensure_ascii=False), '{}', now, now),
)
return task_id
def cancel_task(self, task_id: str) -> bool:
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
with self._conn() as c:
cur = c.execute(
"UPDATE tasks SET status='cancelled',updated_at=? WHERE task_id=? AND status IN ('queued','running')",
(now, task_id),
)
return cur.rowcount > 0
def get_task(self, task_id: str) -> dict[str, Any] | None:
with self._conn() as c:
row = c.execute(
'SELECT task_id,status,payload_json,result_json,created_at,updated_at FROM tasks WHERE task_id=?',
(task_id,),
).fetchone()
if not row:
return None
return {
'task_id': row[0],
'status': row[1],
'payload': json.loads(row[2] or '{}'),
'result': json.loads(row[3] or '{}'),
'created_at': row[4],
'updated_at': row[5],
}
def list_tasks(self, limit: int = 100) -> list[dict[str, Any]]:
with self._conn() as c:
rows = c.execute(
'SELECT task_id,status,payload_json,result_json,created_at,updated_at FROM tasks ORDER BY created_at DESC LIMIT ?',
(limit,),
).fetchall()
return [
{
'task_id': r[0],
'status': r[1],
'payload': json.loads(r[2] or '{}'),
'result': json.loads(r[3] or '{}'),
'created_at': r[4],
'updated_at': r[5],
}
for r in rows
]