feat: automate image pipeline and simplify gemini flow
This commit is contained in:
@@ -18,6 +18,7 @@ from db.pending_transfer_db import (
|
||||
retry_pending_transfer,
|
||||
)
|
||||
from services.dispatch_service import dispatch_service
|
||||
from services.service_auto_image_pipeline import auto_image_pipeline_service
|
||||
|
||||
logger = logging.getLogger("cs_agent")
|
||||
|
||||
@@ -86,6 +87,7 @@ class SystemOrchestrator:
|
||||
self._user_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._pending_transfer_task: Optional[asyncio.Task] = None
|
||||
self._last_retry_transfer_time: Dict[str, float] = {}
|
||||
self._auto_pipeline_jobs: Dict[str, float] = {}
|
||||
|
||||
bus.subscribe("MESSAGE_OUTBOUND", self.handle_outbound_event)
|
||||
|
||||
@@ -223,6 +225,102 @@ class SystemOrchestrator:
|
||||
return "我在帮你看记录,稍等哈"
|
||||
return cleaned
|
||||
|
||||
@staticmethod
|
||||
def _extract_designer_name(transfer_cmd: str) -> str:
|
||||
text = str(transfer_cmd or "").strip()
|
||||
match = re.search(r"\[转移会话\],([^,]+),", text)
|
||||
return str(match.group(1)).strip() if match else ""
|
||||
|
||||
@staticmethod
|
||||
def _infer_processing_intent(requirement_text: str, history: Optional[List[dict]] = None) -> str:
|
||||
combined_parts = [str(requirement_text or "").lower()]
|
||||
for item in history or []:
|
||||
if item.get("role") == "user":
|
||||
combined_parts.append(str(item.get("content") or "").lower())
|
||||
combined = "\n".join(combined_parts)
|
||||
repair_keywords = ("修复", "高清", "清晰", "放大", "老照片")
|
||||
if any(k in combined for k in repair_keywords):
|
||||
return "repair"
|
||||
return "find_original"
|
||||
|
||||
@staticmethod
|
||||
def _collect_recent_image_urls(history: List[dict], fallback_urls: Optional[List[str]] = None) -> List[str]:
|
||||
urls: List[str] = []
|
||||
seen = set()
|
||||
|
||||
def add_url(url: str):
|
||||
value = str(url or "").strip()
|
||||
if not value or value in seen:
|
||||
return
|
||||
seen.add(value)
|
||||
urls.append(value)
|
||||
|
||||
for url in fallback_urls or []:
|
||||
add_url(url)
|
||||
|
||||
for item in reversed(history or []):
|
||||
if item.get("role") != "user":
|
||||
continue
|
||||
raw_urls = item.get("image_urls") or []
|
||||
if isinstance(raw_urls, str):
|
||||
for part in re.split(r"[\n#]+", raw_urls):
|
||||
add_url(part)
|
||||
elif isinstance(raw_urls, list):
|
||||
for part in raw_urls:
|
||||
add_url(part)
|
||||
content = str(item.get("content") or "")
|
||||
for match in re.findall(r"https?://[^\s#]+", content):
|
||||
add_url(match)
|
||||
if len(urls) >= 5:
|
||||
break
|
||||
return urls
|
||||
|
||||
def _schedule_auto_pipeline(
|
||||
self,
|
||||
*,
|
||||
session_key: str,
|
||||
customer_id: str,
|
||||
acc_id: str,
|
||||
designer_name: str,
|
||||
requirement_text: str,
|
||||
history: List[dict],
|
||||
image_urls: Optional[List[str]] = None,
|
||||
):
|
||||
resolved_urls = self._collect_recent_image_urls(history, image_urls)
|
||||
if not resolved_urls:
|
||||
logger.info(f"[Orchestrator] 自动处理跳过:未找到客户图片 user={customer_id} acc={acc_id}")
|
||||
return
|
||||
|
||||
intent = self._infer_processing_intent(requirement_text, history)
|
||||
signature_src = f"{session_key}|{designer_name}|{intent}|{'|'.join(resolved_urls)}"
|
||||
signature = str(abs(hash(signature_src)))
|
||||
now = time.time()
|
||||
last_run = self._auto_pipeline_jobs.get(signature, 0.0)
|
||||
if now - last_run < 600:
|
||||
logger.info(f"[Orchestrator] 自动处理已在近期触发,跳过重复任务 user={customer_id} acc={acc_id}")
|
||||
return
|
||||
self._auto_pipeline_jobs[signature] = now
|
||||
|
||||
async def _runner():
|
||||
try:
|
||||
result = await auto_image_pipeline_service.process_and_notify(
|
||||
session_key=session_key,
|
||||
customer_id=customer_id,
|
||||
acc_id=acc_id,
|
||||
designer_name=designer_name,
|
||||
requirement_text=requirement_text,
|
||||
image_urls=resolved_urls,
|
||||
intent=intent,
|
||||
)
|
||||
logger.info(
|
||||
f"[Orchestrator] 自动处理完成 user={customer_id} acc={acc_id} "
|
||||
f"ok={result.get('success')} uploaded={len(result.get('uploaded') or [])}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Orchestrator] 自动处理失败 user={customer_id} acc={acc_id}: {e}")
|
||||
|
||||
asyncio.create_task(_runner())
|
||||
|
||||
async def on_raw_message_received(self, platform: str, raw_data: dict):
|
||||
"""链路入口"""
|
||||
try:
|
||||
@@ -339,7 +437,7 @@ class SystemOrchestrator:
|
||||
except Exception as e:
|
||||
logger.warning(f"[Orchestrator] 订单消息处理异常: {e}")
|
||||
|
||||
async def _analyze_images_background(self, session_key: str, image_urls: List[str]):
|
||||
async def _analyze_images_background(self, session_key: str, image_urls: List[str], requirement_text: str = ""):
|
||||
"""后台静默分析图片,存入用户数据库用于数据标定"""
|
||||
try:
|
||||
from services.service_image_analyzer import image_analyzer_service
|
||||
@@ -348,9 +446,9 @@ class SystemOrchestrator:
|
||||
db = CustomerDatabase()
|
||||
profile = db.get_customer(session_key)
|
||||
|
||||
for url in image_urls:
|
||||
for url in (image_urls or [])[:1]:
|
||||
try:
|
||||
result = await image_analyzer_service.analyze(url)
|
||||
result = await image_analyzer_service.analyze(url, customer_requirement=requirement_text)
|
||||
result_json = json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# 更新最近一次分析
|
||||
@@ -454,7 +552,7 @@ class SystemOrchestrator:
|
||||
|
||||
# B2. 后台图片分析(不阻塞主流程,用于数据标定)
|
||||
if all_image_urls:
|
||||
asyncio.create_task(self._analyze_images_background(session_key, all_image_urls))
|
||||
asyncio.create_task(self._analyze_images_background(session_key, all_image_urls, combined_content))
|
||||
|
||||
history_start = time.time()
|
||||
history = await repo.get_chat_history(user_id, limit=12, acc_id=acc_id)
|
||||
@@ -504,6 +602,7 @@ class SystemOrchestrator:
|
||||
|
||||
# 转接场景:先发一句安抚话,再发转接指令
|
||||
if "[转移会话]" in std_res.reply_content:
|
||||
designer_name = self._extract_designer_name(std_res.reply_content)
|
||||
transfer_prelude = str(std_res.metadata.get("transfer_prelude") or "").strip()
|
||||
greet = StandardResponse(
|
||||
reply_content=transfer_prelude or "收到,我叫设计师来看下哈",
|
||||
@@ -547,6 +646,15 @@ class SystemOrchestrator:
|
||||
|
||||
if "[转移会话]" in std_res.reply_content:
|
||||
self._last_transfer_time[session_key] = time.time()
|
||||
self._schedule_auto_pipeline(
|
||||
session_key=session_key,
|
||||
customer_id=user_id,
|
||||
acc_id=acc_id,
|
||||
designer_name=self._extract_designer_name(std_res.reply_content),
|
||||
requirement_text=combined_content,
|
||||
history=history,
|
||||
image_urls=all_image_urls,
|
||||
)
|
||||
|
||||
except asyncio.CancelledError: pass
|
||||
except Exception as e: logger.exception(f"[Orchestrator] 处理失败: {e}")
|
||||
@@ -618,6 +726,15 @@ class SystemOrchestrator:
|
||||
)
|
||||
|
||||
self._last_transfer_time[f"{customer_id}@{acc_id}"] = time.time()
|
||||
history = await repo.get_chat_history(customer_id, limit=12, acc_id=acc_id)
|
||||
self._schedule_auto_pipeline(
|
||||
session_key=f"{customer_id}@{acc_id}",
|
||||
customer_id=customer_id,
|
||||
acc_id=acc_id,
|
||||
designer_name=designer_name,
|
||||
requirement_text=reason,
|
||||
history=history,
|
||||
)
|
||||
await asyncio.to_thread(complete_pending_transfer, row_id)
|
||||
logger.info(
|
||||
f"[Orchestrator] 待转接自动完成: pending_id={row_id} user={customer_id} designer={designer_name} reason={reason}"
|
||||
|
||||
232
services/service_auto_image_pipeline.py
Normal file
232
services/service_auto_image_pipeline.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from db.customer_db import CustomerDatabase
|
||||
from db.image_tasks_db import TaskStatus, db as task_db
|
||||
from services.service_gemini import GeminiExtractV2Service
|
||||
from services.service_tuhui_upload import upload_to_tuhui
|
||||
from services.service_wecom_bot import wecom_bot_service
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger("cs_agent")
|
||||
|
||||
AUTO_PROCESS_PRICE = int(os.getenv("AUTO_PROCESS_DEFAULT_PRICE", "12"))
|
||||
AUTO_PROCESS_CATEGORY = os.getenv("AUTO_PROCESS_CATEGORY", "设计素材")
|
||||
AUTO_PROCESS_ROOT = Path(
|
||||
os.getenv("AUTO_PROCESS_ROOT", str(Path(__file__).resolve().parents[1] / "runtime" / "auto_processed"))
|
||||
)
|
||||
|
||||
|
||||
def _safe_name(text: str, fallback: str = "image") -> str:
|
||||
cleaned = re.sub(r"[^0-9A-Za-z\u4e00-\u9fa5_-]+", "_", str(text or "").strip())
|
||||
cleaned = cleaned.strip("_")
|
||||
return cleaned[:40] or fallback
|
||||
|
||||
|
||||
def _suffix_from_url(url: str) -> str:
|
||||
path = urlparse(str(url or "")).path
|
||||
suffix = Path(path).suffix.lower()
|
||||
if suffix in {".png", ".jpg", ".jpeg", ".webp"}:
|
||||
return suffix
|
||||
return ".png"
|
||||
|
||||
|
||||
def _build_processing_prompt(intent: str, requirement_text: str, analysis: Dict) -> str:
|
||||
base_prompt = str((analysis or {}).get("gemini_prompt") or "").strip()
|
||||
req = str(requirement_text or "").strip()
|
||||
if base_prompt:
|
||||
return base_prompt
|
||||
if intent == "repair":
|
||||
return f"根据客户需求“{req or '高清修复'}”,保留主体和构图,做高清修复并补足细节。"
|
||||
return f"根据客户需求“{req or '找原图'}”,严格参考原图元素与构图,生成完整干净的高质量素材图。"
|
||||
|
||||
|
||||
class AutoImagePipelineService:
|
||||
def __init__(self):
|
||||
self.customer_db = CustomerDatabase()
|
||||
|
||||
async def _download_image(self, image_url: str, dest_path: Path) -> Path:
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
|
||||
response = await client.get(image_url)
|
||||
response.raise_for_status()
|
||||
dest_path.write_bytes(response.content)
|
||||
return dest_path
|
||||
|
||||
@staticmethod
|
||||
def _format_transfer_notice(
|
||||
customer_id: str,
|
||||
acc_id: str,
|
||||
designer_name: str,
|
||||
requirement_text: str,
|
||||
intent: str,
|
||||
image_urls: List[str],
|
||||
) -> str:
|
||||
lines = [
|
||||
"【AI自动转设计师】",
|
||||
f"店铺:{acc_id or '-'}",
|
||||
f"客户:{customer_id or '-'}",
|
||||
f"设计师:{designer_name or '-'}",
|
||||
f"需求:{requirement_text or '-'}",
|
||||
f"类型:{'高清修复' if intent == 'repair' else '找原图'}",
|
||||
f"默认价格:{AUTO_PROCESS_PRICE}元",
|
||||
]
|
||||
if image_urls:
|
||||
lines.append("原图URL:")
|
||||
lines.extend(image_urls[:5])
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _format_finish_notice(
|
||||
customer_id: str,
|
||||
acc_id: str,
|
||||
designer_name: str,
|
||||
links: List[Dict[str, str]],
|
||||
failures: List[str],
|
||||
) -> str:
|
||||
lines = [
|
||||
"【AI处理完成】",
|
||||
f"店铺:{acc_id or '-'}",
|
||||
f"客户:{customer_id or '-'}",
|
||||
f"设计师:{designer_name or '-'}",
|
||||
f"默认价格:{AUTO_PROCESS_PRICE}元",
|
||||
]
|
||||
if links:
|
||||
lines.append("处理结果:")
|
||||
for idx, item in enumerate(links, 1):
|
||||
lines.append(f"{idx}. 图绘链接:{item.get('download_url') or '-'}")
|
||||
lines.append(f" 原图URL:{item.get('source_url') or '-'}")
|
||||
if failures:
|
||||
lines.append("失败项:")
|
||||
lines.extend(failures[:5])
|
||||
return "\n".join(lines)
|
||||
|
||||
async def process_and_notify(
|
||||
self,
|
||||
*,
|
||||
session_key: str,
|
||||
customer_id: str,
|
||||
acc_id: str,
|
||||
designer_name: str,
|
||||
requirement_text: str,
|
||||
image_urls: List[str],
|
||||
intent: str = "",
|
||||
) -> Dict:
|
||||
image_urls = [str(url).strip() for url in (image_urls or []) if str(url).strip()]
|
||||
if not image_urls:
|
||||
return {"success": False, "message": "no_images"}
|
||||
image_urls = image_urls[:1]
|
||||
|
||||
profile = self.customer_db.get_customer(session_key)
|
||||
analysis = {}
|
||||
if getattr(profile, "last_image_analysis", ""):
|
||||
try:
|
||||
analysis = json.loads(profile.last_image_analysis)
|
||||
except Exception:
|
||||
analysis = {}
|
||||
|
||||
if not intent:
|
||||
intent = "repair" if "修复" in requirement_text else "find_original"
|
||||
|
||||
await wecom_bot_service.send_text(
|
||||
self._format_transfer_notice(
|
||||
customer_id=customer_id,
|
||||
acc_id=acc_id,
|
||||
designer_name=designer_name,
|
||||
requirement_text=requirement_text,
|
||||
intent=intent,
|
||||
image_urls=image_urls,
|
||||
)
|
||||
)
|
||||
|
||||
pipeline_root = AUTO_PROCESS_ROOT / _safe_name(customer_id, "customer")
|
||||
pipeline_root.mkdir(parents=True, exist_ok=True)
|
||||
gemini_service = GeminiExtractV2Service()
|
||||
uploaded_links: List[Dict[str, str]] = []
|
||||
failures: List[str] = []
|
||||
|
||||
for idx, image_url in enumerate(image_urls, 1):
|
||||
digest = hashlib.md5(f"{customer_id}|{acc_id}|{image_url}".encode("utf-8")).hexdigest()[:10]
|
||||
input_path = pipeline_root / f"{digest}_src{_suffix_from_url(image_url)}"
|
||||
output_path = pipeline_root / f"{digest}_out.png"
|
||||
title = f"{_safe_name(customer_id, '客户')}_{'修复' if intent == 'repair' else '原图'}_{idx}"
|
||||
prompt = _build_processing_prompt(intent, requirement_text, analysis)
|
||||
task_id = task_db.add_task(
|
||||
customer_id=customer_id,
|
||||
platform="qianniu",
|
||||
original_image=image_url,
|
||||
operation=intent or "auto_process",
|
||||
requirements=requirement_text,
|
||||
status=TaskStatus.PROCESSING.value,
|
||||
)
|
||||
try:
|
||||
await self._download_image(image_url, input_path)
|
||||
success, message, data = await gemini_service.extract_pattern(
|
||||
str(input_path),
|
||||
str(output_path),
|
||||
custom_prompt=prompt,
|
||||
aspect_ratio=str((analysis or {}).get("aspect_ratio") or "1:1"),
|
||||
)
|
||||
if not success or not output_path.exists():
|
||||
if task_id:
|
||||
task_db.update_status(task_id, TaskStatus.FAILED.value)
|
||||
failures.append(f"{idx}. Gemini失败:{message}")
|
||||
continue
|
||||
|
||||
upload_result = await upload_to_tuhui(
|
||||
image_path=str(output_path),
|
||||
title=title,
|
||||
description=requirement_text or prompt[:120],
|
||||
price=AUTO_PROCESS_PRICE,
|
||||
category=AUTO_PROCESS_CATEGORY,
|
||||
tags="AI处理,自动转接",
|
||||
designer_name=designer_name,
|
||||
)
|
||||
if not upload_result.success:
|
||||
if task_id:
|
||||
task_db.update_status(task_id, TaskStatus.FAILED.value)
|
||||
failures.append(f"{idx}. 图绘上传失败:{upload_result.message}")
|
||||
continue
|
||||
|
||||
if task_id:
|
||||
task_db.update_status(task_id, TaskStatus.COMPLETED.value, upload_result.download_url)
|
||||
uploaded_links.append(
|
||||
{
|
||||
"download_url": upload_result.download_url,
|
||||
"source_url": image_url,
|
||||
"work_id": str(upload_result.work_id),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
if task_id:
|
||||
task_db.update_status(task_id, TaskStatus.FAILED.value)
|
||||
failures.append(f"{idx}. 处理异常:{e}")
|
||||
|
||||
await wecom_bot_service.send_text(
|
||||
self._format_finish_notice(
|
||||
customer_id=customer_id,
|
||||
acc_id=acc_id,
|
||||
designer_name=designer_name,
|
||||
links=uploaded_links,
|
||||
failures=failures,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"success": bool(uploaded_links),
|
||||
"uploaded": uploaded_links,
|
||||
"failures": failures,
|
||||
}
|
||||
|
||||
|
||||
auto_image_pipeline_service = AutoImagePipelineService()
|
||||
@@ -1,106 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Gemini印花提取V2服务 - 使用服务
|
||||
更经济的选择:1.4毛/张
|
||||
"""
|
||||
"""Gemini 出图服务。固定走老张 Gemini 原生出图接口。"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from utils.metrics_tracker import emit as metrics_emit
|
||||
|
||||
|
||||
from utils.service_base import BaseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
load_dotenv()
|
||||
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3.1-flash-image-preview")
|
||||
GEMINI_IMAGE_FALLBACK_MODEL = os.getenv("GEMINI_IMAGE_FALLBACK_MODEL", "gemini-2.5-flash-image")
|
||||
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "1K")
|
||||
GEMINI_THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "MINIMAL")
|
||||
|
||||
GEMINI_API_KEY = os.getenv(
|
||||
"GEMINI_API_KEY",
|
||||
"sk-8i7uYE0RtnQwDImV8a5f7014DcAb46F6BcEb72Df92218aC8",
|
||||
)
|
||||
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")
|
||||
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "2K")
|
||||
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")
|
||||
GEMINI_THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "MINIMAL")
|
||||
|
||||
|
||||
class GeminiExtractV2Service(BaseService):
|
||||
"""Gemini印花提取V2服务类 - 使用服务,更经济"""
|
||||
|
||||
"""固定单接口的 Gemini 出图服务。"""
|
||||
|
||||
SERVICE_NAME = "gemini_extract_v2"
|
||||
|
||||
# 多API配置,按优先级排序(便宜的优先使用)
|
||||
API_CONFIGS = [
|
||||
API_BASE_URL = "https://api.laozhang.ai/v1beta/models"
|
||||
DEFAULT_PROMPT = (
|
||||
"提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,"
|
||||
"严格按照原图的元素位置生成平面的印花图,不要相似的,相似度要100%,生成高质量的印刷图"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
# {
|
||||
# "name": "西风接口$0.003逆向",
|
||||
# "api_key": "sk-UT9aupbfHI4rc3RUn8x5D8gN5Kk31yvLZQu8M3BCY5Nja1Fc",
|
||||
# "api_url": "https://api.apiqik.com/v1/chat/completions" ,
|
||||
# "api_model": "gemini-2.5-flash-image",
|
||||
# "max_retries": 3, # 贵接口少重试
|
||||
# "cost": "低"
|
||||
# },
|
||||
|
||||
|
||||
{
|
||||
"name": "西风接口$0.014",
|
||||
"api_key": "sk-uRuvzLfIHsc3BiHZ2cyebk0cYsZ8NR9rLL326QqXCKIy9EpK",
|
||||
"api_url": "https://api.apiqik.online/v1beta/models",
|
||||
"api_model": GEMINI_IMAGE_MODEL,
|
||||
"max_retries": 2,
|
||||
"cost": "中",
|
||||
"use_gemini_format": True # 使用Gemini原生API格式
|
||||
},
|
||||
{
|
||||
"name": "西风接口Fallback",
|
||||
"api_key": "sk-uRuvzLfIHsc3BiHZ2cyebk0cYsZ8NR9rLL326QqXCKIy9EpK",
|
||||
"api_url": "https://api.apiqik.online/v1beta/models",
|
||||
"api_model": GEMINI_IMAGE_FALLBACK_MODEL,
|
||||
"max_retries": 1,
|
||||
"cost": "中",
|
||||
"use_gemini_format": True
|
||||
},
|
||||
|
||||
{
|
||||
"name": "最贵的",
|
||||
"api_key": "sk-8i7uYE0RtnQwDImV8a5f7014DcAb46F6BcEb72Df92218aC8",
|
||||
"api_url": "https://api.laozhang.ai/v1/chat/completions",
|
||||
"api_model": GEMINI_IMAGE_MODEL,
|
||||
"max_retries": 1,
|
||||
"cost": "高"
|
||||
}
|
||||
]
|
||||
|
||||
# 默认提示词
|
||||
DEFAULT_PROMPT = "提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,严格按照原图的元素位置生成平面的印花图,不要相似的,相似度要100%,生成高质量的印刷图"
|
||||
# DEFAULT_PROMPT = "生成图片,把衣服的图案展开起来做成数码印花印刷平面图。去掉皱褶,生成图案增强细节。排除衣服图案以外内容"
|
||||
def __init__(self):
|
||||
super().__init__(name="gemini_extract_v2")
|
||||
self.session = None
|
||||
|
||||
def image_to_base64(self, image_path: str) -> str:
|
||||
"""将图片文件转换为base64编码字符串"""
|
||||
super().__init__(name=self.SERVICE_NAME)
|
||||
|
||||
@staticmethod
|
||||
def _image_to_base64(image_path: str) -> str:
|
||||
if not os.path.exists(image_path):
|
||||
logger.error(f"文件不存在: {image_path}")
|
||||
return ""
|
||||
try:
|
||||
if not os.path.exists(image_path):
|
||||
logger.error(f"文件不存在: {image_path}")
|
||||
return None
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return encoded_string
|
||||
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Base64转换失败: {e}")
|
||||
return None
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _guess_mime_type(image_path: str) -> str:
|
||||
mime_type, _ = mimetypes.guess_type(str(image_path))
|
||||
return mime_type or "image/png"
|
||||
|
||||
@staticmethod
|
||||
def _build_generation_config(
|
||||
aspect_ratio: str,
|
||||
image_size: str,
|
||||
person_generation: str,
|
||||
thinking_level: str,
|
||||
) -> Dict:
|
||||
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||||
valid_sizes = {"1K", "2K", "4K"}
|
||||
valid_thinking = {"MINIMAL", "LOW", "MEDIUM", "HIGH"}
|
||||
|
||||
image_config = {}
|
||||
if aspect_ratio in valid_ratios:
|
||||
image_config["aspectRatio"] = aspect_ratio
|
||||
size_val = (image_size or GEMINI_IMAGE_SIZE or "").upper().strip()
|
||||
if size_val in valid_sizes:
|
||||
image_config["imageSize"] = size_val
|
||||
person_val = (person_generation or GEMINI_PERSON_GENERATION or "").strip()
|
||||
if person_val:
|
||||
image_config["personGeneration"] = person_val
|
||||
|
||||
generation_config = {"responseModalities": ["IMAGE"]}
|
||||
if image_config:
|
||||
generation_config["imageConfig"] = image_config
|
||||
|
||||
thinking_val = (thinking_level or GEMINI_THINKING_LEVEL or "").upper().strip()
|
||||
if thinking_val in valid_thinking:
|
||||
generation_config["thinkingConfig"] = {"thinkingLevel": thinking_val}
|
||||
|
||||
return generation_config
|
||||
|
||||
@staticmethod
|
||||
def _extract_image_bytes(result: Dict) -> bytes:
|
||||
candidates = result.get("candidates") or []
|
||||
if not candidates:
|
||||
raise ValueError("响应缺少 candidates")
|
||||
parts = ((candidates[0] or {}).get("content") or {}).get("parts") or []
|
||||
for part in parts:
|
||||
inline_data = part.get("inlineData") or {}
|
||||
encoded = inline_data.get("data")
|
||||
if encoded:
|
||||
return base64.b64decode(encoded)
|
||||
finish_reason = candidates[0].get("finishReason") or ""
|
||||
if finish_reason == "NO_IMAGE":
|
||||
raise ValueError("模型未返回图片(NO_IMAGE)")
|
||||
raise ValueError("响应中未找到 inlineData 图片")
|
||||
|
||||
@staticmethod
|
||||
def _save_image(image_data: bytes, output_path: str) -> Dict:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
file_size = os.path.getsize(output_path)
|
||||
return {
|
||||
"output_path": output_path,
|
||||
"file_size": file_size,
|
||||
"api_used": "laozhang_gemini_native",
|
||||
}
|
||||
|
||||
async def extract_pattern(
|
||||
self,
|
||||
input_path: str,
|
||||
@@ -111,415 +126,85 @@ class GeminiExtractV2Service(BaseService):
|
||||
person_generation: str = "",
|
||||
thinking_level: str = "",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""
|
||||
使用多API配置进行印花图案提取
|
||||
|
||||
Args:
|
||||
input_path: 输入图片路径
|
||||
output_path: 输出图片路径
|
||||
custom_prompt: 自定义提示词
|
||||
|
||||
Returns:
|
||||
tuple: (success, message, data)
|
||||
"""
|
||||
# 转换图片为Base64
|
||||
img64 = self.image_to_base64(input_path)
|
||||
img64 = self._image_to_base64(input_path)
|
||||
if not img64:
|
||||
return False, "图片编码失败", {}
|
||||
|
||||
# 使用自定义提示词或默认提示词
|
||||
prompt = custom_prompt or self.DEFAULT_PROMPT
|
||||
|
||||
# 按优先级逐个尝试API配置
|
||||
for config_index, config in enumerate(self.API_CONFIGS):
|
||||
logger.info(f"尝试使用API: {config['name']} (成本: {config['cost']})")
|
||||
metrics_emit("gemini_request", model=config.get("api_model", ""), provider=config.get("name", ""))
|
||||
|
||||
# 对每个API配置进行重试
|
||||
for attempt in range(config['max_retries']):
|
||||
try:
|
||||
logger.info(f"开始Gemini V2印花提取 - {config['name']} (第{attempt + 1}/{config['max_retries']}次尝试): {input_path}")
|
||||
|
||||
# 准备请求数据和URL
|
||||
if config.get('use_gemini_format', False):
|
||||
# Gemini原生API格式
|
||||
api_url = f"{config['api_url']}/{config['api_model']}:generateContent?key={config['api_key']}"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 有效比例列表(Auto 不传 aspectRatio)
|
||||
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||||
valid_sizes = {"1K", "2K", "4K"}
|
||||
valid_thinking = {"MINIMAL", "LOW", "MEDIUM", "HIGH"}
|
||||
image_config = {}
|
||||
if aspect_ratio in valid_ratios:
|
||||
image_config["aspectRatio"] = aspect_ratio
|
||||
size_val = (image_size or GEMINI_IMAGE_SIZE or "").upper().strip()
|
||||
if size_val in valid_sizes:
|
||||
image_config["imageSize"] = size_val
|
||||
person_val = (person_generation or GEMINI_PERSON_GENERATION or "").strip()
|
||||
if person_val:
|
||||
# 中转接口若支持该字段会生效;不设置时不发送,保证兼容
|
||||
image_config["personGeneration"] = person_val
|
||||
thinking_val = (thinking_level or GEMINI_THINKING_LEVEL or "").upper().strip()
|
||||
thinking_config = {}
|
||||
if thinking_val in valid_thinking:
|
||||
thinking_config["thinkingLevel"] = thinking_val
|
||||
|
||||
data = {
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": img64
|
||||
}
|
||||
},
|
||||
{
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["IMAGE"],
|
||||
**({"imageConfig": image_config} if image_config else {}),
|
||||
**({"thinkingConfig": thinking_config} if thinking_config else {}),
|
||||
}
|
||||
}
|
||||
logger.info(
|
||||
f"Gemini 生成配置: 比例={aspect_ratio} 尺寸={image_config.get('imageSize', '默认')} "
|
||||
f"person={image_config.get('personGeneration', '默认')} thinking={thinking_config.get('thinkingLevel', '默认')}"
|
||||
)
|
||||
else:
|
||||
# OpenAI兼容格式
|
||||
api_url = config['api_url']
|
||||
headers = {
|
||||
"Authorization": f"Bearer {config['api_key']}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": config['api_model'],
|
||||
"stream": False,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{img64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
logger.info(f"正在请求{config['name']}服务 (第{attempt + 1}次)...")
|
||||
|
||||
# 发送异步请求
|
||||
timeout = aiohttp.ClientTimeout(total=300, connect=30)
|
||||
connector = aiohttp.TCPConnector(limit=10, limit_per_host=5)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"{config['name']} API请求失败 (第{attempt + 1}次): {response.status} - {error_text}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 所有重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
result = await response.json()
|
||||
# Gemini 偶发只返回文本不返回图片:NO_IMAGE 时快速重试/降级
|
||||
if config.get('use_gemini_format', False):
|
||||
finish_reason = ""
|
||||
try:
|
||||
finish_reason = (
|
||||
(result.get("candidates") or [{}])[0].get("finishReason", "")
|
||||
)
|
||||
except Exception:
|
||||
finish_reason = ""
|
||||
if finish_reason == "NO_IMAGE":
|
||||
logger.warning(
|
||||
f"{config['name']} 返回 NO_IMAGE (模型={config.get('api_model')}),第{attempt + 1}次"
|
||||
)
|
||||
metrics_emit("gemini_no_image", model=config.get("api_model", ""), provider=config.get("name", ""))
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} NO_IMAGE 重试已用完,切换下一个配置")
|
||||
break
|
||||
await asyncio.sleep(1 + attempt)
|
||||
continue
|
||||
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, AssertionError) as e:
|
||||
logger.error(f"{config['name']} 网络连接错误 (第{attempt + 1}次): {str(e)}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 网络重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
logger.info(f"{config['name']} 服务请求成功 (第{attempt + 1}次),正在处理响应...")
|
||||
|
||||
# 处理API响应并提取图片
|
||||
success, message, data = await self._process_api_response(result, output_path, config['name'], config)
|
||||
|
||||
if success:
|
||||
logger.info(f"使用 {config['name']} 成功完成印花提取")
|
||||
metrics_emit("gemini_success", model=config.get("api_model", ""), provider=config.get("name", ""))
|
||||
try:
|
||||
from utils.api_cost_tracker import record
|
||||
record("gemini_extract", count=1)
|
||||
except Exception:
|
||||
pass
|
||||
return True, f"Gemini V2印花提取完成 - 使用{config['name']}", data
|
||||
else:
|
||||
logger.warning(f"{config['name']} 响应处理失败: {message}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 所有重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{config['name']} API调用异常 (第{attempt + 1}次): {str(e)}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 异常重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# 所有API配置都尝试过了,返回失败
|
||||
return False, "所有API配置都已尝试失败", {}
|
||||
|
||||
async def _process_api_response(self, result: dict, output_path: str, api_name: str, config: dict) -> tuple[bool, str, dict]:
|
||||
"""处理API响应并提取图片"""
|
||||
try:
|
||||
# 根据API格式提取内容
|
||||
if config.get('use_gemini_format', False):
|
||||
# Gemini原生API格式: candidates[0].content.parts[0]
|
||||
content_parts = result['candidates'][0]['content']['parts']
|
||||
|
||||
# 查找包含图片数据的part
|
||||
image_data = None
|
||||
for part in content_parts:
|
||||
# 注意:响应中使用驼峰命名 inlineData
|
||||
if 'inlineData' in part:
|
||||
# 提取Base64图片数据
|
||||
base64_data = part['inlineData']['data']
|
||||
logger.info(f"{api_name} 找到Gemini格式的inlineData图片")
|
||||
try:
|
||||
image_data = base64.b64decode(base64_data)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||||
return False, f"Base64解码失败: {e}", {}
|
||||
|
||||
if not image_data:
|
||||
logger.error(f"{api_name} 在Gemini响应中未找到图片数据")
|
||||
return False, "未找到图片数据", {}
|
||||
|
||||
# 直接保存图片
|
||||
return await self._save_image(image_data, output_path, api_name)
|
||||
|
||||
else:
|
||||
# OpenAI兼容格式: choices[0].message.content
|
||||
content = result['choices'][0]['message']['content']
|
||||
logger.info(f"{api_name} 收到内容: {content[:200]}...")
|
||||
|
||||
# 使用原有的URL/Base64提取逻辑
|
||||
return await self._extract_and_save_image(content, output_path, api_name)
|
||||
|
||||
except KeyError as e:
|
||||
logger.error(f"{api_name} 响应格式不正确,缺少字段: {e}")
|
||||
logger.error(f"响应内容: {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||
return False, f"响应格式错误: {e}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} 处理响应时发生异常: {e}")
|
||||
return False, f"处理异常: {e}", {}
|
||||
|
||||
async def _save_image(self, image_data: bytes, output_path: str, api_name: str) -> tuple[bool, str, dict]:
|
||||
"""保存图片文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
logger.info(f"{api_name} 图片已保存到: {output_path}")
|
||||
|
||||
# 验证保存的图片
|
||||
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
||||
file_size = os.path.getsize(output_path)
|
||||
logger.info(f"{api_name} 图片保存成功,文件大小: {file_size} bytes")
|
||||
|
||||
return True, f"{api_name} 印花提取完成", {
|
||||
'output_path': output_path,
|
||||
'file_size': file_size,
|
||||
'api_used': api_name
|
||||
prompt = str(custom_prompt or self.DEFAULT_PROMPT).strip()
|
||||
api_url = f"{self.API_BASE_URL}/{GEMINI_IMAGE_MODEL}:generateContent"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {GEMINI_API_KEY}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{"inlineData": {"mimeType": self._guess_mime_type(input_path), "data": img64}},
|
||||
{"text": prompt},
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.error(f"{api_name} 保存的图片文件无效")
|
||||
return False, "保存的图片文件无效", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} 保存图片时发生错误: {e}")
|
||||
return False, f"保存图片失败: {e}", {}
|
||||
|
||||
async def _extract_and_save_image(self, content: str, output_path: str, api_name: str) -> tuple[bool, str, dict]:
|
||||
"""从响应内容中提取并保存图片(URL或Base64格式)"""
|
||||
# 查找和处理图片数据
|
||||
image_data = None
|
||||
|
||||
# 方法1: 查找URL链接 (优先检查URL格式)
|
||||
url_match = re.search(r'https?://[^\s\)]+\.(?:png|jpg|jpeg|gif|webp)', content)
|
||||
if url_match:
|
||||
image_url = url_match.group(0)
|
||||
logger.info(f"{api_name} 找到图片URL: {image_url}")
|
||||
|
||||
# 图片下载重试机制
|
||||
download_retries = 3
|
||||
for download_attempt in range(download_retries):
|
||||
],
|
||||
"generationConfig": self._build_generation_config(
|
||||
aspect_ratio=aspect_ratio,
|
||||
image_size=image_size,
|
||||
person_generation=person_generation,
|
||||
thinking_level=thinking_level,
|
||||
),
|
||||
}
|
||||
|
||||
metrics_emit("gemini_request", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
|
||||
timeout = aiohttp.ClientTimeout(total=300, connect=30)
|
||||
|
||||
for attempt in range(1, 3):
|
||||
try:
|
||||
logger.info(f"Gemini 出图开始 attempt={attempt}/2 model={GEMINI_IMAGE_MODEL} input={input_path}")
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(api_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Gemini API请求失败 attempt={attempt}: {response.status} - {error_text}")
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
return False, f"Gemini API请求失败: {response.status}", {}
|
||||
result = await response.json()
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
data = self._save_image(image_bytes, output_path)
|
||||
metrics_emit("gemini_success", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
|
||||
try:
|
||||
logger.info(f"{api_name} 开始下载图片 (第{download_attempt + 1}/{download_retries}次尝试): {image_url}")
|
||||
|
||||
# 异步下载图片,增加超时时间
|
||||
timeout = aiohttp.ClientTimeout(total=300, connect=60)
|
||||
connector = aiohttp.TCPConnector(limit=5, limit_per_host=2)
|
||||
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
connector=connector,
|
||||
headers=headers
|
||||
) as download_session:
|
||||
logger.info(f"{api_name} 正在发送HTTP请求...")
|
||||
async with download_session.get(image_url) as img_response:
|
||||
logger.info(f"{api_name} 收到HTTP响应: {img_response.status}")
|
||||
if img_response.status == 200:
|
||||
image_data = await img_response.read()
|
||||
logger.info(f"{api_name} 图片下载成功,大小: {len(image_data)} bytes")
|
||||
break # 成功则跳出重试循环
|
||||
else:
|
||||
logger.error(f"{api_name} 图片下载失败,HTTP状态码: {img_response.status}")
|
||||
if download_attempt == download_retries - 1:
|
||||
return False, "图片下载失败", {}
|
||||
else:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} 下载图片时发生异常 (第{download_attempt + 1}次): {type(e).__name__}: {str(e)}")
|
||||
if download_attempt == download_retries - 1:
|
||||
return False, f"图片下载异常: {str(e)}", {}
|
||||
else:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
else:
|
||||
# 方法2: 查找标准格式 data:image/type;base64,data
|
||||
base64_match = re.search(r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)', content)
|
||||
|
||||
if base64_match:
|
||||
base64_data = base64_match.group(1)
|
||||
logger.info(f"{api_name} 找到标准格式的Base64数据")
|
||||
try:
|
||||
image_data = base64.b64decode(base64_data)
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||||
return False, f"Base64解码失败: {e}", {}
|
||||
else:
|
||||
# 方法3: 查找纯Base64数据(长字符串)
|
||||
base64_match = re.search(r'([A-Za-z0-9+/=]{100,})', content)
|
||||
if base64_match:
|
||||
base64_data = base64_match.group(1)
|
||||
logger.info(f"{api_name} 找到纯Base64数据")
|
||||
try:
|
||||
image_data = base64.b64decode(base64_data)
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||||
return False, f"Base64解码失败: {e}", {}
|
||||
else:
|
||||
logger.error(f"{api_name} 在响应中未找到图片数据")
|
||||
return False, "未找到图片数据", {}
|
||||
|
||||
# 检查图片数据
|
||||
if not image_data:
|
||||
logger.error(f"{api_name} 图片数据为空")
|
||||
return False, "图片数据为空", {}
|
||||
|
||||
# 保存图片
|
||||
return await self._save_image(image_data, output_path, api_name)
|
||||
|
||||
from utils.api_cost_tracker import record
|
||||
|
||||
record("gemini_extract", count=1)
|
||||
except Exception:
|
||||
pass
|
||||
return True, "Gemini 出图完成", data
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini 出图异常 attempt={attempt}: {e}")
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(attempt)
|
||||
continue
|
||||
return False, f"Gemini 出图失败: {e}", {}
|
||||
|
||||
return False, "Gemini 出图失败", {}
|
||||
|
||||
async def correct_perspective(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
level: str = "mild",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""
|
||||
透视矫正:先把有透视畸变的图还原为正面平铺视图,再做后续处理。
|
||||
|
||||
Args:
|
||||
input_path: 本地图片路径
|
||||
output_path: 矫正后输出路径
|
||||
level: "mild" 或 "strong"
|
||||
"""
|
||||
if level == "strong":
|
||||
prompt = (
|
||||
"这张图存在明显透视畸变(俯拍/斜拍/贴墙)。"
|
||||
"请对图片进行透视矫正:将主体变换为正面平铺视图,"
|
||||
"使所有边缘变成水平或垂直,去除梯形形变,"
|
||||
"保持图案颜色和细节完全不变,只矫正几何形状,输出矫正后的完整图片。"
|
||||
"这张图存在明显透视畸变。请把主体矫正为正面平铺视图,"
|
||||
"所有边缘尽量水平或垂直,保持图案颜色和细节不变,只做几何矫正。"
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
"这张图存在轻微透视畸变(衣物悬挂/桌面斜拍)。"
|
||||
"请做轻度透视矫正:将主体调整为尽量正视角,"
|
||||
"消除轻微的梯形拉伸感,保持图案颜色和细节不变,输出矫正后的图片。"
|
||||
"这张图存在轻微透视畸变。请做轻度透视矫正,"
|
||||
"消除斜拍拉伸感,保持图案颜色和细节不变。"
|
||||
)
|
||||
|
||||
# 透视矫正使用 1:1 比例避免比例失真
|
||||
return await self.extract_pattern(
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
@@ -528,40 +213,17 @@ class GeminiExtractV2Service(BaseService):
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
return None
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def extract_pattern_v2(
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
custom_prompt: str = None,
|
||||
aspect_ratio: str = "1:1",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""Gemini V2印花提取便捷函数"""
|
||||
service = GeminiExtractV2Service()
|
||||
try:
|
||||
return await service.extract_pattern(input_path, output_path, custom_prompt, aspect_ratio)
|
||||
finally:
|
||||
await service.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
import asyncio
|
||||
|
||||
async def test():
|
||||
service = GeminiExtractV2Service()
|
||||
|
||||
input_path = "F:/api/134.png"
|
||||
output_path = "test_output_v2.png"
|
||||
|
||||
success, message, data = await service.extract_pattern(input_path, output_path)
|
||||
|
||||
print(f"结果: {success}")
|
||||
print(f"消息: {message}")
|
||||
print(f"数据: {data}")
|
||||
|
||||
await service.cleanup()
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
@@ -20,6 +20,13 @@ logger = logging.getLogger("cs_agent")
|
||||
|
||||
|
||||
ANALYSIS_PROMPT = """你是一个电商图片处理评估专家。
|
||||
客户需求如下:
|
||||
{customer_requirement}
|
||||
|
||||
请结合客户需求和图片内容一起判断,不要只看图片本身。
|
||||
如果客户明确说了“找原图/找图/素材/大图”,类型优先判断为“找原图/素材提取”类;
|
||||
如果客户明确说了“修复/高清/清晰/放大”,类型优先判断为“高清修复”类。
|
||||
|
||||
请仔细分析这张图片,输出以下字段,每行一个,不要多余内容:
|
||||
|
||||
敏感内容: <yes|no>
|
||||
@@ -101,7 +108,7 @@ class ImageAnalyzerService:
|
||||
logger.debug(f"[ImageAnalyzer] 获取尺寸失败: {e}")
|
||||
return (0, 0)
|
||||
|
||||
async def analyze(self, image_url: str) -> Dict[str, Any]:
|
||||
async def analyze(self, image_url: str, customer_requirement: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
异步分析图片,返回结构化结果
|
||||
|
||||
@@ -133,7 +140,8 @@ class ImageAnalyzerService:
|
||||
return self._fallback(image_url, "未配置 API Key")
|
||||
|
||||
# 缓存检查
|
||||
cache_key = image_url
|
||||
customer_requirement = str(customer_requirement or "").strip()
|
||||
cache_key = f"{image_url}|{customer_requirement}"
|
||||
now = time.monotonic()
|
||||
cached = self._analysis_cache.get(cache_key)
|
||||
if cached:
|
||||
@@ -149,6 +157,9 @@ class ImageAnalyzerService:
|
||||
try:
|
||||
client = AsyncOpenAI(base_url=self.base_url, api_key=self.api_key)
|
||||
|
||||
prompt_text = ANALYSIS_PROMPT.format(
|
||||
customer_requirement=customer_requirement or "未提供明确补充需求"
|
||||
)
|
||||
response = await asyncio.wait_for(
|
||||
client.chat.completions.create(
|
||||
model=self.vision_model,
|
||||
@@ -156,7 +167,7 @@ class ImageAnalyzerService:
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": ANALYSIS_PROMPT}
|
||||
{"type": "text", "text": prompt_text}
|
||||
]
|
||||
}],
|
||||
max_tokens=500
|
||||
@@ -170,6 +181,7 @@ class ImageAnalyzerService:
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
result = self._parse_result(image_url, content)
|
||||
result["customer_requirement"] = customer_requirement
|
||||
result["elapsed"] = round(elapsed, 2)
|
||||
|
||||
# 获取尺寸
|
||||
@@ -241,6 +253,7 @@ class ImageAnalyzerService:
|
||||
|
||||
return {
|
||||
"url": url,
|
||||
"customer_requirement": "",
|
||||
"complexity": complexity,
|
||||
"reason": extract("原因"),
|
||||
"subject": extract("主体"),
|
||||
@@ -280,6 +293,7 @@ class ImageAnalyzerService:
|
||||
from datetime import datetime
|
||||
return {
|
||||
"url": url,
|
||||
"customer_requirement": "",
|
||||
"complexity": "normal",
|
||||
"reason": reason,
|
||||
"subject": "",
|
||||
|
||||
@@ -7,8 +7,9 @@ import os
|
||||
import httpx
|
||||
import logging
|
||||
import mimetypes
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from typing import Iterator, Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,11 +18,41 @@ load_dotenv()
|
||||
# 图绘平台配置
|
||||
TUHUI_BASE_URL = os.getenv("TUHUI_BASE_URL", "https://tuhui.cloud")
|
||||
TUHUI_FALLBACK_BASE_URL = "https://tuhui.cloud"
|
||||
TUHUI_WEB_BASE_URL = os.getenv("TUHUI_WEB_BASE_URL", "https://tuhui.cloud").rstrip("/")
|
||||
TUHUI_PHONE = os.getenv("TUHUI_PHONE", "17520145271") # 图绘账号手机号
|
||||
TUHUI_PASSWORD = os.getenv("TUHUI_PASSWORD", "zuowei1216") # 图绘账号密码
|
||||
TUHUI_DEFAULT_PRICE = int(os.getenv("TUHUI_DEFAULT_PRICE", "20")) # 默认定价(元)
|
||||
TUHUI_DEFAULT_CATEGORY = os.getenv("TUHUI_DEFAULT_CATEGORY", "设计素材")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TuhuiUploadResult:
|
||||
"""图绘上传结果。主返回 URL 为站内作品页,保留三元组解包兼容。"""
|
||||
success: bool
|
||||
download_url: str
|
||||
work_id: int
|
||||
image_url: str = ""
|
||||
thumbnail_url: str = ""
|
||||
watermarked_url: str = ""
|
||||
message: str = ""
|
||||
|
||||
def __iter__(self) -> Iterator[object]:
|
||||
# 兼容历史调用:ok, download_url, work_id = result
|
||||
yield self.success
|
||||
yield self.download_url
|
||||
yield self.work_id
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"download_url": self.download_url,
|
||||
"work_id": self.work_id,
|
||||
"image_url": self.image_url,
|
||||
"thumbnail_url": self.thumbnail_url,
|
||||
"watermarked_url": self.watermarked_url,
|
||||
"message": self.message,
|
||||
}
|
||||
|
||||
class TuhuiUploadService:
|
||||
"""图绘平台上传服务"""
|
||||
|
||||
@@ -49,6 +80,10 @@ class TuhuiUploadService:
|
||||
def _api_url(self, path: str) -> str:
|
||||
return self._build_api_url(self.base_url, path)
|
||||
|
||||
@staticmethod
|
||||
def _build_work_url(work_id: int) -> str:
|
||||
return f"{TUHUI_WEB_BASE_URL}/detail/{int(work_id)}"
|
||||
|
||||
@staticmethod
|
||||
def _guess_file_meta(image_path: str) -> tuple[str, str]:
|
||||
path = Path(image_path)
|
||||
@@ -97,7 +132,8 @@ class TuhuiUploadService:
|
||||
price: Optional[int] = None,
|
||||
category: str = TUHUI_DEFAULT_CATEGORY,
|
||||
tags: str = "",
|
||||
) -> Tuple[bool, str, int]:
|
||||
designer_name: str = "",
|
||||
) -> TuhuiUploadResult:
|
||||
"""
|
||||
上传图片到图绘平台
|
||||
|
||||
@@ -109,16 +145,19 @@ class TuhuiUploadService:
|
||||
category: 分类
|
||||
|
||||
Returns:
|
||||
(success, image_url, work_id)
|
||||
TuhuiUploadResult
|
||||
- success: 是否上传成功
|
||||
- image_url: 图片 URL
|
||||
- download_url: 站内作品页地址
|
||||
- image_url: 原图 URL(保留,便于需要时取用)
|
||||
- thumbnail_url: 缩略图 URL
|
||||
- watermarked_url: 水印图 URL
|
||||
- work_id: 作品 ID
|
||||
"""
|
||||
try:
|
||||
# 如果 token 过期,重新登录
|
||||
if not self.access_token:
|
||||
if not await self.login():
|
||||
return False, "登录失败", 0
|
||||
return TuhuiUploadResult(False, "", 0, message="登录失败")
|
||||
|
||||
# 准备上传数据
|
||||
price = price or self.default_price
|
||||
@@ -126,7 +165,7 @@ class TuhuiUploadService:
|
||||
# 读取图片文件
|
||||
if not os.path.exists(image_path):
|
||||
logger.error(f"图片文件不存在:{image_path}")
|
||||
return False, "文件不存在", 0
|
||||
return TuhuiUploadResult(False, "", 0, message="文件不存在")
|
||||
|
||||
filename, mime_type = self._guess_file_meta(image_path)
|
||||
with open(image_path, "rb") as f:
|
||||
@@ -142,6 +181,8 @@ class TuhuiUploadService:
|
||||
}
|
||||
if tags:
|
||||
data["tags"] = tags
|
||||
if designer_name:
|
||||
data["designer_name"] = str(designer_name).strip()
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}"
|
||||
@@ -160,12 +201,34 @@ class TuhuiUploadService:
|
||||
payload = response.json()
|
||||
if not payload.get("success", False):
|
||||
logger.error(f"图绘平台上传返回失败:{payload}")
|
||||
return False, payload.get("message", "上传失败"), 0
|
||||
return TuhuiUploadResult(
|
||||
False,
|
||||
"",
|
||||
0,
|
||||
message=str(payload.get("message", "上传失败")),
|
||||
)
|
||||
|
||||
work_id = int(payload.get("work_id") or payload.get("work", {}).get("id") or 0)
|
||||
image_url = str(payload.get("image_url") or payload.get("work", {}).get("original_image") or "")
|
||||
logger.info(f"图绘平台上传成功,作品 ID: {work_id}, URL: {image_url}")
|
||||
return True, image_url, work_id
|
||||
thumbnail_url = str(
|
||||
payload.get("thumbnail_url") or payload.get("work", {}).get("thumbnail_image") or ""
|
||||
)
|
||||
watermarked_url = str(
|
||||
payload.get("watermarked_url") or payload.get("work", {}).get("watermarked_image") or ""
|
||||
)
|
||||
download_url = self._build_work_url(work_id) if work_id else ""
|
||||
logger.info(
|
||||
f"图绘平台上传成功,作品 ID: {work_id}, 站内地址: {download_url}, 原图: {image_url}"
|
||||
)
|
||||
return TuhuiUploadResult(
|
||||
True,
|
||||
download_url,
|
||||
work_id,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
watermarked_url=watermarked_url,
|
||||
message=str(payload.get("message", "上传成功")),
|
||||
)
|
||||
else:
|
||||
logger.error(f"图绘平台上传失败:{response.status_code} {response.text}")
|
||||
|
||||
@@ -176,13 +239,13 @@ class TuhuiUploadService:
|
||||
if await self.login():
|
||||
# 重新上传
|
||||
return await self.upload_image(
|
||||
image_path, title, description, price, category
|
||||
image_path, title, description, price, category, tags, designer_name
|
||||
)
|
||||
|
||||
return False, f"上传失败:{response.text}", 0
|
||||
return TuhuiUploadResult(False, "", 0, message=f"上传失败:{response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"图绘平台上传异常:{e}")
|
||||
return False, f"上传异常:{e}", 0
|
||||
return TuhuiUploadResult(False, "", 0, message=f"上传异常:{e}")
|
||||
|
||||
|
||||
# 单例
|
||||
@@ -204,12 +267,13 @@ async def upload_to_tuhui(
|
||||
price: int = 20,
|
||||
category: str = TUHUI_DEFAULT_CATEGORY,
|
||||
tags: str = "",
|
||||
) -> Tuple[bool, str, int]:
|
||||
designer_name: str = "",
|
||||
) -> TuhuiUploadResult:
|
||||
"""
|
||||
便捷函数:上传图片到图绘平台
|
||||
|
||||
Returns:
|
||||
(success, image_url, work_id)
|
||||
TuhuiUploadResult
|
||||
"""
|
||||
service = get_tuhui_service()
|
||||
return await service.upload_image(image_path, title, description, price, category, tags)
|
||||
return await service.upload_image(image_path, title, description, price, category, tags, designer_name)
|
||||
|
||||
51
services/service_wecom_bot.py
Normal file
51
services/service_wecom_bot.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger("cs_agent")
|
||||
|
||||
WECOM_BOT_WEBHOOK = os.getenv(
|
||||
"WECOM_BOT_WEBHOOK",
|
||||
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=cc88bdef-a13f-4d7e-bdb6-ee51b68b8205",
|
||||
).strip()
|
||||
|
||||
|
||||
class WecomBotService:
|
||||
def __init__(self, webhook_url: str = WECOM_BOT_WEBHOOK):
|
||||
self.webhook_url = str(webhook_url or "").strip()
|
||||
|
||||
async def send_text(self, content: str) -> bool:
|
||||
text = str(content or "").strip()
|
||||
if not text:
|
||||
return False
|
||||
if not self.webhook_url:
|
||||
logger.warning("[WeComBot] 未配置 webhook,跳过发送")
|
||||
return False
|
||||
|
||||
payload = {
|
||||
"msgtype": "text",
|
||||
"text": {
|
||||
"content": text[:3500],
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(self.webhook_url, json=payload)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"[WeComBot] 发送失败 HTTP {response.status_code}: {response.text}")
|
||||
return False
|
||||
data = response.json()
|
||||
ok = int(data.get("errcode", -1)) == 0
|
||||
if not ok:
|
||||
logger.warning(f"[WeComBot] 发送失败: {data}")
|
||||
return ok
|
||||
except Exception as e:
|
||||
logger.warning(f"[WeComBot] 发送异常: {e}")
|
||||
return False
|
||||
|
||||
|
||||
wecom_bot_service = WecomBotService()
|
||||
Reference in New Issue
Block a user