# -*- coding: utf-8 -*-
"""
AI chat API.

Features: conversation management, message persistence, and function calling
(tool calls that the frontend executes and reports back via ``/ai/tool-result``).
"""
from fastapi import APIRouter, HTTPException, Depends
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from typing import List, Optional
import json
import logging
from sqlalchemy.orm import Session as DBSession

from app.core.config import settings
from app.core.security import get_current_user
from app.db import get_db
from app.models.user import User
from app.models.chat import ChatSession, ChatMessage
from app.api.v1.ai_skills import list_ai_skills_public, resolve_ai_skill
from app.api.v1.ai_llm import (
    call_llm_with_tools,
    call_vision_llm,
    get_runtime_ai_config,
    has_ai_config,
    mock_reply,
)

log = logging.getLogger(__name__)
if not log.handlers:
    # Attach a dedicated handler only when none exists yet (so repeated
    # execution does not stack handlers) and stop propagation so records are
    # not emitted a second time by ancestor loggers.
    log.setLevel(logging.INFO)
    log.propagate = False
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
    )
    log.addHandler(_h)

router = APIRouter()

# Number of most recent stored messages sent to the LLM as conversation context.
_HISTORY_WINDOW = 20


# ==================== Data models ====================


class ChatRequest(BaseModel):
    message: str
    session_id: Optional[int] = None
    image_base64: Optional[str] = None
    skill_id: Optional[str] = None
    model: Optional[str] = None
    vision_model: Optional[str] = None
    image_edit_model: Optional[str] = None


class ToolResultRequest(BaseModel):
    session_id: int
    tool_call_id: str
    tool_name: str
    result: str
    skill_id: Optional[str] = None
    model: Optional[str] = None
    vision_model: Optional[str] = None
    image_edit_model: Optional[str] = None


class WorkflowStageUpdate(BaseModel):
    id: str
    title: str
    description: str
    status: str
    detail: Optional[str] = None


class SessionWorkflowUpdateRequest(BaseModel):
    session_id: int
    active_skill_id: Optional[str] = None
    workflow_state: List[WorkflowStageUpdate] = []


# ==================== Helpers ====================

# Phrases asking for help or hands-on action; messages containing any of these
# never take the vision-only shortcut, regardless of skill mode.
_HELP_OR_ACTION_TOKENS = (
    "怎么做",
    "怎么操作",
    "如何操作",
    "教我",
    "帮我做",
    "帮我改",
    "帮我操作",
    "直接做",
    "不会",
    "为什么",
    "报错",
    "错误",
    "失败",
    "没反应",
    "不能",
    "无法",
    "步骤",
    "下一步",
)
# In "hybrid_chat" mode: phrases requesting generation/execution (tool path).
_HYBRID_ACTION_TOKENS = (
    "出图",
    "生成",
    "方案图",
    "方向图",
    "效果图",
    "背景图",
    "概念图",
    "排版",
    "执行",
    "落地",
)
# In "hybrid_chat" mode: phrases requesting analysis of the image (vision path).
_HYBRID_ANALYSIS_TOKENS = ("看图", "分析", "你看见", "风格", "版型", "配色", "元素", "建议")


def _should_use_vision_shortcut(skill_mode: str, message: str, has_image: bool) -> bool:
    """Decide whether an image message goes straight to the vision model.

    Returns True only when an image is attached and the message is not asking
    for help/action. In ``"vision_chat"`` mode that is sufficient. In
    ``"hybrid_chat"`` mode the message must also contain an analysis phrase and
    no generation/execution phrase. Any other mode returns False, so the
    tool-calling LLM is used.
    """
    if not has_image:
        return False
    lowered = (message or "").lower()
    if any(token in lowered for token in _HELP_OR_ACTION_TOKENS):
        return False
    if skill_mode == "vision_chat":
        return True
    if skill_mode != "hybrid_chat":
        return False
    if any(token in lowered for token in _HYBRID_ACTION_TOKENS):
        return False
    return any(token in lowered for token in _HYBRID_ANALYSIS_TOKENS)


def _decode_workflow_state(raw: Optional[str]) -> list[dict]:
    """Decode the JSON workflow state stored on a session.

    Returns an empty list when *raw* is empty, is not valid JSON, or does not
    decode to a JSON array, so callers can always treat the result as a list.
    """
    if not raw:
        return []
    try:
        payload = json.loads(raw)
    except (TypeError, ValueError):
        # Corrupt stored values read as "no workflow" instead of failing the request.
        return []
    return payload if isinstance(payload, list) else []


