From 00166d7ebfebad6db1c843c8009fe95eb7baa424 Mon Sep 17 00:00:00 2001 From: jimi <1847930177@qq.com> Date: Tue, 3 Mar 2026 10:18:02 +0800 Subject: [PATCH] feat: improve context memory and fix auto-draw gemini/upload chain --- qingjian_cs/app/auto_draw.py | 166 +++++++------------ qingjian_cs/app/client.py | 55 +++++- qingjian_cs/app/rules.py | 4 + qingjian_cs/services/service_tuhui_upload.py | 107 ++++++++++++ 4 files changed, 223 insertions(+), 109 deletions(-) create mode 100644 qingjian_cs/services/service_tuhui_upload.py diff --git a/qingjian_cs/app/auto_draw.py b/qingjian_cs/app/auto_draw.py index 7bccbc1..5b29e99 100644 --- a/qingjian_cs/app/auto_draw.py +++ b/qingjian_cs/app/auto_draw.py @@ -1,105 +1,13 @@ from __future__ import annotations -import asyncio import os -import sys import tempfile import uuid -from pathlib import Path from typing import Any import requests -from dotenv import load_dotenv -from .config import AUTO_DRAW_ENDPOINT, AUTO_DRAW_TIMEOUT_SECONDS - - -def _add_legacy_tw_path() -> None: - root = os.getenv("LEGACY_TW_ROOT", r"D:\main\sandbox\tw_terminator").strip() - if not root: - return - p = Path(root) - # 先加载 legacy 项目的 .env,确保 service_tuhui_upload 在 import 时拿到正确账号配置 - legacy_env = p / ".env" - if legacy_env.exists(): - load_dotenv(legacy_env, override=True) - if p.exists() and str(p) not in sys.path: - sys.path.insert(0, str(p)) - - -async def _draw_via_legacy_tw( - image_url: str, - customer_id: str, - requirement: str = "", -) -> dict[str, Any]: - # 优先使用当前项目中拷贝过来的 service_gemini - from services.service_gemini import GeminiExtractV2Service # type: ignore - - # 上传模块暂时仍走 legacy(你后续可替换为新项目本地上传实现) - _add_legacy_tw_path() - from services.service_tuhui_upload import upload_to_tuhui # type: ignore - - prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本" - - # 1) 下载原图到本地临时文件 - input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg") - output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg") - headers = { - "User-Agent": ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/122.0.0.0 Safari/537.36" - ), - "Referer": "https://www.taobao.com/", - "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8", - } - resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS) - if resp.status_code != 200: - return {"ok": False, "error": f"download_http_{resp.status_code}"} - with open(input_path, "wb") as f: - f.write(resp.content) - - # 2) 直调你原来的 service_gemini 作图 API - service = GeminiExtractV2Service() - ok_extract, msg_extract, _ = await service.extract_pattern( - input_path=input_path, - output_path=output_path, - custom_prompt=prompt, - aspect_ratio="1:1", - ) - if not ok_extract: - return {"ok": False, "error": f"extract_failed:{msg_extract}"} - if not os.path.exists(output_path): - return {"ok": False, "error": "extract_no_output_file"} - - # 3) 上传图绘,返回可外发 URL - ok, link, _ = await upload_to_tuhui( - output_path, - title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图", - description="AI自动作图预览", - price=1, - ) - if not ok: - return {"ok": False, "error": str(link)} - return {"ok": True, "url": str(link)} - - -def _draw_via_http_endpoint(image_url: str, customer_id: str, requirement: str = "") -> dict[str, Any]: - if not AUTO_DRAW_ENDPOINT: - return {"ok": False, "error": "AUTO_DRAW_ENDPOINT not configured"} - payload = { - "image_url": image_url, - "customer_id": customer_id, - "requirement": requirement, - } - resp = requests.post(AUTO_DRAW_ENDPOINT, json=payload, timeout=AUTO_DRAW_TIMEOUT_SECONDS) - if resp.status_code != 200: - return {"ok": False, "error": f"http_{resp.status_code}:{resp.text[:200]}"} - data = resp.json() if resp.text else {} - url = str(data.get("url", "") or data.get("preview_url", "") or "") - if not url: - return {"ok": False, "error": "missing_preview_url"} - return {"ok": True, "url": url} +from .config import AUTO_DRAW_TIMEOUT_SECONDS async def auto_draw_preview( @@ -108,24 +16,68 @@ async def auto_draw_preview( requirement: str = "", ) -> dict[str, Any]: """ - 统一自动作图入口: - 1) 优先走 tw_terminator 的 service_gemini 直调链路 - 2) 失败时回退 AUTO_DRAW_ENDPOINT + 统一自动作图入口(直调本地链路): + 1) 下载客户图 + 2) 调 Gemini 生成 + 3) 上传图绘,返回可外发 URL """ try: - return await _draw_via_legacy_tw(image_url=image_url, customer_id=customer_id, requirement=requirement) + from services.service_gemini import GeminiExtractV2Service # type: ignore + from services.service_tuhui_upload import upload_to_tuhui # type: ignore except Exception as e: - legacy_error = str(e) + return {"ok": False, "error": f"import_failed:{e}"} + + prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本" + input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg") + output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg") try: - data = await asyncio.to_thread( - _draw_via_http_endpoint, - image_url, - customer_id, - requirement, + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/122.0.0.0 Safari/537.36" + ), + "Referer": "https://www.taobao.com/", + "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8", + } + resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS) + if resp.status_code != 200: + return {"ok": False, "error": f"download_http_{resp.status_code}"} + with open(input_path, "wb") as f: + f.write(resp.content) + + service = GeminiExtractV2Service() + ok_extract, msg_extract, _ = await service.extract_pattern( + input_path=input_path, + output_path=output_path, + custom_prompt=prompt, + aspect_ratio="1:1", ) - if data.get("ok"): - return data - return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{data.get('error','unknown')}"} + if not ok_extract: + return {"ok": False, "error": f"extract_failed:{msg_extract}"} + if not os.path.exists(output_path): + return {"ok": False, "error": "extract_no_output_file"} + + ok_upload, link, _ = await upload_to_tuhui( + output_path, + title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图", + description="AI自动作图预览", + price=1, + ) + if not ok_upload: + return {"ok": False, "error": f"upload_failed:{link}"} + return {"ok": True, "url": str(link)} except Exception as e: - return {"ok": False, "error": f"legacy:{legacy_error}; endpoint:{e}"} + return {"ok": False, "error": str(e)} + finally: + try: + if os.path.exists(input_path): + os.remove(input_path) + except Exception: + pass + try: + if os.path.exists(output_path): + os.remove(output_path) + except Exception: + pass diff --git a/qingjian_cs/app/client.py b/qingjian_cs/app/client.py index eb989c7..b9de03c 100644 --- a/qingjian_cs/app/client.py +++ b/qingjian_cs/app/client.py @@ -3,6 +3,7 @@ import json import re import time from collections import defaultdict +from datetime import datetime import websockets @@ -35,7 +36,9 @@ class QingjianClient: self.pending_images: dict[str, list[str]] = defaultdict(list) self.auto_quote_tasks: dict[str, asyncio.Task] = {} self.last_reply_key: dict[str, str] = {} + self.first_msg_replied: set[str] = set() self.recent_outbound: list[tuple[str, str, str, float]] = [] + self.recent_dialogue: dict[str, list[dict]] = defaultdict(list) @staticmethod def _customer_key(data: dict) -> str: @@ -45,6 +48,38 @@ class QingjianClient: def _msg_text(data: dict) -> str: return str(data.get("msg", "") or "").strip() + def _append_dialogue(self, key: str, role: str, text: str) -> None: + t = str(text or "").strip() + if not t: + return + self.recent_dialogue[key].append({"role": role, "text": t}) + if len(self.recent_dialogue[key]) > 24: + self.recent_dialogue[key] = self.recent_dialogue[key][-24:] + + @staticmethod + def _parse_msg_ts(data: dict) -> float: + # 兼容常见时间字段;解析失败时返回0,后续按入队顺序兜底 + for key in ("timestamp", "msg_time", "send_time", "create_time", "time"): + v = data.get(key) + if v is None: + continue + if isinstance(v, (int, float)): + return float(v) + s = str(v).strip() + if not s: + continue + # 纯数字时间戳 + if re.fullmatch(r"\d{10,13}", s): + n = float(s) + return n / 1000.0 if len(s) == 13 else n + # 常见日期格式 + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y/%m/%d %H:%M"): + try: + return datetime.strptime(s, fmt).timestamp() + except Exception: + pass + return 0.0 + def _debounce_seconds(self, msg: str) -> float: if extract_image_urls(msg): return 2.5 @@ -61,6 +96,7 @@ class QingjianClient: if not text: return text = self._shorten_reply(text) + key = self._customer_key(data) msg = { "msg_id": "", "acc_id": data.get("acc_id", ""), @@ -74,6 +110,7 @@ class QingjianClient: } activity_event(self.logger, "send_reply_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=text) await self.send_message(msg) + self._append_dialogue(key, "assistant", text) self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), text, time.monotonic())) if len(self.recent_outbound) > 200: self.recent_outbound = self.recent_outbound[-200:] @@ -83,6 +120,7 @@ class QingjianClient: image_url = str(image_url or "").strip() if not image_url: return + key = self._customer_key(data) msg = { "msg_id": "", "acc_id": data.get("acc_id", ""), @@ -96,6 +134,7 @@ class QingjianClient: } activity_event(self.logger, "send_image_attempt", trace_id=trace_id, customer_id=data.get("from_id", "-"), msg=image_url) await self.send_message(msg) + self._append_dialogue(key, "assistant", f"[image]{image_url}") self.recent_outbound.append((str(data.get("acc_id", "")), str(data.get("from_id", "")), image_url, time.monotonic())) if len(self.recent_outbound) > 200: self.recent_outbound = self.recent_outbound[-200:] @@ -196,6 +235,7 @@ class QingjianClient: "pending_images": len(self.pending_images[key]), "auto_quote_trigger": auto_quote, "last_reply": self.last_reply_key.get(key, ""), + "recent_dialogue": self.recent_dialogue.get(key, [])[-12:], } activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"]) @@ -300,8 +340,11 @@ class QingjianClient: queue = self.pending_msgs.get(key, []) if not queue: return - merged = "、".join([self._msg_text(x) for x in queue if self._msg_text(x)]) - data = queue[-1] + indexed = list(enumerate(queue)) + indexed.sort(key=lambda it: (self._parse_msg_ts(it[1]), it[0])) + ordered = [x for _, x in indexed] + merged = "、".join([self._msg_text(x) for x in ordered if self._msg_text(x)]) + data = ordered[-1] self.pending_msgs[key].clear() await self._handle_decision(data, merged) @@ -356,6 +399,14 @@ class QingjianClient: patched = dict(data) patched["msg"] = rule.normalized_msg or msg + key = self._customer_key(patched) + self._append_dialogue(key, "user", patched["msg"]) + + # 硬编码:每个客户首条消息先快速回复“在的” + if key not in self.first_msg_replied: + await self.send_reply(patched, "在的") + self.last_reply_key[key] = "在的" + self.first_msg_replied.add(key) if msg_type == 1: await self._handle_decision(patched, patched["msg"]) diff --git a/qingjian_cs/app/rules.py b/qingjian_cs/app/rules.py index ab2cd71..213c80d 100644 --- a/qingjian_cs/app/rules.py +++ b/qingjian_cs/app/rules.py @@ -117,6 +117,10 @@ def rules_prompt() -> str: "4) 客户说“发完了/就这些/报价吧”: 若有图则 action=quote。\n" "5) 不能承诺“一模一样原图必找到”,可说先看图评估。\n" "6) 尺寸很大或要求高还原时,不夸张承诺,先说明可评估后给结论。\n\n" + "7) 话术意图匹配(不要硬模板):\n" + " - 客户是“找图/有没有/找原图”诉求时,优先用“我先给你找找看”这类承接语。\n" + " - 客户是“高清修复/清晰处理/重修”诉求时,再用“我先评估下报价/难度”类话术。\n" + " - 避免把“找图”直接说成“评估报价”。\n\n" "C. 订单阶段\n" "1) 已付款: 可回复“已安排处理/正在处理/完成后发你确认”。\n" "2) 待付款: 可提示付款,但不与客户争执;必要时先给预览再引导付款。\n" diff --git a/qingjian_cs/services/service_tuhui_upload.py b/qingjian_cs/services/service_tuhui_upload.py new file mode 100644 index 0000000..ae50e05 --- /dev/null +++ b/qingjian_cs/services/service_tuhui_upload.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import os +from typing import Optional, Tuple +from urllib.parse import urljoin + +import requests + + +TUHUI_BASE_URL = os.getenv("TUHUI_BASE_URL", "http://127.0.0.1:8002").strip() +TUHUI_PHONE = os.getenv("TUHUI_PHONE", "17520145271").strip() +TUHUI_PASSWORD = os.getenv("TUHUI_PASSWORD", "zuowei1216").strip() +TUHUI_DEFAULT_PRICE = int(os.getenv("TUHUI_DEFAULT_PRICE", "20")) +TUHUI_UPLOAD_ENDPOINT = os.getenv("TUHUI_UPLOAD_ENDPOINT", "/api/upload").strip() +TUHUI_UPLOAD_FILE_FIELD = os.getenv("TUHUI_UPLOAD_FILE_FIELD", "file").strip() +TUHUI_DEFAULT_CATEGORY = os.getenv("TUHUI_DEFAULT_CATEGORY", "高清修复").strip() +TUHUI_TIMEOUT_SECONDS = int(os.getenv("TUHUI_TIMEOUT_SECONDS", "30")) + + +def _login() -> Tuple[bool, str]: + try: + resp = requests.post( + f"{TUHUI_BASE_URL}/api/auth/login", + json={"phone": TUHUI_PHONE, "password": TUHUI_PASSWORD}, + timeout=TUHUI_TIMEOUT_SECONDS, + ) + if resp.status_code != 200: + return False, f"login_http_{resp.status_code}:{resp.text[:120]}" + data = resp.json() if resp.text else {} + token = str(data.get("access_token", "") or "") + if not token: + return False, "login_no_token" + return True, token + except Exception as e: + return False, f"login_error:{e}" + + +def _upload_sync( + image_path: str, + title: str, + description: str = "", + price: int = 20, + category: Optional[str] = None, +) -> Tuple[bool, str, int]: + if not os.path.exists(image_path): + return False, "file_not_found", 0 + + ok, token_or_err = _login() + if not ok: + return False, token_or_err, 0 + token = token_or_err + + use_price = int(price or TUHUI_DEFAULT_PRICE) + use_category = (category or TUHUI_DEFAULT_CATEGORY or "高清修复").strip() + headers = {"Authorization": f"Bearer {token}"} + data = { + "title": title, + "description": description, + "price": str(use_price), + "category": use_category, + } + + try: + with open(image_path, "rb") as f: + files = {TUHUI_UPLOAD_FILE_FIELD: ("image.jpg", f, "image/jpeg")} + resp = requests.post( + f"{TUHUI_BASE_URL}{TUHUI_UPLOAD_ENDPOINT}", + files=files, + data=data, + headers=headers, + timeout=TUHUI_TIMEOUT_SECONDS, + ) + if resp.status_code not in (200, 201): + return False, f"upload_http_{resp.status_code}:{resp.text[:120]}", 0 + payload = resp.json() if resp.text else {} + work = payload.get("work", {}) if isinstance(payload.get("work"), dict) else {} + work_id = int(work.get("id") or payload.get("work_id") or payload.get("id") or 0) + image_url = ( + str(work.get("original_image") or work.get("image_url") or payload.get("image_url") or "") + ) + if image_url.startswith("/"): + image_url = urljoin(f"{TUHUI_BASE_URL}/", image_url.lstrip("/")) + if not image_url: + return False, "upload_no_image_url", 0 + return True, image_url, work_id + except Exception as e: + return False, f"upload_error:{e}", 0 + + +async def upload_to_tuhui( + image_path: str, + title: str, + description: str = "", + price: int = 20, + category: Optional[str] = None, +) -> Tuple[bool, str, int]: + import asyncio + + return await asyncio.to_thread( + _upload_sync, + image_path, + title, + description, + price, + category, + ) +