chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
1
qingjian_cs/app/__init__.py
Normal file
1
qingjian_cs/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
71
qingjian_cs/app/agent_tools.py
Normal file
71
qingjian_cs/app/agent_tools.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
from .rules import (
|
||||
detect_intent,
|
||||
detect_order_status,
|
||||
extract_image_urls,
|
||||
extract_size_pairs_m,
|
||||
has_map_or_political_risk,
|
||||
has_porn_risk,
|
||||
is_meaningless_short,
|
||||
requests_external_contact,
|
||||
)
|
||||
|
||||
|
||||
def _tool_ok(value: Any) -> ToolResponse:
|
||||
"""Wrap python values into AgentScope ToolResponse."""
|
||||
try:
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
except Exception:
|
||||
text = str(value)
|
||||
return ToolResponse(
|
||||
content=[{"type": "text", "text": text}],
|
||||
metadata={"value": value},
|
||||
)
|
||||
|
||||
|
||||
def tool_detect_intent(msg: str) -> ToolResponse:
|
||||
"""识别客户当前意图: image/pricing/greeting/external_contact/finish_or_quote_trigger/nonsense/unknown。"""
|
||||
return _tool_ok(detect_intent(msg or ""))
|
||||
|
||||
|
||||
def tool_extract_image_urls(msg: str) -> ToolResponse:
|
||||
"""提取消息中的图片 URL 列表。"""
|
||||
return _tool_ok(extract_image_urls(msg or ""))
|
||||
|
||||
|
||||
def tool_detect_order_status(goods_order: str) -> ToolResponse:
|
||||
"""识别订单状态: paid/pending_payment/refund/unknown。"""
|
||||
return _tool_ok(detect_order_status(goods_order or ""))
|
||||
|
||||
|
||||
def tool_extract_size_pairs(msg: str) -> ToolResponse:
|
||||
"""提取尺寸对,单位米。返回 [(w, h), ...]。"""
|
||||
return _tool_ok(extract_size_pairs_m(msg or ""))
|
||||
|
||||
|
||||
def tool_detect_risk(msg: str, goods_name: str = "") -> ToolResponse:
|
||||
"""检测风险:地图政治、黄暴。"""
|
||||
text = msg or ""
|
||||
gname = goods_name or ""
|
||||
return _tool_ok(
|
||||
{
|
||||
"map_or_political": has_map_or_political_risk(text, gname),
|
||||
"porn": has_porn_risk(text),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def tool_detect_external_contact(msg: str) -> ToolResponse:
|
||||
"""检测是否索要外部联系方式(微信/QQ/手机号等)。"""
|
||||
return _tool_ok(requests_external_contact(msg or ""))
|
||||
|
||||
|
||||
def tool_is_meaningless_short(msg: str) -> ToolResponse:
|
||||
"""检测是否无意义短句(嗯/哦/ok等)。"""
|
||||
return _tool_ok(is_meaningless_short(msg or ""))
|
||||
194
qingjian_cs/app/agents.py
Normal file
194
qingjian_cs/app/agents.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .agent_tools import (
|
||||
tool_detect_external_contact,
|
||||
tool_detect_intent,
|
||||
tool_detect_order_status,
|
||||
tool_detect_risk,
|
||||
tool_extract_image_urls,
|
||||
tool_extract_size_pairs,
|
||||
tool_is_meaningless_short,
|
||||
)
|
||||
from .config import AGENT_MAX_ITERS, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME
|
||||
from .models import Decision, DecisionModel, RouteModel
|
||||
from .rules import rules_prompt
|
||||
|
||||
|
||||
def _ensure_agentscope_importable() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
src_dir = repo_root / "src"
|
||||
if src_dir.exists() and str(src_dir) not in sys.path:
|
||||
sys.path.insert(0, str(src_dir))
|
||||
|
||||
|
||||
class _AgentRuntime:
|
||||
def __init__(self, name: str, sys_prompt: str):
|
||||
_ensure_agentscope_importable()
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
if not OPENAI_API_KEY:
|
||||
raise RuntimeError("OPENAI_API_KEY 未设置")
|
||||
|
||||
self.Msg = Msg
|
||||
toolkit = Toolkit()
|
||||
toolkit.register_tool_function(tool_detect_intent)
|
||||
toolkit.register_tool_function(tool_extract_image_urls)
|
||||
toolkit.register_tool_function(tool_detect_order_status)
|
||||
toolkit.register_tool_function(tool_extract_size_pairs)
|
||||
toolkit.register_tool_function(tool_detect_risk)
|
||||
toolkit.register_tool_function(tool_detect_external_contact)
|
||||
toolkit.register_tool_function(tool_is_meaningless_short)
|
||||
|
||||
model = OpenAIChatModel(
|
||||
model_name=OPENAI_MODEL_NAME,
|
||||
api_key=OPENAI_API_KEY,
|
||||
stream=False,
|
||||
client_kwargs={"base_url": OPENAI_BASE_URL},
|
||||
generate_kwargs={"temperature": 0.1},
|
||||
)
|
||||
|
||||
self.agent = ReActAgent(
|
||||
name=name,
|
||||
sys_prompt=sys_prompt,
|
||||
model=model,
|
||||
formatter=OpenAIChatFormatter(),
|
||||
toolkit=toolkit,
|
||||
memory=InMemoryMemory(),
|
||||
max_iters=max(1, AGENT_MAX_ITERS),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_json(text: str) -> dict[str, Any] | None:
|
||||
m = re.search(r"\{[\s\S]*\}", text or "")
|
||||
if not m:
|
||||
return None
|
||||
try:
|
||||
return json.loads(m.group(0))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _msg_to_text(msg: Any) -> str:
|
||||
try:
|
||||
if hasattr(msg, "get_text_content"):
|
||||
v = msg.get_text_content()
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
except Exception:
|
||||
pass
|
||||
c = getattr(msg, "content", None)
|
||||
if isinstance(c, str):
|
||||
return c
|
||||
if isinstance(c, list):
|
||||
out: list[str] = []
|
||||
for b in c:
|
||||
t = getattr(b, "text", None)
|
||||
if isinstance(t, str) and t.strip():
|
||||
out.append(t)
|
||||
return "\n".join(out)
|
||||
return str(msg)
|
||||
|
||||
@staticmethod
|
||||
def _extract_structured(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
candidates = [metadata, metadata.get("structured_output"), metadata.get("result"), metadata.get("output"), metadata.get("json")]
|
||||
for obj in candidates:
|
||||
if isinstance(obj, dict):
|
||||
return obj
|
||||
return None
|
||||
|
||||
|
||||
class RouterAgent(_AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"RouterAgent",
|
||||
rules_prompt()
|
||||
+ "\n你是路由Agent。只输出路由 pre_sales/quote/after_sales/risk,不直接回复客户。"
|
||||
+ " 你必须基于上下文语义路由,禁止关键词硬匹配。",
|
||||
)
|
||||
|
||||
async def route(self, context: dict[str, Any]) -> tuple[str, str]:
|
||||
prompt = f"按上下文路由到 pre_sales/quote/after_sales/risk。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
|
||||
res = await self.agent(self.Msg("user", prompt, "user"), structured_model=RouteModel)
|
||||
obj = self._extract_structured(getattr(res, "metadata", None)) or self._extract_json(self._msg_to_text(res)) or {}
|
||||
route = str(obj.get("route", "pre_sales") or "pre_sales")
|
||||
if route not in {"pre_sales", "quote", "after_sales", "risk"}:
|
||||
route = "pre_sales"
|
||||
return route, str(obj.get("reason", "") or "")
|
||||
|
||||
|
||||
class QuoteAgent(_AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"QuoteAgent",
|
||||
rules_prompt() + "\n你是报价Agent。负责收图、报价触发、报价回复和报价阶段状态更新。",
|
||||
)
|
||||
|
||||
async def decide(self, context: dict[str, Any]) -> Decision:
|
||||
prompt = f"你负责报价相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
|
||||
return await _decide_with_model(self, prompt)
|
||||
|
||||
|
||||
class AfterSalesAgent(_AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"AfterSalesAgent",
|
||||
rules_prompt() + "\n你是售后Agent。负责退款/重发/不满意等售后处理与状态推进。",
|
||||
)
|
||||
|
||||
async def decide(self, context: dict[str, Any]) -> Decision:
|
||||
prompt = f"你负责售后相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
|
||||
return await _decide_with_model(self, prompt)
|
||||
|
||||
|
||||
class RiskAgent(_AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"RiskAgent",
|
||||
rules_prompt() + "\n你是风控Agent。专注风险识别与风险动作决策。",
|
||||
)
|
||||
|
||||
async def decide(self, context: dict[str, Any]) -> Decision:
|
||||
prompt = f"你负责风控相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
|
||||
return await _decide_with_model(self, prompt)
|
||||
|
||||
|
||||
class PreSalesAgent(_AgentRuntime):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"PreSalesAgent",
|
||||
rules_prompt() + "\n你是售前Agent。处理咨询承接、收图、澄清需求与转报价前动作。",
|
||||
)
|
||||
|
||||
async def decide(self, context: dict[str, Any]) -> Decision:
|
||||
prompt = f"你负责售前相关决策。\n上下文:\n{json.dumps(context, ensure_ascii=False)}"
|
||||
return await _decide_with_model(self, prompt)
|
||||
|
||||
|
||||
async def _decide_with_model(rt: _AgentRuntime, prompt: str) -> Decision:
|
||||
res = await rt.agent(rt.Msg("user", prompt, "user"), structured_model=DecisionModel)
|
||||
obj = rt._extract_structured(getattr(res, "metadata", None)) or rt._extract_json(rt._msg_to_text(res)) or {}
|
||||
action = str(obj.get("action", "reply") or "reply").strip().lower()
|
||||
if action not in {"reply", "quote", "transfer", "noop", "update_state"}:
|
||||
action = "reply"
|
||||
return Decision(
|
||||
action=action,
|
||||
reply=str(obj.get("reply", "") or "").strip(),
|
||||
transfer_msg=str(obj.get("transfer_msg", "") or "").strip(),
|
||||
quote_mode=str(obj.get("quote_mode", "") or "").strip(),
|
||||
state_patch=obj.get("state_patch") if isinstance(obj.get("state_patch"), dict) else {},
|
||||
reason=str(obj.get("reason", "") or "").strip(),
|
||||
)
|
||||
131
qingjian_cs/app/auto_draw.py
Normal file
131
qingjian_cs/app/auto_draw.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .config import AUTO_DRAW_ENDPOINT, AUTO_DRAW_TIMEOUT_SECONDS
|
||||
|
||||
|
||||
def _add_legacy_tw_path() -> None:
|
||||
root = os.getenv("LEGACY_TW_ROOT", r"D:\main\sandbox\tw_terminator").strip()
|
||||
if not root:
|
||||
return
|
||||
p = Path(root)
|
||||
# 先加载 legacy 项目的 .env,确保 service_tuhui_upload 在 import 时拿到正确账号配置
|
||||
legacy_env = p / ".env"
|
||||
if legacy_env.exists():
|
||||
load_dotenv(legacy_env, override=True)
|
||||
if p.exists() and str(p) not in sys.path:
|
||||
sys.path.insert(0, str(p))
|
||||
|
||||
|
||||
async def _draw_via_legacy_tw(
|
||||
image_url: str,
|
||||
customer_id: str,
|
||||
requirement: str = "",
|
||||
) -> dict[str, Any]:
|
||||
# 优先使用当前项目中拷贝过来的 service_gemini
|
||||
from services.service_gemini import GeminiExtractV2Service # type: ignore
|
||||
|
||||
# 上传模块暂时仍走 legacy(你后续可替换为新项目本地上传实现)
|
||||
_add_legacy_tw_path()
|
||||
from services.service_tuhui_upload import upload_to_tuhui # type: ignore
|
||||
|
||||
prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本"
|
||||
|
||||
# 1) 下载原图到本地临时文件
|
||||
input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg")
|
||||
output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg")
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/122.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Referer": "https://www.taobao.com/",
|
||||
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
|
||||
}
|
||||
resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
|
||||
if resp.status_code != 200:
|
||||
return {"ok": False, "error": f"download_http_{resp.status_code}"}
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
|
||||
# 2) 直调你原来的 service_gemini 作图 API
|
||||
service = GeminiExtractV2Service()
|
||||
ok_extract, msg_extract, _ = await service.extract_pattern(
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
custom_prompt=prompt,
|
||||
aspect_ratio="1:1",
|
||||
)
|
||||
if not ok_extract:
|
||||
return {"ok": False, "error": f"extract_failed:{msg_extract}"}
|
||||
if not os.path.exists(output_path):
|
||||
return {"ok": False, "error": "extract_no_output_file"}
|
||||
|
||||
# 3) 上传图绘,返回可外发 URL
|
||||
ok, link, _ = await upload_to_tuhui(
|
||||
output_path,
|
||||
title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图",
|
||||
description="AI自动作图预览",
|
||||
price=1,
|
||||
)
|
||||
if not ok:
|
||||
return {"ok": False, "error": str(link)}
|
||||
return {"ok": True, "url": str(link)}
|
||||
|
||||
|
||||
def _draw_via_http_endpoint(image_url: str, customer_id: str, requirement: str = "") -> dict[str, Any]:
|
||||
if not AUTO_DRAW_ENDPOINT:
|
||||
return {"ok": False, "error": "AUTO_DRAW_ENDPOINT not configured"}
|
||||
payload = {
|
||||
"image_url": image_url,
|
||||
"customer_id": customer_id,
|
||||
"requirement": requirement,
|
||||
}
|
||||
resp = requests.post(AUTO_DRAW_ENDPOINT, json=payload, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
|
||||
if resp.status_code != 200:
|
||||
return {"ok": False, "error": f"http_{resp.status_code}:{resp.text[:200]}"}
|
||||
data = resp.json() if resp.text else {}
|
||||
url = str(data.get("url", "") or data.get("preview_url", "") or "")
|
||||
if not url:
|
||||
return {"ok": False, "error": "missing_preview_url"}
|
||||
return {"ok": True, "url": url}
|
||||
|
||||
|
||||
async def auto_draw_preview(
|
||||
image_url: str,
|
||||
customer_id: str,
|
||||
requirement: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
统一自动作图入口:
|
||||
1) 优先走 tw_terminator 的 service_gemini 直调链路
|
||||
2) 失败时回退 AUTO_DRAW_ENDPOINT
|
||||
"""
|
||||
try:
|
||||
return await _draw_via_legacy_tw(image_url=image_url, customer_id=customer_id, requirement=requirement)
|
||||
except Exception as e:
|
||||
legacy_error = str(e)
|
||||
|
||||
try:
|
||||
data = await asyncio.to_thread(
|
||||
_draw_via_http_endpoint,
|
||||
image_url,
|
||||
customer_id,
|
||||
requirement,
|
||||
)
|
||||
if data.get("ok"):
|
||||
return data
|
||||
return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{data.get('error','unknown')}"}
|
||||
except Exception as e:
|
||||
return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{e}"}
|
||||
26
qingjian_cs/app/callbacks.py
Normal file
26
qingjian_cs/app/callbacks.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from .config import TIANWANG_CALLBACK_URL
|
||||
|
||||
|
||||
async def post_tianwang_callback(event: str, data: dict[str, Any], extra: dict[str, Any] | None = None, timeout_s: int = 5) -> tuple[bool, int, str]:
|
||||
payload = {
|
||||
'event': event,
|
||||
'data': data,
|
||||
'extra': extra or {},
|
||||
}
|
||||
|
||||
def _post() -> tuple[bool, int, str]:
|
||||
try:
|
||||
resp = requests.post(TIANWANG_CALLBACK_URL, json=payload, timeout=timeout_s)
|
||||
ok = 200 <= resp.status_code < 300
|
||||
return ok, resp.status_code, (resp.text or '')[:300]
|
||||
except Exception as e:
|
||||
return False, 0, str(e)
|
||||
|
||||
return await asyncio.to_thread(_post)
|
||||
370
qingjian_cs/app/client.py
Normal file
370
qingjian_cs/app/client.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import websockets
|
||||
|
||||
from .callbacks import post_tianwang_callback
|
||||
from .auto_draw import auto_draw_preview
|
||||
from .config import (
|
||||
AUTO_DRAW_ENABLED,
|
||||
AUTO_QUOTE_WAIT_SECONDS,
|
||||
MESSAGE_DEBOUNCE_SECONDS,
|
||||
QINGJIAN_WS_URI,
|
||||
SHORT_REPLY_MAX_CHARS,
|
||||
)
|
||||
from .logger import setup_logger
|
||||
from .observability import activity_event, build_trace_id
|
||||
from .orchestrator import Orchestrator
|
||||
from .rules import extract_image_urls, prefilter_message
|
||||
|
||||
|
||||
class QingjianClient:
|
||||
def __init__(self) -> None:
|
||||
self.logger = setup_logger()
|
||||
self.uri = QINGJIAN_WS_URI
|
||||
self.reply_id = "tb001"
|
||||
self.websocket = None
|
||||
self.running = True
|
||||
self.orchestrator = Orchestrator()
|
||||
|
||||
self.pending_msgs: dict[str, list[dict]] = defaultdict(list)
|
||||
self.debounce_tasks: dict[str, asyncio.Task] = {}
|
||||
self.pending_images: dict[str, list[str]] = defaultdict(list)
|
||||
self.auto_quote_tasks: dict[str, asyncio.Task] = {}
|
||||
self.last_reply_key: dict[str, str] = {}
|
||||
self.recent_outbound: list[tuple[str, str, str, float]] = []
|
||||
|
||||
@staticmethod
|
||||
def _customer_key(data: dict) -> str:
|
||||
return f"{data.get('acc_id','')}:{data.get('from_id','')}"
|
||||
|
||||
@staticmethod
|
||||
def _msg_text(data: dict) -> str:
|
||||
return str(data.get("msg", "") or "").strip()
|
||||
|
||||
def _debounce_seconds(self, msg: str) -> float:
|
||||
if extract_image_urls(msg):
|
||||
return 2.5
|
||||
return float(MESSAGE_DEBOUNCE_SECONDS)
|
||||
|
||||
async def send_message(self, message: dict) -> None:
|
||||
if not self.websocket:
|
||||
return
|
||||
await self.websocket.send(json.dumps(message, ensure_ascii=False))
|
||||
self.logger.info("[发送] %s", message.get("msg", ""))
|
||||
|
||||
async def send_reply(self, data: dict, text: str, trace_id: str = "-") -> None:
|
||||
text = str(text or "").strip()
|
||||
if not text:
|
||||
return
|
||||
text = self._shorten_reply(text)
|
||||
msg = {
|
||||
"msg_id": "",
|
||||
"acc_id": data.get("acc_id", ""),
|
||||
"msg": text,
|
||||
"from_id": data.get("from_id", ""),
|
||||
"from_name": data.get("from_name", data.get("from_id", "")),
|
||||
"cy_id": data.get("from_id", ""),
|
||||
"acc_type": data.get("acc_type", "AliWorkbench"),
|
||||
"msg_type": 0,
|
||||
"cy_name": data.get("from_name", data.get("from_id", "")),
|
||||
}
|
||||
activity_event(self.logger, "send_reply_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text)
|
||||
await self.send_message(msg)
|
||||
self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), text, time.monotonic()))
|
||||
if len(self.recent_outbound) > 200:
|
||||
self.recent_outbound = self.recent_outbound[-200:]
|
||||
activity_event(self.logger, "send_reply_success", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text)
|
||||
|
||||
async def send_image(self, data: dict, image_url: str, trace_id: str = "-") -> None:
|
||||
image_url = str(image_url or "").strip()
|
||||
if not image_url:
|
||||
return
|
||||
msg = {
|
||||
"msg_id": "",
|
||||
"acc_id": data.get("acc_id", ""),
|
||||
"msg": image_url,
|
||||
"from_id": data.get("from_id", ""),
|
||||
"from_name": data.get("from_name", data.get("from_id", "")),
|
||||
"cy_id": data.get("from_id", ""),
|
||||
"acc_type": data.get("acc_type", "AliWorkbench"),
|
||||
"msg_type": 1,
|
||||
"cy_name": data.get("from_name", data.get("from_id", "")),
|
||||
}
|
||||
activity_event(self.logger, "send_image_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url)
|
||||
await self.send_message(msg)
|
||||
self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), image_url, time.monotonic()))
|
||||
if len(self.recent_outbound) > 200:
|
||||
self.recent_outbound = self.recent_outbound[-200:]
|
||||
activity_event(self.logger, "send_image_success", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url)
|
||||
|
||||
@staticmethod
|
||||
def _clean_text(text: str) -> str:
|
||||
t = str(text or "").strip()
|
||||
t = re.sub(r"\s+", "", t)
|
||||
return t
|
||||
|
||||
def _shorten_reply(self, text: str) -> str:
|
||||
max_len = max(8, int(SHORT_REPLY_MAX_CHARS))
|
||||
t = str(text or "").strip()
|
||||
t = self._humanize_reply(t)
|
||||
if len(t) <= max_len:
|
||||
return t
|
||||
parts = re.split(r"[。!?!?]", t)
|
||||
head = next((p.strip() for p in parts if p and p.strip()), t)
|
||||
if len(head) > max_len:
|
||||
head = head[:max_len].rstrip(",,;;:: ")
|
||||
return head or t[:max_len]
|
||||
|
||||
@staticmethod
|
||||
def _humanize_reply(text: str) -> str:
|
||||
t = str(text or "").strip()
|
||||
return t
|
||||
|
||||
@staticmethod
|
||||
def _is_invalid_ai_reply(text: str) -> bool:
|
||||
t = str(text or "").strip().lower()
|
||||
if not t:
|
||||
return True
|
||||
if "i noticed that you have interrupted me" in t:
|
||||
return True
|
||||
if t.startswith("action:") or t.startswith("{"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _fallback_reply(self, action: str) -> str:
|
||||
if action == "transfer":
|
||||
return "我先给你转人工处理。"
|
||||
return "收到,我先处理一下。"
|
||||
|
||||
def _is_outbound_echo(self, data: dict, msg: str) -> bool:
|
||||
"""
|
||||
轻简可能会把我方刚发送文本回推为“收到消息”。
|
||||
对“短时间完全相同文本”做回环拦截,兼容 acc/from 对调回推,避免无限对话。
|
||||
"""
|
||||
in_acc = str(data.get("acc_id", ""))
|
||||
in_from = str(data.get("from_id", ""))
|
||||
in_msg = self._clean_text(msg)
|
||||
now = time.monotonic()
|
||||
if not in_msg:
|
||||
return False
|
||||
for out_acc, out_to, out_msg, ts in reversed(self.recent_outbound):
|
||||
if (now - ts) > 120:
|
||||
break
|
||||
if self._clean_text(out_msg) != in_msg:
|
||||
continue
|
||||
if (out_acc == in_acc and out_to == in_from) or (out_acc == in_from and out_to == in_acc):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _handle_decision(self, data: dict, merged_msg: str, *, auto_quote: bool = False) -> None:
|
||||
key = self._customer_key(data)
|
||||
trace_id = build_trace_id(data.get("acc_id", ""), data.get("from_id", ""), merged_msg)
|
||||
t0 = time.perf_counter()
|
||||
|
||||
urls = extract_image_urls(merged_msg)
|
||||
if urls:
|
||||
for u in urls:
|
||||
if u not in self.pending_images[key]:
|
||||
self.pending_images[key].append(u)
|
||||
|
||||
context = {
|
||||
"customer_key": key,
|
||||
"acc_id": data.get("acc_id", ""),
|
||||
"customer_id": data.get("from_id", ""),
|
||||
"goods_name": data.get("goods_name", ""),
|
||||
"goods_order": data.get("goods_order", ""),
|
||||
"msg": merged_msg,
|
||||
"intent": "unknown",
|
||||
"pending_images": len(self.pending_images[key]),
|
||||
"auto_quote_trigger": auto_quote,
|
||||
"last_reply": self.last_reply_key.get(key, ""),
|
||||
}
|
||||
|
||||
activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"])
|
||||
route, decision, state = await self.orchestrator.decide(context)
|
||||
latency_ms = int((time.perf_counter() - t0) * 1000)
|
||||
activity_event(
|
||||
self.logger,
|
||||
"agent_process_done",
|
||||
trace_id=trace_id,
|
||||
customer_id=context["customer_id"],
|
||||
route=route,
|
||||
action=decision.action,
|
||||
reason=decision.reason,
|
||||
latency_ms=latency_ms,
|
||||
after_sales_stage=state.get("after_sales_stage", "new"),
|
||||
)
|
||||
|
||||
if decision.action == "transfer":
|
||||
text = (decision.transfer_msg or "").strip()
|
||||
if self._is_invalid_ai_reply(text):
|
||||
text = self._fallback_reply("transfer")
|
||||
await self.send_reply(data, text, trace_id=trace_id)
|
||||
self.last_reply_key[key] = text
|
||||
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "transfer", "reply": text})
|
||||
return
|
||||
|
||||
if decision.action == "quote":
|
||||
if AUTO_DRAW_ENABLED and self.pending_images.get(key):
|
||||
latest_image = self.pending_images[key][-1]
|
||||
activity_event(
|
||||
self.logger,
|
||||
"auto_draw_start",
|
||||
trace_id=trace_id,
|
||||
customer_id=context["customer_id"],
|
||||
image_url=latest_image,
|
||||
)
|
||||
draw_res = await auto_draw_preview(
|
||||
image_url=latest_image,
|
||||
customer_id=context["customer_id"],
|
||||
requirement=merged_msg,
|
||||
)
|
||||
if draw_res.get("ok"):
|
||||
preview_url = str(draw_res.get("url", "") or "")
|
||||
await self.send_reply(data, "先给你做了预览图。", trace_id=trace_id)
|
||||
await self.send_image(data, preview_url, trace_id=trace_id)
|
||||
final_text = "看下预览,满意再拍下付款。"
|
||||
await self.send_reply(data, final_text, trace_id=trace_id)
|
||||
self.last_reply_key[key] = final_text
|
||||
# 预览完成后清掉当前批次,避免同一图重复触发
|
||||
self.pending_images[key].clear()
|
||||
activity_event(
|
||||
self.logger,
|
||||
"auto_draw_success",
|
||||
trace_id=trace_id,
|
||||
customer_id=context["customer_id"],
|
||||
preview_url=preview_url,
|
||||
)
|
||||
await post_tianwang_callback(
|
||||
"message_processed",
|
||||
data,
|
||||
extra={"trace_id": trace_id, "route": route, "action": "quote", "reply": final_text, "auto_draw": True},
|
||||
)
|
||||
return
|
||||
activity_event(
|
||||
self.logger,
|
||||
"auto_draw_fail",
|
||||
trace_id=trace_id,
|
||||
customer_id=context["customer_id"],
|
||||
error=str(draw_res.get("error", "unknown")),
|
||||
)
|
||||
text = (decision.reply or "").strip()
|
||||
if self._is_invalid_ai_reply(text):
|
||||
text = self._fallback_reply("quote")
|
||||
if self.last_reply_key.get(key) != text:
|
||||
await self.send_reply(data, text, trace_id=trace_id)
|
||||
self.last_reply_key[key] = text
|
||||
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "quote", "reply": text})
|
||||
return
|
||||
|
||||
if decision.action == "noop":
|
||||
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "noop", "reply": ""})
|
||||
return
|
||||
|
||||
text = (decision.reply or "").strip()
|
||||
if self._is_invalid_ai_reply(text):
|
||||
text = self._fallback_reply("reply")
|
||||
if self.last_reply_key.get(key) != text:
|
||||
await self.send_reply(data, text, trace_id=trace_id)
|
||||
self.last_reply_key[key] = text
|
||||
|
||||
await post_tianwang_callback("message_processed", data, extra={"trace_id": trace_id, "route": route, "action": "reply", "reply": text})
|
||||
|
||||
if self.pending_images[key] and key not in self.auto_quote_tasks:
|
||||
self.auto_quote_tasks[key] = asyncio.create_task(self._auto_quote_later(data))
|
||||
|
||||
async def _auto_quote_later(self, data: dict) -> None:
|
||||
key = self._customer_key(data)
|
||||
try:
|
||||
await asyncio.sleep(AUTO_QUOTE_WAIT_SECONDS)
|
||||
if self.pending_images.get(key):
|
||||
await self._handle_decision(data, "", auto_quote=True)
|
||||
finally:
|
||||
self.auto_quote_tasks.pop(key, None)
|
||||
|
||||
async def _flush_customer(self, key: str) -> None:
|
||||
queue = self.pending_msgs.get(key, [])
|
||||
if not queue:
|
||||
return
|
||||
merged = "、".join([self._msg_text(x) for x in queue if self._msg_text(x)])
|
||||
data = queue[-1]
|
||||
self.pending_msgs[key].clear()
|
||||
await self._handle_decision(data, merged)
|
||||
|
||||
async def _debounce_enqueue(self, data: dict) -> None:
|
||||
key = self._customer_key(data)
|
||||
msg = self._msg_text(data)
|
||||
self.pending_msgs[key].append(data)
|
||||
|
||||
if key in self.debounce_tasks:
|
||||
self.debounce_tasks[key].cancel()
|
||||
|
||||
wait_s = self._debounce_seconds(msg)
|
||||
activity_event(self.logger, "debounce_enqueue", customer_id=data.get("from_id", "-"), key=key, queue_size=len(self.pending_msgs[key]), wait_s=wait_s)
|
||||
|
||||
async def later() -> None:
|
||||
try:
|
||||
await asyncio.sleep(wait_s)
|
||||
await self._flush_customer(key)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
self.debounce_tasks.pop(key, None)
|
||||
|
||||
self.debounce_tasks[key] = asyncio.create_task(later())
|
||||
|
||||
async def _on_message(self, raw: str) -> None:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except Exception:
|
||||
self.logger.info("[非JSON] %s", raw)
|
||||
return
|
||||
|
||||
msg_type = int(data.get("msg_type", 0) or 0)
|
||||
msg = self._msg_text(data)
|
||||
rule = prefilter_message(msg, msg_type)
|
||||
|
||||
self.logger.info("[收消息] acc=%s from=%s type=%s msg=%s", data.get("acc_id", ""), data.get("from_id", ""), msg_type, msg)
|
||||
await post_tianwang_callback("message_received", data, extra={"msg_type": msg_type})
|
||||
|
||||
if self._is_outbound_echo(data, msg):
|
||||
activity_event(
|
||||
self.logger,
|
||||
"inbound_ignored",
|
||||
customer_id=data.get("from_id", "-"),
|
||||
reason="outbound_echo_loop_guard",
|
||||
)
|
||||
return
|
||||
|
||||
if rule.ignore:
|
||||
activity_event(self.logger, "inbound_ignored", customer_id=data.get("from_id", "-"), reason=rule.reason)
|
||||
return
|
||||
|
||||
patched = dict(data)
|
||||
patched["msg"] = rule.normalized_msg or msg
|
||||
|
||||
if msg_type == 1:
|
||||
await self._handle_decision(patched, patched["msg"])
|
||||
return
|
||||
|
||||
await self._debounce_enqueue(patched)
|
||||
|
||||
async def _serve(self) -> None:
|
||||
while self.running:
|
||||
try:
|
||||
self.logger.info("[连接] %s", self.uri)
|
||||
async with websockets.connect(self.uri) as ws:
|
||||
self.websocket = ws
|
||||
self.logger.info("[连接成功]")
|
||||
async for raw in ws:
|
||||
await self._on_message(raw)
|
||||
except Exception as e:
|
||||
self.logger.info("[连接异常] %s", e)
|
||||
await asyncio.sleep(3)
|
||||
|
||||
def run(self) -> None:
|
||||
asyncio.run(self._serve())
|
||||
37
qingjian_cs/app/config.py
Normal file
37
qingjian_cs/app/config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(Path(__file__).resolve().parents[1] / '.env')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
QINGJIAN_WS_URI = os.getenv("QINGJIAN_WS_URI", "ws://127.0.0.1:9528")
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
|
||||
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3").strip()
|
||||
OPENAI_MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "doubao-seed-2-0-pro-260215").strip()
|
||||
|
||||
MESSAGE_DEBOUNCE_SECONDS = int(os.getenv("MESSAGE_DEBOUNCE_SECONDS", "6"))
|
||||
AUTO_QUOTE_WAIT_SECONDS = int(os.getenv("AUTO_QUOTE_WAIT_SECONDS", "18"))
|
||||
AGENT_MAX_ITERS = int(os.getenv("AGENT_MAX_ITERS", "3"))
|
||||
FAST_ROUTE_ENABLED = os.getenv("FAST_ROUTE_ENABLED", "1").strip() in {"1", "true", "True", "yes", "on"}
|
||||
SHORT_REPLY_MAX_CHARS = int(os.getenv("SHORT_REPLY_MAX_CHARS", "20"))
|
||||
|
||||
STORE_BACKEND = os.getenv("STORE_BACKEND", "sqlite").strip().lower()
|
||||
STORE_SQLITE_PATH = os.getenv("STORE_SQLITE_PATH", "").strip()
|
||||
|
||||
MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1").strip()
|
||||
MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
MYSQL_USER = os.getenv("MYSQL_USER", "root").strip()
|
||||
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "").strip()
|
||||
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs").strip()
|
||||
MYSQL_TABLE_PREFIX = os.getenv("MYSQL_TABLE_PREFIX", "qjcs_").strip()
|
||||
|
||||
HTTP_HOST = os.getenv("HTTP_HOST", "127.0.0.1").strip()
|
||||
HTTP_PORT = int(os.getenv("HTTP_PORT", "6060"))
|
||||
TIANWANG_CALLBACK_URL = os.getenv("TIANWANG_CALLBACK_URL", "http://139.199.3.75:18789/api/callback").strip()
|
||||
|
||||
AUTO_DRAW_ENABLED = os.getenv("AUTO_DRAW_ENABLED", "1").strip() in {"1", "true", "True", "yes", "on"}
|
||||
AUTO_DRAW_ENDPOINT = os.getenv("AUTO_DRAW_ENDPOINT", "").strip()
|
||||
AUTO_DRAW_TIMEOUT_SECONDS = int(os.getenv("AUTO_DRAW_TIMEOUT_SECONDS", "25"))
|
||||
51
qingjian_cs/app/http_api.py
Normal file
51
qingjian_cs/app/http_api.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import Flask, jsonify, request
|
||||
|
||||
from .logger import setup_logger
|
||||
from .task_manager import TaskManager
|
||||
|
||||
|
||||
def create_http_app(task_manager: TaskManager | None = None) -> Flask:
|
||||
app = Flask(__name__)
|
||||
logger = setup_logger()
|
||||
tm = task_manager or TaskManager()
|
||||
|
||||
@app.get('/api/health')
|
||||
def health():
|
||||
return jsonify({'ok': True})
|
||||
|
||||
@app.post('/api/task/receive')
|
||||
def receive_task():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
task_id = tm.create_task(payload)
|
||||
logger.info('[任务] receive task_id=%s', task_id)
|
||||
return jsonify({'ok': True, 'task_id': task_id})
|
||||
|
||||
@app.post('/api/task/cancel')
|
||||
def cancel_task():
|
||||
body = request.get_json(silent=True) or {}
|
||||
task_id = str(body.get('task_id', '')).strip()
|
||||
if not task_id:
|
||||
return jsonify({'ok': False, 'error': 'task_id required'}), 400
|
||||
ok = tm.cancel_task(task_id)
|
||||
return jsonify({'ok': ok, 'task_id': task_id})
|
||||
|
||||
@app.get('/api/task/status/<task_id>')
|
||||
def task_status(task_id: str):
|
||||
task = tm.get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({'ok': False, 'error': 'not found'}), 404
|
||||
return jsonify({'ok': True, 'task': task})
|
||||
|
||||
@app.get('/api/task/list')
|
||||
def task_list():
|
||||
limit = int(request.args.get('limit', 100))
|
||||
return jsonify({'ok': True, 'tasks': tm.list_tasks(limit=limit)})
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_http_server(host: str, port: int, task_manager: TaskManager | None = None) -> None:
|
||||
app = create_http_app(task_manager=task_manager)
|
||||
app.run(host=host, port=port, debug=False, use_reloader=False)
|
||||
22
qingjian_cs/app/logger.py
Normal file
22
qingjian_cs/app/logger.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
def setup_logger() -> logging.Logger:
|
||||
logger = logging.getLogger("qingjian_cs")
|
||||
if logger.handlers:
|
||||
return logger
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", "%H:%M:%S")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# 降低 AgentScope 内部推理/格式器日志噪音,保留本项目活动日志。
|
||||
logging.getLogger("agentscope").setLevel(logging.ERROR)
|
||||
logging.getLogger("agentscope.formatter").setLevel(logging.ERROR)
|
||||
logging.getLogger("agentscope.agent").setLevel(logging.ERROR)
|
||||
logging.getLogger("_openai_formatter").setLevel(logging.ERROR)
|
||||
logging.getLogger("_react_agent").setLevel(logging.ERROR)
|
||||
|
||||
return logger
|
||||
30
qingjian_cs/app/models.py
Normal file
30
qingjian_cs/app/models.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DecisionModel(BaseModel):
|
||||
action: Literal["reply", "quote", "transfer", "noop", "update_state"] = Field(description="唯一动作")
|
||||
reply: str = Field(default="", description="给客户的回复")
|
||||
transfer_msg: str = Field(default="", description="转人工提示")
|
||||
quote_mode: Literal["flush_pending", "analyze_current_or_recent", "collect_only", ""] = Field(default="")
|
||||
state_patch: dict = Field(default_factory=dict, description="状态增量")
|
||||
reason: str = Field(default="", description="内部原因")
|
||||
|
||||
|
||||
class RouteModel(BaseModel):
|
||||
route: Literal["pre_sales", "quote", "after_sales", "risk"] = Field(description="路由目标")
|
||||
reason: str = Field(default="")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Decision:
|
||||
action: str
|
||||
reply: str = ""
|
||||
transfer_msg: str = ""
|
||||
quote_mode: str = ""
|
||||
state_patch: dict | None = None
|
||||
reason: str = ""
|
||||
23
qingjian_cs/app/observability.py
Normal file
23
qingjian_cs/app/observability.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
def now_ts() -> str:
|
||||
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
|
||||
def build_trace_id(acc_id: str, customer_id: str, msg: str) -> str:
|
||||
base = f"{acc_id}|{customer_id}|{msg}|{now_ts()}"
|
||||
return hex(abs(hash(base)))[2:18]
|
||||
|
||||
|
||||
def activity_event(logger, event: str, *, trace_id: str = '-', customer_id: str = '-', result: str = 'ok', **kwargs: Any) -> None:
|
||||
payload = {
|
||||
'trace_id': trace_id or '-',
|
||||
'customer_id': customer_id or '-',
|
||||
'event': event,
|
||||
'result': result,
|
||||
**kwargs,
|
||||
}
|
||||
logger.info('[活动日志] %s', json.dumps(payload, ensure_ascii=False))
|
||||
71
qingjian_cs/app/orchestrator.py
Normal file
71
qingjian_cs/app/orchestrator.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .agents import AfterSalesAgent, PreSalesAgent, QuoteAgent, RiskAgent, RouterAgent
|
||||
from .models import Decision
|
||||
from .state_machine import evolve_after_sales_state, migrate_state_schema
|
||||
from .store import ConversationStore
|
||||
|
||||
|
||||
class Orchestrator:
|
||||
def __init__(self) -> None:
|
||||
self.router = RouterAgent()
|
||||
self.pre_sales = PreSalesAgent()
|
||||
self.quote = QuoteAgent()
|
||||
self.after_sales = AfterSalesAgent()
|
||||
self.risk = RiskAgent()
|
||||
self.store = ConversationStore()
|
||||
|
||||
async def decide(self, context: dict[str, Any]) -> tuple[str, Decision, dict[str, Any]]:
|
||||
customer_key = context["customer_key"]
|
||||
session = self.store.get_session(customer_key)
|
||||
prev_state = migrate_state_schema(session.get("state", {}))
|
||||
prev_route = session.get("route", "pre_sales")
|
||||
|
||||
# 统一改为语义决策:不走关键词意图/订单硬判定。
|
||||
intent = "unknown"
|
||||
order_status = "unknown"
|
||||
|
||||
merged_ctx = {
|
||||
**context,
|
||||
"session_state": prev_state,
|
||||
"previous_route": prev_route,
|
||||
"intent": intent,
|
||||
"order_status": order_status,
|
||||
}
|
||||
|
||||
route, route_reason = await self.router.route(merged_ctx)
|
||||
|
||||
if route == "quote":
|
||||
decision = await self.quote.decide(merged_ctx)
|
||||
elif route == "after_sales":
|
||||
decision = await self.after_sales.decide(merged_ctx)
|
||||
elif route == "risk":
|
||||
decision = await self.risk.decide(merged_ctx)
|
||||
else:
|
||||
decision = await self.pre_sales.decide(merged_ctx)
|
||||
|
||||
merged_state = {**prev_state, **(decision.state_patch or {})}
|
||||
new_state = evolve_after_sales_state(
|
||||
merged_state,
|
||||
route=route,
|
||||
action=decision.action,
|
||||
intent=intent,
|
||||
order_status=order_status,
|
||||
msg=str(context.get("msg", "") or ""),
|
||||
)
|
||||
|
||||
self.store.upsert_session(customer_key, context.get("acc_id", ""), context.get("customer_id", ""), route, new_state)
|
||||
self.store.append_event(
|
||||
customer_key,
|
||||
"decision",
|
||||
{
|
||||
"route": route,
|
||||
"route_reason": route_reason,
|
||||
"action": decision.action,
|
||||
"reason": decision.reason,
|
||||
"after_sales_stage": new_state.get("after_sales_stage", "new"),
|
||||
},
|
||||
)
|
||||
return route, decision, new_state
|
||||
137
qingjian_cs/app/rules.py
Normal file
137
qingjian_cs/app/rules.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
IMAGE_URL_RE = re.compile(r"https?://[^\s]+(?:\.jpg|\.jpeg|\.png|\.webp|\.bmp|\.gif)(?:\?[^\s]*)?", re.I)
|
||||
SIZE_RE = re.compile(r"(\d+(?:\.\d+)?)\s*(米|m|M)\s*[xX*乘]\s*(\d+(?:\.\d+)?)\s*(米|m|M)")
|
||||
|
||||
@dataclass
|
||||
class RuleResult:
|
||||
ignore: bool = False
|
||||
normalized_msg: str = ""
|
||||
reason: str = ""
|
||||
|
||||
|
||||
def extract_customer_text_from_shop_card(msg: str) -> str:
|
||||
if "[进店卡片]" not in (msg or ""):
|
||||
return ""
|
||||
prefix = msg.split("#*#[进店卡片]", 1)[0].strip()
|
||||
if prefix and prefix not in {"你好", "您好", "在吗"}:
|
||||
return prefix
|
||||
return prefix
|
||||
|
||||
|
||||
def detect_order_status(order_text: str) -> str:
|
||||
# 订单状态交给主决策 AI 从上下文语义判断。
|
||||
return "unknown"
|
||||
|
||||
|
||||
def extract_size_pairs_m(msg: str) -> list[tuple[float, float]]:
|
||||
out: list[tuple[float, float]] = []
|
||||
for m in SIZE_RE.finditer(msg or ""):
|
||||
w = float(m.group(1))
|
||||
h = float(m.group(3))
|
||||
out.append((w, h))
|
||||
return out
|
||||
|
||||
|
||||
def has_map_or_political_risk(msg: str, goods_name: str = "") -> bool:
|
||||
# 风险由 RiskAgent 语义判断。
|
||||
return False
|
||||
|
||||
|
||||
def has_porn_risk(msg: str) -> bool:
|
||||
# 风险由 RiskAgent 语义判断。
|
||||
return False
|
||||
|
||||
|
||||
def requests_external_contact(msg: str) -> bool:
|
||||
# 外联风险由 RiskAgent 语义判断。
|
||||
return False
|
||||
|
||||
|
||||
def is_meaningless_short(msg: str) -> bool:
|
||||
# 无意义短句由主决策 AI 语义判断。
|
||||
return False
|
||||
|
||||
|
||||
def prefilter_message(msg: str, msg_type: int) -> RuleResult:
|
||||
m = (msg or "").strip()
|
||||
if not m:
|
||||
return RuleResult(ignore=True, reason="empty")
|
||||
if msg_type not in (0, 1):
|
||||
return RuleResult(ignore=True, reason="unsupported_msg_type")
|
||||
if "由 " in m and " 转交给 " in m:
|
||||
return RuleResult(ignore=True, reason="transfer_notice")
|
||||
if "Gemini 店铺消息,跳过" in m:
|
||||
return RuleResult(ignore=True, reason="system_echo")
|
||||
if "[进店卡片]" in m:
|
||||
t = extract_customer_text_from_shop_card(m)
|
||||
if t:
|
||||
return RuleResult(ignore=False, normalized_msg=t, reason="shop_card_with_text")
|
||||
return RuleResult(ignore=True, reason="pure_shop_card")
|
||||
return RuleResult(ignore=False, normalized_msg=m, reason="normal")
|
||||
|
||||
|
||||
def detect_intent(msg: str) -> str:
|
||||
m = (msg or "").lower()
|
||||
if IMAGE_URL_RE.search(m):
|
||||
return "image"
|
||||
# 其余意图交给 AI 语义判断。
|
||||
return "unknown"
|
||||
|
||||
|
||||
def extract_image_urls(msg: str) -> list[str]:
|
||||
return IMAGE_URL_RE.findall(msg or "")
|
||||
|
||||
|
||||
def rules_prompt() -> str:
|
||||
return (
|
||||
"你是淘宝图像服务客服系统的统一决策AI。必须按以下 MASTER_RULES 执行。\n"
|
||||
"只输出 JSON 决策,不要解释过程。\n"
|
||||
"动作 action 只能是: reply / quote / transfer / noop / update_state。\n\n"
|
||||
"HARD_RULES(最高优先级,必须先判断):\n"
|
||||
"1) 命中以下任一类,禁止报价(action=quote):\n"
|
||||
" - 政治/涉政/政治人物/政治事件/政治宣传;\n"
|
||||
" - 地图类(地图/地形图/行政区划图/卫星地图等);\n"
|
||||
" - 黄暴/擦边/色情/露点/明显违规内容;\n"
|
||||
" - 客户索要站外联系方式(微信/QQ/手机号);\n"
|
||||
" - 文本出现超大尺寸需求(如超长边或超大面积,明显超出常规制作能力)。\n"
|
||||
"2) 命中硬规则时,只能 action=reply 或 action=transfer:\n"
|
||||
" - 可直接拒绝的,action=reply,reply简短明确边界;\n"
|
||||
" - 风险高或不确定的,action=transfer,给 transfer_msg。\n"
|
||||
"3) 命中硬规则后禁止改口:后续客户追问“能不能做”,仍保持同一结论。\n"
|
||||
"4) 多图场景若部分可做、部分不可做,必须明确“哪张可做、哪张不做”,禁止含糊表述。\n\n"
|
||||
"MASTER_RULES:\n"
|
||||
"A. 统一动作语义\n"
|
||||
"1) reply: 直接回复客户。\n"
|
||||
"2) quote: 触发报价或触发看图后报价流程。\n"
|
||||
"3) transfer: 转人工,必须给 transfer_msg。\n"
|
||||
"4) noop: 当前不需要回复。\n"
|
||||
"5) update_state: 仅更新状态,不对外发消息。\n\n"
|
||||
"B. 售前与报价\n"
|
||||
"1) 客户发图: 优先承接,可继续收图;不强行一次报完。\n"
|
||||
"2) 客户询价且已有图(当前图/待处理图/最近图): 优先 action=quote。\n"
|
||||
"3) 客户无图询价: action=reply,引导其先发图。\n"
|
||||
"4) 客户说“发完了/就这些/报价吧”: 若有图则 action=quote。\n"
|
||||
"5) 不能承诺“一模一样原图必找到”,可说先看图评估。\n"
|
||||
"6) 尺寸很大或要求高还原时,不夸张承诺,先说明可评估后给结论。\n\n"
|
||||
"C. 订单阶段\n"
|
||||
"1) 已付款: 可回复“已安排处理/正在处理/完成后发你确认”。\n"
|
||||
"2) 待付款: 可提示付款,但不与客户争执;必要时先给预览再引导付款。\n"
|
||||
"3) 退款/售后诉求: 进入售后语境,保持克制,必要时转人工。\n\n"
|
||||
"D. 风控与合规\n"
|
||||
"1) 涉政治/地图边界/黄暴/违规内容: 按 HARD_RULES 执行,禁止报价。\n"
|
||||
"2) 客户索要微信/QQ/手机号等站外联系方式: 不外呼,站内引导。\n"
|
||||
"3) 高风险不确定时,不硬答,给保守回复或转人工。\n\n"
|
||||
"E. 对话质量\n"
|
||||
"1) 单次只做一个动作,不混合。\n"
|
||||
"2) 避免重复同一句话;若语义相同,换表达。\n"
|
||||
"3) reply 必须短: 优先 1 句,口语化,避免AI腔。\n"
|
||||
"4) 不要输出思考过程,不要输出 tool_use 文本给客户。\n"
|
||||
"5) 若上下文不足,先澄清 1 个关键问题,不要连续追问。\n\n"
|
||||
"F. 店铺人格\n"
|
||||
"1) 按店铺/账号口吻说话,像真人客服,不要机械模板。\n"
|
||||
"2) 语气友好直接,不啰嗦,不说“作为AI”。\n\n"
|
||||
"输出格式:\n"
|
||||
'{"action":"reply|quote|transfer|noop|update_state","reply":"","transfer_msg":"","quote_mode":"flush_pending|analyze_current_or_recent|collect_only","state_patch":{},"reason":""}'
|
||||
)
|
||||
60
qingjian_cs/app/state_machine.py
Normal file
60
qingjian_cs/app/state_machine.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
AFTER_SALES_STAGES = {
|
||||
"new",
|
||||
"waiting_material",
|
||||
"quoted",
|
||||
"processing",
|
||||
"waiting_feedback",
|
||||
"done",
|
||||
"refunding",
|
||||
"transferred",
|
||||
}
|
||||
|
||||
|
||||
def migrate_state_schema(state: dict[str, Any] | None) -> dict[str, Any]:
|
||||
src = dict(state or {})
|
||||
# 兼容旧字段
|
||||
if "after_sales_stage" not in src:
|
||||
old = src.get("aftersales_stage") or src.get("status") or "new"
|
||||
src["after_sales_stage"] = str(old)
|
||||
if src.get("after_sales_stage") not in AFTER_SALES_STAGES:
|
||||
src["after_sales_stage"] = "new"
|
||||
|
||||
if "quote_count" not in src:
|
||||
src["quote_count"] = int(src.get("quotes", 0) or 0)
|
||||
if "image_count" not in src:
|
||||
src["image_count"] = int(src.get("images", 0) or 0)
|
||||
if "last_intent" not in src:
|
||||
src["last_intent"] = str(src.get("intent", "unknown") or "unknown")
|
||||
if "version" not in src:
|
||||
src["version"] = 2
|
||||
return src
|
||||
|
||||
|
||||
def evolve_after_sales_state(prev_state: dict[str, Any], *, route: str, action: str, intent: str, order_status: str, msg: str) -> dict[str, Any]:
|
||||
s = migrate_state_schema(prev_state)
|
||||
stage = s.get("after_sales_stage", "new")
|
||||
|
||||
if action == "transfer" or route == "risk":
|
||||
stage = "transferred"
|
||||
elif route == "quote" and action == "quote":
|
||||
stage = "quoted"
|
||||
s["quote_count"] = int(s.get("quote_count", 0)) + 1
|
||||
elif route == "after_sales":
|
||||
if order_status == "paid":
|
||||
stage = "processing"
|
||||
elif stage == "new":
|
||||
stage = "waiting_material"
|
||||
elif route == "pre_sales":
|
||||
if intent == "image":
|
||||
stage = "waiting_material"
|
||||
|
||||
s["after_sales_stage"] = stage
|
||||
s["last_intent"] = intent
|
||||
s["last_order_status"] = order_status or "unknown"
|
||||
s["version"] = 2
|
||||
return s
|
||||
251
qingjian_cs/app/store.py
Normal file
251
qingjian_cs/app/store.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .config import (
|
||||
MYSQL_DATABASE,
|
||||
MYSQL_HOST,
|
||||
MYSQL_PASSWORD,
|
||||
MYSQL_PORT,
|
||||
MYSQL_TABLE_PREFIX,
|
||||
MYSQL_USER,
|
||||
STORE_BACKEND,
|
||||
STORE_SQLITE_PATH,
|
||||
)
|
||||
from .state_machine import migrate_state_schema
|
||||
|
||||
DB_PATH = Path(__file__).resolve().parents[1] / "qingjian_cs.db"
|
||||
|
||||
|
||||
def _safe_prefix(v: str) -> str:
|
||||
p = re.sub(r"[^a-zA-Z0-9_]", "", (v or "").strip())
|
||||
return p or "qjcs_"
|
||||
|
||||
|
||||
class ConversationStore:
|
||||
def __init__(self, backend: str | None = None, db_path: str | None = None) -> None:
|
||||
self.backend = (backend or STORE_BACKEND or "sqlite").lower()
|
||||
self.db_path = db_path or STORE_SQLITE_PATH or str(DB_PATH)
|
||||
self.prefix = _safe_prefix(MYSQL_TABLE_PREFIX)
|
||||
self.sessions_table = f"{self.prefix}sessions"
|
||||
self.events_table = f"{self.prefix}events"
|
||||
self._init_db()
|
||||
|
||||
def _sqlite_conn(self):
|
||||
return sqlite3.connect(self.db_path)
|
||||
|
||||
def _mysql_conn(self):
|
||||
import pymysql
|
||||
|
||||
return pymysql.connect(
|
||||
host=MYSQL_HOST,
|
||||
port=MYSQL_PORT,
|
||||
user=MYSQL_USER,
|
||||
password=MYSQL_PASSWORD,
|
||||
database=MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
def _conn(self):
|
||||
if self.backend == "mysql":
|
||||
return self._mysql_conn()
|
||||
return self._sqlite_conn()
|
||||
|
||||
def _init_db(self) -> None:
|
||||
if self.backend == "mysql":
|
||||
self._init_mysql()
|
||||
else:
|
||||
self._init_sqlite()
|
||||
|
||||
def _ensure_sqlite_column(self, conn: sqlite3.Connection, table: str, col: str, ddl: str) -> None:
|
||||
cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
|
||||
if col not in cols:
|
||||
conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
|
||||
|
||||
def _init_sqlite(self) -> None:
|
||||
t_s = self.sessions_table
|
||||
t_e = self.events_table
|
||||
with self._sqlite_conn() as c:
|
||||
c.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {t_s} (
|
||||
customer_key TEXT PRIMARY KEY,
|
||||
acc_id TEXT,
|
||||
customer_id TEXT,
|
||||
route TEXT,
|
||||
state_json TEXT,
|
||||
after_sales_stage TEXT,
|
||||
state_version INTEGER DEFAULT 2,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
c.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {t_e} (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_key TEXT,
|
||||
event TEXT,
|
||||
payload_json TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
self._ensure_sqlite_column(c, t_s, "after_sales_stage", "after_sales_stage TEXT")
|
||||
self._ensure_sqlite_column(c, t_s, "state_version", "state_version INTEGER DEFAULT 2")
|
||||
|
||||
def _init_mysql(self) -> None:
|
||||
t_s = self.sessions_table
|
||||
t_e = self.events_table
|
||||
conn = self._mysql_conn()
|
||||
try:
|
||||
with conn.cursor() as c:
|
||||
c.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {t_s} (
|
||||
customer_key VARCHAR(255) PRIMARY KEY,
|
||||
acc_id VARCHAR(255),
|
||||
customer_id VARCHAR(255),
|
||||
route VARCHAR(64),
|
||||
state_json JSON,
|
||||
after_sales_stage VARCHAR(64),
|
||||
state_version INT DEFAULT 2,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_after_sales_stage (after_sales_stage)
|
||||
) CHARACTER SET utf8mb4
|
||||
"""
|
||||
)
|
||||
c.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {t_e} (
|
||||
id BIGINT PRIMARY KEY AUTO_INCREMENT,
|
||||
customer_key VARCHAR(255),
|
||||
event VARCHAR(128),
|
||||
payload_json JSON,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_customer_key (customer_key),
|
||||
INDEX idx_event (event)
|
||||
) CHARACTER SET utf8mb4
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_session(self, customer_key: str) -> dict[str, Any]:
|
||||
t_s = self.sessions_table
|
||||
conn = self._conn()
|
||||
try:
|
||||
with conn.cursor() as c:
|
||||
if self.backend == "mysql":
|
||||
c.execute(
|
||||
f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=%s",
|
||||
(customer_key,),
|
||||
)
|
||||
else:
|
||||
c.execute(
|
||||
f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=?",
|
||||
(customer_key,),
|
||||
)
|
||||
row = c.fetchone()
|
||||
if not row:
|
||||
return {"route": "pre_sales", "state": migrate_state_schema({})}
|
||||
|
||||
if isinstance(row, dict):
|
||||
vals = [row.get("acc_id"), row.get("customer_id"), row.get("route"), row.get("state_json"), row.get("after_sales_stage"), row.get("state_version")]
|
||||
else:
|
||||
vals = list(row)
|
||||
|
||||
raw_state = vals[3]
|
||||
try:
|
||||
if isinstance(raw_state, dict):
|
||||
state = raw_state
|
||||
else:
|
||||
state = json.loads(raw_state or "{}")
|
||||
except Exception:
|
||||
state = {}
|
||||
state = migrate_state_schema(state)
|
||||
|
||||
if vals[4] and not state.get("after_sales_stage"):
|
||||
state["after_sales_stage"] = vals[4]
|
||||
if vals[5] and not state.get("version"):
|
||||
state["version"] = int(vals[5])
|
||||
|
||||
return {
|
||||
"acc_id": vals[0],
|
||||
"customer_id": vals[1],
|
||||
"route": vals[2] or "pre_sales",
|
||||
"state": state,
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def upsert_session(self, customer_key: str, acc_id: str, customer_id: str, route: str, state: dict[str, Any]) -> None:
|
||||
t_s = self.sessions_table
|
||||
state = migrate_state_schema(state)
|
||||
state_json = json.dumps(state or {}, ensure_ascii=False)
|
||||
after_sales_stage = str(state.get("after_sales_stage", "new") or "new")
|
||||
state_version = int(state.get("version", 2) or 2)
|
||||
|
||||
conn = self._conn()
|
||||
try:
|
||||
with conn.cursor() as c:
|
||||
if self.backend == "mysql":
|
||||
c.execute(
|
||||
f"""
|
||||
INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version)
|
||||
VALUES(%s,%s,%s,%s,%s,%s,%s)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
acc_id=VALUES(acc_id),
|
||||
customer_id=VALUES(customer_id),
|
||||
route=VALUES(route),
|
||||
state_json=VALUES(state_json),
|
||||
after_sales_stage=VALUES(after_sales_stage),
|
||||
state_version=VALUES(state_version)
|
||||
""",
|
||||
(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version),
|
||||
)
|
||||
else:
|
||||
c.execute(
|
||||
f"""
|
||||
INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version)
|
||||
VALUES(?,?,?,?,?,?,?)
|
||||
ON CONFLICT(customer_key) DO UPDATE SET
|
||||
acc_id=excluded.acc_id,
|
||||
customer_id=excluded.customer_id,
|
||||
route=excluded.route,
|
||||
state_json=excluded.state_json,
|
||||
after_sales_stage=excluded.after_sales_stage,
|
||||
state_version=excluded.state_version,
|
||||
updated_at=CURRENT_TIMESTAMP
|
||||
""",
|
||||
(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def append_event(self, customer_key: str, event: str, payload: dict[str, Any]) -> None:
|
||||
t_e = self.events_table
|
||||
payload_json = json.dumps(payload or {}, ensure_ascii=False)
|
||||
conn = self._conn()
|
||||
try:
|
||||
with conn.cursor() as c:
|
||||
if self.backend == "mysql":
|
||||
c.execute(
|
||||
f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(%s,%s,%s)",
|
||||
(customer_key, event, payload_json),
|
||||
)
|
||||
else:
|
||||
c.execute(
|
||||
f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(?,?,?)",
|
||||
(customer_key, event, payload_json),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
86
qingjian_cs/app/task_manager.py
Normal file
86
qingjian_cs/app/task_manager.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
DB_PATH = Path(__file__).resolve().parents[1] / 'task_db.sqlite'
|
||||
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
self.db_path = db_path or str(DB_PATH)
|
||||
self._init_db()
|
||||
|
||||
def _conn(self):
|
||||
return sqlite3.connect(self.db_path)
|
||||
|
||||
def _init_db(self) -> None:
|
||||
with self._conn() as c:
|
||||
c.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
status TEXT NOT NULL,
|
||||
payload_json TEXT,
|
||||
result_json TEXT,
|
||||
created_at TEXT,
|
||||
updated_at TEXT
|
||||
)
|
||||
''')
|
||||
|
||||
def create_task(self, payload: dict[str, Any]) -> str:
|
||||
task_id = uuid.uuid4().hex
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
with self._conn() as c:
|
||||
c.execute(
|
||||
'INSERT INTO tasks(task_id,status,payload_json,result_json,created_at,updated_at) VALUES(?,?,?,?,?,?)',
|
||||
(task_id, 'queued', json.dumps(payload, ensure_ascii=False), '{}', now, now),
|
||||
)
|
||||
return task_id
|
||||
|
||||
def cancel_task(self, task_id: str) -> bool:
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
with self._conn() as c:
|
||||
cur = c.execute(
|
||||
"UPDATE tasks SET status='cancelled',updated_at=? WHERE task_id=? AND status IN ('queued','running')",
|
||||
(now, task_id),
|
||||
)
|
||||
return cur.rowcount > 0
|
||||
|
||||
def get_task(self, task_id: str) -> dict[str, Any] | None:
|
||||
with self._conn() as c:
|
||||
row = c.execute(
|
||||
'SELECT task_id,status,payload_json,result_json,created_at,updated_at FROM tasks WHERE task_id=?',
|
||||
(task_id,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
return {
|
||||
'task_id': row[0],
|
||||
'status': row[1],
|
||||
'payload': json.loads(row[2] or '{}'),
|
||||
'result': json.loads(row[3] or '{}'),
|
||||
'created_at': row[4],
|
||||
'updated_at': row[5],
|
||||
}
|
||||
|
||||
def list_tasks(self, limit: int = 100) -> list[dict[str, Any]]:
|
||||
with self._conn() as c:
|
||||
rows = c.execute(
|
||||
'SELECT task_id,status,payload_json,result_json,created_at,updated_at FROM tasks ORDER BY created_at DESC LIMIT ?',
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [
|
||||
{
|
||||
'task_id': r[0],
|
||||
'status': r[1],
|
||||
'payload': json.loads(r[2] or '{}'),
|
||||
'result': json.loads(r[3] or '{}'),
|
||||
'created_at': r[4],
|
||||
'updated_at': r[5],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
Reference in New Issue
Block a user