feat: improve context memory and fix auto-draw gemini/upload chain
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled

This commit is contained in:
2026-03-03 10:18:02 +08:00
parent 382581b9bc
commit 00166d7ebf
4 changed files with 223 additions and 109 deletions

View File

@@ -1,105 +1,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import os import os
import sys
import tempfile import tempfile
import uuid import uuid
from pathlib import Path
from typing import Any from typing import Any
import requests import requests
from dotenv import load_dotenv
from .config import AUTO_DRAW_ENDPOINT, AUTO_DRAW_TIMEOUT_SECONDS from .config import AUTO_DRAW_TIMEOUT_SECONDS
def _add_legacy_tw_path() -> None:
root = os.getenv("LEGACY_TW_ROOT", r"D:\main\sandbox\tw_terminator").strip()
if not root:
return
p = Path(root)
# 先加载 legacy 项目的 .env确保 service_tuhui_upload 在 import 时拿到正确账号配置
legacy_env = p / ".env"
if legacy_env.exists():
load_dotenv(legacy_env, override=True)
if p.exists() and str(p) not in sys.path:
sys.path.insert(0, str(p))
async def _draw_via_legacy_tw(
image_url: str,
customer_id: str,
requirement: str = "",
) -> dict[str, Any]:
# 优先使用当前项目中拷贝过来的 service_gemini
from services.service_gemini import GeminiExtractV2Service # type: ignore
# 上传模块暂时仍走 legacy你后续可替换为新项目本地上传实现
_add_legacy_tw_path()
from services.service_tuhui_upload import upload_to_tuhui # type: ignore
prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本"
# 1) 下载原图到本地临时文件
input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg")
output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg")
headers = {
"User-Agent": (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/122.0.0.0 Safari/537.36"
),
"Referer": "https://www.taobao.com/",
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
}
resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
if resp.status_code != 200:
return {"ok": False, "error": f"download_http_{resp.status_code}"}
with open(input_path, "wb") as f:
f.write(resp.content)
# 2) 直调你原来的 service_gemini 作图 API
service = GeminiExtractV2Service()
ok_extract, msg_extract, _ = await service.extract_pattern(
input_path=input_path,
output_path=output_path,
custom_prompt=prompt,
aspect_ratio="1:1",
)
if not ok_extract:
return {"ok": False, "error": f"extract_failed:{msg_extract}"}
if not os.path.exists(output_path):
return {"ok": False, "error": "extract_no_output_file"}
# 3) 上传图绘,返回可外发 URL
ok, link, _ = await upload_to_tuhui(
output_path,
title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图",
description="AI自动作图预览",
price=1,
)
if not ok:
return {"ok": False, "error": str(link)}
return {"ok": True, "url": str(link)}
def _draw_via_http_endpoint(image_url: str, customer_id: str, requirement: str = "") -> dict[str, Any]:
if not AUTO_DRAW_ENDPOINT:
return {"ok": False, "error": "AUTO_DRAW_ENDPOINT not configured"}
payload = {
"image_url": image_url,
"customer_id": customer_id,
"requirement": requirement,
}
resp = requests.post(AUTO_DRAW_ENDPOINT, json=payload, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
if resp.status_code != 200:
return {"ok": False, "error": f"http_{resp.status_code}:{resp.text[:200]}"}
data = resp.json() if resp.text else {}
url = str(data.get("url", "") or data.get("preview_url", "") or "")
if not url:
return {"ok": False, "error": "missing_preview_url"}
return {"ok": True, "url": url}
async def auto_draw_preview( async def auto_draw_preview(
@@ -108,24 +16,68 @@ async def auto_draw_preview(
requirement: str = "", requirement: str = "",
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
统一自动作图入口: 统一自动作图入口(直调本地链路)
1) 优先走 tw_terminator 的 service_gemini 直调链路 1) 下载客户图
2) 失败时回退 AUTO_DRAW_ENDPOINT 2) 调 Gemini 生成
3) 上传图绘,返回可外发 URL
""" """
try: try:
return await _draw_via_legacy_tw(image_url=image_url, customer_id=customer_id, requirement=requirement) from services.service_gemini import GeminiExtractV2Service # type: ignore
from services.service_tuhui_upload import upload_to_tuhui # type: ignore
except Exception as e: except Exception as e:
legacy_error = str(e) return {"ok": False, "error": f"import_failed:{e}"}
prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本"
input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg")
output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg")
try: try:
data = await asyncio.to_thread( headers = {
_draw_via_http_endpoint, "User-Agent": (
image_url, "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
customer_id, "AppleWebKit/537.36 (KHTML, like Gecko) "
requirement, "Chrome/122.0.0.0 Safari/537.36"
),
"Referer": "https://www.taobao.com/",
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
}
resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
if resp.status_code != 200:
return {"ok": False, "error": f"download_http_{resp.status_code}"}
with open(input_path, "wb") as f:
f.write(resp.content)
service = GeminiExtractV2Service()
ok_extract, msg_extract, _ = await service.extract_pattern(
input_path=input_path,
output_path=output_path,
custom_prompt=prompt,
aspect_ratio="1:1",
) )
if data.get("ok"): if not ok_extract:
return data return {"ok": False, "error": f"extract_failed:{msg_extract}"}
return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{data.get('error','unknown')}"} if not os.path.exists(output_path):
return {"ok": False, "error": "extract_no_output_file"}
ok_upload, link, _ = await upload_to_tuhui(
output_path,
title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图",
description="AI自动作图预览",
price=1,
)
if not ok_upload:
return {"ok": False, "error": f"upload_failed:{link}"}
return {"ok": True, "url": str(link)}
except Exception as e: except Exception as e:
return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{e}"} return {"ok": False, "error": str(e)}
finally:
try:
if os.path.exists(input_path):
os.remove(input_path)
except Exception:
pass
try:
if os.path.exists(output_path):
os.remove(output_path)
except Exception:
pass

