572 lines
18 KiB
Python
572 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI 聊天接口
|
||
功能:对话管理 + 消息存库 + function calling(工具调用)
|
||
"""
|
||
|
||
import asyncio
import json
import logging
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session as DBSession

from app.core.config import settings
from app.core.security import get_current_user
from app.db import get_db
from app.models.user import User
from app.models.chat import ChatSession, ChatMessage
from app.api.v1.ai_skills import list_ai_skills_public, resolve_ai_skill
from app.api.v1.ai_llm import (
    call_llm_with_tools,
    call_vision_llm,
    get_runtime_ai_config,
    has_ai_config,
    mock_reply,
)
|
||
|
||
log = logging.getLogger(__name__)
if not log.handlers:
    # Attach a dedicated console handler only once (the handlers check keeps a
    # re-executed module body from adding a second one). propagate=False means
    # records are handled here only, not by ancestor (e.g. root) handlers.
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
    )
    log.addHandler(_h)
    log.setLevel(logging.INFO)
    log.propagate = False
|
||
|
||
# Router carrying every /ai/* endpoint declared in this module.
router = APIRouter()
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
|
||
class ChatRequest(BaseModel):
    """Request body for ``POST /ai/chat``."""

    # The user's message text; its first 30 characters become the title of a
    # newly created session.
    message: str
    # Conversation to continue; when omitted or not owned by the caller, the
    # chat endpoint starts a new session.
    session_id: Optional[int] = None
    # Optional attached image; its presence switches the turn into image mode.
    # NOTE(review): presumably base64-encoded image data (per the name) —
    # confirm the exact encoding/prefix the frontend sends.
    image_base64: Optional[str] = None
    # Explicitly requested skill id, forwarded to resolve_ai_skill.
    skill_id: Optional[str] = None
    # Per-request chat-model override; wins over the user's saved model.
    model: Optional[str] = None
    # Per-request vision-model override for the image-analysis path.
    vision_model: Optional[str] = None
    # Not read by the endpoints in this module.
    image_edit_model: Optional[str] = None
||
|
||
|
||
class ToolResultRequest(BaseModel):
    """Request body for ``POST /ai/tool-result``: the output of a tool the frontend ran."""

    # Conversation the tool call belongs to; must be owned by the caller.
    session_id: int
    # Id of the tool call being answered; not read by the endpoints in this module.
    tool_call_id: str
    # Tool name; embedded in the stored tool message and in skill resolution.
    tool_name: str
    # Tool output as text; stored verbatim inside the tool message.
    result: str
    # Explicitly requested skill id, forwarded to resolve_ai_skill.
    skill_id: Optional[str] = None
    # Per-request chat-model override; wins over the user's saved model.
    model: Optional[str] = None
    # Not read by submit_tool_result.
    vision_model: Optional[str] = None
    # Not read by submit_tool_result.
    image_edit_model: Optional[str] = None
|
||
|
||
|
||
class WorkflowStageUpdate(BaseModel):
    """One stage of a session's workflow as sent by the client.

    Stages are stored as-is (via ``model_dump``) in the session's
    ``workflow_state`` JSON and returned unchanged by the messages endpoint.
    """

    # Client-chosen stage identifier.
    id: str
    # Display title of the stage.
    title: str
    # Display description of the stage.
    description: str
    # Stage status; any string is accepted (no validation here).
    # NOTE(review): the allowed values are defined by the frontend — confirm.
    status: str
    # Optional extra detail text.
    detail: Optional[str] = None
|
||
|
||
|
||
class SessionWorkflowUpdateRequest(BaseModel):
    """Request body for ``POST /ai/sessions/workflow``."""

    # Conversation to update; must be owned by the caller.
    session_id: int
    # New active skill id; None leaves the stored value unchanged.
    active_skill_id: Optional[str] = None
    # Complete replacement list of workflow stages (always overwrites).
    # NOTE(review): relies on pydantic copying mutable defaults per instance;
    # confirm for the pinned pydantic version.
    workflow_state: List[WorkflowStageUpdate] = []
|
||
|
||
|
||
def _should_use_vision_shortcut(skill_mode: str, message: str, has_image: bool) -> bool:
|
||
if not has_image:
|
||
return False
|
||
lowered = (message or "").lower()
|
||
help_or_action_tokens = (
|
||
"怎么做",
|
||
"怎么操作",
|
||
"如何操作",
|
||
"教我",
|
||
"帮我做",
|
||
"帮我改",
|
||
"帮我操作",
|
||
"直接做",
|
||
"不会",
|
||
"为什么",
|
||
"报错",
|
||
"错误",
|
||
"失败",
|
||
"没反应",
|
||
"不能",
|
||
"无法",
|
||
"步骤",
|
||
"下一步",
|
||
)
|
||
if any(token in lowered for token in help_or_action_tokens):
|
||
return False
|
||
if skill_mode == "vision_chat":
|
||
return True
|
||
if skill_mode != "hybrid_chat":
|
||
return False
|
||
|
||
action_tokens = (
|
||
"出图",
|
||
"生成",
|
||
"方案图",
|
||
"方向图",
|
||
"效果图",
|
||
"背景图",
|
||
"概念图",
|
||
"排版",
|
||
"执行",
|
||
"落地",
|
||
)
|
||
if any(token in lowered for token in action_tokens):
|
||
return False
|
||
|
||
analysis_tokens = ("看图", "分析", "你看见", "风格", "版型", "配色", "元素", "建议")
|
||
return any(token in lowered for token in analysis_tokens)
|
||
|
||
|
||
def _decode_workflow_state(raw: Optional[str]) -> list[dict]:
|
||
if not raw:
|
||
return []
|
||
try:
|
||
payload = json.loads(raw)
|
||
return payload if isinstance(payload, list) else []
|
||
except Exception:
|
||
return []
|
||
|
||
|
||
def _store_session_state(
    session: ChatSession,
    db: DBSession,
    *,
    active_skill_id: Optional[str] = None,
    workflow_state: Optional[list[dict]] = None,
):
    """Stage session-level state changes on ``session`` without committing.

    Args:
        session: Conversation row to update.
        db: DB session; ``session`` is (re-)added to it so the caller's next
            commit persists the change.
        active_skill_id: New active skill id; None keeps the current value.
        workflow_state: Stage list stored as JSON text (non-ASCII kept
            verbatim); None keeps the current value.
    """
    # (attribute, incoming value, encoder) — applied in this order, skipping Nones.
    pending = (
        ("active_skill_id", active_skill_id, lambda value: value),
        (
            "workflow_state",
            workflow_state,
            lambda value: json.dumps(value, ensure_ascii=False),
        ),
    )
    for attribute, incoming, encode in pending:
        if incoming is not None:
            setattr(session, attribute, encode(incoming))
    db.add(session)
|
||
|
||
|
||
# ==================== 对话管理接口 ====================
|
||
|
||
|
||
@router.get("/ai/models")
async def get_models(
    db: DBSession = Depends(get_db), current_username: str = Depends(get_current_user)
):
    """Return the selectable chat / vision / image-edit models and current defaults.

    The current value of each category is the user's saved preference (blank
    values ignored) or else the server default from ``settings``. Each option
    list comes from the matching comma-separated ``settings`` value, with the
    current value put first when it is not already listed.

    Unlike the other endpoints this one does not 404 for an unknown user; it
    falls back to the server defaults.
    """

    def build_options(configured: Optional[str], current: str) -> list[dict]:
        """Turn a comma-separated model list into ``[{"id": ..., "name": ...}]``.

        Blank and duplicate entries are dropped (first occurrence wins). A
        missing/None setting is treated as empty, and ``current`` is prepended
        when it is set but not listed.
        """
        items = list(
            dict.fromkeys(
                part.strip() for part in (configured or "").split(",") if part.strip()
            )
        )
        if current and current not in items:
            items.insert(0, current)
        return [{"id": item, "name": item} for item in items]

    user = db.query(User).filter(User.username == current_username).first()
    # ``user`` may be None; getattr then yields "" and the settings default is used.
    chat_model = (getattr(user, "ai_model", "") or "").strip() or settings.AI_MODEL
    vision_model = (getattr(user, "ai_vision_model", "") or "").strip() or settings.AI_VISION_MODEL
    image_model = (getattr(user, "ai_image_model", "") or "").strip() or settings.AI_IMAGE_EDIT_MODEL

    return {
        "code": 200,
        "data": {
            "chat_model": chat_model,
            "vision_model": vision_model,
            "image_edit_model": image_model,
            "chat_models": build_options(settings.AI_CHAT_MODELS, chat_model),
            "vision_models": build_options(settings.AI_VISION_MODELS, vision_model),
            "image_edit_models": build_options(settings.AI_IMAGE_EDIT_MODELS, image_model),
        },
    }
|
||
|
||
|
||
@router.get("/ai/skills")
async def get_skills(current_username: str = Depends(get_current_user)):
    """Return the public skill catalogue, reporting ``"auto"`` as the default skill.

    ``current_username`` is not used in the body; the dependency is presumably
    there to require an authenticated caller.
    """
    catalogue = list_ai_skills_public()
    payload = {"default_skill": "auto", "skills": catalogue}
    return {"code": 200, "data": payload}
|
||
|
||
|
||
@router.get("/ai/sessions")
async def list_sessions(
    db: DBSession = Depends(get_db), current_username: str = Depends(get_current_user)
):
    """List the caller's conversations, most recently updated first (at most 50).

    Raises:
        HTTPException: 404 when the current user does not exist.
    """
    owner = db.query(User).filter(User.username == current_username).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    def _iso(moment):
        # Timestamps may be unset; serialise those as None.
        return moment.isoformat() if moment else None

    owned = db.query(ChatSession).filter(ChatSession.user_id == owner.id)
    recent = owned.order_by(ChatSession.updated_at.desc()).limit(50).all()

    summaries = []
    for item in recent:
        summaries.append(
            {
                "id": item.id,
                "title": item.title,
                "active_skill_id": item.active_skill_id,
                "created_at": _iso(item.created_at),
                "updated_at": _iso(item.updated_at),
            }
        )
    return {"code": 200, "data": summaries}
|
||
|
||
|
||
@router.get("/ai/sessions/{session_id}/messages")
async def get_session_messages(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Return one of the caller's conversations with all of its messages.

    Messages are returned oldest first. ``tool_calls`` is returned exactly as
    stored (the JSON text written by the chat endpoints, or None). The
    session's workflow state is decoded from its stored JSON.

    Raises:
        HTTPException: 404 when the user does not exist or the session does
            not belong to them.
    """
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    session = (
        db.query(ChatSession)
        .filter(ChatSession.id == session_id, ChatSession.user_id == user.id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")

    messages = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        # ``id`` breaks ties between messages sharing a ``created_at`` value
        # (e.g. coarse timestamp precision) so the order is deterministic.
        .order_by(ChatMessage.created_at, ChatMessage.id)
        .all()
    )

    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "title": session.title,
            "active_skill_id": session.active_skill_id,
            "workflow_state": _decode_workflow_state(session.workflow_state),
            "messages": [
                {
                    "id": m.id,
                    "role": m.role,
                    "content": m.content,
                    "tool_calls": m.tool_calls,
                    "created_at": m.created_at.isoformat() if m.created_at else None,
                }
                for m in messages
            ],
        },
    }