def _store_session_state(
    session: ChatSession,
    db: DBSession,
    *,
    active_skill_id: Optional[str] = None,
    workflow_state: Optional[list[dict]] = None,
):
    """Stage updates to a session's active skill and/or workflow state.

    Arguments left as ``None`` leave the corresponding column untouched. The
    workflow state is stored as a JSON string. The session is only added to *db*;
    committing is the caller's responsibility.
    """
    if active_skill_id is not None:
        session.active_skill_id = active_skill_id
    if workflow_state is not None:
        session.workflow_state = json.dumps(workflow_state, ensure_ascii=False)
    db.add(session)


def _user_default_model(user: Optional[User], attr: str, fallback: str) -> str:
    """Return the user's stripped model preference *attr*, or *fallback* if blank/unset."""
    return (getattr(user, attr, "") or "").strip() or fallback


def _get_user_or_404(db: DBSession, username: str) -> User:
    """Return the ``User`` named *username*; raise HTTP 404 if there is none."""
    user = db.query(User).filter(User.username == username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    return user


def _get_owned_session_or_404(db: DBSession, session_id: int, user_id) -> ChatSession:
    """Return session *session_id* if owned by *user_id*; raise HTTP 404 otherwise."""
    session = (
        db.query(ChatSession)
        .filter(ChatSession.id == session_id, ChatSession.user_id == user_id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")
    return session


def _load_recent_history(
    db: DBSession, session_id: int, limit: int = _HISTORY_WINDOW
) -> list[dict]:
    """Return the last *limit* messages of a session, oldest first, as LLM turns.

    Only the newest *limit* rows are fetched (newest-first, then reversed). The
    ``id`` tiebreak keeps the order deterministic when timestamps collide.
    Stored ``"tool"`` messages are relabelled ``"user"`` so tool results reach
    the model as user turns.
    """
    newest_first = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
        .limit(limit)
        .all()
    )
    return [
        {"role": "user" if m.role == "tool" else m.role, "content": m.content}
        for m in reversed(newest_first)
    ]


def _save_assistant_reply(
    session: ChatSession,
    db: DBSession,
    reply_content: Optional[str],
    tool_calls_data,
    skill_data: Optional[dict],
) -> dict:
    """Persist the assistant reply, record the active skill, commit, and build the response.

    Tool calls (if any) are stored as a JSON string on the message and returned
    to the client unchanged.
    """
    tc_json = (
        json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
    )
    db.add(
        ChatMessage(
            session_id=session.id,
            role="assistant",
            content=reply_content or "",
            tool_calls=tc_json,
        )
    )
    # The LLM layer returns the skill it finished on. Tolerate a missing payload
    # so the already-generated reply is still committed.
    _store_session_state(session, db, active_skill_id=(skill_data or {}).get("id"))
    db.commit()
    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "content": reply_content or "",
            "tool_calls": tool_calls_data,
            "skill": skill_data,
        },
    }


# ==================== Session management endpoints ====================


@router.get("/ai/models")
async def get_models(
    db: DBSession = Depends(get_db), current_username: str = Depends(get_current_user)
):
    """Return the selectable chat/vision/image-edit models and the current defaults.

    Each current default is the user's saved preference (if non-blank) or the
    corresponding ``settings`` value. The option lists come from comma-separated
    settings strings, with the current default prepended if it is not listed.
    """

    def build_options(configured: Optional[str], current: str, default_name: str):
        items = [item.strip() for item in (configured or "").split(",") if item.strip()]
        if current and current not in items:
            items.insert(0, current)
        return [{"id": item, "name": item or default_name} for item in items]

    # No 404 here: an unknown user simply falls back to the settings defaults.
    user = db.query(User).filter(User.username == current_username).first()
    chat_model = _user_default_model(user, "ai_model", settings.AI_MODEL)
    vision_model = _user_default_model(user, "ai_vision_model", settings.AI_VISION_MODEL)
    image_model = _user_default_model(user, "ai_image_model", settings.AI_IMAGE_EDIT_MODEL)
    return {
        "code": 200,
        "data": {
            "chat_model": chat_model,
            "vision_model": vision_model,
            "image_edit_model": image_model,
            "chat_models": build_options(
                settings.AI_CHAT_MODELS, chat_model, "当前对话模型"
            ),
            "vision_models": build_options(
                settings.AI_VISION_MODELS,
                vision_model,
                "当前视觉模型",
            ),
            "image_edit_models": build_options(
                settings.AI_IMAGE_EDIT_MODELS,
                image_model,
                "当前图片模型",
            ),
        },
    }


