init
This commit is contained in:
328
image/image_processor.py
Normal file
328
image/image_processor.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""图片处理模块 - 调用 Gemini 作图API,含质检与自动重试"""
|
||||
import os
|
||||
import uuid
|
||||
import tempfile
|
||||
from typing import Optional, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_OUTPUT_DIR = os.getenv("RESULT_IMAGE_DIR", "results")
|
||||
_MAX_RETRIES = int(os.getenv("PROCESS_MAX_RETRIES", "2")) # 含首次共最多处理几次
|
||||
|
||||
|
||||
class ImageProcessor:
|
||||
"""图片处理 - 对接 GeminiExtractV2Service,含质检与重试"""
|
||||
|
||||
def __init__(self):
|
||||
os.makedirs(_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# ─── 内部工具 ────────────────────────────────────────────
|
||||
|
||||
async def _download(self, url: str) -> str:
|
||||
"""下载图片到临时文件,返回本地路径"""
|
||||
import aiohttp
|
||||
tmp = os.path.join(tempfile.gettempdir(), f"gemini_in_{uuid.uuid4().hex}.jpg")
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/122.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Referer": "https://www.taobao.com/",
|
||||
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
|
||||
}
|
||||
async with aiohttp.ClientSession(headers=headers) as session:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"下载图片失败: HTTP {resp.status}")
|
||||
with open(tmp, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return tmp
|
||||
|
||||
async def _do_perspective(self, service, src: str, level: str) -> str:
|
||||
"""透视矫正,返回矫正后文件路径(失败则返回原路径)"""
|
||||
out = os.path.join(tempfile.gettempdir(), f"gemini_persp_{uuid.uuid4().hex}.jpg")
|
||||
ok, msg, _ = await service.correct_perspective(src, out, level=level)
|
||||
if ok:
|
||||
print(f"[ImageProcessor] 透视矫正完成")
|
||||
return out
|
||||
else:
|
||||
print(f"[ImageProcessor] 透视矫正失败 ({msg}),跳过")
|
||||
if os.path.exists(out):
|
||||
os.remove(out)
|
||||
return src
|
||||
|
||||
@staticmethod
|
||||
def _build_retry_prompt(gemini_prompt: str, qa_issue: str, qa_suggestion: str) -> str:
|
||||
"""
|
||||
根据 QA 质检问题类型,智能调整重试提示词。
|
||||
比简单追加建议更有针对性,让 Gemini 知道上次哪里出了问题。
|
||||
"""
|
||||
base = gemini_prompt or ""
|
||||
issue = (qa_issue or "").lower()
|
||||
suggestion = qa_suggestion if qa_suggestion and qa_suggestion != "无" else ""
|
||||
|
||||
# 背景不干净
|
||||
if any(kw in issue for kw in ["背景", "杂物", "多余", "白色不纯"]):
|
||||
prefix = "【重要:背景必须是纯白色 #FFFFFF,去掉所有杂物和阴影】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 清晰度/细节不足
|
||||
if any(kw in issue for kw in ["模糊", "清晰", "细节", "锐化", "分辨率"]):
|
||||
prefix = "【重要:提升整体清晰度和细节,输出高分辨率版本,不要模糊】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 内容缺失/截断
|
||||
if any(kw in issue for kw in ["缺失", "截断", "不完整", "边缘", "裁剪"]):
|
||||
prefix = "【重要:保留主体完整内容,不要截断边缘,确保四角全部保留】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 颜色偏差
|
||||
if any(kw in issue for kw in ["颜色", "色彩", "偏色", "色调"]):
|
||||
prefix = "【重要:忠实还原原图颜色,不要改变色调或过度饱和】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# AI幻觉/变形
|
||||
if any(kw in issue for kw in ["幻觉", "变形", "失真", "扭曲", "ai生成"]):
|
||||
prefix = "【重要:严格按原图内容处理,不要添加或改变任何图案细节】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 没有匹配到具体类型,直接用质检建议
|
||||
if suggestion:
|
||||
return (base + f"\n【上次问题:{qa_issue}。本次改进方向:{suggestion}】").strip()
|
||||
|
||||
return base
|
||||
|
||||
async def _do_main(self, service, src: str, gemini_prompt: str, aspect_ratio: str,
|
||||
attempt: int, qa_issue: str = "", qa_suggestion: str = "") -> tuple[bool, str, str]:
|
||||
"""
|
||||
执行一次主处理。
|
||||
重试时根据 QA 问题类型智能调整提示词。
|
||||
|
||||
Returns:
|
||||
(success, output_path, message)
|
||||
"""
|
||||
out_name = f"result_{uuid.uuid4().hex}.jpg"
|
||||
output_path = os.path.join(_OUTPUT_DIR, out_name)
|
||||
|
||||
if attempt == 1:
|
||||
prompt = gemini_prompt or None
|
||||
else:
|
||||
prompt = self._build_retry_prompt(gemini_prompt, qa_issue, qa_suggestion)
|
||||
print(f"[ImageProcessor] 重试策略 | 问题: {qa_issue} | 提示词: {(prompt or '')[:80]}...")
|
||||
|
||||
print(f"[ImageProcessor] 主处理第 {attempt} 次 (比例={aspect_ratio})...")
|
||||
success, message, _ = await service.extract_pattern(
|
||||
input_path=src,
|
||||
output_path=output_path,
|
||||
custom_prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
return success, output_path, message
|
||||
|
||||
# ─── 主入口 ──────────────────────────────────────────────
|
||||
|
||||
async def process_image(
|
||||
self,
|
||||
image_url: str,
|
||||
operation: str,
|
||||
requirements: str = "",
|
||||
gemini_prompt: str = "",
|
||||
aspect_ratio: str = "1:1",
|
||||
perspective: str = "no",
|
||||
proc_type: str = "",
|
||||
subject: str = "",
|
||||
quality: str = "",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
完整处理流程:下载 → 透视矫正(可选)→ Gemini主处理 → 质检 → 重试(可选)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": bool,
|
||||
"result_path": str,
|
||||
"message": str,
|
||||
"qa_score": int, # 质检得分 0-100
|
||||
"qa_pass": bool, # 是否通过质检
|
||||
"qa_issue": str, # 质检发现的问题
|
||||
"attempts": int, # 共处理了几次
|
||||
}
|
||||
"""
|
||||
from services.service_gemini import GeminiExtractV2Service
|
||||
from image.image_qa import image_qa
|
||||
|
||||
# Step 1: 下载原图
|
||||
try:
|
||||
tmp_input = await self._download(image_url)
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False, "result_path": "", "message": str(e),
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "下载失败", "attempts": 0,
|
||||
}
|
||||
|
||||
# Step 1.5: 敏感图片检测
|
||||
try:
|
||||
from utils.content_filter import is_sensitive_image
|
||||
sensitive, reason = await is_sensitive_image(tmp_input)
|
||||
if sensitive:
|
||||
if os.path.exists(tmp_input):
|
||||
os.remove(tmp_input)
|
||||
return {
|
||||
"success": False, "result_path": "", "message": reason,
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "敏感图片", "attempts": 0,
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"[ImageProcessor] 敏感图片检测异常: {e},继续处理")
|
||||
|
||||
# Step 1.6: 预检(尺寸/格式/损坏)
|
||||
try:
|
||||
from image.image_precheck import precheck
|
||||
ok, msg = precheck(tmp_input)
|
||||
if not ok:
|
||||
if os.path.exists(tmp_input):
|
||||
os.remove(tmp_input)
|
||||
return {
|
||||
"success": False, "result_path": "", "message": msg,
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "预检不通过", "attempts": 0,
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"[ImageProcessor] 预检异常: {e},继续处理")
|
||||
|
||||
service = GeminiExtractV2Service()
|
||||
tmp_files = [tmp_input]
|
||||
try:
|
||||
# Step 2: 透视矫正
|
||||
current_input = tmp_input
|
||||
if perspective in ("mild", "strong"):
|
||||
print(f"[ImageProcessor] 透视矫正中 (level={perspective})...")
|
||||
corrected = await self._do_perspective(service, tmp_input, perspective)
|
||||
if corrected != tmp_input:
|
||||
tmp_files.append(corrected)
|
||||
current_input = corrected
|
||||
|
||||
# Step 3: 主处理 + 质检,最多 _MAX_RETRIES 次
|
||||
qa_result = {"score": 0, "pass": False, "issue": "未质检", "suggestion": "无"}
|
||||
output_path = ""
|
||||
last_message = ""
|
||||
qa_issue = ""
|
||||
qa_suggestion = ""
|
||||
|
||||
for attempt in range(1, _MAX_RETRIES + 1):
|
||||
ok, output_path, last_message = await self._do_main(
|
||||
service, current_input, gemini_prompt, aspect_ratio,
|
||||
attempt=attempt, qa_issue=qa_issue, qa_suggestion=qa_suggestion,
|
||||
)
|
||||
|
||||
if not ok:
|
||||
print(f"[ImageProcessor] 第 {attempt} 次处理失败: {last_message}")
|
||||
if attempt < _MAX_RETRIES:
|
||||
continue
|
||||
return {
|
||||
"success": False, "result_path": "", "message": last_message,
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "Gemini处理失败", "attempts": attempt,
|
||||
}
|
||||
|
||||
# Step 4: 质检
|
||||
print(f"[ImageProcessor] 质检中 (第 {attempt} 次结果)...")
|
||||
qa_result = await image_qa.check(
|
||||
original_path=current_input,
|
||||
result_path=output_path,
|
||||
proc_type=proc_type,
|
||||
subject=subject,
|
||||
quality=quality,
|
||||
gemini_prompt=gemini_prompt,
|
||||
)
|
||||
qa_issue = qa_result.get("issue", "")
|
||||
qa_suggestion = qa_result.get("suggestion", "无")
|
||||
|
||||
if qa_result["pass"]:
|
||||
print(f"[ImageProcessor] 质检通过 ({qa_result['score']}分),共处理 {attempt} 次")
|
||||
break
|
||||
else:
|
||||
print(f"[ImageProcessor] 质检不合格 ({qa_result['score']}分),问题: {qa_result['issue']}")
|
||||
if attempt < _MAX_RETRIES:
|
||||
# 清理这次不合格的结果
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path)
|
||||
print(f"[ImageProcessor] 准备第 {attempt + 1} 次重试...")
|
||||
else:
|
||||
print(f"[ImageProcessor] 已达最大重试次数 {_MAX_RETRIES},保留最后结果,人工跟进")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result_path": output_path,
|
||||
"message": last_message,
|
||||
"qa_score": qa_result.get("score", 0),
|
||||
"qa_pass": qa_result.get("pass", False),
|
||||
"qa_issue": qa_result.get("issue", ""),
|
||||
"attempts": attempt,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False, "result_path": "", "message": f"处理异常: {e}",
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": str(e), "attempts": 0,
|
||||
}
|
||||
finally:
|
||||
await service.cleanup()
|
||||
for f in tmp_files:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
async def enhance(self, image_url: str) -> Dict[str, Any]:
|
||||
return await self.process_image(image_url, "enhance")
|
||||
|
||||
async def remove_bg(self, image_url: str) -> Dict[str, Any]:
|
||||
return await self.process_image(image_url, "remove_bg")
|
||||
|
||||
async def resize(self, image_url: str, width: int, height: int = 0) -> Dict[str, Any]:
|
||||
"""
|
||||
改尺寸:下载图片(或读取本地路径),按指定宽高缩放,保存到 results/。
|
||||
|
||||
Args:
|
||||
image_url: 图片 URL 或本地路径
|
||||
width: 目标宽度(像素)
|
||||
height: 目标高度(0=按宽度等比缩放)
|
||||
|
||||
Returns:
|
||||
{"success": bool, "result_path": str, "message": str}
|
||||
"""
|
||||
from PIL import Image
|
||||
is_temp = image_url.startswith(("http://", "https://"))
|
||||
try:
|
||||
if is_temp:
|
||||
tmp = await self._download(image_url)
|
||||
else:
|
||||
tmp = image_url
|
||||
if not os.path.exists(tmp):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {tmp}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
|
||||
try:
|
||||
img = Image.open(tmp).convert("RGB")
|
||||
w_orig, h_orig = img.size
|
||||
if width <= 0 or width > 10000:
|
||||
return {"success": False, "result_path": "", "message": f"宽度无效: {width}"}
|
||||
if height == 0:
|
||||
ratio = width / w_orig
|
||||
height = int(h_orig * ratio)
|
||||
elif height <= 0 or height > 10000:
|
||||
return {"success": False, "result_path": "", "message": f"高度无效: {height}"}
|
||||
resized = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||
out_name = f"resize_{uuid.uuid4().hex}.jpg"
|
||||
out_path = os.path.join(_OUTPUT_DIR, out_name)
|
||||
resized.save(out_path, "JPEG", quality=95)
|
||||
print(f"[ImageProcessor] 改尺寸完成: {w_orig}x{h_orig} → {width}x{height}")
|
||||
return {"success": True, "result_path": out_path, "message": f"已改为 {width}x{height}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if is_temp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
# 全局实例
|
||||
image_processor = ImageProcessor()
|
||||
Reference in New Issue
Block a user