View File

@@ -3,6 +3,7 @@ import json
import re import re
import time import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime
import websockets import websockets
@@ -35,7 +36,9 @@ class QingjianClient:
self.pending_images: dict[str, list[str]] = defaultdict(list) self.pending_images: dict[str, list[str]] = defaultdict(list)
self.auto_quote_tasks: dict[str, asyncio.Task] = {} self.auto_quote_tasks: dict[str, asyncio.Task] = {}
self.last_reply_key: dict[str, str] = {} self.last_reply_key: dict[str, str] = {}
self.first_msg_replied: set[str] = set()
self.recent_outbound: list[tuple[str, str, str, float]] = [] self.recent_outbound: list[tuple[str, str, str, float]] = []
self.recent_dialogue: dict[str, list[dict]] = defaultdict(list)
@staticmethod @staticmethod
def _customer_key(data: dict) -> str: def _customer_key(data: dict) -> str:
@@ -45,6 +48,38 @@ class QingjianClient:
def _msg_text(data: dict) -> str: def _msg_text(data: dict) -> str:
return str(data.get("msg", "") or "").strip() return str(data.get("msg", "") or "").strip()
def _append_dialogue(self, key: str, role: str, text: str) -> None:
t = str(text or "").strip()
if not t:
return
self.recent_dialogue[key].append({"role": role, "text": t})
if len(self.recent_dialogue[key]) > 24:
self.recent_dialogue[key] = self.recent_dialogue[key][-24:]
@staticmethod
def _parse_msg_ts(data: dict) -> float:
# 兼容常见时间字段解析失败时返回0后续按入队顺序兜底
for key in ("timestamp", "msg_time", "send_time", "create_time", "time"):
v = data.get(key)
if v is None:
continue
if isinstance(v, (int, float)):
return float(v)
s = str(v).strip()
if not s:
continue
# 纯数字时间戳
if re.fullmatch(r"\d{10,13}", s):
n = float(s)
return n / 1000.0 if len(s) == 13 else n
# 常见日期格式
for fmt in ("%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y/%m/%d %H:%M"):
try:
return datetime.strptime(s, fmt).timestamp()
except Exception:
pass
return 0.0
def _debounce_seconds(self, msg: str) -> float: def _debounce_seconds(self, msg: str) -> float:
if extract_image_urls(msg): if extract_image_urls(msg):
return 2.5 return 2.5
@@ -61,6 +96,7 @@ class QingjianClient:
if not text: if not text:
return return
text = self._shorten_reply(text) text = self._shorten_reply(text)
key = self._customer_key(data)
msg = { msg = {
"msg_id": "", "msg_id": "",
"acc_id": data.get("acc_id", ""), "acc_id": data.get("acc_id", ""),
@@ -74,6 +110,7 @@ class QingjianClient:
} }
activity_event(self.logger, "send_reply_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text) activity_event(self.logger, "send_reply_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text)
await self.send_message(msg) await self.send_message(msg)
self._append_dialogue(key, "assistant", text)
self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), text, time.monotonic())) self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), text, time.monotonic()))
if len(self.recent_outbound) > 200: if len(self.recent_outbound) > 200:
self.recent_outbound = self.recent_outbound[-200:] self.recent_outbound = self.recent_outbound[-200:]
@@ -83,6 +120,7 @@ class QingjianClient:
image_url = str(image_url or "").strip() image_url = str(image_url or "").strip()
if not image_url: if not image_url:
return return
key = self._customer_key(data)
msg = { msg = {
"msg_id": "", "msg_id": "",
"acc_id": data.get("acc_id", ""), "acc_id": data.get("acc_id", ""),
@@ -96,6 +134,7 @@ class QingjianClient:
} }
activity_event(self.logger, "send_image_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url) activity_event(self.logger, "send_image_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url)
await self.send_message(msg) await self.send_message(msg)
self._append_dialogue(key, "assistant", f"[image]{image_url}")
self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), image_url, time.monotonic())) self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), image_url, time.monotonic()))
if len(self.recent_outbound) > 200: if len(self.recent_outbound) > 200:
self.recent_outbound = self.recent_outbound[-200:] self.recent_outbound = self.recent_outbound[-200:]
@@ -196,6 +235,7 @@ class QingjianClient:
"pending_images": len(self.pending_images[key]), "pending_images": len(self.pending_images[key]),
"auto_quote_trigger": auto_quote, "auto_quote_trigger": auto_quote,
"last_reply": self.last_reply_key.get(key, ""), "last_reply": self.last_reply_key.get(key, ""),
"recent_dialogue": self.recent_dialogue.get(key, [])[-12:],
} }
activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"]) activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"])
@@ -300,8 +340,11 @@ class QingjianClient:
queue = self.pending_msgs.get(key, []) queue = self.pending_msgs.get(key, [])
if not queue: if not queue:
return return
merged = "".join([self._msg_text(x) for x in queue if self._msg_text(x)]) indexed = list(enumerate(queue))
data = queue[-1] indexed.sort(key=lambda it: (self._parse_msg_ts(it[1]), it[0]))
ordered = [x for _, x in indexed]
merged = "".join([self._msg_text(x) for x in ordered if self._msg_text(x)])
data = ordered[-1]
self.pending_msgs[key].clear() self.pending_msgs[key].clear()
await self._handle_decision(data, merged) await self._handle_decision(data, merged)
@@ -356,6 +399,14 @@ class QingjianClient:
patched = dict(data) patched = dict(data)
patched["msg"] = rule.normalized_msg or msg patched["msg"] = rule.normalized_msg or msg
key = self._customer_key(patched)
self._append_dialogue(key, "user", patched["msg"])
# 硬编码:每个客户首条消息先快速回复“在的”
if key not in self.first_msg_replied:
await self.send_reply(patched, "在的")
self.last_reply_key[key] = "在的"
self.first_msg_replied.add(key)
if msg_type == 1: if msg_type == 1:
await self._handle_decision(patched, patched["msg"]) await self._handle_decision(patched, patched["msg"])

