feat: localize logs, colorize streams, and fix draw pipeline params
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
@@ -6,9 +6,12 @@ import uuid
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from .logger import setup_logger
|
||||||
|
|
||||||
from .config import AUTO_DRAW_TIMEOUT_SECONDS
|
from .config import AUTO_DRAW_TIMEOUT_SECONDS
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
async def auto_draw_preview(
|
async def auto_draw_preview(
|
||||||
image_url: str,
|
image_url: str,
|
||||||
@@ -22,16 +25,19 @@ async def auto_draw_preview(
|
|||||||
3) 上传图绘,返回可外发 URL
|
3) 上传图绘,返回可外发 URL
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
logger.info("[作图] 开始 customer=%s image=%s", customer_id, image_url)
|
||||||
from services.service_gemini import GeminiExtractV2Service # type: ignore
|
from services.service_gemini import GeminiExtractV2Service # type: ignore
|
||||||
from services.service_tuhui_upload import upload_to_tuhui # type: ignore
|
from services.service_tuhui_upload import upload_to_tuhui # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"ok": False, "error": f"import_failed:{e}"}
|
logger.error("[作图] 依赖加载失败: %s", e)
|
||||||
|
return {"ok": False, "error": f"依赖加载失败:{e}"}
|
||||||
|
|
||||||
prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本"
|
prompt = requirement.strip() or "按原图做高清修复,保留主体细节,输出清晰可用版本"
|
||||||
input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg")
|
input_path = os.path.join(tempfile.gettempdir(), f"qjcs_in_{uuid.uuid4().hex}.jpg")
|
||||||
output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg")
|
output_path = os.path.join(tempfile.gettempdir(), f"qjcs_out_{uuid.uuid4().hex}.jpg")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info("[作图] 下载原图中")
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": (
|
"User-Agent": (
|
||||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
@@ -43,10 +49,13 @@ async def auto_draw_preview(
|
|||||||
}
|
}
|
||||||
resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
|
resp = requests.get(image_url, headers=headers, timeout=AUTO_DRAW_TIMEOUT_SECONDS)
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
return {"ok": False, "error": f"download_http_{resp.status_code}"}
|
logger.error("[作图] 原图下载失败: http_%s", resp.status_code)
|
||||||
|
return {"ok": False, "error": f"原图下载失败:http_{resp.status_code}"}
|
||||||
with open(input_path, "wb") as f:
|
with open(input_path, "wb") as f:
|
||||||
f.write(resp.content)
|
f.write(resp.content)
|
||||||
|
logger.info("[作图] 原图下载完成 size=%s", len(resp.content))
|
||||||
|
|
||||||
|
logger.info("[作图] Gemini 生成中")
|
||||||
service = GeminiExtractV2Service()
|
service = GeminiExtractV2Service()
|
||||||
ok_extract, msg_extract, _ = await service.extract_pattern(
|
ok_extract, msg_extract, _ = await service.extract_pattern(
|
||||||
input_path=input_path,
|
input_path=input_path,
|
||||||
@@ -55,10 +64,14 @@ async def auto_draw_preview(
|
|||||||
aspect_ratio="1:1",
|
aspect_ratio="1:1",
|
||||||
)
|
)
|
||||||
if not ok_extract:
|
if not ok_extract:
|
||||||
return {"ok": False, "error": f"extract_failed:{msg_extract}"}
|
logger.error("[作图] Gemini 生成失败: %s", msg_extract)
|
||||||
|
return {"ok": False, "error": f"生成失败:{msg_extract}"}
|
||||||
if not os.path.exists(output_path):
|
if not os.path.exists(output_path):
|
||||||
return {"ok": False, "error": "extract_no_output_file"}
|
logger.error("[作图] Gemini 未产出文件")
|
||||||
|
return {"ok": False, "error": "生成失败:未产出文件"}
|
||||||
|
logger.info("[作图] Gemini 生成完成")
|
||||||
|
|
||||||
|
logger.info("[作图] 上传图绘中")
|
||||||
ok_upload, link, _ = await upload_to_tuhui(
|
ok_upload, link, _ = await upload_to_tuhui(
|
||||||
output_path,
|
output_path,
|
||||||
title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图",
|
title=f"客户{customer_id[-4:]}-预览图" if customer_id else "预览图",
|
||||||
@@ -66,10 +79,13 @@ async def auto_draw_preview(
|
|||||||
price=1,
|
price=1,
|
||||||
)
|
)
|
||||||
if not ok_upload:
|
if not ok_upload:
|
||||||
return {"ok": False, "error": f"upload_failed:{link}"}
|
logger.error("[作图] 图绘上传失败: %s", link)
|
||||||
|
return {"ok": False, "error": f"上传失败:{link}"}
|
||||||
|
logger.info("[作图] 上传成功 url=%s", link)
|
||||||
return {"ok": True, "url": str(link)}
|
return {"ok": True, "url": str(link)}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"ok": False, "error": str(e)}
|
logger.exception("[作图] 异常")
|
||||||
|
return {"ok": False, "error": f"作图异常:{e}"}
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
if os.path.exists(input_path):
|
if os.path.exists(input_path):
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from .config import (
|
|||||||
AUTO_QUOTE_WAIT_SECONDS,
|
AUTO_QUOTE_WAIT_SECONDS,
|
||||||
MESSAGE_DEBOUNCE_SECONDS,
|
MESSAGE_DEBOUNCE_SECONDS,
|
||||||
QINGJIAN_WS_URI,
|
QINGJIAN_WS_URI,
|
||||||
SHORT_REPLY_MAX_CHARS,
|
|
||||||
)
|
)
|
||||||
from .logger import setup_logger
|
from .logger import setup_logger
|
||||||
from .observability import activity_event, build_trace_id
|
from .observability import activity_event, build_trace_id
|
||||||
@@ -49,6 +48,48 @@ class QingjianClient:
|
|||||||
def _msg_text(data: dict) -> str:
|
def _msg_text(data: dict) -> str:
|
||||||
return str(data.get("msg", "") or "").strip()
|
return str(data.get("msg", "") or "").strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_price_tokens(text: str) -> list[str]:
|
||||||
|
s = str(text or "")
|
||||||
|
if not s:
|
||||||
|
return []
|
||||||
|
out: list[str] = []
|
||||||
|
out += re.findall(r"(?:¥|¥)\s*\d+(?:\.\d{1,2})?", s)
|
||||||
|
out += re.findall(r"\d+(?:\.\d{1,2})?\s*元", s)
|
||||||
|
# 去重保序
|
||||||
|
seen = set()
|
||||||
|
uniq: list[str] = []
|
||||||
|
for x in out:
|
||||||
|
k = x.strip()
|
||||||
|
if k and k not in seen:
|
||||||
|
seen.add(k)
|
||||||
|
uniq.append(k)
|
||||||
|
return uniq
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _route_cn(route: str) -> str:
|
||||||
|
return {
|
||||||
|
"pre_sales": "售前",
|
||||||
|
"quote": "报价",
|
||||||
|
"after_sales": "售后",
|
||||||
|
"risk": "风控",
|
||||||
|
}.get(str(route or ""), "未知")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _status_text(route: str, action: str) -> str:
|
||||||
|
a = str(action or "")
|
||||||
|
if a == "quote":
|
||||||
|
return "开始作图中"
|
||||||
|
if a == "reply":
|
||||||
|
return "已回复客户"
|
||||||
|
if a == "transfer":
|
||||||
|
return "已转人工"
|
||||||
|
if a == "noop":
|
||||||
|
return "仅监听中"
|
||||||
|
if a == "update_state":
|
||||||
|
return "状态已更新"
|
||||||
|
return f"{QingjianClient._route_cn(route)}处理中"
|
||||||
|
|
||||||
def _append_dialogue(self, key: str, role: str, text: str) -> None:
|
def _append_dialogue(self, key: str, role: str, text: str) -> None:
|
||||||
t = str(text or "").strip()
|
t = str(text or "").strip()
|
||||||
if not t:
|
if not t:
|
||||||
@@ -148,21 +189,16 @@ class QingjianClient:
|
|||||||
return t
|
return t
|
||||||
|
|
||||||
def _shorten_reply(self, text: str) -> str:
|
def _shorten_reply(self, text: str) -> str:
|
||||||
max_len = max(8, int(SHORT_REPLY_MAX_CHARS))
|
|
||||||
t = str(text or "").strip()
|
t = str(text or "").strip()
|
||||||
t = self._humanize_reply(t)
|
t = self._humanize_reply(t)
|
||||||
if len(t) <= max_len:
|
# 只取首句,不做按字数硬截断,避免半句/残句
|
||||||
return t
|
|
||||||
# 优先按句号切,避免把一句话硬腰斩成“AI半句”
|
|
||||||
parts = re.split(r"[。!?!?]", t)
|
parts = re.split(r"[。!?!?]", t)
|
||||||
head = next((p.strip() for p in parts if p and p.strip()), "")
|
head = next((p.strip() for p in parts if p and p.strip()), "")
|
||||||
if not head:
|
if not head:
|
||||||
# 无句号时按逗号切第一短分句
|
# 无句号时按逗号切第一分句
|
||||||
sub_parts = re.split(r"[,,;;::]", t)
|
sub_parts = re.split(r"[,,;;::]", t)
|
||||||
head = next((p.strip() for p in sub_parts if p and p.strip()), t)
|
head = next((p.strip() for p in sub_parts if p and p.strip()), t)
|
||||||
if len(head) > max_len:
|
return head or t
|
||||||
head = head[:max_len].rstrip(",,;;:: ")
|
|
||||||
return head or t[:max_len]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _humanize_reply(text: str) -> str:
|
def _humanize_reply(text: str) -> str:
|
||||||
@@ -250,17 +286,27 @@ class QingjianClient:
|
|||||||
activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"])
|
activity_event(self.logger, "agent_process_start", trace_id=trace_id, customer_id=context["customer_id"], acc_id=context["acc_id"], intent=context["intent"])
|
||||||
route, decision, state = await self.orchestrator.decide(context)
|
route, decision, state = await self.orchestrator.decide(context)
|
||||||
latency_ms = int((time.perf_counter() - t0) * 1000)
|
latency_ms = int((time.perf_counter() - t0) * 1000)
|
||||||
|
status_text = self._status_text(route, decision.action)
|
||||||
activity_event(
|
activity_event(
|
||||||
self.logger,
|
self.logger,
|
||||||
"agent_process_done",
|
"agent_process_done",
|
||||||
trace_id=trace_id,
|
trace_id=trace_id,
|
||||||
customer_id=context["customer_id"],
|
customer_id=context["customer_id"],
|
||||||
route=route,
|
route=route,
|
||||||
|
route_cn=self._route_cn(route),
|
||||||
action=decision.action,
|
action=decision.action,
|
||||||
reason=decision.reason,
|
reason=status_text,
|
||||||
|
raw_reason=decision.reason,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
after_sales_stage=state.get("after_sales_stage", "new"),
|
after_sales_stage=state.get("after_sales_stage", "new"),
|
||||||
)
|
)
|
||||||
|
# 价格日志(中文)
|
||||||
|
price_hits = []
|
||||||
|
price_hits.extend(self._extract_price_tokens(merged_msg))
|
||||||
|
price_hits.extend(self._extract_price_tokens(decision.reply))
|
||||||
|
price_hits.extend(self._extract_price_tokens(decision.reason))
|
||||||
|
if price_hits:
|
||||||
|
self.logger.info("[价格] 客户=%s 金额=%s", context["customer_id"], " | ".join(price_hits))
|
||||||
|
|
||||||
if decision.action == "transfer":
|
if decision.action == "transfer":
|
||||||
text = (decision.transfer_msg or "").strip()
|
text = (decision.transfer_msg or "").strip()
|
||||||
@@ -281,6 +327,7 @@ class QingjianClient:
|
|||||||
customer_id=context["customer_id"],
|
customer_id=context["customer_id"],
|
||||||
image_url=latest_image,
|
image_url=latest_image,
|
||||||
)
|
)
|
||||||
|
self.logger.info("[作图] 开始 customer=%s", context["customer_id"])
|
||||||
draw_res = await auto_draw_preview(
|
draw_res = await auto_draw_preview(
|
||||||
image_url=latest_image,
|
image_url=latest_image,
|
||||||
customer_id=context["customer_id"],
|
customer_id=context["customer_id"],
|
||||||
@@ -298,6 +345,7 @@ class QingjianClient:
|
|||||||
customer_id=context["customer_id"],
|
customer_id=context["customer_id"],
|
||||||
preview_url=preview_url,
|
preview_url=preview_url,
|
||||||
)
|
)
|
||||||
|
self.logger.info("[作图] 成功 customer=%s url=%s", context["customer_id"], preview_url)
|
||||||
await post_tianwang_callback(
|
await post_tianwang_callback(
|
||||||
"message_processed",
|
"message_processed",
|
||||||
data,
|
data,
|
||||||
@@ -311,6 +359,7 @@ class QingjianClient:
|
|||||||
customer_id=context["customer_id"],
|
customer_id=context["customer_id"],
|
||||||
error=str(draw_res.get("error", "unknown")),
|
error=str(draw_res.get("error", "unknown")),
|
||||||
)
|
)
|
||||||
|
self.logger.error("[作图] 失败 customer=%s error=%s", context["customer_id"], draw_res.get("error", "unknown"))
|
||||||
text = (decision.reply or "").strip()
|
text = (decision.reply or "").strip()
|
||||||
if self._is_invalid_ai_reply(text):
|
if self._is_invalid_ai_reply(text):
|
||||||
text = self._fallback_reply("quote")
|
text = self._fallback_reply("quote")
|
||||||
|
|||||||
@@ -2,21 +2,115 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamColorizer:
|
||||||
|
RESET = "\033[0m"
|
||||||
|
C_AI_THINK = "\033[33m" # yellow
|
||||||
|
C_INBOUND = "\033[36m" # cyan
|
||||||
|
C_OUTBOUND = "\033[32m" # green
|
||||||
|
|
||||||
|
def __init__(self, stream):
|
||||||
|
self.stream = stream
|
||||||
|
|
||||||
|
def write(self, data):
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
out = str(data)
|
||||||
|
if "\x1b[" in out:
|
||||||
|
self.stream.write(out)
|
||||||
|
return
|
||||||
|
if "(thinking):" in out:
|
||||||
|
out = f"{self.C_AI_THINK}{out}{self.RESET}"
|
||||||
|
elif "[收消息]" in out:
|
||||||
|
out = f"{self.C_INBOUND}{out}{self.RESET}"
|
||||||
|
elif "[发送]" in out:
|
||||||
|
out = f"{self.C_OUTBOUND}{out}{self.RESET}"
|
||||||
|
self.stream.write(out)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.stream.flush()
|
||||||
|
|
||||||
|
def isatty(self):
|
||||||
|
return getattr(self.stream, "isatty", lambda: False)()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return getattr(self.stream, "encoding", "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
_stream_color_installed = False
|
||||||
|
|
||||||
|
|
||||||
|
def install_stream_colorizer() -> None:
|
||||||
|
global _stream_color_installed
|
||||||
|
if _stream_color_installed:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
sys.stdout = _StreamColorizer(sys.stdout)
|
||||||
|
sys.stderr = _StreamColorizer(sys.stderr)
|
||||||
|
_stream_color_installed = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _ColorFormatter(logging.Formatter):
|
||||||
|
RESET = "\033[0m"
|
||||||
|
C_INFO = "\033[36m" # cyan
|
||||||
|
C_WARN = "\033[33m" # yellow
|
||||||
|
C_ERR = "\033[31m" # red
|
||||||
|
C_EVENT = "\033[35m" # magenta
|
||||||
|
C_DRAW = "\033[34m" # blue
|
||||||
|
C_PRICE = "\033[32m" # green
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
base = super().format(record)
|
||||||
|
msg = str(record.getMessage() or "")
|
||||||
|
if msg.startswith("[活动日志]"):
|
||||||
|
return f"{self.C_EVENT}{base}{self.RESET}"
|
||||||
|
if msg.startswith("[作图]"):
|
||||||
|
return f"{self.C_DRAW}{base}{self.RESET}"
|
||||||
|
if msg.startswith("[价格]"):
|
||||||
|
return f"{self.C_PRICE}{base}{self.RESET}"
|
||||||
|
if record.levelno >= logging.ERROR:
|
||||||
|
return f"{self.C_ERR}{base}{self.RESET}"
|
||||||
|
if record.levelno >= logging.WARNING:
|
||||||
|
return f"{self.C_WARN}{base}{self.RESET}"
|
||||||
|
return f"{self.C_INFO}{base}{self.RESET}"
|
||||||
|
|
||||||
|
|
||||||
def setup_logger() -> logging.Logger:
|
def setup_logger() -> logging.Logger:
|
||||||
|
install_stream_colorizer()
|
||||||
logger = logging.getLogger("qingjian_cs")
|
logger = logging.getLogger("qingjian_cs")
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
return logger
|
return logger
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
formatter = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", "%H:%M:%S")
|
formatter = _ColorFormatter("[%(asctime)s] %(levelname)s: %(message)s", "%H:%M:%S")
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(handler)
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
# 降低 AgentScope 内部推理/格式器日志噪音,保留本项目活动日志。
|
# 降低 AgentScope 内部推理/格式器日志噪音,保留本项目活动日志。
|
||||||
logging.getLogger("agentscope").setLevel(logging.ERROR)
|
logging.getLogger("agentscope").setLevel(logging.ERROR)
|
||||||
logging.getLogger("agentscope.formatter").setLevel(logging.ERROR)
|
logging.getLogger("agentscope.formatter").setLevel(logging.ERROR)
|
||||||
logging.getLogger("agentscope.agent").setLevel(logging.ERROR)
|
logging.getLogger("agentscope.agent").setLevel(logging.ERROR)
|
||||||
logging.getLogger("_openai_formatter").setLevel(logging.ERROR)
|
fmt_logger = logging.getLogger("_openai_formatter")
|
||||||
|
fmt_logger.setLevel(logging.CRITICAL)
|
||||||
|
fmt_logger.propagate = False
|
||||||
|
fmt_logger.disabled = True
|
||||||
|
fmt_logger.handlers.clear()
|
||||||
|
fmt_logger.addHandler(logging.NullHandler())
|
||||||
logging.getLogger("_react_agent").setLevel(logging.ERROR)
|
logging.getLogger("_react_agent").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||||
|
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# 兜底:把当前已注册的同类噪声 logger 一并禁掉
|
||||||
|
for name in list(logging.root.manager.loggerDict.keys()):
|
||||||
|
if "_openai_formatter" in name:
|
||||||
|
lg = logging.getLogger(name)
|
||||||
|
lg.setLevel(logging.CRITICAL)
|
||||||
|
lg.propagate = False
|
||||||
|
lg.disabled = True
|
||||||
|
lg.handlers.clear()
|
||||||
|
lg.addHandler(logging.NullHandler())
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
|
|||||||
@@ -87,7 +87,8 @@ class GeminiExtractV2Service:
|
|||||||
self,
|
self,
|
||||||
input_path: str,
|
input_path: str,
|
||||||
output_path: str,
|
output_path: str,
|
||||||
custom_prompt: str = None
|
custom_prompt: str = None,
|
||||||
|
aspect_ratio: str = "1:1",
|
||||||
) -> tuple[bool, str, dict]:
|
) -> tuple[bool, str, dict]:
|
||||||
"""
|
"""
|
||||||
使用多API配置进行印花图案提取
|
使用多API配置进行印花图案提取
|
||||||
@@ -124,6 +125,10 @@ class GeminiExtractV2Service:
|
|||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
image_config = {}
|
||||||
|
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||||||
|
if aspect_ratio in valid_ratios:
|
||||||
|
image_config["aspectRatio"] = aspect_ratio
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"contents": [
|
"contents": [
|
||||||
@@ -143,8 +148,8 @@ class GeminiExtractV2Service:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"generationConfig": {
|
"generationConfig": {
|
||||||
"responseModalities": ["IMAGE"] # 只生成图片
|
"responseModalities": ["IMAGE"], # 只生成图片
|
||||||
# 不传imageConfig,让输出图片比例与输入图片一致
|
**({"imageConfig": image_config} if image_config else {}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user