Files
DP/Server/app/api/v1/ai_chat.py

572 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
AI 聊天接口
功能:对话管理 + 消息存库 + function calling工具调用
"""
import asyncio
import json, logging
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session as DBSession

from app.core.config import settings
from app.core.security import get_current_user
from app.db import get_db
from app.models.user import User
from app.models.chat import ChatSession, ChatMessage
from app.api.v1.ai_skills import list_ai_skills_public, resolve_ai_skill
from app.api.v1.ai_llm import (
    call_llm_with_tools,
    call_vision_llm,
    get_runtime_ai_config,
    has_ai_config,
    mock_reply,
)
log = logging.getLogger(__name__)

# Configure the module logger only while it has no handler yet, so running
# this setup again never stacks a second handler. propagate=False keeps the
# records from additionally flowing to ancestor loggers' handlers.
if not log.handlers:
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter(fmt="[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
    )
    log.addHandler(_h)
    log.setLevel(logging.INFO)
    log.propagate = False

router = APIRouter()
# ==================== 数据模型 ====================
class ChatRequest(BaseModel):
    """Request body for ``POST /ai/chat``."""

    # The user's message text; also used to build the title of a newly created session.
    message: str
    # Session to append to; when omitted or not owned by the caller, a new session is created.
    session_id: Optional[int] = None
    # Optional image payload (presumably base64-encoded — TODO confirm format); its presence enables image handling.
    image_base64: Optional[str] = None
    # Requested skill id, forwarded to resolve_ai_skill as requested_skill_id.
    skill_id: Optional[str] = None
    # Per-request chat model override; wins over the user's setting and the global default.
    model: Optional[str] = None
    # Per-request vision model override, used when the vision shortcut is taken.
    vision_model: Optional[str] = None
    # Accepted but not read by the chat endpoint in this module.
    image_edit_model: Optional[str] = None
class ToolResultRequest(BaseModel):
    """Request body for ``POST /ai/tool-result`` (a client-executed tool's output)."""

    # Session the result belongs to; must be owned by the caller.
    session_id: int
    # Not read by the tool-result endpoint in this module.
    tool_call_id: str
    # Tool name; embedded in the stored tool message and in the skill-resolution prompt.
    tool_name: str
    # Tool output text, stored verbatim inside the tool message.
    result: str
    # Requested skill id, forwarded to resolve_ai_skill as requested_skill_id.
    skill_id: Optional[str] = None
    # Per-request chat model override.
    model: Optional[str] = None
    # Not read by the tool-result endpoint in this module.
    vision_model: Optional[str] = None
    # Not read by the tool-result endpoint in this module.
    image_edit_model: Optional[str] = None
class WorkflowStageUpdate(BaseModel):
    """One workflow stage sent by the client; stored via model_dump() in the session's workflow_state JSON."""

    # Stage identifier as chosen by the client.
    id: str
    # Display title of the stage.
    title: str
    # Description text of the stage.
    description: str
    # Free-form status string; no set of allowed values is enforced here.
    status: str
    # Optional extra detail text.
    detail: Optional[str] = None
class SessionWorkflowUpdateRequest(BaseModel):
    """Request body for ``POST /ai/sessions/workflow``."""

    # Session to update; must be owned by the caller.
    session_id: int
    # New active skill; None leaves the stored active skill unchanged.
    active_skill_id: Optional[str] = None
    # Replaces the stored stages wholesale; since it defaults to an empty list,
    # omitting it clears the stored workflow.
    workflow_state: List[WorkflowStageUpdate] = []
def _should_use_vision_shortcut(skill_mode: str, message: str, has_image: bool) -> bool:
if not has_image:
return False
lowered = (message or "").lower()
help_or_action_tokens = (
"怎么做",
"怎么操作",
"如何操作",
"教我",
"帮我做",
"帮我改",
"帮我操作",
"直接做",
"不会",
"为什么",
"报错",
"错误",
"失败",
"没反应",
"不能",
"无法",
"步骤",
"下一步",
)
if any(token in lowered for token in help_or_action_tokens):
return False
if skill_mode == "vision_chat":
return True
if skill_mode != "hybrid_chat":
return False
action_tokens = (
"出图",
"生成",
"方案图",
"方向图",
"效果图",
"背景图",
"概念图",
"排版",
"执行",
"落地",
)
if any(token in lowered for token in action_tokens):
return False
analysis_tokens = ("看图", "分析", "你看见", "风格", "版型", "配色", "元素", "建议")
return any(token in lowered for token in analysis_tokens)
def _decode_workflow_state(raw: Optional[str]) -> list[dict]:
if not raw:
return []
try:
payload = json.loads(raw)
return payload if isinstance(payload, list) else []
except Exception:
return []
def _store_session_state(
    session: ChatSession,
    db: DBSession,
    *,
    active_skill_id: Optional[str] = None,
    workflow_state: Optional[list[dict]] = None,
):
    """Copy the given state onto ``session`` and register it with ``db``.

    Arguments left as ``None`` are not touched on the session. The workflow
    stages are stored as a JSON string (non-ASCII kept as-is). Nothing is
    committed here; the caller is responsible for ``db.commit()``.
    """
    if active_skill_id is not None:
        session.active_skill_id = active_skill_id
    if workflow_state is not None:
        encoded_state = json.dumps(workflow_state, ensure_ascii=False)
        session.workflow_state = encoded_state
    db.add(session)
# ==================== 对话管理接口 ====================
def _build_model_options(configured: Optional[str], current: Optional[str]) -> list[dict]:
    """Turn a comma-separated model list into ``[{"id": ..., "name": ...}]`` options.

    Blank entries and duplicates are dropped (first occurrence wins); an unset
    (``None``) list is treated as empty. When ``current`` is set but missing
    from the list it is put first, so the model in effect is always selectable.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    items = list(
        dict.fromkeys(
            name
            for name in (part.strip() for part in (configured or "").split(","))
            if name
        )
    )
    if current and current not in items:
        items.insert(0, current)
    return [{"id": item, "name": item} for item in items]


@router.get("/ai/models")
async def get_models(
    db: DBSession = Depends(get_db), current_username: str = Depends(get_current_user)
):
    """Return the selectable chat / vision / image-edit models and the current choices.

    Each current choice is the user's own setting when non-blank, otherwise the
    global default from ``settings``. An unknown username is not an error here;
    it simply gets the global defaults.
    """
    user = db.query(User).filter(User.username == current_username).first()

    def preferred(attr: str, fallback):
        # getattr tolerates user being None (unknown username).
        return (getattr(user, attr, "") or "").strip() or fallback

    chat_model = preferred("ai_model", settings.AI_MODEL)
    vision_model = preferred("ai_vision_model", settings.AI_VISION_MODEL)
    image_model = preferred("ai_image_model", settings.AI_IMAGE_EDIT_MODEL)
    return {
        "code": 200,
        "data": {
            "chat_model": chat_model,
            "vision_model": vision_model,
            "image_edit_model": image_model,
            "chat_models": _build_model_options(settings.AI_CHAT_MODELS, chat_model),
            "vision_models": _build_model_options(
                settings.AI_VISION_MODELS, vision_model
            ),
            "image_edit_models": _build_model_options(
                settings.AI_IMAGE_EDIT_MODELS, image_model
            ),
        },
    }
@router.get("/ai/skills")
async def get_skills(current_username: str = Depends(get_current_user)):
    """Return the public skill catalogue along with the default skill id ("auto").

    ``current_username`` comes from the auth dependency (presumably to require
    a logged-in caller); its value is not used.
    """
    payload = {"default_skill": "auto", "skills": list_ai_skills_public()}
    return {"code": 200, "data": payload}
@router.get("/ai/sessions")
async def list_sessions(
    db: DBSession = Depends(get_db), current_username: str = Depends(get_current_user)
):
    """List the caller's 50 most recently updated chat sessions (newest first).

    Raises:
        HTTPException(404): the current user does not exist.
    """
    owner = db.query(User).filter(User.username == current_username).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    recent = (
        db.query(ChatSession)
        .filter(ChatSession.user_id == owner.id)
        .order_by(ChatSession.updated_at.desc())
        .limit(50)
        .all()
    )

    def _iso(moment):
        return moment.isoformat() if moment else None

    def _summarize(chat_session):
        return {
            "id": chat_session.id,
            "title": chat_session.title,
            "active_skill_id": chat_session.active_skill_id,
            "created_at": _iso(chat_session.created_at),
            "updated_at": _iso(chat_session.updated_at),
        }

    return {"code": 200, "data": [_summarize(item) for item in recent]}
@router.get("/ai/sessions/{session_id}/messages")
async def get_session_messages(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Return one of the caller's sessions with its full message history.

    The payload also carries the session's active skill id and its decoded
    workflow stages.

    Raises:
        HTTPException(404): the user does not exist, or the session does not
            exist / belongs to another user.
    """
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    session = (
        db.query(ChatSession)
        .filter(ChatSession.id == session_id, ChatSession.user_id == user.id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")
    # created_at alone can tie (several messages stored within one timestamp
    # tick), which leaves their relative order up to the database; id as a
    # secondary key keeps insertion order. Assumes id is an auto-incrementing
    # primary key — TODO confirm against the ChatMessage model.
    messages = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.created_at, ChatMessage.id)
        .all()
    )
    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "title": session.title,
            "active_skill_id": session.active_skill_id,
            "workflow_state": _decode_workflow_state(session.workflow_state),
            "messages": [
                {
                    "id": m.id,
                    "role": m.role,
                    "content": m.content,
                    "tool_calls": m.tool_calls,
                    "created_at": m.created_at.isoformat() if m.created_at else None,
                }
                for m in messages
            ],
        },
    }