View File

@@ -117,6 +117,10 @@ def rules_prompt() -> str:
"4) 客户说“发完了/就这些/报价吧”: 若有图则 action=quote。\n" "4) 客户说“发完了/就这些/报价吧”: 若有图则 action=quote。\n"
"5) 不能承诺“一模一样原图必找到”,可说先看图评估。\n" "5) 不能承诺“一模一样原图必找到”,可说先看图评估。\n"
"6) 尺寸很大或要求高还原时,不夸张承诺,先说明可评估后给结论。\n\n" "6) 尺寸很大或要求高还原时,不夸张承诺,先说明可评估后给结论。\n\n"
"7) 话术意图匹配(不要硬模板):\n"
" - 客户是“找图/有没有/找原图”诉求时,优先用“我先给你找找看”这类承接语。\n"
" - 客户是“高清修复/清晰处理/重修”诉求时,再用“我先评估下报价/难度”类话术。\n"
" - 避免把“找图”直接说成“评估报价”。\n\n"
"C. 订单阶段\n" "C. 订单阶段\n"
"1) 已付款: 可回复“已安排处理/正在处理/完成后发你确认”。\n" "1) 已付款: 可回复“已安排处理/正在处理/完成后发你确认”。\n"
"2) 待付款: 可提示付款,但不与客户争执;必要时先给预览再引导付款。\n" "2) 待付款: 可提示付款,但不与客户争执;必要时先给预览再引导付款。\n"

