219 lines
7.9 KiB
Python
219 lines
7.9 KiB
Python
import time
|
||
import asyncio
|
||
import aiohttp
|
||
from PIL import Image
|
||
from pydantic import BaseModel
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
api_key = '8e32d44e3007447cb4be6ee52c5d3110'
|
||
|
||
|
||
class UploadInfo(BaseModel):
|
||
fileName: str
|
||
fileType: str
|
||
|
||
|
||
class CreateInfo(BaseModel):
|
||
taskId: str # 创建的任务 ID,可用于查询状态或获取结果
|
||
taskStatus: str # 初始状态,可能为:QUEUED、RUNNING、FAILED
|
||
clientId: str # 平台内部标识,用于排错,无需关注
|
||
netWssUrl: str # WebSocket 地址(当前不稳定,不推荐使用)
|
||
promptTips: str # ComfyUI 校验信息(字符串格式的 JSON),可用于识别配置异常节点
|
||
|
||
|
||
class RunHubResponse(BaseModel):
|
||
code: int # 状态码,0 表示成功
|
||
msg: str # 提示信息
|
||
data: UploadInfo | CreateInfo | str | None = None # 数据对象
|
||
|
||
class Config:
|
||
extra = 'allow' # 允许添加额外字段
|
||
|
||
|
||
async def upload(img_path: str) -> RunHubResponse:
|
||
with open(img_path, 'rb') as f:
|
||
img_data = f.read()
|
||
|
||
form = aiohttp.FormData()
|
||
form.add_field('apiKey', api_key)
|
||
form.add_field('file', img_data, filename='image.jpg', content_type='image/jpeg')
|
||
form.add_field('fileType', 'image')
|
||
|
||
url = 'https://www.runninghub.cn/task/openapi/upload'
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, data=form) as resp:
|
||
response = await resp.json()
|
||
|
||
return RunHubResponse.model_validate(response)
|
||
|
||
|
||
async def create(workflow_id: str, node_info_list: list[dict[str, str]]) -> RunHubResponse:
|
||
url = 'https://www.runninghub.cn/task/openapi/create'
|
||
json_data = {'apiKey': api_key, 'workflowId': workflow_id, 'nodeInfoList': node_info_list}
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, json=json_data) as resp:
|
||
response = await resp.json()
|
||
|
||
return RunHubResponse.model_validate(response)
|
||
|
||
|
||
async def status(task_id: str) -> RunHubResponse:
|
||
# 查询状态
|
||
url = 'https://www.runninghub.cn/task/openapi/status'
|
||
payload = {'apiKey': api_key, 'taskId': task_id}
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, json=payload) as resp:
|
||
response = await resp.json()
|
||
|
||
# ["QUEUED","RUNNING","FAILED","SUCCESS"]
|
||
return RunHubResponse.model_validate(response)
|
||
|
||
|
||
async def outputs(task_id: str) -> dict:
|
||
# 获取结果
|
||
url = 'https://www.runninghub.cn/task/openapi/outputs'
|
||
payload = {'apiKey': api_key, 'taskId': task_id}
|
||
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.post(url, json=payload) as resp:
|
||
response = await resp.json()
|
||
|
||
return response
|
||
|
||
|
||
async def 花纹提取_api(img_path: str, save_path: str, prompt: str = '') -> bool:
|
||
"""
|
||
异步花纹提取API
|
||
|
||
Args:
|
||
img_path: 输入图片路径
|
||
save_path: 输出图片路径
|
||
prompt: 自定义提示词,为空则使用默认提示词
|
||
|
||
Returns:
|
||
bool: 处理是否成功
|
||
"""
|
||
try:
|
||
upload_res = await upload(img_path=img_path)
|
||
if upload_res.code != 0 or not upload_res.data:
|
||
logger.error(f"Qwen上传失败: code={upload_res.code}, msg={upload_res.msg}")
|
||
return False
|
||
|
||
# 确保 data 是 UploadInfo 类型
|
||
if not hasattr(upload_res.data, 'fileName'):
|
||
logger.error(f"Qwen上传返回数据格式错误: {upload_res.data}")
|
||
return False
|
||
|
||
logger.info(f"Qwen上传成功: {upload_res.data.fileName}")
|
||
|
||
workflow_id = '1980864078929379330'
|
||
if len(prompt) == 0:
|
||
prompt = '提取桌布上的花纹,自动补全空白,使得所有位置饱满并且完美衔接,去除所有的皱纹和扭曲和凸凹不平,图案自动摆正对齐并且铺平,使直线变得笔直,平行的花纹更有规律,没有残缺的花纹和折痕和断痕,铺满画布,完整的图案。简单的纯色背景'
|
||
|
||
node_info_list = [
|
||
{
|
||
'nodeId': '78',
|
||
'fieldName': 'image',
|
||
'fieldValue': upload_res.data.fileName,
|
||
},
|
||
{
|
||
'nodeId': '103',
|
||
'fieldName': 'text',
|
||
'fieldValue': prompt,
|
||
},
|
||
]
|
||
create_res = await create(workflow_id=workflow_id, node_info_list=node_info_list)
|
||
|
||
if create_res.code != 0 or not create_res.data:
|
||
logger.error(f"Qwen任务创建失败: code={create_res.code}, msg={create_res.msg}")
|
||
return False
|
||
|
||
# 确保 data 是 CreateInfo 类型
|
||
if not hasattr(create_res.data, 'taskId'):
|
||
logger.error(f"Qwen任务创建返回数据格式错误: {create_res.data}")
|
||
return False
|
||
|
||
task_id = create_res.data.taskId
|
||
logger.info(f"Qwen任务创建成功: {task_id}")
|
||
|
||
# 轮询检查状态
|
||
max_retries = 120 # 最多等待10分钟(120次 * 5秒)
|
||
retry_count = 0
|
||
|
||
while retry_count < max_retries:
|
||
status_res = await status(task_id=task_id)
|
||
if status_res.code == 0:
|
||
if status_res.data == 'QUEUED':
|
||
logger.info('Qwen队列排队中...')
|
||
elif status_res.data == 'RUNNING':
|
||
logger.info('Qwen正在处理中...')
|
||
elif status_res.data == 'FAILED':
|
||
logger.error(f'Qwen处理失败: {status_res}')
|
||
return False
|
||
elif status_res.data == 'SUCCESS':
|
||
logger.info('Qwen处理完成,开始下载结果')
|
||
outputs_res = await outputs(task_id=task_id)
|
||
img_url = outputs_res['data'][0]['fileUrl']
|
||
|
||
# 下载结果图片
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(img_url) as resp:
|
||
img_data = await resp.read()
|
||
|
||
with open(save_path, 'wb') as f:
|
||
f.write(img_data)
|
||
|
||
logger.info(f"Qwen结果保存成功: {save_path}")
|
||
try:
|
||
from utils.api_cost_tracker import record
|
||
record("qwen_enhance", count=1)
|
||
except Exception:
|
||
pass
|
||
return True
|
||
|
||
await asyncio.sleep(5) # 每5秒检查一次
|
||
retry_count += 1
|
||
else:
|
||
logger.error(f'Qwen处理失败: {status_res}')
|
||
return False
|
||
|
||
logger.error(f"Qwen处理超时,超过{max_retries * 5}秒")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"Qwen花纹提取异常: {e}")
|
||
import traceback
|
||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||
return False
|
||
|
||
|
||
async def 清晰化_api(img_path: str, save_path: str) -> bool:
|
||
"""
|
||
高清增强:对透视矫正后的图案进行清晰化处理。
|
||
使用与花纹提取相同的 ComfyUI 工作流,但提示词聚焦于清晰度增强。
|
||
|
||
Args:
|
||
img_path: 输入图片路径(透视矫正后的结果)
|
||
save_path: 输出图片路径
|
||
|
||
Returns:
|
||
bool: 处理是否成功
|
||
"""
|
||
prompt = (
|
||
"对这张已展平的图案进行高清增强处理:"
|
||
"提升整体清晰度和锐利度,修复模糊边缘,补全细节纹理,"
|
||
"使图案线条清晰笔直,颜色鲜艳均匀,"
|
||
"去除噪点和压缩痕迹,输出印刷级高质量平面图,"
|
||
"背景保持纯白色,不要改变图案内容和构图。"
|
||
)
|
||
return await 花纹提取_api(img_path=img_path, save_path=save_path, prompt=prompt)
|
||
|
||
|
||
# 测试代码(注释掉)
|
||
# if __name__ == "__main__":
|
||
# asyncio.run(花纹提取_api(img_path=r'1.jpg', save_path='save1.png', prompt='')) |