@router.delete("/ai/sessions/{session_id}")
async def delete_session(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Delete one of the caller's sessions together with all of its messages.

    Raises:
        HTTPException(404): the user does not exist, or the session does not
            exist / belongs to another user.
    """
    owner = db.query(User).filter(User.username == current_username).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    doomed = (
        db.query(ChatSession)
        .filter(ChatSession.id == session_id, ChatSession.user_id == owner.id)
        .first()
    )
    if not doomed:
        raise HTTPException(status_code=404, detail="对话不存在")

    # Bulk-delete the session's messages, then the session row; a single
    # commit covers both.
    message_rows = db.query(ChatMessage).filter(ChatMessage.session_id == session_id)
    message_rows.delete()
    db.delete(doomed)
    db.commit()
    return {"code": 200, "message": "对话已删除"}
@router.post("/ai/sessions/workflow")
async def update_session_workflow(
    data: SessionWorkflowUpdateRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Persist the client's workflow stages (and optionally the active skill) for a session.

    ``active_skill_id=None`` leaves the stored skill untouched, whereas the
    stages are always overwritten with ``data.workflow_state``.

    Raises:
        HTTPException(404): the user does not exist, or the session does not
            exist / belongs to another user.
    """
    owner = db.query(User).filter(User.username == current_username).first()
    if not owner:
        raise HTTPException(status_code=404, detail="用户不存在")

    target = (
        db.query(ChatSession)
        .filter(ChatSession.id == data.session_id, ChatSession.user_id == owner.id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="对话不存在")

    stages = [entry.model_dump() for entry in data.workflow_state]
    _store_session_state(
        target,
        db,
        active_skill_id=data.active_skill_id,
        workflow_state=stages,
    )
    db.commit()

    result = {
        "session_id": target.id,
        "active_skill_id": target.active_skill_id,
        "workflow_state": stages,
    }
    return {"code": 200, "data": result}
# ==================== 聊天接口 ====================
@router.post("/ai/chat")
async def chat(
    data: ChatRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Store the user's message, get an AI reply, and store the reply.

    The reply may include tool-call instructions (``tool_calls``) for the
    client to execute.

    Flow:
      1. Reuse ``data.session_id`` if it belongs to the caller, otherwise start
         a new session titled after the message.
      2. Persist the user message and load the latest 20 messages as history.
      3. Resolve the skill, then answer with a mock reply (no AI config), the
         vision model (vision shortcut), or the tool-calling chat model.
      4. Persist the assistant reply and the session's active skill.

    Raises:
        HTTPException(404): the current user does not exist.
    """
    has_image = bool(data.image_base64)
    log.info("=" * 60)
    log.info(
        "[Chat] 收到消息: '%s' session=%s 有图片=%s",
        data.message[:80],
        data.session_id,
        "是" if has_image else "否",
    )
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    session = None
    if data.session_id:
        session = (
            db.query(ChatSession)
            .filter(ChatSession.id == data.session_id, ChatSession.user_id == user.id)
            .first()
        )
    if not session:
        title = data.message[:30] + ("..." if len(data.message) > 30 else "")
        session = ChatSession(user_id=user.id, username=current_username, title=title)
        db.add(session)
        db.flush()

    save_content = ("[图片分析] " + data.message) if has_image else data.message
    db.add(ChatMessage(session_id=session.id, role="user", content=save_content))
    db.commit()

    # Only the latest 20 messages are used, so fetch just those (newest first)
    # and restore chronological order instead of loading the whole session.
    # id breaks created_at ties so the message just stored stays last —
    # assumes id is auto-incrementing (TODO confirm against ChatMessage).
    recent = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
        .limit(20)
        .all()
    )
    recent.reverse()
    history_list = [
        # Stored tool results are replayed to the model as user turns.
        {"role": m.role if m.role != "tool" else "user", "content": m.content}
        for m in recent
    ]
    # Everything before the message that was just stored.
    history_context = history_list[:-1]

    selected_skill = resolve_ai_skill(
        message=data.message,
        history_messages=history_context,
        requested_skill_id=data.skill_id,
        has_image=has_image,
    )
    selected_skill_data = selected_skill.to_public_dict()
    _store_session_state(session, db, active_skill_id=selected_skill.id)

    runtime_ai_config = get_runtime_ai_config(user)
    # user is non-None here (404 above). Blank per-user settings fall back to
    # the global defaults, matching what /ai/models reports as current.
    effective_chat_model = (
        data.model or (user.ai_model or "").strip() or settings.AI_MODEL
    )
    effective_vision_model = (
        data.vision_model
        or (user.ai_vision_model or "").strip()
        or settings.AI_VISION_MODEL
    )
    try:
        if not has_ai_config(runtime_ai_config):
            reply_content = mock_reply(data.message, has_image)
            tool_calls_data = None
        elif _should_use_vision_shortcut(selected_skill.mode, data.message, has_image):
            # The LLM helpers are synchronous; run them in a worker thread so
            # this async endpoint does not block the event loop while waiting.
            reply_content = await asyncio.to_thread(
                call_vision_llm,
                data.message or "请分析这张图片。",
                data.image_base64 or "",
                history_context,
                model_override=effective_vision_model,
                skill_id=selected_skill.id,
                runtime_config=runtime_ai_config,
            )
            tool_calls_data = None
        else:
            call_history = history_list
            if has_image:
                # The raw image is not passed to the chat model; the final user
                # turn is rewritten to tell the model an image is available.
                call_history = history_list[:-1] + [
                    {
                        "role": "user",
                        "content": (
                            f"[用户上传了一张图片,已保存可供工具使用]"
                            f"[当前技能: {selected_skill.name}] "
                            f"{data.message or '请处理这张图片'}"
                        ),
                    }
                ]
            (
                reply_content,
                tool_calls_data,
                selected_skill_data,
            ) = await asyncio.to_thread(
                call_llm_with_tools,
                call_history,
                model_override=effective_chat_model,
                skill_id=selected_skill.id,
                has_image=has_image,
                runtime_config=runtime_ai_config,
            )
    except Exception as e:
        # Deliberate best-effort: the error is returned as the reply, but keep
        # the traceback in the server log instead of discarding it.
        log.exception("[Chat] AI 请求出错 session=%s", session.id)
        reply_content = f"AI 请求出错: {str(e)}"
        tool_calls_data = None

    tc_json = (
        json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
    )
    db.add(
        ChatMessage(
            session_id=session.id,
            role="assistant",
            content=reply_content or "",
            tool_calls=tc_json,
        )
    )
    _store_session_state(session, db, active_skill_id=selected_skill_data.get("id"))
    db.commit()
    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "content": reply_content or "",
            "tool_calls": tool_calls_data,
            "skill": selected_skill_data,
        },
    }
@router.post("/ai/tool-result")
async def submit_tool_result(
    data: ToolResultRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Store a tool result reported by the client and let the AI continue.

    The result is saved as a ``tool`` message, and the latest 20 messages are
    sent back to the tool-calling chat model. Its reply, which may request
    further tool calls, is stored and returned.

    Raises:
        HTTPException(404): the user does not exist, or the session does not
            exist / belongs to another user.
    """
    log.info("=" * 60)
    log.info("[ToolResult] 工具: %s, 结果: %s", data.tool_name, data.result[:150])
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    session = (
        db.query(ChatSession)
        .filter(ChatSession.id == data.session_id, ChatSession.user_id == user.id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")

    db.add(
        ChatMessage(
            session_id=session.id,
            role="tool",
            content=f"[工具 {data.tool_name} 执行结果]: {data.result}",
        )
    )
    db.commit()

    # Fetch only the latest 20 messages (newest first), then restore
    # chronological order. id breaks created_at ties so the tool result stays
    # after the assistant message that requested it — assumes id is
    # auto-incrementing (TODO confirm against ChatMessage).
    recent = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session.id)
        .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
        .limit(20)
        .all()
    )
    recent.reverse()
    history_list = [
        # Stored tool results are replayed to the model as user turns.
        {"role": m.role if m.role != "tool" else "user", "content": m.content}
        for m in recent
    ]

    selected_skill = resolve_ai_skill(
        message=f"{data.tool_name} 已执行",
        history_messages=history_list,
        requested_skill_id=data.skill_id,
        has_image=False,
    )
    selected_skill_data = selected_skill.to_public_dict()
    _store_session_state(session, db, active_skill_id=selected_skill.id)

    runtime_ai_config = get_runtime_ai_config(user)
    # user is non-None here (404 above). A blank per-user setting falls back
    # to the global default, matching what /ai/models reports as current.
    effective_chat_model = (
        data.model or (user.ai_model or "").strip() or settings.AI_MODEL
    )
    try:
        if not has_ai_config(runtime_ai_config):
            reply_content = f"工具 {data.tool_name} 执行完成。"
            tool_calls_data = None
        else:
            # The LLM helper is synchronous; run it in a worker thread so this
            # async endpoint does not block the event loop while waiting.
            (
                reply_content,
                tool_calls_data,
                selected_skill_data,
            ) = await asyncio.to_thread(
                call_llm_with_tools,
                history_list,
                model_override=effective_chat_model,
                skill_id=selected_skill.id,
                has_image=False,
                runtime_config=runtime_ai_config,
            )
    except Exception as e:
        # Deliberate best-effort: the error is returned as the reply, but keep
        # the traceback in the server log instead of discarding it.
        log.exception("[ToolResult] AI 总结出错 session=%s", session.id)
        reply_content = f"AI 总结出错: {str(e)}"
        tool_calls_data = None

    tc_json = (
        json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
    )
    db.add(
        ChatMessage(
            session_id=session.id,
            role="assistant",
            content=reply_content or "",
            tool_calls=tc_json,
        )
    )
    _store_session_state(session, db, active_skill_id=selected_skill_data.get("id"))
    db.commit()
    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "content": reply_content or "",
            "tool_calls": tool_calls_data,
            "skill": selected_skill_data,
        },
    }