View File

@@ -0,0 +1,107 @@
from __future__ import annotations
import os
from typing import Optional, Tuple
from urllib.parse import urljoin
import requests
TUHUI_BASE_URL = os.getenv("TUHUI_BASE_URL", "http://127.0.0.1:8002").strip()
TUHUI_PHONE = os.getenv("TUHUI_PHONE", "17520145271").strip()
TUHUI_PASSWORD = os.getenv("TUHUI_PASSWORD", "zuowei1216").strip()
TUHUI_DEFAULT_PRICE = int(os.getenv("TUHUI_DEFAULT_PRICE", "20"))
TUHUI_UPLOAD_ENDPOINT = os.getenv("TUHUI_UPLOAD_ENDPOINT", "/api/upload").strip()
TUHUI_UPLOAD_FILE_FIELD = os.getenv("TUHUI_UPLOAD_FILE_FIELD", "file").strip()
TUHUI_DEFAULT_CATEGORY = os.getenv("TUHUI_DEFAULT_CATEGORY", "高清修复").strip()
TUHUI_TIMEOUT_SECONDS = int(os.getenv("TUHUI_TIMEOUT_SECONDS", "30"))
def _login() -> Tuple[bool, str]:
try:
resp = requests.post(
f"{TUHUI_BASE_URL}/api/auth/login",
json={"phone": TUHUI_PHONE, "password": TUHUI_PASSWORD},
timeout=TUHUI_TIMEOUT_SECONDS,
)
if resp.status_code != 200:
return False, f"login_http_{resp.status_code}:{resp.text[:120]}"
data = resp.json() if resp.text else {}
token = str(data.get("access_token", "") or "")
if not token:
return False, "login_no_token"
return True, token
except Exception as e:
return False, f"login_error:{e}"
def _upload_sync(
image_path: str,
title: str,
description: str = "",
price: int = 20,
category: Optional[str] = None,
) -> Tuple[bool, str, int]:
if not os.path.exists(image_path):
return False, "file_not_found", 0
ok, token_or_err = _login()
if not ok:
return False, token_or_err, 0
token = token_or_err
use_price = int(price or TUHUI_DEFAULT_PRICE)
use_category = (category or TUHUI_DEFAULT_CATEGORY or "高清修复").strip()
headers = {"Authorization": f"Bearer {token}"}
data = {
"title": title,
"description": description,
"price": str(use_price),
"category": use_category,
}
try:
with open(image_path, "rb") as f:
files = {TUHUI_UPLOAD_FILE_FIELD: ("image.jpg", f, "image/jpeg")}
resp = requests.post(
f"{TUHUI_BASE_URL}{TUHUI_UPLOAD_ENDPOINT}",
files=files,
data=data,
headers=headers,
timeout=TUHUI_TIMEOUT_SECONDS,
)
if resp.status_code not in (200, 201):
return False, f"upload_http_{resp.status_code}:{resp.text[:120]}", 0
payload = resp.json() if resp.text else {}
work = payload.get("work", {}) if isinstance(payload.get("work"), dict) else {}
work_id = int(work.get("id") or payload.get("work_id") or payload.get("id") or 0)
image_url = (
str(work.get("original_image") or work.get("image_url") or payload.get("image_url") or "")
)
if image_url.startswith("/"):
image_url = urljoin(f"{TUHUI_BASE_URL}/", image_url.lstrip("/"))
if not image_url:
return False, "upload_no_image_url", 0
return True, image_url, work_id
except Exception as e:
return False, f"upload_error:{e}", 0
async def upload_to_tuhui(
image_path: str,
title: str,
description: str = "",
price: int = 20,
category: Optional[str] = None,
) -> Tuple[bool, str, int]:
import asyncio
return await asyncio.to_thread(
_upload_sync,
image_path,
title,
description,
price,
category,
)