|
||
|
||
|
||
@router.delete("/ai/sessions/{session_id}")
async def delete_session(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Delete one of the caller's conversations together with all of its messages.

    Raises:
        HTTPException: 404 when the user does not exist or the session does
            not belong to them.
    """
    owner = db.query(User).filter(User.username == current_username).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    ownership = (ChatSession.id == session_id, ChatSession.user_id == owner.id)
    target = db.query(ChatSession).filter(*ownership).first()
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    # Messages are bulk-deleted explicitly before the session row itself,
    # rather than relying on any ORM/DB cascade.
    child_messages = db.query(ChatMessage).filter(ChatMessage.session_id == session_id)
    child_messages.delete()
    db.delete(target)
    db.commit()
    return {"code": 200, "message": "对话已删除"}
|
||
|
||
|
||
@router.post("/ai/sessions/workflow")
async def update_session_workflow(
    data: SessionWorkflowUpdateRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Overwrite the workflow stages (and optionally the active skill) of a session.

    ``active_skill_id`` of None leaves the stored skill unchanged, while
    ``workflow_state`` always replaces the stored stages. The saved values are
    echoed back.

    Raises:
        HTTPException: 404 when the user does not exist or the session does
            not belong to them.
    """
    owner = db.query(User).filter(User.username == current_username).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    ownership = (ChatSession.id == data.session_id, ChatSession.user_id == owner.id)
    target = db.query(ChatSession).filter(*ownership).first()
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    stages = [entry.model_dump() for entry in data.workflow_state]
    _store_session_state(
        target,
        db,
        active_skill_id=data.active_skill_id,
        workflow_state=stages,
    )
    db.commit()

    saved = {
        "session_id": target.id,
        "active_skill_id": target.active_skill_id,
        "workflow_state": stages,
    }
    return {"code": 200, "data": saved}
|
||
|
||
|
||
# ==================== 聊天接口 ====================
|
||
|
||
|
||
@router.post("/ai/chat")
async def chat(
    data: ChatRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Send a user message and return the AI reply (which may contain tool calls).

    Flow:
    1. Resolve the caller (404 if missing) and the requested session. A falsy
       ``session_id``, or one the caller does not own, starts a new session
       titled after the first 30 characters of the message.
    2. Persist the user message (prefixed with ``[图片分析]`` when an image is
       attached) and commit it before calling the model.
    3. Build LLM history from the last 20 messages (``tool`` rows are sent with
       role ``user``) and resolve which AI skill handles the turn.
    4. Produce the reply: ``mock_reply`` without AI config, the vision model
       for pure image-analysis turns, otherwise the tool-calling chat model.
       Model errors are logged and returned as the reply text.
    5. Persist the assistant reply and its tool calls (as JSON text), record
       the active skill, and return ``{session_id, content, tool_calls, skill}``.

    Raises:
        HTTPException: 404 when the current user does not exist.
    """
    history_window = 20

    log.info("=" * 60)
    log.info(
        "[Chat] 收到消息: '%s' session=%s 有图片=%s",
        data.message[:80],
        data.session_id,
        "是" if data.image_base64 else "否",
    )

    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    session = None
    if data.session_id:
        session = (
            db.query(ChatSession)
            .filter(ChatSession.id == data.session_id, ChatSession.user_id == user.id)
            .first()
        )

    if not session:
        title = data.message[:30] + ("..." if len(data.message) > 30 else "")
        session = ChatSession(user_id=user.id, username=current_username, title=title)
        db.add(session)
        db.flush()  # assigns session.id for the message row below

    has_image = bool(data.image_base64)
    save_content = ("[图片分析] " + data.message) if has_image else data.message
    db.add(ChatMessage(session_id=session.id, role="user", content=save_content))
    db.commit()

    # Fetch only the newest messages, then restore chronological order. ``id``
    # breaks created_at ties so (assuming ids grow with insertion) the message
    # saved above is reliably the last entry.
    recent_messages = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
        .limit(history_window)
        .all()
    )
    recent_messages.reverse()
    history_list = [
        {"role": m.role if m.role != "tool" else "user", "content": m.content}
        for m in recent_messages
    ]
    # Conversation before the message that was just stored.
    history_context = history_list[:-1]

    selected_skill = resolve_ai_skill(
        message=data.message,
        history_messages=history_context,
        requested_skill_id=data.skill_id,
        has_image=has_image,
    )
    selected_skill_data = selected_skill.to_public_dict()
    _store_session_state(session, db, active_skill_id=selected_skill.id)

    runtime_ai_config = get_runtime_ai_config(user)
    # ``user`` is non-None here (404 above). Precedence: request override, then
    # the user's saved preference, then the server default; blanks are ignored.
    effective_chat_model = (
        (data.model or "").strip()
        or (user.ai_model or "").strip()
        or settings.AI_MODEL
    )
    effective_vision_model = (
        (data.vision_model or "").strip()
        or (user.ai_vision_model or "").strip()
        or settings.AI_VISION_MODEL
    )

    try:
        if not has_ai_config(runtime_ai_config):
            reply_content = mock_reply(data.message, has_image)
            tool_calls_data = None
        elif _should_use_vision_shortcut(selected_skill.mode, data.message, has_image):
            # The LLM helpers are synchronous; run them in a worker thread so a
            # slow model call does not block the event loop for other requests.
            reply_content = await asyncio.to_thread(
                call_vision_llm,
                data.message or "请分析这张图片。",
                data.image_base64 or "",
                history_context,
                model_override=effective_vision_model,
                skill_id=selected_skill.id,
                runtime_config=runtime_ai_config,
            )
            tool_calls_data = None
        else:
            call_history = history_list
            if has_image:
                # image_base64 is not passed to call_llm_with_tools; replace the
                # last user turn with a text note about the uploaded image.
                call_history = history_list[:-1] + [
                    {
                        "role": "user",
                        "content": (
                            f"[用户上传了一张图片,已保存可供工具使用]"
                            f"[当前技能: {selected_skill.name}] "
                            f"{data.message or '请处理这张图片'}"
                        ),
                    }
                ]
            reply_content, tool_calls_data, selected_skill_data = await asyncio.to_thread(
                call_llm_with_tools,
                call_history,
                model_override=effective_chat_model,
                skill_id=selected_skill.id,
                has_image=has_image,
                runtime_config=runtime_ai_config,
            )
    except Exception as e:
        # Boundary handler: the error is returned to the client as the reply,
        # so keep the traceback in the server log.
        log.exception("[Chat] AI 请求出错 session=%s", session.id)
        reply_content = f"AI 请求出错: {str(e)}"
        tool_calls_data = None

    tc_json = (
        json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
    )
    db.add(
        ChatMessage(
            session_id=session.id,
            role="assistant",
            content=reply_content or "",
            tool_calls=tc_json,
        )
    )
    _store_session_state(session, db, active_skill_id=selected_skill_data.get("id"))
    db.commit()

    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "content": reply_content or "",
            "tool_calls": tool_calls_data,
            "skill": selected_skill_data,
        },
    }
|
||
|
||
|
||
@router.post("/ai/tool-result")
async def submit_tool_result(
    data: ToolResultRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Accept a tool result from the frontend and let the AI continue from it.

    The result is stored as a ``tool`` message. The last 20 messages are then
    sent to the tool-calling chat model (``tool`` rows with role ``user``). The
    reply, which may itself request further tool calls, is stored and returned
    as ``{session_id, content, tool_calls, skill}``. Model errors are logged
    and returned as the reply text.

    Raises:
        HTTPException: 404 when the user does not exist or the session does
            not belong to them.
    """
    history_window = 20

    log.info("=" * 60)
    log.info("[ToolResult] 工具: %s, 结果: %s", data.tool_name, data.result[:150])

    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    session = (
        db.query(ChatSession)
        .filter(ChatSession.id == data.session_id, ChatSession.user_id == user.id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")

    db.add(
        ChatMessage(
            session_id=session.id,
            role="tool",
            content=f"[工具 {data.tool_name} 执行结果]: {data.result}",
        )
    )
    db.commit()

    # Fetch only the newest messages, then restore chronological order; ``id``
    # breaks ties between rows sharing a ``created_at`` value.
    recent_messages = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
        .limit(history_window)
        .all()
    )
    recent_messages.reverse()
    history_list = [
        {"role": m.role if m.role != "tool" else "user", "content": m.content}
        for m in recent_messages
    ]

    selected_skill = resolve_ai_skill(
        message=f"{data.tool_name} 已执行",
        history_messages=history_list,
        requested_skill_id=data.skill_id,
        has_image=False,
    )
    selected_skill_data = selected_skill.to_public_dict()
    _store_session_state(session, db, active_skill_id=selected_skill.id)

    runtime_ai_config = get_runtime_ai_config(user)
    # ``user`` is non-None here (404 above). Precedence: request override, then
    # the user's saved preference, then the server default; blanks are ignored.
    effective_chat_model = (
        (data.model or "").strip()
        or (user.ai_model or "").strip()
        or settings.AI_MODEL
    )

    try:
        if not has_ai_config(runtime_ai_config):
            reply_content = f"工具 {data.tool_name} 执行完成。"
            tool_calls_data = None
        else:
            # call_llm_with_tools is synchronous; run it in a worker thread so a
            # slow model call does not block the event loop for other requests.
            reply_content, tool_calls_data, selected_skill_data = await asyncio.to_thread(
                call_llm_with_tools,
                history_list,
                model_override=effective_chat_model,
                skill_id=selected_skill.id,
                has_image=False,
                runtime_config=runtime_ai_config,
            )
    except Exception as e:
        # Boundary handler: the error is returned to the client as the reply,
        # so keep the traceback in the server log.
        log.exception("[ToolResult] AI 总结出错 session=%s", data.session_id)
        reply_content = f"AI 总结出错: {str(e)}"
        tool_calls_data = None

    tc_json = (
        json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
    )
    db.add(
        ChatMessage(
            session_id=session.id,
            role="assistant",
            content=reply_content or "",
            tool_calls=tc_json,
        )
    )
    _store_session_state(session, db, active_skill_id=selected_skill_data.get("id"))
    db.commit()

    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "content": reply_content or "",
            "tool_calls": tool_calls_data,
            "skill": selected_skill_data,
        },
    }
|