init
This commit is contained in:
0
image/__init__.py
Normal file
0
image/__init__.py
Normal file
BIN
image/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
image/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
image/__pycache__/image_analyzer.cpython-310.pyc
Normal file
BIN
image/__pycache__/image_analyzer.cpython-310.pyc
Normal file
Binary file not shown.
BIN
image/__pycache__/image_precheck.cpython-310.pyc
Normal file
BIN
image/__pycache__/image_precheck.cpython-310.pyc
Normal file
Binary file not shown.
BIN
image/__pycache__/image_processor.cpython-310.pyc
Normal file
BIN
image/__pycache__/image_processor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
image/__pycache__/image_qa.cpython-310.pyc
Normal file
BIN
image/__pycache__/image_qa.cpython-310.pyc
Normal file
Binary file not shown.
593
image/image_analyzer.py
Normal file
593
image/image_analyzer.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""
|
||||
图片复杂度识别模块
|
||||
|
||||
使用智谱 GLM-4V 视觉模型分析客户发来的图片,
|
||||
判断处理难度,为客服AI提供报价依据。
|
||||
|
||||
复杂度等级(越平整越便宜):
|
||||
simple → 10-15元(画面平整、无小字、无人脸、无阴影)
|
||||
normal → 15-20元(一般复杂度)
|
||||
complex → 20-25元(有褶皱/小字/人脸/阴影)
|
||||
hard → 25-30元(非常复杂)
|
||||
|
||||
报价维度:平整度、含文字(小字加价)、含人脸、阴影。
|
||||
同一 URL 5 分钟内复用缓存,节省 API 调用。
|
||||
"""
|
||||
import os
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
from openai import AsyncOpenAI
|
||||
from dotenv import load_dotenv
|
||||
from PIL import Image
|
||||
import aiohttp
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
ANALYSIS_PROMPT = """你是一个电商图片处理评估专家,同时也是 Gemini 图像生成提示词专家。
|
||||
请仔细分析这张图片,输出以下字段,每行一个,不要多余内容:
|
||||
|
||||
敏感内容: <yes|no>
|
||||
平整度: <flat|mild|rough>
|
||||
含文字: <yes|no>
|
||||
含人脸: <yes|no>
|
||||
阴影: <yes|no>
|
||||
复杂度: <simple|normal|complex|hard>
|
||||
原因: <15字以内,说明复杂度判断依据>
|
||||
主体: <图片核心内容,如:印花图案/logo/人物/产品/老照片/风景/文字/其他>
|
||||
类型: <处理类型,如:印花提取/高清修复/去背景/老照片修复/logo提取/人像修复/其他>
|
||||
质量: <原图质量,如:清晰/轻微模糊/严重模糊/低分辨率/截图/扫描件>
|
||||
可做: <yes|partial|no>
|
||||
风险: <none|low|high>
|
||||
透视: <no|mild|strong>
|
||||
比例: <从以下选一个最合适的:1:1 / 9:16 / 16:9 / 3:4 / 4:3 / 3:2 / 2:3 / 5:4 / 4:5>
|
||||
提示词: <为 Gemini 写处理指令,中文,60字以内,说明要做什么、保留什么、去掉什么>
|
||||
备注: <给客服AI的特别提示,没有则填无>
|
||||
|
||||
判断规则:
|
||||
|
||||
【报价核心:越平整越便宜】
|
||||
- 平整度 flat:画面平整、无褶皱、无透视 → 便宜
|
||||
- 平整度 mild:轻微褶皱/透视 → 中等
|
||||
- 平整度 rough:有褶皱/透视/曲面 → 贵
|
||||
- 含文字:大字没关系不加价;小字需精细保留/清晰化 → 加价(含文字填 yes 仅指有小字的情况)
|
||||
- 含人脸 yes:有人脸 → 加价
|
||||
- 阴影 yes:有明显阴影需处理 → 加价
|
||||
综合以上因素,越平整、无小字、无人脸、无阴影 → 越便宜(simple)
|
||||
|
||||
【含文字】
|
||||
- yes:含小字需精细保留/清晰化(小字难处理 → 加价)
|
||||
- no:无文字,或仅有大字(大字没关系 → 不加价)
|
||||
|
||||
【含人脸】
|
||||
- yes:图中有真实人物面孔(人像照/集体照/证件照/老照片等)
|
||||
- no:无人脸或人脸极小不影响主体
|
||||
|
||||
【风险评估】
|
||||
- none:印花/图案/logo/风景/产品,AI处理效果稳定,可直接报价
|
||||
- low:有人脸但清晰度尚可,AI修复后人脸相似度70-90%,建议先看效果
|
||||
- high:以下任一情况 → 严重模糊的人脸照片/老照片人像/需要打印/客户问能否找回原图
|
||||
high情况下,可做改为partial,备注写明风险话术
|
||||
|
||||
【敏感内容】优先判断,若为 yes 则 可做 必填 no
|
||||
- yes:图片含色情/黄色/擦边/裸露/性暗示/大尺度等违规内容
|
||||
- no:无上述敏感内容
|
||||
|
||||
【可做判断】
|
||||
- yes:效果有把握,可直接处理
|
||||
- partial:能处理但有明显限制(人脸变形风险/分辨率极低/严重损坏)
|
||||
- no:无法处理(纯黑/纯白/完全损坏/找原始RAW文件/敏感内容)
|
||||
|
||||
【风险话术模板(备注字段)】
|
||||
- 含人脸+需打印:AI修复后人脸可能有轻微变化,建议先看效果确认再打印
|
||||
- 严重模糊人脸:这张模糊程度较高,修复后清晰了但人脸可能跟原来有差异
|
||||
- 找原图:找不到原始文件,只能对现有图片做高清修复处理
|
||||
- 完全损坏:这张无法处理
|
||||
|
||||
【透视判断】
|
||||
- no:正面拍摄,无明显变形
|
||||
- mild:轻微透视(衣服悬挂/桌面小角度斜拍)
|
||||
- strong:严重透视(俯拍/贴墙/大角度倾斜)
|
||||
|
||||
【比例选择】
|
||||
- 印花/图案/logo/正方形 -> 1:1
|
||||
- 竖屏壁纸/短视频封面 -> 9:16
|
||||
- 宽屏/横版视频 -> 16:9
|
||||
- 移动广告/Instagram竖图 -> 4:5
|
||||
- 竖向人像/海报/证件照 -> 3:4
|
||||
- 竖向相机照片 -> 2:3
|
||||
- 接近正方形产品图 -> 5:4
|
||||
- 横向标准图/风景 -> 4:3
|
||||
- 横向相机照片/产品实拍 -> 3:2
|
||||
|
||||
示例1(印花,无风险):
|
||||
敏感内容: no
|
||||
平整度: mild
|
||||
含文字: no
|
||||
含人脸: no
|
||||
阴影: no
|
||||
复杂度: complex
|
||||
原因: 印花细节密集颜色层次多
|
||||
主体: 印花图案
|
||||
类型: 印花提取
|
||||
质量: 轻微模糊
|
||||
可做: yes
|
||||
风险: none
|
||||
透视: mild
|
||||
比例: 1:1
|
||||
提示词: 提取衣物印花图案,去除褶皱和背景杂色,补全缺失部分,保持颜色细节100%还原,输出干净平面印花图
|
||||
备注: 无
|
||||
|
||||
示例2(人像老照片,要打印):
|
||||
敏感内容: no
|
||||
平整度: flat
|
||||
含文字: no
|
||||
含人脸: yes
|
||||
阴影: no
|
||||
复杂度: hard
|
||||
原因: 严重模糊人脸细节丢失
|
||||
主体: 人物照片
|
||||
类型: 人像修复
|
||||
质量: 严重模糊
|
||||
可做: partial
|
||||
风险: high
|
||||
透视: no
|
||||
比例: 3:4
|
||||
提示词: 对模糊人像进行高清修复,增强细节,保持人物特征不变
|
||||
备注: AI修复后人脸可能有轻微变化,建议先看效果确认满意再用于打印
|
||||
|
||||
示例3(平整印花,最便宜):
|
||||
敏感内容: no
|
||||
平整度: flat
|
||||
含文字: no
|
||||
含人脸: no
|
||||
阴影: no
|
||||
复杂度: simple
|
||||
原因: 画面平整无褶皱无文字无人脸
|
||||
主体: 印花图案
|
||||
类型: 印花提取
|
||||
质量: 清晰
|
||||
可做: yes
|
||||
风险: none
|
||||
透视: no
|
||||
比例: 1:1
|
||||
提示词: 提取印花图案,去除背景,输出干净平面图
|
||||
备注: 无"""
|
||||
|
||||
|
||||
class ImageAnalyzer:
|
||||
"""图片复杂度分析器"""
|
||||
|
||||
# 同一 URL 5 分钟内复用结果,节省 API 调用
|
||||
_CACHE_TTL_SECONDS = 300
|
||||
_analysis_cache: dict = {} # url -> (result_dict, timestamp)
|
||||
|
||||
PRICE_MAP = {
|
||||
"simple": (10, 15, "画面简单干净"),
|
||||
"normal": (15, 20, "一般复杂度"),
|
||||
"complex": (20, 25, "细节偏多"),
|
||||
"hard": (25, 30, "非常复杂"),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
|
||||
# 视觉模型,智谱 GLM-4V 系列
|
||||
self.vision_model = os.getenv("VISION_MODEL", "glm-4v-flash")
|
||||
|
||||
def _is_url(self, image_path: str) -> bool:
|
||||
return image_path.startswith("http://") or image_path.startswith("https://")
|
||||
|
||||
def _load_image_base64(self, image_path: str) -> Optional[str]:
|
||||
"""本地图片转 base64"""
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
return base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"[ImageAnalyzer] 读取图片失败: {e}")
|
||||
return None
|
||||
|
||||
async def _get_image_size(self, image_path: str) -> Tuple[int, int]:
|
||||
"""获取图片像素尺寸 (width, height),URL 或 本地路径"""
|
||||
try:
|
||||
if self._is_url(image_path):
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(image_path) as resp:
|
||||
if resp.status != 200:
|
||||
return (0, 0)
|
||||
data = await resp.read()
|
||||
from io import BytesIO
|
||||
with Image.open(BytesIO(data)) as img:
|
||||
w, h = img.size
|
||||
return (int(w), int(h))
|
||||
else:
|
||||
with Image.open(image_path) as img:
|
||||
w, h = img.size
|
||||
return (int(w), int(h))
|
||||
except Exception as e:
|
||||
print(f"[ImageAnalyzer] 获取尺寸失败: {e}")
|
||||
return (0, 0)
|
||||
|
||||
# 最短等待时间(秒):即使AI极快返回,也等这么久,看起来像真人在找
|
||||
MIN_WAIT_SECONDS = 4
|
||||
|
||||
async def analyze(self, image_path: str) -> dict:
|
||||
"""
|
||||
异步分析图片复杂度(使用火山引擎 /responses 接口)。
|
||||
实际等待时间 = max(视觉AI响应时间, MIN_WAIT_SECONDS)
|
||||
|
||||
Args:
|
||||
image_path: 图片URL 或 本地路径
|
||||
|
||||
Returns:
|
||||
{
|
||||
"complexity": "simple|normal|complex|hard",
|
||||
"reason": "原因描述",
|
||||
"price_min": 最低报价,
|
||||
"price_max": 最高报价,
|
||||
"price_suggest": 建议报价,
|
||||
"elapsed": 实际耗时秒数,
|
||||
"success": True/False
|
||||
}
|
||||
"""
|
||||
if not self.api_key:
|
||||
await asyncio.sleep(self.MIN_WAIT_SECONDS)
|
||||
return self._fallback("未配置 API Key")
|
||||
|
||||
# 缓存:仅对 URL 生效,本地路径不缓存
|
||||
cache_key = image_path if self._is_url(image_path) else None
|
||||
if cache_key:
|
||||
now = time.monotonic()
|
||||
cached = self._analysis_cache.get(cache_key)
|
||||
if cached:
|
||||
result, cached_at = cached
|
||||
if now - cached_at < self._CACHE_TTL_SECONDS:
|
||||
print(f"[ImageAnalyzer] 缓存命中 | URL 已分析过,跳过 API 调用")
|
||||
result = dict(result)
|
||||
result["elapsed"] = 0
|
||||
return result
|
||||
else:
|
||||
del self._analysis_cache[cache_key]
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
# 构建图片内容
|
||||
if self._is_url(image_path):
|
||||
image_item = {
|
||||
"type": "input_image",
|
||||
"image_url": image_path
|
||||
}
|
||||
else:
|
||||
b64 = self._load_image_base64(image_path)
|
||||
if not b64:
|
||||
await asyncio.sleep(self.MIN_WAIT_SECONDS)
|
||||
return self._fallback("图片读取失败")
|
||||
image_item = {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/jpeg;base64,{b64}"
|
||||
}
|
||||
|
||||
# 使用火山引擎官方 SDK(AsyncOpenAI + /responses 接口)
|
||||
client = AsyncOpenAI(
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
)
|
||||
|
||||
response = await client.responses.create(
|
||||
model=self.vision_model,
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
image_item,
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": ANALYSIS_PROMPT
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
content = response.output_text
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
print(f"[ImageAnalyzer] 视觉AI响应耗时: {elapsed:.1f}s")
|
||||
|
||||
await self._wait_remaining(elapsed)
|
||||
|
||||
result = self._parse_result(content)
|
||||
result["elapsed"] = elapsed
|
||||
|
||||
# 计算尺寸与类型加价
|
||||
try:
|
||||
w, h = await self._get_image_size(image_path)
|
||||
mp = round((w * h) / 1_000_000, 2) if w and h else 0.0
|
||||
result["width"] = w
|
||||
result["height"] = h
|
||||
result["megapixels"] = mp
|
||||
|
||||
# 归一化类型
|
||||
subj = (result.get("subject") or "").lower()
|
||||
ptype = (result.get("proc_type") or "").lower()
|
||||
ratio = result.get("aspect_ratio") or "1:1"
|
||||
category = "general"
|
||||
# 初步判断
|
||||
if ("壁纸" in subj) or ("wallpaper" in subj) or ratio in ("9:16", "16:9"):
|
||||
category = "wallpaper"
|
||||
elif ("衣" in subj) or ("服" in subj) or ("印花" in subj) or ("fabric" in subj) or ("cloth" in subj) or ("服装" in subj) or ("印花" in ptype):
|
||||
category = "clothing"
|
||||
elif ("logo" in subj) or ("logo" in ptype):
|
||||
category = "logo"
|
||||
elif ("海报" in subj) or ("poster" in subj):
|
||||
category = "poster"
|
||||
elif ("人像" in subj) or ("人物" in subj) or ("portrait" in subj):
|
||||
category = "portrait"
|
||||
elif ("产品" in subj) or ("product" in subj):
|
||||
category = "product"
|
||||
elif ("老照片" in subj) or ("old photo" in subj):
|
||||
category = "old_photo"
|
||||
# 可印花/印刷物体扩展
|
||||
keywords = subj + " " + ptype
|
||||
if any(k in keywords for k in ["装饰画", "挂画", "油画", "canvas", "painting"]):
|
||||
category = "decor_painting"
|
||||
elif any(k in keywords for k in ["窗帘", "curtain"]):
|
||||
category = "curtain"
|
||||
elif any(k in keywords for k in ["地垫", "脚垫", "地毯", "垫", "mat", "rug"]):
|
||||
category = "floor_mat"
|
||||
elif any(k in keywords for k in ["广告牌", "喷绘", "展架", "灯箱", "banner", "billboard"]):
|
||||
category = "billboard"
|
||||
elif any(k in keywords for k in ["毯子", "毛毯", "blanket"]):
|
||||
category = "blanket"
|
||||
elif any(k in keywords for k in ["桌布", "台布", "tablecloth", "桌旗"]):
|
||||
category = "tablecloth"
|
||||
elif any(k in keywords for k in ["书本", "书籍", "封面", "book", "book cover"]):
|
||||
category = "book"
|
||||
elif any(k in keywords for k in ["鼠标垫", "mouse pad", "mousepad"]):
|
||||
category = "mouse_pad"
|
||||
elif any(k in keywords for k in ["头像", "个人头像", "个人照", "profile", "avatar"]):
|
||||
category = "avatar"
|
||||
result["category"] = category
|
||||
|
||||
surcharge = 0
|
||||
size_note = ""
|
||||
# 按类别设定尺寸要求与加价阈值(单位:百万像素)
|
||||
if category == "wallpaper":
|
||||
if h and h < 1920:
|
||||
size_note = "壁纸高度低于1920px,清晰度可能不足"
|
||||
if mp > 8:
|
||||
surcharge = 10
|
||||
elif mp > 3:
|
||||
surcharge = 5
|
||||
elif category == "clothing":
|
||||
if (w and w < 1024) or (h and h < 1024):
|
||||
size_note = "印花源图边长低于1024px,放大后细节可能不足"
|
||||
if mp > 6:
|
||||
surcharge = 10
|
||||
elif mp > 2:
|
||||
surcharge = 5
|
||||
elif category in ("poster", "portrait", "product"):
|
||||
if mp > 12:
|
||||
surcharge = 10
|
||||
elif mp > 6:
|
||||
surcharge = 5
|
||||
elif category == "logo":
|
||||
if mp > 6:
|
||||
surcharge = 5
|
||||
elif category == "decor_painting":
|
||||
if (w and w < 1500) or (h and h < 1500):
|
||||
size_note = "装饰画边长低于1500px,打印放大可能不够清晰"
|
||||
if mp > 12:
|
||||
surcharge = 10
|
||||
elif mp > 6:
|
||||
surcharge = 5
|
||||
elif category == "curtain":
|
||||
if (w and w < 1500):
|
||||
size_note = "窗帘宽度低于1500px,印花放大可能不够清晰"
|
||||
if mp > 16:
|
||||
surcharge = 10
|
||||
elif mp > 8:
|
||||
surcharge = 5
|
||||
elif category == "floor_mat":
|
||||
if mp > 12:
|
||||
surcharge = 10
|
||||
elif mp > 6:
|
||||
surcharge = 5
|
||||
elif category == "billboard":
|
||||
if (w and w < 2000) or (h and h < 1000):
|
||||
size_note = "广告牌尺寸较小,建议更高分辨率以保证喷绘清晰"
|
||||
if mp > 20:
|
||||
surcharge = 10
|
||||
elif mp > 10:
|
||||
surcharge = 5
|
||||
elif category == "blanket":
|
||||
if mp > 16:
|
||||
surcharge = 10
|
||||
elif mp > 8:
|
||||
surcharge = 5
|
||||
elif category == "tablecloth":
|
||||
if mp > 12:
|
||||
surcharge = 10
|
||||
elif mp > 6:
|
||||
surcharge = 5
|
||||
elif category == "book":
|
||||
if (w and w < 800):
|
||||
size_note = "书本封面宽度低于800px,印刷细节可能不足"
|
||||
if mp > 6:
|
||||
surcharge = 5
|
||||
elif category == "mouse_pad":
|
||||
if (w and w < 1000):
|
||||
size_note = "鼠标垫源图宽度低于1000px,细节可能不足"
|
||||
if mp > 4:
|
||||
surcharge = 5
|
||||
elif category == "avatar":
|
||||
if (w and w < 800) or (h and h < 800):
|
||||
size_note = "头像边长低于800px,清晰度可能不足"
|
||||
if mp > 6:
|
||||
surcharge = 5
|
||||
else:
|
||||
if mp > 8:
|
||||
surcharge = 10
|
||||
elif mp > 4:
|
||||
surcharge = 5
|
||||
|
||||
# 应用加价,保持5的整数倍与 10-30 区间
|
||||
base = result.get("price_suggest", 20)
|
||||
adjusted = base + surcharge
|
||||
adjusted = max(10, min(30, adjusted))
|
||||
adjusted = round(adjusted / 5) * 5
|
||||
# 同步范围
|
||||
result["price_suggest"] = adjusted
|
||||
result["price_max"] = max(result["price_max"], adjusted)
|
||||
result["size_surcharge"] = surcharge
|
||||
result["size_note"] = size_note
|
||||
except Exception as e:
|
||||
print(f"[ImageAnalyzer] 尺寸与类型加价计算失败: {e}")
|
||||
|
||||
# 写入缓存
|
||||
if cache_key:
|
||||
self._analysis_cache[cache_key] = (dict(result), time.monotonic())
|
||||
# 简单清理:缓存超过 50 条时删最旧的
|
||||
if len(self._analysis_cache) > 50:
|
||||
oldest = min(self._analysis_cache.items(), key=lambda x: x[1][1])
|
||||
del self._analysis_cache[oldest[0]]
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
elapsed = time.monotonic() - start
|
||||
print(f"[ImageAnalyzer] 请求超时 ({elapsed:.1f}s)")
|
||||
return self._fallback("请求超时")
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start
|
||||
print(f"[ImageAnalyzer] 分析失败: {e}")
|
||||
await self._wait_remaining(elapsed)
|
||||
return self._fallback(str(e))
|
||||
|
||||
async def _wait_remaining(self, elapsed: float):
|
||||
"""补足最短等待时间"""
|
||||
remaining = self.MIN_WAIT_SECONDS - elapsed
|
||||
if remaining > 0:
|
||||
await asyncio.sleep(remaining)
|
||||
|
||||
def _parse_line(self, content: str, *keys: str) -> str:
|
||||
"""从多行文本中提取指定字段值,支持中英文冒号"""
|
||||
for line in content.strip().split("\n"):
|
||||
line = line.strip()
|
||||
for key in keys:
|
||||
if line.startswith(key):
|
||||
return line.split(":", 1)[-1].split(":", 1)[-1].strip()
|
||||
return ""
|
||||
|
||||
def _parse_result(self, content: str) -> dict:
|
||||
"""解析模型返回的结果"""
|
||||
p = self._parse_line
|
||||
|
||||
# 复杂度
|
||||
complexity_raw = p(content, "复杂度:", "复杂度:").lower()
|
||||
complexity = complexity_raw if complexity_raw in self.PRICE_MAP else "normal"
|
||||
|
||||
sensitive = p(content, "敏感内容:", "敏感内容:").lower().strip()
|
||||
flatness = p(content, "平整度:", "平整度:").lower().strip() # flat|mild|rough
|
||||
has_text = p(content, "含文字:", "含文字:").lower().strip()
|
||||
has_face = p(content, "含人脸:", "含人脸:").lower().strip()
|
||||
has_shadow = p(content, "阴影:", "阴影:").lower().strip()
|
||||
reason = p(content, "原因:", "原因:")
|
||||
subject = p(content, "主体:", "主体:")
|
||||
proc_type = p(content, "类型:", "类型:")
|
||||
quality = p(content, "质量:", "质量:")
|
||||
feasibility = p(content, "可做:", "可做:").lower()
|
||||
risk = p(content, "风险:", "风险:").lower().strip()
|
||||
perspective = p(content, "透视:", "透视:").lower().strip()
|
||||
aspect_ratio = p(content, "比例:", "比例:").strip()
|
||||
gemini_prompt= p(content, "提示词:", "提示词:")
|
||||
note = p(content, "备注:", "备注:")
|
||||
|
||||
if has_face not in ("yes", "no"):
|
||||
has_face = "no"
|
||||
if risk not in ("none", "low", "high"):
|
||||
risk = "none"
|
||||
if perspective not in ("no", "mild", "strong"):
|
||||
perspective = "no"
|
||||
|
||||
# 校验比例合法性
|
||||
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||||
if aspect_ratio not in valid_ratios:
|
||||
aspect_ratio = "1:1" # 默认正方形
|
||||
|
||||
price_min, price_max, default_reason = self.PRICE_MAP[complexity]
|
||||
if not reason:
|
||||
reason = default_reason
|
||||
if feasibility not in ("yes", "partial", "no"):
|
||||
feasibility = "yes"
|
||||
|
||||
# 建议报价:complex/hard 取固定值,simple/normal 取中间,且必须为5的整数倍
|
||||
raw = price_max if complexity in ("complex", "hard") else (price_min + price_max) // 2
|
||||
price_suggest = round(raw / 5) * 5
|
||||
|
||||
if sensitive == "yes":
|
||||
feasibility = "no"
|
||||
note = "图片含敏感内容,不接单"
|
||||
risk_label = {"none": "无风险", "low": "低风险", "high": "高风险"}.get(risk, "")
|
||||
sens_tag = " | 敏感:是" if sensitive == "yes" else ""
|
||||
print(f"[ImageAnalyzer] 识别结果: {complexity} | {reason} | 建议报价: {price_suggest}元{sens_tag}")
|
||||
print(f"[ImageAnalyzer] 主体: {subject} | 类型: {proc_type} | 质量: {quality} | 平整度: {flatness} | 含文字: {has_text} | 含人脸: {has_face} | 阴影: {has_shadow} | 风险: {risk_label} | 透视: {perspective} | 比例: {aspect_ratio} | 可做: {feasibility}")
|
||||
if gemini_prompt:
|
||||
print(f"[ImageAnalyzer] Gemini提示词: {gemini_prompt}")
|
||||
if note and note not in ("无", ""):
|
||||
print(f"[ImageAnalyzer] 备注: {note}")
|
||||
|
||||
return {
|
||||
"complexity": complexity,
|
||||
"reason": reason,
|
||||
"subject": subject,
|
||||
"proc_type": proc_type,
|
||||
"quality": quality,
|
||||
"flatness": flatness if flatness in ("flat", "mild", "rough") else "",
|
||||
"has_text": has_text if has_text in ("yes", "no") else "no",
|
||||
"has_face": has_face, # yes / no
|
||||
"has_shadow": has_shadow if has_shadow in ("yes", "no") else "no",
|
||||
"risk": risk, # none / low / high
|
||||
"feasibility": feasibility,
|
||||
"perspective": perspective,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"gemini_prompt": gemini_prompt,
|
||||
"note": note,
|
||||
"price_min": price_min,
|
||||
"price_max": price_max,
|
||||
"price_suggest": price_suggest,
|
||||
"success": True
|
||||
}
|
||||
|
||||
def _fallback(self, reason: str) -> dict:
|
||||
"""识别失败时的默认结果(返回 normal,让人工判断)"""
|
||||
print(f"[ImageAnalyzer] 识别失败,使用默认值: {reason}")
|
||||
return {
|
||||
"complexity": "normal",
|
||||
"reason": reason,
|
||||
"subject": "",
|
||||
"proc_type": "",
|
||||
"quality": "",
|
||||
"flatness": "",
|
||||
"has_text": "no",
|
||||
"has_face": "no",
|
||||
"has_shadow": "no",
|
||||
"risk": "none",
|
||||
"feasibility": "yes",
|
||||
"perspective": "no",
|
||||
"aspect_ratio": "1:1",
|
||||
"gemini_prompt": "",
|
||||
"note": "",
|
||||
"price_min": 20,
|
||||
"price_max": 30,
|
||||
"price_suggest": 25,
|
||||
"success": False
|
||||
}
|
||||
|
||||
|
||||
# 全局实例
|
||||
image_analyzer = ImageAnalyzer()
|
||||
47
image/image_precheck.py
Normal file
47
image/image_precheck.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
图片预检 - 下载后检查尺寸/格式/是否损坏,不合格直接拒单
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 可配置
|
||||
MIN_WIDTH = int(os.getenv("IMAGE_PRECHECK_MIN_WIDTH", "50"))
|
||||
MIN_HEIGHT = int(os.getenv("IMAGE_PRECHECK_MIN_HEIGHT", "50"))
|
||||
MAX_WIDTH = int(os.getenv("IMAGE_PRECHECK_MAX_WIDTH", "8000"))
|
||||
MAX_HEIGHT = int(os.getenv("IMAGE_PRECHECK_MAX_HEIGHT", "8000"))
|
||||
MIN_SIZE = int(os.getenv("IMAGE_PRECHECK_MIN_BYTES", "100")) # 至少 100 字节
|
||||
MAX_SIZE = int(os.getenv("IMAGE_PRECHECK_MAX_BYTES", "0")) # 0=不限制
|
||||
SUPPORTED_FORMATS = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
|
||||
|
||||
def precheck(local_path: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
预检图片文件。
|
||||
|
||||
Returns:
|
||||
(ok, message) - ok=False 时 message 为拒单原因
|
||||
"""
|
||||
if not os.path.exists(local_path):
|
||||
return False, "图片文件不存在"
|
||||
size = os.path.getsize(local_path)
|
||||
if size < MIN_SIZE:
|
||||
return False, f"图片太小({size} 字节),可能损坏或格式异常"
|
||||
if MAX_SIZE > 0 and size > MAX_SIZE:
|
||||
return False, f"图片过大({size/1024/1024:.1f}MB),超过 {MAX_SIZE/1024/1024:.0f}MB 限制"
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
with Image.open(local_path) as img:
|
||||
w, h = img.size
|
||||
if w < MIN_WIDTH or h < MIN_HEIGHT:
|
||||
return False, f"图片尺寸过小({w}x{h}),最小 {MIN_WIDTH}x{MIN_HEIGHT}"
|
||||
if w > MAX_WIDTH or h > MAX_HEIGHT:
|
||||
return False, f"图片尺寸过大({w}x{h}),最大 {MAX_WIDTH}x{MAX_HEIGHT}"
|
||||
img.verify()
|
||||
except Exception as e:
|
||||
return False, f"图片无法读取或已损坏:{str(e)[:50]}"
|
||||
return True, ""
|
||||
328
image/image_processor.py
Normal file
328
image/image_processor.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""图片处理模块 - 调用 Gemini 作图API,含质检与自动重试"""
|
||||
import os
|
||||
import uuid
|
||||
import tempfile
|
||||
from typing import Optional, Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_OUTPUT_DIR = os.getenv("RESULT_IMAGE_DIR", "results")
|
||||
_MAX_RETRIES = int(os.getenv("PROCESS_MAX_RETRIES", "2")) # 含首次共最多处理几次
|
||||
|
||||
|
||||
class ImageProcessor:
|
||||
"""图片处理 - 对接 GeminiExtractV2Service,含质检与重试"""
|
||||
|
||||
def __init__(self):
|
||||
os.makedirs(_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# ─── 内部工具 ────────────────────────────────────────────
|
||||
|
||||
async def _download(self, url: str) -> str:
|
||||
"""下载图片到临时文件,返回本地路径"""
|
||||
import aiohttp
|
||||
tmp = os.path.join(tempfile.gettempdir(), f"gemini_in_{uuid.uuid4().hex}.jpg")
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/122.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Referer": "https://www.taobao.com/",
|
||||
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
|
||||
}
|
||||
async with aiohttp.ClientSession(headers=headers) as session:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"下载图片失败: HTTP {resp.status}")
|
||||
with open(tmp, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return tmp
|
||||
|
||||
async def _do_perspective(self, service, src: str, level: str) -> str:
|
||||
"""透视矫正,返回矫正后文件路径(失败则返回原路径)"""
|
||||
out = os.path.join(tempfile.gettempdir(), f"gemini_persp_{uuid.uuid4().hex}.jpg")
|
||||
ok, msg, _ = await service.correct_perspective(src, out, level=level)
|
||||
if ok:
|
||||
print(f"[ImageProcessor] 透视矫正完成")
|
||||
return out
|
||||
else:
|
||||
print(f"[ImageProcessor] 透视矫正失败 ({msg}),跳过")
|
||||
if os.path.exists(out):
|
||||
os.remove(out)
|
||||
return src
|
||||
|
||||
@staticmethod
|
||||
def _build_retry_prompt(gemini_prompt: str, qa_issue: str, qa_suggestion: str) -> str:
|
||||
"""
|
||||
根据 QA 质检问题类型,智能调整重试提示词。
|
||||
比简单追加建议更有针对性,让 Gemini 知道上次哪里出了问题。
|
||||
"""
|
||||
base = gemini_prompt or ""
|
||||
issue = (qa_issue or "").lower()
|
||||
suggestion = qa_suggestion if qa_suggestion and qa_suggestion != "无" else ""
|
||||
|
||||
# 背景不干净
|
||||
if any(kw in issue for kw in ["背景", "杂物", "多余", "白色不纯"]):
|
||||
prefix = "【重要:背景必须是纯白色 #FFFFFF,去掉所有杂物和阴影】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 清晰度/细节不足
|
||||
if any(kw in issue for kw in ["模糊", "清晰", "细节", "锐化", "分辨率"]):
|
||||
prefix = "【重要:提升整体清晰度和细节,输出高分辨率版本,不要模糊】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 内容缺失/截断
|
||||
if any(kw in issue for kw in ["缺失", "截断", "不完整", "边缘", "裁剪"]):
|
||||
prefix = "【重要:保留主体完整内容,不要截断边缘,确保四角全部保留】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 颜色偏差
|
||||
if any(kw in issue for kw in ["颜色", "色彩", "偏色", "色调"]):
|
||||
prefix = "【重要:忠实还原原图颜色,不要改变色调或过度饱和】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# AI幻觉/变形
|
||||
if any(kw in issue for kw in ["幻觉", "变形", "失真", "扭曲", "ai生成"]):
|
||||
prefix = "【重要:严格按原图内容处理,不要添加或改变任何图案细节】"
|
||||
return prefix + ("\n" + base if base else "")
|
||||
|
||||
# 没有匹配到具体类型,直接用质检建议
|
||||
if suggestion:
|
||||
return (base + f"\n【上次问题:{qa_issue}。本次改进方向:{suggestion}】").strip()
|
||||
|
||||
return base
|
||||
|
||||
async def _do_main(self, service, src: str, gemini_prompt: str, aspect_ratio: str,
|
||||
attempt: int, qa_issue: str = "", qa_suggestion: str = "") -> tuple[bool, str, str]:
|
||||
"""
|
||||
执行一次主处理。
|
||||
重试时根据 QA 问题类型智能调整提示词。
|
||||
|
||||
Returns:
|
||||
(success, output_path, message)
|
||||
"""
|
||||
out_name = f"result_{uuid.uuid4().hex}.jpg"
|
||||
output_path = os.path.join(_OUTPUT_DIR, out_name)
|
||||
|
||||
if attempt == 1:
|
||||
prompt = gemini_prompt or None
|
||||
else:
|
||||
prompt = self._build_retry_prompt(gemini_prompt, qa_issue, qa_suggestion)
|
||||
print(f"[ImageProcessor] 重试策略 | 问题: {qa_issue} | 提示词: {(prompt or '')[:80]}...")
|
||||
|
||||
print(f"[ImageProcessor] 主处理第 {attempt} 次 (比例={aspect_ratio})...")
|
||||
success, message, _ = await service.extract_pattern(
|
||||
input_path=src,
|
||||
output_path=output_path,
|
||||
custom_prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
return success, output_path, message
|
||||
|
||||
# ─── 主入口 ──────────────────────────────────────────────
|
||||
|
||||
async def process_image(
|
||||
self,
|
||||
image_url: str,
|
||||
operation: str,
|
||||
requirements: str = "",
|
||||
gemini_prompt: str = "",
|
||||
aspect_ratio: str = "1:1",
|
||||
perspective: str = "no",
|
||||
proc_type: str = "",
|
||||
subject: str = "",
|
||||
quality: str = "",
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
完整处理流程:下载 → 透视矫正(可选)→ Gemini主处理 → 质检 → 重试(可选)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": bool,
|
||||
"result_path": str,
|
||||
"message": str,
|
||||
"qa_score": int, # 质检得分 0-100
|
||||
"qa_pass": bool, # 是否通过质检
|
||||
"qa_issue": str, # 质检发现的问题
|
||||
"attempts": int, # 共处理了几次
|
||||
}
|
||||
"""
|
||||
from services.service_gemini import GeminiExtractV2Service
|
||||
from image.image_qa import image_qa
|
||||
|
||||
# Step 1: 下载原图
|
||||
try:
|
||||
tmp_input = await self._download(image_url)
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False, "result_path": "", "message": str(e),
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "下载失败", "attempts": 0,
|
||||
}
|
||||
|
||||
# Step 1.5: 敏感图片检测
|
||||
try:
|
||||
from utils.content_filter import is_sensitive_image
|
||||
sensitive, reason = await is_sensitive_image(tmp_input)
|
||||
if sensitive:
|
||||
if os.path.exists(tmp_input):
|
||||
os.remove(tmp_input)
|
||||
return {
|
||||
"success": False, "result_path": "", "message": reason,
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "敏感图片", "attempts": 0,
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"[ImageProcessor] 敏感图片检测异常: {e},继续处理")
|
||||
|
||||
# Step 1.6: 预检(尺寸/格式/损坏)
|
||||
try:
|
||||
from image.image_precheck import precheck
|
||||
ok, msg = precheck(tmp_input)
|
||||
if not ok:
|
||||
if os.path.exists(tmp_input):
|
||||
os.remove(tmp_input)
|
||||
return {
|
||||
"success": False, "result_path": "", "message": msg,
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "预检不通过", "attempts": 0,
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"[ImageProcessor] 预检异常: {e},继续处理")
|
||||
|
||||
service = GeminiExtractV2Service()
|
||||
tmp_files = [tmp_input]
|
||||
try:
|
||||
# Step 2: 透视矫正
|
||||
current_input = tmp_input
|
||||
if perspective in ("mild", "strong"):
|
||||
print(f"[ImageProcessor] 透视矫正中 (level={perspective})...")
|
||||
corrected = await self._do_perspective(service, tmp_input, perspective)
|
||||
if corrected != tmp_input:
|
||||
tmp_files.append(corrected)
|
||||
current_input = corrected
|
||||
|
||||
# Step 3: 主处理 + 质检,最多 _MAX_RETRIES 次
|
||||
qa_result = {"score": 0, "pass": False, "issue": "未质检", "suggestion": "无"}
|
||||
output_path = ""
|
||||
last_message = ""
|
||||
qa_issue = ""
|
||||
qa_suggestion = ""
|
||||
|
||||
for attempt in range(1, _MAX_RETRIES + 1):
|
||||
ok, output_path, last_message = await self._do_main(
|
||||
service, current_input, gemini_prompt, aspect_ratio,
|
||||
attempt=attempt, qa_issue=qa_issue, qa_suggestion=qa_suggestion,
|
||||
)
|
||||
|
||||
if not ok:
|
||||
print(f"[ImageProcessor] 第 {attempt} 次处理失败: {last_message}")
|
||||
if attempt < _MAX_RETRIES:
|
||||
continue
|
||||
return {
|
||||
"success": False, "result_path": "", "message": last_message,
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": "Gemini处理失败", "attempts": attempt,
|
||||
}
|
||||
|
||||
# Step 4: 质检
|
||||
print(f"[ImageProcessor] 质检中 (第 {attempt} 次结果)...")
|
||||
qa_result = await image_qa.check(
|
||||
original_path=current_input,
|
||||
result_path=output_path,
|
||||
proc_type=proc_type,
|
||||
subject=subject,
|
||||
quality=quality,
|
||||
gemini_prompt=gemini_prompt,
|
||||
)
|
||||
qa_issue = qa_result.get("issue", "")
|
||||
qa_suggestion = qa_result.get("suggestion", "无")
|
||||
|
||||
if qa_result["pass"]:
|
||||
print(f"[ImageProcessor] 质检通过 ({qa_result['score']}分),共处理 {attempt} 次")
|
||||
break
|
||||
else:
|
||||
print(f"[ImageProcessor] 质检不合格 ({qa_result['score']}分),问题: {qa_result['issue']}")
|
||||
if attempt < _MAX_RETRIES:
|
||||
# 清理这次不合格的结果
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path)
|
||||
print(f"[ImageProcessor] 准备第 {attempt + 1} 次重试...")
|
||||
else:
|
||||
print(f"[ImageProcessor] 已达最大重试次数 {_MAX_RETRIES},保留最后结果,人工跟进")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result_path": output_path,
|
||||
"message": last_message,
|
||||
"qa_score": qa_result.get("score", 0),
|
||||
"qa_pass": qa_result.get("pass", False),
|
||||
"qa_issue": qa_result.get("issue", ""),
|
||||
"attempts": attempt,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False, "result_path": "", "message": f"处理异常: {e}",
|
||||
"qa_score": 0, "qa_pass": False, "qa_issue": str(e), "attempts": 0,
|
||||
}
|
||||
finally:
|
||||
await service.cleanup()
|
||||
for f in tmp_files:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
async def enhance(self, image_url: str) -> Dict[str, Any]:
|
||||
return await self.process_image(image_url, "enhance")
|
||||
|
||||
async def remove_bg(self, image_url: str) -> Dict[str, Any]:
|
||||
return await self.process_image(image_url, "remove_bg")
|
||||
|
||||
async def resize(self, image_url: str, width: int, height: int = 0) -> Dict[str, Any]:
|
||||
"""
|
||||
改尺寸:下载图片(或读取本地路径),按指定宽高缩放,保存到 results/。
|
||||
|
||||
Args:
|
||||
image_url: 图片 URL 或本地路径
|
||||
width: 目标宽度(像素)
|
||||
height: 目标高度(0=按宽度等比缩放)
|
||||
|
||||
Returns:
|
||||
{"success": bool, "result_path": str, "message": str}
|
||||
"""
|
||||
from PIL import Image
|
||||
is_temp = image_url.startswith(("http://", "https://"))
|
||||
try:
|
||||
if is_temp:
|
||||
tmp = await self._download(image_url)
|
||||
else:
|
||||
tmp = image_url
|
||||
if not os.path.exists(tmp):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {tmp}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
|
||||
try:
|
||||
img = Image.open(tmp).convert("RGB")
|
||||
w_orig, h_orig = img.size
|
||||
if width <= 0 or width > 10000:
|
||||
return {"success": False, "result_path": "", "message": f"宽度无效: {width}"}
|
||||
if height == 0:
|
||||
ratio = width / w_orig
|
||||
height = int(h_orig * ratio)
|
||||
elif height <= 0 or height > 10000:
|
||||
return {"success": False, "result_path": "", "message": f"高度无效: {height}"}
|
||||
resized = img.resize((width, height), Image.Resampling.LANCZOS)
|
||||
out_name = f"resize_{uuid.uuid4().hex}.jpg"
|
||||
out_path = os.path.join(_OUTPUT_DIR, out_name)
|
||||
resized.save(out_path, "JPEG", quality=95)
|
||||
print(f"[ImageProcessor] 改尺寸完成: {w_orig}x{h_orig} → {width}x{height}")
|
||||
return {"success": True, "result_path": out_path, "message": f"已改为 {width}x{height}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if is_temp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
# 全局实例
|
||||
image_processor = ImageProcessor()
|
||||
189
image/image_qa.py
Normal file
189
image/image_qa.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
图片处理结果质检模块
|
||||
|
||||
处理完成后,用视觉 AI 对比原图和结果图,判断是否符合客户需求。
|
||||
评分 0-100,低于阈值则判定不合格,触发重试或人工跟进。
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_QA_PASS_SCORE = int(os.getenv("QA_PASS_SCORE", "70")) # 合格分数线,默认70
|
||||
|
||||
QA_PROMPT_TEMPLATE = """\
|
||||
你是一名专业的图片处理质检员,需要评估处理结果是否满足要求。
|
||||
|
||||
【处理类型】{proc_type}
|
||||
【客户需求/Gemini提示词】{gemini_prompt}
|
||||
【原图描述】主体:{subject},类型:{proc_type},质量:{quality}
|
||||
|
||||
请对比左图(原图)和右图(处理结果),从以下维度打分(每项0-25分):
|
||||
|
||||
1. 内容完整性:主体图案/内容是否完整保留,有无缺失、截断
|
||||
2. 畸变去除:褶皱/透视变形/背景是否已被清除
|
||||
3. 细节还原:颜色、线条、纹理等细节与原图的匹配程度
|
||||
4. 输出干净度:背景是否干净,有无多余内容、AI幻觉、模糊块
|
||||
|
||||
输出格式(严格按照此格式,每行一个字段):
|
||||
完整性: <0-25>
|
||||
畸变: <0-25>
|
||||
细节: <0-25>
|
||||
干净: <0-25>
|
||||
总分: <0-100>
|
||||
结论: <pass|fail>
|
||||
问题: <简述主要问题,不超过30字,无问题填"无">
|
||||
建议: <如果fail,给出重试改进建议,不超过40字,pass则填"无">
|
||||
"""
|
||||
|
||||
|
||||
class ImageQA:
|
||||
"""处理结果质检器"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_BASE_URL", "https://open.bigmodel.cn/api/paas/v4")
|
||||
self.model = os.getenv("VISION_MODEL", "glm-4v-flash")
|
||||
self.pass_score = _QA_PASS_SCORE
|
||||
|
||||
def _to_base64(self, path: str) -> Optional[str]:
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
return base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
print(f"[ImageQA] 读取图片失败 {path}: {e}")
|
||||
return None
|
||||
|
||||
def _parse(self, text: str) -> dict:
|
||||
def p(key):
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
for k in [f"{key}:", f"{key}:"]:
|
||||
if line.startswith(k):
|
||||
return line[len(k):].strip()
|
||||
return ""
|
||||
|
||||
try:
|
||||
score = int(p("总分"))
|
||||
except ValueError:
|
||||
score = 0
|
||||
|
||||
conclusion = p("结论").lower()
|
||||
if conclusion not in ("pass", "fail"):
|
||||
conclusion = "pass" if score >= self.pass_score else "fail"
|
||||
|
||||
return {
|
||||
"score": score,
|
||||
"pass": conclusion == "pass",
|
||||
"issue": p("问题"),
|
||||
"suggestion": p("建议"),
|
||||
"detail": {
|
||||
"completeness": p("完整性"),
|
||||
"distortion": p("畸变"),
|
||||
"detail": p("细节"),
|
||||
"clean": p("干净"),
|
||||
},
|
||||
"raw": text,
|
||||
}
|
||||
|
||||
async def check(
|
||||
self,
|
||||
original_path: str,
|
||||
result_path: str,
|
||||
proc_type: str = "",
|
||||
subject: str = "",
|
||||
quality: str = "",
|
||||
gemini_prompt: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
质检处理结果。
|
||||
|
||||
Args:
|
||||
original_path: 原图本地路径
|
||||
result_path: 处理结果本地路径
|
||||
proc_type: 处理类型(印花提取 / 高清修复等)
|
||||
subject: 主体描述
|
||||
quality: 原图质量
|
||||
gemini_prompt: 传给 Gemini 的提示词(体现客户需求)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"score": int, # 0-100
|
||||
"pass": bool, # 是否合格
|
||||
"issue": str, # 主要问题
|
||||
"suggestion": str, # 重试改进建议
|
||||
"detail": dict, # 各维度分数
|
||||
}
|
||||
"""
|
||||
if not self.api_key:
|
||||
print("[ImageQA] 未配置 API Key,跳过质检,默认通过")
|
||||
return {"score": 80, "pass": True, "issue": "无", "suggestion": "无", "detail": {}}
|
||||
|
||||
orig_b64 = self._to_base64(original_path)
|
||||
result_b64 = self._to_base64(result_path)
|
||||
if not orig_b64 or not result_b64:
|
||||
print("[ImageQA] 图片读取失败,跳过质检")
|
||||
return {"score": 75, "pass": True, "issue": "质检图片读取失败", "suggestion": "无", "detail": {}}
|
||||
|
||||
prompt = QA_PROMPT_TEMPLATE.format(
|
||||
proc_type=proc_type or "图片处理",
|
||||
subject=subject or "未知",
|
||||
quality=quality or "未知",
|
||||
gemini_prompt=gemini_prompt or "按标准处理",
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
client = AsyncOpenAI(base_url=self.base_url, api_key=self.api_key)
|
||||
|
||||
response = await client.responses.create(
|
||||
model=self.model,
|
||||
input=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/jpeg;base64,{orig_b64}",
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/jpeg;base64,{result_b64}",
|
||||
},
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": prompt,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
content = response.output_text
|
||||
elapsed = time.monotonic() - start
|
||||
result = self._parse(content)
|
||||
result["elapsed"] = round(elapsed, 1)
|
||||
|
||||
status = "✓ 合格" if result["pass"] else "✗ 不合格"
|
||||
print(f"[ImageQA] {status} | 得分: {result['score']}/100 | 问题: {result['issue']} | 耗时: {elapsed:.1f}s")
|
||||
if not result["pass"]:
|
||||
print(f"[ImageQA] 改进建议: {result['suggestion']}")
|
||||
try:
|
||||
from utils.api_cost_tracker import record
|
||||
record("gemini_vision", count=1)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start
|
||||
print(f"[ImageQA] 质检失败 ({elapsed:.1f}s): {e}")
|
||||
return {"score": 75, "pass": True, "issue": f"质检异常: {e}", "suggestion": "无", "detail": {}}
|
||||
|
||||
|
||||
# 全局实例
|
||||
image_qa = ImageQA()
|
||||
293
image/image_tools.py
Normal file
293
image/image_tools.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
图片处理独立工具 - 可单独调用,也可被主流程复用。
|
||||
|
||||
主流程(付款触发)不变,这些工具供 AI 按需组合使用。
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
import tempfile
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
_OUTPUT_DIR = os.getenv("RESULT_IMAGE_DIR", "results")
|
||||
os.makedirs(_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
|
||||
async def _download(url: str) -> str:
|
||||
"""下载图片到临时文件"""
|
||||
import aiohttp
|
||||
tmp = os.path.join(tempfile.gettempdir(), f"img_{uuid.uuid4().hex}.jpg")
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
"Referer": "https://www.taobao.com/",
|
||||
}
|
||||
async with aiohttp.ClientSession(headers=headers) as session:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"下载失败: HTTP {resp.status}")
|
||||
with open(tmp, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return tmp
|
||||
|
||||
|
||||
async def remove_background(image_url: str, save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】去背景 → 纯白/纯色背景。
|
||||
输入 URL 或本地路径,输出白底产品图。
|
||||
"""
|
||||
from image.perspective_fix import _gemini_call, PROMPT_WHITE_BG
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"bg_{uuid.uuid4().hex}.jpg")
|
||||
ok = await _gemini_call(src, out, PROMPT_WHITE_BG, aspect_ratio="auto", label="去背景")
|
||||
if ok:
|
||||
return {"success": True, "result_path": out, "message": "去背景完成"}
|
||||
return {"success": False, "result_path": "", "message": "去背景失败"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
async def perspective_correct(image_url: str, save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】透视矫正。
|
||||
输入需为白底图(可先调 remove_background),输出展平后的图。
|
||||
"""
|
||||
import cv2
|
||||
from image.perspective_fix import find_quad, four_point_transform
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
img = cv2.imread(src)
|
||||
if img is None:
|
||||
return {"success": False, "result_path": "", "message": "无法读取图片"}
|
||||
pts = find_quad(img)
|
||||
if pts is None:
|
||||
return {"success": False, "result_path": "", "message": "未检测到四边形,无法透视矫正"}
|
||||
warped = four_point_transform(img, pts)
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"persp_{uuid.uuid4().hex}.jpg")
|
||||
cv2.imwrite(out, warped, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
return {"success": True, "result_path": out, "message": "透视矫正完成"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
async def extract_pattern(image_url: str, prompt: str = "", aspect_ratio: str = "1:1",
|
||||
save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】印花提取/主处理。
|
||||
按提示词和比例输出处理后的图。
|
||||
"""
|
||||
from services.service_gemini import GeminiExtractV2Service
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"extract_{uuid.uuid4().hex}.jpg")
|
||||
service = GeminiExtractV2Service()
|
||||
try:
|
||||
ok, msg, _ = await service.extract_pattern(
|
||||
input_path=src, output_path=out,
|
||||
custom_prompt=prompt or None, aspect_ratio=aspect_ratio,
|
||||
)
|
||||
if ok and os.path.exists(out):
|
||||
return {"success": True, "result_path": out, "message": "提取完成"}
|
||||
return {"success": False, "result_path": "", "message": msg or "提取失败"}
|
||||
finally:
|
||||
await service.cleanup()
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
async def enhance_image(image_url: str, save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】高清增强。
|
||||
使用 Qwen RunningHub,失败时降级 Gemini。
|
||||
"""
|
||||
from services.service_qwen import 清晰化_api
|
||||
from image.perspective_fix import _gemini_call, PROMPT_ENHANCE_SIMPLE
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"enh_{uuid.uuid4().hex}.jpg")
|
||||
ok = await 清晰化_api(img_path=src, save_path=out)
|
||||
if not ok:
|
||||
ok = await _gemini_call(src, out, PROMPT_ENHANCE_SIMPLE, aspect_ratio="auto", label="增强")
|
||||
if ok:
|
||||
return {"success": True, "result_path": out, "message": "高清增强完成"}
|
||||
return {"success": False, "result_path": "", "message": "高清增强失败"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
async def color_match_images(orig_url: str, result_url: str, save_path: str = "",
|
||||
strength: float = 0.75) -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】颜色匹配。将 result 的色调匹配到 orig。
|
||||
"""
|
||||
import cv2
|
||||
from image.perspective_fix import _color_match
|
||||
tmp_orig = tmp_result = None
|
||||
try:
|
||||
if orig_url.startswith(("http://", "https://")):
|
||||
tmp_orig = await _download(orig_url)
|
||||
orig_path = tmp_orig
|
||||
else:
|
||||
orig_path = orig_url
|
||||
if result_url.startswith(("http://", "https://")):
|
||||
tmp_result = await _download(result_url)
|
||||
result_path = tmp_result
|
||||
else:
|
||||
result_path = result_url
|
||||
|
||||
orig_img = cv2.imread(orig_path)
|
||||
result_img = cv2.imread(result_path)
|
||||
if orig_img is None or result_img is None:
|
||||
return {"success": False, "result_path": "", "message": "图片读取失败"}
|
||||
matched = _color_match(orig_img, result_img, strength=strength)
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"color_{uuid.uuid4().hex}.jpg")
|
||||
cv2.imwrite(out, matched, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
return {"success": True, "result_path": out, "message": f"颜色匹配完成(强度{strength:.0%})"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
for t in (tmp_orig, tmp_result):
|
||||
if t and os.path.exists(t):
|
||||
os.remove(t)
|
||||
|
||||
|
||||
async def trim_border(image_url: str, save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】裁切四周背景边(支持任意颜色:白/黄/米等)。
|
||||
"""
|
||||
import cv2
|
||||
from image.perspective_fix import tool_trim_white_border
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
img = cv2.imread(src)
|
||||
if img is None:
|
||||
return {"success": False, "result_path": "", "message": "无法读取图片"}
|
||||
trimmed, did_trim, info = tool_trim_white_border(img)
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"trim_{uuid.uuid4().hex}.jpg")
|
||||
cv2.imwrite(out, trimmed, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
return {"success": True, "result_path": out, "message": "裁边完成" if did_trim else "无需裁边"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
async def vectorize_to_eps(image_url: str, save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】矢量化 - 将图片转为 EPS 矢量文件。
|
||||
客户要做矢量图、转 EPS、转 AI 格式时调用。
|
||||
"""
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
from services.service_vectorizer import VectorizerService
|
||||
svc = VectorizerService()
|
||||
out = save_path or os.path.join(_OUTPUT_DIR, f"vec_{uuid.uuid4().hex}.eps")
|
||||
result_path = await svc.image_to_eps(src, save_eps_path=out)
|
||||
if result_path and os.path.exists(result_path):
|
||||
return {"success": True, "result_path": result_path, "message": "矢量化完成,已生成 EPS 文件"}
|
||||
return {"success": False, "result_path": "", "message": "矢量化失败"}
|
||||
except ImportError as e:
|
||||
return {"success": False, "result_path": "", "message": f"矢量化服务不可用: {e}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
|
||||
|
||||
async def meitu_enhance(image_url: str, mode: str = "standard", save_path: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
【独立工具】美图画质增强。
|
||||
模式: crystal(极速重绘) standard(标准) enhance(增强) hdr(HDR) portrait(人像优化)
|
||||
客户要画质增强、清晰化、美图处理时调用。
|
||||
"""
|
||||
tmp = None
|
||||
try:
|
||||
if image_url.startswith(("http://", "https://")):
|
||||
tmp = await _download(image_url)
|
||||
src = tmp
|
||||
else:
|
||||
src = image_url
|
||||
if not os.path.exists(src):
|
||||
return {"success": False, "result_path": "", "message": f"文件不存在: {src}"}
|
||||
|
||||
from pathlib import Path
|
||||
from services.service_meitu import MeituAPIService
|
||||
svc = MeituAPIService()
|
||||
output_dir = Path(_OUTPUT_DIR)
|
||||
result = await svc.process_image(src, mode=mode, output_dir=output_dir)
|
||||
out = result.get("processed_path")
|
||||
if out and os.path.exists(str(out)):
|
||||
if save_path:
|
||||
import shutil
|
||||
shutil.copy(str(out), save_path)
|
||||
out = save_path
|
||||
return {"success": True, "result_path": str(out), "message": f"画质增强完成({result.get('mode_name', mode)})"}
|
||||
return {"success": False, "result_path": "", "message": "美图处理失败"}
|
||||
except ImportError as e:
|
||||
return {"success": False, "result_path": "", "message": f"美图服务不可用: {e}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "result_path": "", "message": str(e)}
|
||||
finally:
|
||||
if tmp and os.path.exists(tmp):
|
||||
os.remove(tmp)
|
||||
651
image/perspective_fix.py
Normal file
651
image/perspective_fix.py
Normal file
@@ -0,0 +1,651 @@
|
||||
"""
|
||||
透视矫正三步流程:
|
||||
Step1: Gemini 去背景 → 纯白背景
|
||||
Step2: OpenCV 在白背景图上检测四角 → warpPerspective 展平
|
||||
Step3: Gemini 对展平结果做高清增强
|
||||
|
||||
用法:
|
||||
python perspective_fix.py <图片路径或URL> [--debug] [--skip-step1] [--skip-step3]
|
||||
"""
|
||||
import sys, io
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
||||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
||||
|
||||
import os, asyncio, uuid, tempfile
|
||||
import numpy as np
|
||||
import cv2
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_OUTPUT_DIR = os.getenv("RESULT_IMAGE_DIR", "results")
|
||||
os.makedirs(_OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# Gemini 辅助函数
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
async def _gemini_call(input_path: str, output_path: str, prompt: str,
|
||||
aspect_ratio: str = "1:1", label: str = "") -> bool:
|
||||
from services.service_gemini import GeminiExtractV2Service
|
||||
service = GeminiExtractV2Service()
|
||||
try:
|
||||
ok, msg, _ = await service.extract_pattern(
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
custom_prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
)
|
||||
status = "成功" if ok else "失败"
|
||||
print(f" [{label}] Gemini {status}: {msg[:80]}")
|
||||
return ok and os.path.exists(output_path)
|
||||
except Exception as e:
|
||||
print(f" [{label}] Gemini 异常: {e}")
|
||||
return False
|
||||
finally:
|
||||
await service.cleanup()
|
||||
|
||||
|
||||
PROMPT_WHITE_BG = (
|
||||
"请处理这张图片:\n"
|
||||
"1. 识别图中的地毯/地垫/印花布料/产品本体作为主体\n"
|
||||
"2. 去掉主体上面放置的所有物品(杯子、碗、餐具、装饰品等),只保留地垫本身\n"
|
||||
"3. 把所有背景(桌面、地板、墙壁、阴影)全部替换为纯白色(#FFFFFF)\n"
|
||||
"4. 保持地垫/产品的颜色、图案、边缘完全不变\n"
|
||||
"输出:只有主体产品、纯白背景、无杂物的干净产品图。"
|
||||
)
|
||||
|
||||
# 当第一次去背景效果不好时(白色覆盖率过低),用更强硬的提示词重试
|
||||
PROMPT_WHITE_BG_STRONG = (
|
||||
"严格执行:将这张图的背景彻底替换为纯白色 RGB(255,255,255)。\n"
|
||||
"只保留图片中央的产品/地毯/布料主体,其他所有区域(桌面/地板/墙/阴影/物品)"
|
||||
"一律改为纯白色。产品边缘要干净锐利,不留任何半透明或灰色区域。\n"
|
||||
"重要:不论主体上摆放了什么东西,统统去掉,只输出产品本身+白色背景。"
|
||||
)
|
||||
|
||||
PROMPT_ENHANCE = (
|
||||
"请对这张已展平的图案进行高清增强:提升整体清晰度和色彩饱和度,"
|
||||
"修复边缘锯齿,补全缺失细节,输出印刷级高质量平面图,背景保持纯白。"
|
||||
)
|
||||
|
||||
# Step3 增强失败时的兜底提示词(更简单,成功率更高)
|
||||
PROMPT_ENHANCE_SIMPLE = (
|
||||
"请提升这张图片的清晰度和画质,输出高清版本,背景保持纯白。"
|
||||
)
|
||||
|
||||
|
||||
def _measure_white_coverage(image: np.ndarray) -> float:
|
||||
"""返回图片中白色像素的百分比,用于判断去背景效果"""
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
_, mask = cv2.threshold(gray, 245, 255, cv2.THRESH_BINARY)
|
||||
return float(np.sum(mask == 255)) / mask.size
|
||||
|
||||
|
||||
def _color_match(source: np.ndarray, target: np.ndarray,
|
||||
strength: float = 0.75, exclude_white: bool = True) -> np.ndarray:
|
||||
"""
|
||||
将 target 的色调匹配到 source(类 PS「匹配颜色」)。
|
||||
使用 LAB 色彩空间 Reinhard 均值/标准差迁移。
|
||||
|
||||
Args:
|
||||
source: 原图(色彩参考来源)
|
||||
target: 待调整图(处理后结果)
|
||||
strength: 迁移强度 0.0-1.0,推荐 0.6-0.85
|
||||
exclude_white: 统计时排除白色像素,避免背景影响肤色/图案计算
|
||||
Returns:
|
||||
调色后的 BGR 图像
|
||||
"""
|
||||
src_f = source.astype(np.float32) / 255.0
|
||||
tgt_f = target.astype(np.float32) / 255.0
|
||||
|
||||
src_lab = cv2.cvtColor(src_f, cv2.COLOR_BGR2Lab)
|
||||
tgt_lab = cv2.cvtColor(tgt_f, cv2.COLOR_BGR2Lab)
|
||||
result = tgt_lab.copy()
|
||||
|
||||
for ch in range(3):
|
||||
if exclude_white:
|
||||
# 排除极亮像素(L > 95)统计,只看图案区域
|
||||
src_mask = src_lab[:, :, 0] < 95
|
||||
tgt_mask = tgt_lab[:, :, 0] < 95
|
||||
src_vals = src_lab[:, :, ch][src_mask]
|
||||
tgt_vals = tgt_lab[:, :, ch][tgt_mask]
|
||||
else:
|
||||
src_vals = src_lab[:, :, ch].ravel()
|
||||
tgt_vals = tgt_lab[:, :, ch].ravel()
|
||||
|
||||
if src_vals.size == 0 or tgt_vals.size == 0:
|
||||
continue
|
||||
|
||||
src_mean, src_std = float(src_vals.mean()), float(src_vals.std())
|
||||
tgt_mean, tgt_std = float(tgt_vals.mean()), float(tgt_vals.std())
|
||||
|
||||
if tgt_std < 1e-6:
|
||||
continue
|
||||
|
||||
# Reinhard 迁移:先归一化到目标,再重映射到源分布
|
||||
shifted = (tgt_lab[:, :, ch] - tgt_mean) / tgt_std * src_std + src_mean
|
||||
# 按 strength 混合:strength=1 完全迁移,0 保持不变
|
||||
result[:, :, ch] = shifted * strength + tgt_lab[:, :, ch] * (1.0 - strength)
|
||||
|
||||
result_bgr = cv2.cvtColor(result, cv2.COLOR_Lab2BGR)
|
||||
result_bgr = np.clip(result_bgr * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
print(f" [颜色匹配] 强度={strength:.0%} | "
|
||||
f"源均值L={src_lab[:,:,0].mean():.1f} → 目标均值L={tgt_lab[:,:,0].mean():.1f}")
|
||||
return result_bgr
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# OpenCV 透视矫正
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
def order_points(pts: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
把四个点排列为 [左上, 右上, 右下, 左下]。
|
||||
使用质心角度排序,对矩形、菱形、平行四边形等各种透视形状均适用。
|
||||
"""
|
||||
cx, cy = pts[:, 0].mean(), pts[:, 1].mean()
|
||||
# 计算每个点相对质心的角度(从正上方顺时针)
|
||||
angles = np.arctan2(pts[:, 1] - cy, pts[:, 0] - cx)
|
||||
# 顺时针排序:从右上开始(角度最小的)
|
||||
order = np.argsort(angles)
|
||||
sorted_pts = pts[order]
|
||||
# 找到最左上角作为起点(x+y 最小)
|
||||
s = sorted_pts.sum(axis=1)
|
||||
start = np.argmin(s)
|
||||
# 从左上角开始顺时针排列 → [左上, 右上, 右下, 左下]
|
||||
indices = [(start + i) % 4 for i in range(4)]
|
||||
rect = sorted_pts[indices].astype("float32")
|
||||
return rect
|
||||
|
||||
|
||||
def four_point_transform(image: np.ndarray, pts: np.ndarray) -> np.ndarray:
|
||||
rect = order_points(pts)
|
||||
tl, tr, br, bl = rect
|
||||
|
||||
w1 = np.linalg.norm(br - bl)
|
||||
w2 = np.linalg.norm(tr - tl)
|
||||
h1 = np.linalg.norm(tr - br)
|
||||
h2 = np.linalg.norm(tl - bl)
|
||||
W = int(max(w1, w2))
|
||||
H = int(max(h1, h2))
|
||||
|
||||
print(f" [CV] 角点: TL={tl.astype(int)} TR={tr.astype(int)} BR={br.astype(int)} BL={bl.astype(int)}")
|
||||
print(f" [CV] 矫正后目标尺寸: {W}x{H}")
|
||||
|
||||
dst = np.array([
|
||||
[0, 0 ],
|
||||
[W - 1, 0 ],
|
||||
[W - 1, H - 1],
|
||||
[0, H - 1],
|
||||
], dtype="float32")
|
||||
|
||||
M = cv2.getPerspectiveTransform(rect, dst)
|
||||
warped = cv2.warpPerspective(
|
||||
image, M, (W, H),
|
||||
flags=cv2.INTER_LANCZOS4,
|
||||
borderMode=cv2.BORDER_CONSTANT,
|
||||
borderValue=(255, 255, 255),
|
||||
)
|
||||
return warped
|
||||
|
||||
|
||||
def _detect_bg_color(image: np.ndarray, corner_size: int = 24) -> np.ndarray:
|
||||
"""
|
||||
从图片四个角落采样,估计背景颜色(BGR)。
|
||||
适用于白色、米色、黄色、灰色等各种背景。
|
||||
"""
|
||||
H, W = image.shape[:2]
|
||||
cs = min(corner_size, H // 5, W // 5)
|
||||
corners = [
|
||||
image[:cs, :cs], # 左上
|
||||
image[:cs, W-cs:], # 右上
|
||||
image[H-cs:, :cs], # 左下
|
||||
image[H-cs:, W-cs:], # 右下
|
||||
]
|
||||
pixels = np.concatenate([c.reshape(-1, 3) for c in corners], axis=0)
|
||||
bg = np.median(pixels, axis=0).astype(np.uint8)
|
||||
return bg # BGR
|
||||
|
||||
|
||||
def tool_trim_white_border(image: np.ndarray,
|
||||
tolerance: int = 18,
|
||||
bg_ratio: float = 0.90,
|
||||
padding: int = 4) -> tuple[np.ndarray, bool, dict]:
|
||||
"""
|
||||
【Tool】智能背景边裁切(支持任意背景色:白/黄/米/灰等)。
|
||||
|
||||
算法:
|
||||
1. 从四角采样估计背景色
|
||||
2. 逐行/列扫描:若该行/列中 bg_ratio 以上的像素与背景色差异 <= tolerance,则为背景行/列
|
||||
3. 找到内容区域边界后裁切
|
||||
|
||||
Returns:
|
||||
(裁切后图片, 是否裁切, 详情dict)
|
||||
"""
|
||||
H, W = image.shape[:2]
|
||||
bg_color = _detect_bg_color(image)
|
||||
img_f = image.astype(np.int32)
|
||||
|
||||
# 每个像素与背景色的最大通道差异
|
||||
diff = np.abs(img_f - bg_color.astype(np.int32)).max(axis=2) # H x W
|
||||
is_bg = diff <= tolerance # True = 接近背景色
|
||||
|
||||
row_bg_ratio = is_bg.mean(axis=1) # 每行的背景像素占比
|
||||
col_bg_ratio = is_bg.mean(axis=0) # 每列的背景像素占比
|
||||
|
||||
top = next((i for i in range(H) if row_bg_ratio[i] < bg_ratio), H)
|
||||
bottom = next((i for i in range(H-1,-1,-1) if row_bg_ratio[i] < bg_ratio), -1) + 1
|
||||
left = next((i for i in range(W) if col_bg_ratio[i] < bg_ratio), W)
|
||||
right = next((i for i in range(W-1,-1,-1) if col_bg_ratio[i] < bg_ratio), -1) + 1
|
||||
|
||||
border_top = top
|
||||
border_bottom = H - bottom
|
||||
border_left = left
|
||||
border_right = W - right
|
||||
max_border = max(border_top, border_bottom, border_left, border_right)
|
||||
|
||||
bg_hex = "#{:02X}{:02X}{:02X}".format(int(bg_color[2]), int(bg_color[1]), int(bg_color[0]))
|
||||
info = {"top": border_top, "bottom": border_bottom,
|
||||
"left": border_left, "right": border_right, "bg_color": bg_hex}
|
||||
|
||||
if max_border < 5:
|
||||
print(f" [裁边] 背景色{bg_hex} | 上{border_top} 下{border_bottom} 左{border_left} 右{border_right}px → 无需裁切")
|
||||
return image, False, info
|
||||
|
||||
y1 = max(0, top - padding)
|
||||
y2 = min(H, bottom + padding)
|
||||
x1 = max(0, left - padding)
|
||||
x2 = min(W, right + padding)
|
||||
cropped = image[y1:y2, x1:x2]
|
||||
ch, cw = cropped.shape[:2]
|
||||
print(f" [裁边] 背景色{bg_hex} | 上{border_top} 下{border_bottom} 左{border_left} 右{border_right}px → 裁切 {W}x{H}→{cw}x{ch}")
|
||||
return cropped, True, info
|
||||
|
||||
|
||||
async def tool_color_match(orig_img: np.ndarray, result_img: np.ndarray,
|
||||
strength: float = 0.75) -> np.ndarray:
|
||||
"""【Tool】颜色匹配(封装版,供 AI 决策层调用)"""
|
||||
return _color_match(orig_img, result_img, strength=strength)
|
||||
|
||||
|
||||
async def ai_decide_postprocess(orig_img: np.ndarray, result_img: np.ndarray) -> dict:
|
||||
"""
|
||||
【AI 决策层】用视觉模型分析出图效果,决定是否需要颜色匹配和白边裁切。
|
||||
|
||||
Returns:
|
||||
{
|
||||
"need_color_match": bool,
|
||||
"color_strength": float, # 0.5-0.9
|
||||
"need_trim": bool,
|
||||
"reason": str,
|
||||
}
|
||||
"""
|
||||
import base64
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_BASE_URL")
|
||||
model = os.getenv("VISION_MODEL", "glm-4v-flash")
|
||||
|
||||
# 无 API 时默认两个都做
|
||||
if not api_key:
|
||||
return {"need_color_match": True, "color_strength": 0.75,
|
||||
"need_trim": True, "reason": "无API Key,默认执行"}
|
||||
|
||||
def _encode(img: np.ndarray) -> str:
|
||||
resized = cv2.resize(img, (512, 512))
|
||||
_, buf = cv2.imencode(".jpg", resized, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
||||
return base64.b64encode(buf).decode()
|
||||
|
||||
orig_b64 = _encode(orig_img)
|
||||
result_b64 = _encode(result_img)
|
||||
|
||||
prompt = (
|
||||
"你是图片后处理决策助手。图一是原图,图二是AI处理后的结果图。请判断:\n\n"
|
||||
"【问题1】颜色差异:处理后图片的整体色调与原图相比,差异是否明显?\n"
|
||||
"(明显=色调/饱和度/冷暖差异很大;轻微=有轻微偏差;无=颜色基本一致)\n\n"
|
||||
"【问题2】多余边框:处理后图片四周是否有不属于图案内容的多余空白边框?\n"
|
||||
"注意:边框颜色不一定是白色,也可能是黄色、米色、灰色等任何纯色。\n"
|
||||
"判断标准:图案内容的外围是否有一圈明显的纯色空白带。\n\n"
|
||||
"严格按格式回答(每行一个字段,不要多余内容):\n"
|
||||
"颜色差异: <明显|轻微|无>\n"
|
||||
"多余边框: <有|无>\n"
|
||||
"边框位置: <有边框的方向如「上下」,没有则填无>"
|
||||
)
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{orig_b64}"}},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{result_b64}"}},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}],
|
||||
)
|
||||
text = response.choices[0].message.content or ""
|
||||
print(f" [AI决策] 原始回答: {text.strip()[:120]}")
|
||||
|
||||
def _get(key):
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith(key):
|
||||
return line.split(":", 1)[-1].strip()
|
||||
return ""
|
||||
|
||||
color_level = _get("颜色差异")
|
||||
has_border = "有" in _get("多余边框")
|
||||
border_pos = _get("边框位置")
|
||||
|
||||
strength_map = {"明显": 0.80, "轻微": 0.55, "无": 0.0}
|
||||
color_strength = strength_map.get(color_level, 0.75)
|
||||
need_color = color_strength > 0
|
||||
|
||||
reason = f"颜色差异={color_level or '?'}, 边框={'有('+border_pos+')' if has_border else '无'}"
|
||||
print(f" [AI决策] {reason} → 颜色匹配={'✓' if need_color else '✗'}(强度{color_strength:.0%}), 裁边={'✓' if has_border else '✗'}")
|
||||
|
||||
return {
|
||||
"need_color_match": need_color,
|
||||
"color_strength": color_strength,
|
||||
"need_trim": has_border,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f" [AI决策] 调用失败({e}),默认执行颜色匹配+裁边")
|
||||
return {"need_color_match": True, "color_strength": 0.75,
|
||||
"need_trim": True, "reason": f"AI决策失败: {e}"}
|
||||
|
||||
|
||||
def _points_are_unique(pts: np.ndarray, min_dist: float = 20.0) -> bool:
|
||||
"""检查4个角点两两之间距离都大于 min_dist,防止重复点导致退化变换"""
|
||||
for i in range(len(pts)):
|
||||
for j in range(i + 1, len(pts)):
|
||||
if np.linalg.norm(pts[i] - pts[j]) < min_dist:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def find_quad(image: np.ndarray):
|
||||
"""
|
||||
在白背景图上检测主体四边形角点。
|
||||
策略(按优先级):
|
||||
1. 二值化 + approxPolyDP(epsilon 从小到大尝试)
|
||||
2. 凸包取极值四点(最左/最右/最上/最下)
|
||||
3. minAreaRect 四角
|
||||
"""
|
||||
h, w = image.shape[:2]
|
||||
img_area = h * w
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# ── 获取主体轮廓 ──────────────────────────────────────────
|
||||
_, thresh = cv2.threshold(gray, 245, 255, cv2.THRESH_BINARY_INV)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (20, 20))
|
||||
closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
|
||||
|
||||
cnts, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not cnts:
|
||||
edges = cv2.Canny(gray, 30, 100)
|
||||
k2 = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))
|
||||
closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, k2)
|
||||
cnts, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
if not cnts:
|
||||
print(" [CV] 无法检测轮廓")
|
||||
return None
|
||||
|
||||
c = max(cnts, key=cv2.contourArea)
|
||||
area = cv2.contourArea(c)
|
||||
print(f" [CV] 主体轮廓面积: {area:.0f} / {img_area} ({area/img_area*100:.1f}%)")
|
||||
if area < img_area * 0.05:
|
||||
print(" [CV] 面积太小,背景可能去除不完全")
|
||||
return None
|
||||
|
||||
peri = cv2.arcLength(c, True)
|
||||
|
||||
# ── 策略1:approxPolyDP,epsilon 逐步放大直到得到4个唯一角点 ──
|
||||
for eps_ratio in [0.02, 0.03, 0.04, 0.05, 0.06]:
|
||||
approx = cv2.approxPolyDP(c, eps_ratio * peri, True)
|
||||
pts = approx.reshape(-1, 2).astype("float32")
|
||||
if len(pts) == 4 and _points_are_unique(pts):
|
||||
print(f" [CV] approxPolyDP 成功 (eps={eps_ratio}), 4个唯一角点")
|
||||
return pts
|
||||
print(f" [CV] approxPolyDP eps={eps_ratio}: {len(pts)} 顶点,唯一={_points_are_unique(pts) if len(pts)==4 else 'N/A'}")
|
||||
|
||||
# ── 策略2:凸包极值四点(最左/最上/最右/最下)─────────────
|
||||
hull = cv2.convexHull(c).reshape(-1, 2).astype("float32")
|
||||
if len(hull) >= 4:
|
||||
# 取4个极值方向的点
|
||||
left = hull[np.argmin(hull[:, 0])] # 最左
|
||||
right = hull[np.argmax(hull[:, 0])] # 最右
|
||||
top = hull[np.argmin(hull[:, 1])] # 最上
|
||||
bottom = hull[np.argmax(hull[:, 1])] # 最下
|
||||
pts = np.array([left, top, right, bottom], dtype="float32")
|
||||
if _points_are_unique(pts):
|
||||
print(f" [CV] 使用凸包极值四点: L={left.astype(int)} T={top.astype(int)} R={right.astype(int)} B={bottom.astype(int)}")
|
||||
return pts
|
||||
|
||||
# ── 策略3:minAreaRect 四角(兜底)─────────────────────────
|
||||
print(f" [CV] 兜底:使用 minAreaRect")
|
||||
rect = cv2.minAreaRect(c)
|
||||
box = cv2.boxPoints(rect).astype("float32")
|
||||
return box
|
||||
|
||||
|
||||
def save_debug_img(image: np.ndarray, pts, path: str):
|
||||
"""保存带角点标注的调试图"""
|
||||
dbg = image.copy()
|
||||
if pts is not None:
|
||||
rect = order_points(pts)
|
||||
labels = ["TL", "TR", "BR", "BL"]
|
||||
colors = [(0,0,255), (0,255,0), (255,0,0), (0,165,255)]
|
||||
for i, (px, py) in enumerate(rect):
|
||||
cv2.circle(dbg, (int(px), int(py)), 12, colors[i], -1)
|
||||
cv2.putText(dbg, labels[i], (int(px)+15, int(py)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 1.2, colors[i], 3)
|
||||
box = rect.reshape((-1,1,2)).astype(np.int32)
|
||||
cv2.polylines(dbg, [box], True, (0,0,255), 3)
|
||||
cv2.imwrite(path, dbg, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
||||
print(f" [Debug] 调试图: {path}")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# 主流程
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
async def process(src: str, debug: bool = False,
|
||||
skip_step1: bool = False, skip_step3: bool = False) -> str | None:
|
||||
uid = uuid.uuid4().hex
|
||||
tmp = [] # 临时文件列表,最后统一清理
|
||||
|
||||
# ── 下载(URL 情况)──────────────────────────────────────
|
||||
if src.startswith("http"):
|
||||
import aiohttp
|
||||
dl = os.path.join(tempfile.gettempdir(), f"pfix_dl_{uid}.jpg")
|
||||
tmp.append(dl)
|
||||
print("[下载] 原图中...")
|
||||
async with aiohttp.ClientSession(headers={
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",
|
||||
"Referer": "https://www.taobao.com/",
|
||||
}) as sess:
|
||||
async with sess.get(src, timeout=aiohttp.ClientTimeout(total=30)) as r:
|
||||
if r.status != 200:
|
||||
print(f"[下载] 失败: HTTP {r.status}")
|
||||
return None
|
||||
with open(dl, "wb") as f:
|
||||
f.write(await r.read())
|
||||
local_src = dl
|
||||
else:
|
||||
local_src = src
|
||||
|
||||
current = local_src # 当前处理中的文件
|
||||
orig_img = cv2.imread(local_src) # 保留原图用于颜色匹配
|
||||
# 记录原图宽高比,用于检测 Gemini 旋转问题
|
||||
orig_ratio = (orig_img.shape[1] / orig_img.shape[0]) if orig_img is not None else 1.0
|
||||
|
||||
try:
|
||||
# ── Step 1: Gemini 去背景 → 白背景 ──────────────────
|
||||
if not skip_step1:
|
||||
print("\n" + "─"*50)
|
||||
print("Step 1 / 3 | Gemini 去背景 → 白色背景")
|
||||
print("─"*50)
|
||||
s1_out = os.path.join(tempfile.gettempdir(), f"pfix_s1_{uid}.jpg")
|
||||
tmp.append(s1_out)
|
||||
ok = await _gemini_call(current, s1_out, PROMPT_WHITE_BG,
|
||||
aspect_ratio="auto", label="去背景")
|
||||
if ok:
|
||||
# 检查白色覆盖率,判断背景去除是否充分
|
||||
s1_img = cv2.imread(s1_out)
|
||||
white_pct = _measure_white_coverage(s1_img) if s1_img is not None else 0.0
|
||||
print(f" [去背景] 白色覆盖率: {white_pct:.1%}", end="")
|
||||
if white_pct < 0.20:
|
||||
# 背景去除太差,用强化提示词重试
|
||||
print(" → 太低,强化提示词重试...")
|
||||
s1_retry = os.path.join(tempfile.gettempdir(), f"pfix_s1r_{uid}.jpg")
|
||||
tmp.append(s1_retry)
|
||||
ok2 = await _gemini_call(current, s1_retry, PROMPT_WHITE_BG_STRONG,
|
||||
aspect_ratio="auto", label="去背景(强化)")
|
||||
if ok2:
|
||||
r_img = cv2.imread(s1_retry)
|
||||
retry_pct = _measure_white_coverage(r_img) if r_img is not None else 0.0
|
||||
print(f" [去背景] 重试白色覆盖率: {retry_pct:.1%}", end="")
|
||||
if retry_pct >= white_pct:
|
||||
print(" → 效果更好,采用重试结果")
|
||||
current = s1_retry
|
||||
else:
|
||||
print(" → 效果未提升,保留首次结果")
|
||||
current = s1_out
|
||||
else:
|
||||
print(" [去背景] 重试失败,保留首次结果")
|
||||
current = s1_out
|
||||
else:
|
||||
print(" → 合格")
|
||||
current = s1_out
|
||||
else:
|
||||
print(" Step1 失败,用原图继续")
|
||||
else:
|
||||
print("\n[跳过 Step1] 直接用原图")
|
||||
|
||||
# ── Step 2: OpenCV 在白背景图上检测+透视矫正 ─────────
|
||||
print("\n" + "─"*50)
|
||||
print("Step 2 / 3 | OpenCV 轮廓检测 + 透视矫正")
|
||||
print("─"*50)
|
||||
img = cv2.imread(current)
|
||||
if img is None:
|
||||
print(f" 无法读取: {current}")
|
||||
return None
|
||||
|
||||
h, w = img.shape[:2]
|
||||
print(f" 输入尺寸: {w}x{h}")
|
||||
pts = find_quad(img)
|
||||
|
||||
if debug:
|
||||
dbg_path = os.path.join(_OUTPUT_DIR, f"debug_{uid}.jpg")
|
||||
save_debug_img(img, pts, dbg_path)
|
||||
|
||||
if pts is not None:
|
||||
warped = four_point_transform(img, pts)
|
||||
|
||||
# ── 方向校正:Gemini 可能把图旋转 90°,需要纠正 ──
|
||||
wh2, ww2 = warped.shape[:2]
|
||||
warped_ratio = ww2 / wh2 # 宽/高
|
||||
# 若原图横竖方向与矫正结果相反(比例差异超过 1.5 倍),旋转 90°
|
||||
if orig_ratio > 1.0 and warped_ratio < 1.0 / 1.5:
|
||||
# 原图横,结果竖 → 顺时针转 90°
|
||||
warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
|
||||
print(f" [方向校正] 原图横({orig_ratio:.2f}) vs 矫正竖({warped_ratio:.2f}) → 旋转90°")
|
||||
elif orig_ratio < 1.0 and warped_ratio > 1.5:
|
||||
# 原图竖,结果横 → 逆时针转 90°
|
||||
warped = cv2.rotate(warped, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
print(f" [方向校正] 原图竖({orig_ratio:.2f}) vs 矫正横({warped_ratio:.2f}) → 旋转-90°")
|
||||
else:
|
||||
print(f" [方向校正] 方向一致,无需旋转 (原图比例={orig_ratio:.2f}, 矫正比例={warped_ratio:.2f})")
|
||||
|
||||
s2_out = os.path.join(tempfile.gettempdir(), f"pfix_s2_{uid}.jpg")
|
||||
tmp.append(s2_out)
|
||||
cv2.imwrite(s2_out, warped, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
current = s2_out
|
||||
wh2, ww2 = warped.shape[:2]
|
||||
print(f" 透视矫正完成 → {ww2}x{wh2}")
|
||||
else:
|
||||
print(" 角点检测失败,跳过透视矫正,继续用白背景图")
|
||||
|
||||
# ── Step 3: Qwen 高清增强 ─────────────────────────────
|
||||
if not skip_step3:
|
||||
print("\n" + "─"*50)
|
||||
print("Step 3 / 5 | Qwen 高清增强(RunningHub)")
|
||||
print("─"*50)
|
||||
final_out = os.path.join(_OUTPUT_DIR, f"pfix_final_{uid}.jpg")
|
||||
from services.service_qwen import 清晰化_api
|
||||
ok = await 清晰化_api(img_path=current, save_path=final_out)
|
||||
if ok:
|
||||
print(f" [高清增强] Qwen 成功")
|
||||
else:
|
||||
# Qwen 失败,用 Gemini 简化提示词兜底
|
||||
print(" Qwen 失败,Gemini 兜底重试...")
|
||||
ok = await _gemini_call(current, final_out, PROMPT_ENHANCE_SIMPLE,
|
||||
aspect_ratio="auto", label="高清增强(Gemini兜底)")
|
||||
if not ok:
|
||||
print(" Step3 全部失败,直接保存矫正结果")
|
||||
import shutil
|
||||
shutil.copy2(current, final_out)
|
||||
else:
|
||||
final_out = os.path.join(_OUTPUT_DIR, f"pfix_final_{uid}.jpg")
|
||||
import shutil
|
||||
shutil.copy2(current, final_out)
|
||||
print("\n[跳过 Step3] 直接保存矫正结果")
|
||||
|
||||
# ── Step 4: AI 决策 + 后处理(颜色匹配 & 白边裁切)────
|
||||
print("\n" + "─"*50)
|
||||
print("Step 4 / 4 | AI 决策后处理(颜色匹配 / 白边裁切)")
|
||||
print("─"*50)
|
||||
final_img = cv2.imread(final_out)
|
||||
if final_img is not None and orig_img is not None:
|
||||
decision = await ai_decide_postprocess(orig_img, final_img)
|
||||
|
||||
# Tool 1: 颜色匹配
|
||||
if decision["need_color_match"]:
|
||||
final_img = await tool_color_match(orig_img, final_img,
|
||||
strength=decision["color_strength"])
|
||||
cv2.imwrite(final_out, final_img, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
else:
|
||||
print(" [颜色匹配] AI 判断无需调色,跳过")
|
||||
|
||||
# Tool 2: 白边裁切
|
||||
if decision["need_trim"]:
|
||||
trimmed, did_trim, _ = tool_trim_white_border(final_img)
|
||||
if did_trim:
|
||||
cv2.imwrite(final_out, trimmed, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
||||
else:
|
||||
print(" [裁边] AI 判断无白边,跳过")
|
||||
else:
|
||||
print(" [Step4] 图片读取失败,跳过后处理")
|
||||
|
||||
size_kb = os.path.getsize(final_out) / 1024
|
||||
print(f"\n{'='*50}")
|
||||
print(f" 完成!输出文件: {final_out}")
|
||||
print(f" 文件大小: {size_kb:.0f} KB")
|
||||
print(f"{'='*50}")
|
||||
return final_out
|
||||
|
||||
finally:
|
||||
for f in tmp:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("用法: python perspective_fix.py <图片路径或URL> [--debug] [--skip-step1] [--skip-step3]")
|
||||
sys.exit(1)
|
||||
|
||||
src_arg = sys.argv[1]
|
||||
debug_arg = "--debug" in sys.argv
|
||||
skip1_arg = "--skip-step1" in sys.argv
|
||||
skip3_arg = "--skip-step3" in sys.argv
|
||||
asyncio.run(process(src_arg, debug=debug_arg, skip_step1=skip1_arg, skip_step3=skip3_arg))
|
||||
Reference in New Issue
Block a user