@router.get("/ai/skills")
async def get_skills(current_username: str = Depends(get_current_user)):
    """Return the available skills (``current_username`` only enforces authentication)."""
    return {
        "code": 200,
        "data": {
            "default_skill": "auto",
            "skills": list_ai_skills_public(),
        },
    }


@router.get("/ai/sessions")
async def list_sessions(
    db: DBSession = Depends(get_db), current_username: str = Depends(get_current_user)
):
    """Return the current user's 50 most recently updated sessions.

    Raises:
        HTTPException: 404 if the user does not exist.
    """
    user = _get_user_or_404(db, current_username)
    sessions = (
        db.query(ChatSession)
        .filter(ChatSession.user_id == user.id)
        .order_by(ChatSession.updated_at.desc())
        .limit(50)
        .all()
    )
    return {
        "code": 200,
        "data": [
            {
                "id": s.id,
                "title": s.title,
                "active_skill_id": s.active_skill_id,
                "created_at": s.created_at.isoformat() if s.created_at else None,
                "updated_at": s.updated_at.isoformat() if s.updated_at else None,
            }
            for s in sessions
        ],
    }


@router.get("/ai/sessions/{session_id}/messages")
async def get_session_messages(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Return all messages of one of the user's sessions, oldest first.

    Raises:
        HTTPException: 404 if the user or the (owned) session does not exist.
    """
    user = _get_user_or_404(db, current_username)
    session = _get_owned_session_or_404(db, session_id, user.id)
    messages = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.created_at)
        .all()
    )
    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "title": session.title,
            "active_skill_id": session.active_skill_id,
            "workflow_state": _decode_workflow_state(session.workflow_state),
            "messages": [
                {
                    "id": m.id,
                    "role": m.role,
                    "content": m.content,
                    "tool_calls": m.tool_calls,
                    "created_at": m.created_at.isoformat() if m.created_at else None,
                }
                for m in messages
            ],
        },
    }


@router.delete("/ai/sessions/{session_id}")
async def delete_session(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Delete one of the user's sessions together with all of its messages.

    Raises:
        HTTPException: 404 if the user or the (owned) session does not exist.
    """
    user = _get_user_or_404(db, current_username)
    session = _get_owned_session_or_404(db, session_id, user.id)
    db.query(ChatMessage).filter(ChatMessage.session_id == session_id).delete()
    db.delete(session)
    db.commit()
    return {"code": 200, "message": "对话已删除"}


@router.post("/ai/sessions/workflow")
async def update_session_workflow(
    data: SessionWorkflowUpdateRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Replace a session's workflow stages and optionally set its active skill.

    Raises:
        HTTPException: 404 if the user or the (owned) session does not exist.
    """
    user = _get_user_or_404(db, current_username)
    session = _get_owned_session_or_404(db, data.session_id, user.id)
    workflow_payload = [stage.model_dump() for stage in data.workflow_state]
    _store_session_state(
        session,
        db,
        active_skill_id=data.active_skill_id,
        workflow_state=workflow_payload,
    )
    db.commit()
    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "active_skill_id": session.active_skill_id,
            "workflow_state": workflow_payload,
        },
    }


# ==================== Chat endpoints ====================


@router.post("/ai/chat")
async def chat(
    data: ChatRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Store a user message, pick a skill, and return the AI reply.

    The reply may carry tool-call instructions for the frontend to execute;
    their results come back through ``/ai/tool-result``. If ``session_id`` is
    absent or does not belong to the user, a new session is created instead.

    Raises:
        HTTPException: 404 if the authenticated user does not exist.
    """
    log.info(f"{'=' * 60}")
    log.info(
        f"[Chat] 收到消息: '{data.message[:80]}' session={data.session_id} 有图片={'是' if data.image_base64 else '否'}"
    )
    user = _get_user_or_404(db, current_username)

    session = None
    if data.session_id:
        session = (
            db.query(ChatSession)
            .filter(ChatSession.id == data.session_id, ChatSession.user_id == user.id)
            .first()
        )
    if not session:
        title = data.message[:30] + ("..." if len(data.message) > 30 else "")
        session = ChatSession(user_id=user.id, username=current_username, title=title)
        db.add(session)
        db.flush()  # assigns session.id for the message rows below
    session_id = session.id

    has_image = bool(data.image_base64)
    save_content = ("[图片分析] " + data.message) if has_image else data.message
    db.add(ChatMessage(session_id=session_id, role="user", content=save_content))
    db.commit()

    history_list = _load_recent_history(db, session_id)
    # The newest entry is the message just stored. Skill routing and the vision
    # call get only the earlier turns, since the message is passed separately.
    history_context = history_list[:-1]

    selected_skill = resolve_ai_skill(
        message=data.message,
        history_messages=history_context,
        requested_skill_id=data.skill_id,
        has_image=has_image,
    )
    selected_skill_data = selected_skill.to_public_dict()
    _store_session_state(session, db, active_skill_id=selected_skill.id)

    runtime_ai_config = get_runtime_ai_config(user)
    effective_chat_model = data.model or _user_default_model(
        user, "ai_model", settings.AI_MODEL
    )
    effective_vision_model = data.vision_model or _user_default_model(
        user, "ai_vision_model", settings.AI_VISION_MODEL
    )

    # The LLM helpers are synchronous; run them in the threadpool so a slow model
    # request does not block the event loop for every other request.
    try:
        if not has_ai_config(runtime_ai_config):
            reply_content = mock_reply(data.message, has_image)
            tool_calls_data = None
        elif _should_use_vision_shortcut(selected_skill.mode, data.message, has_image):
            reply_content = await run_in_threadpool(
                call_vision_llm,
                data.message or "请分析这张图片。",
                data.image_base64 or "",
                history_context,
                model_override=effective_vision_model,
                skill_id=selected_skill.id,
                runtime_config=runtime_ai_config,
            )
            tool_calls_data = None
        else:
            call_history = history_list
            if has_image:
                # Replace the stored "[图片分析] ..." turn with an annotated prompt
                # telling the model an image is available to tools.
                call_history = history_list[:-1] + [
                    {
                        "role": "user",
                        "content": (
                            f"[用户上传了一张图片,已保存可供工具使用]"
                            f"[当前技能: {selected_skill.name}] "
                            f"{data.message or '请处理这张图片'}"
                        ),
                    }
                ]
            (
                reply_content,
                tool_calls_data,
                selected_skill_data,
            ) = await run_in_threadpool(
                call_llm_with_tools,
                call_history,
                model_override=effective_chat_model,
                skill_id=selected_skill.id,
                has_image=has_image,
                runtime_config=runtime_ai_config,
            )
    except Exception as e:
        # Top-level boundary: report the failure as the reply, but keep the trace.
        log.exception("[Chat] AI call failed (session=%s)", session_id)
        reply_content = f"AI 请求出错: {str(e)}"
        tool_calls_data = None

    return _save_assistant_reply(
        session, db, reply_content, tool_calls_data, selected_skill_data
    )


@router.post("/ai/tool-result")
async def submit_tool_result(
    data: ToolResultRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user),
):
    """Record a tool result reported by the frontend and let the AI continue.

    The result is stored as a ``"tool"`` message. The LLM is then asked to
    summarize or continue, possibly issuing further tool calls.

    Raises:
        HTTPException: 404 if the user or the (owned) session does not exist.
    """
    log.info(f"{'=' * 60}")
    log.info(f"[ToolResult] 工具: {data.tool_name}, 结果: {data.result[:150]}")
    user = _get_user_or_404(db, current_username)
    session = _get_owned_session_or_404(db, data.session_id, user.id)
    session_id = session.id

    db.add(
        ChatMessage(
            session_id=session_id,
            role="tool",
            content=f"[工具 {data.tool_name} 执行结果]: {data.result}",
        )
    )
    db.commit()

    history_list = _load_recent_history(db, session_id)
    selected_skill = resolve_ai_skill(
        message=f"{data.tool_name} 已执行",
        history_messages=history_list,
        requested_skill_id=data.skill_id,
        has_image=False,
    )
    selected_skill_data = selected_skill.to_public_dict()
    _store_session_state(session, db, active_skill_id=selected_skill.id)

    runtime_ai_config = get_runtime_ai_config(user)
    effective_chat_model = data.model or _user_default_model(
        user, "ai_model", settings.AI_MODEL
    )

    try:
        if not has_ai_config(runtime_ai_config):
            reply_content = f"工具 {data.tool_name} 执行完成。"
            tool_calls_data = None
        else:
            # Synchronous LLM call: keep it off the event loop.
            (
                reply_content,
                tool_calls_data,
                selected_skill_data,
            ) = await run_in_threadpool(
                call_llm_with_tools,
                history_list,
                model_override=effective_chat_model,
                skill_id=selected_skill.id,
                has_image=False,
                runtime_config=runtime_ai_config,
            )
    except Exception as e:
        log.exception("[ToolResult] AI call failed (session=%s)", session_id)
        reply_content = f"AI 总结出错: {str(e)}"
        tool_calls_data = None

    return _save_assistant_reply(
        session, db, reply_content, tool_calls_data, selected_skill_data
    )