ai-demo/app/api/chat.py

148 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API 路由模块
定义 FastAPI 接口路由
"""
import json
import logging
from typing import Optional, List, Dict

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from app.services import ai_service
from app.tools.base import ToolResult
from app.tools.registry import ToolRegistry
# Router collecting this module's endpoints (e.g. POST /chat-with-tools).
# NOTE(review): presumably included into the FastAPI app elsewhere — confirm.
router = APIRouter()


class ChatRequest(BaseModel):
    """Chat request body."""

    # The user's message; chat_with_tools sends it to the model as the user turn.
    message: str
    # Optional caller-supplied system prompt.
    # NOTE(review): chat_with_tools in this module ignores it and always uses
    # its own fixed tool-use prompt — confirm whether that is intended.
    system_prompt: Optional[str] = None
    # Optional earlier conversation turns, each a str -> str dict
    # (presumably {"role": ..., "content": ...} — confirm against callers).
    # NOTE(review): not read by chat_with_tools in this module.
    history: Optional[List[Dict[str, str]]] = None


class ChatResponse(BaseModel):
    """Chat response body."""

    # Whether the request was handled successfully.
    success: bool
    # Short status text; chat_with_tools always sets a fixed success string.
    message: str
    # The assistant's reply text.
    response: str


class TimeResponse(BaseModel):
    """Time response body."""

    # Whether the request was handled successfully.
    success: bool
    # Unparameterized dict payload; its keys are not constrained by this model.
    # NOTE(review): no endpoint in this view uses this model — presumably a
    # time-lookup endpoint elsewhere; confirm the payload shape there.
    data: Dict


logger = logging.getLogger(__name__)

# System prompt telling the model that tools are available.
_TOOL_SYSTEM_PROMPT = """你是一个智能助手,可以使用工具来帮助回答问题"""
# Reply used whenever the model produces no text content.
_FALLBACK_REPLY = "抱歉,我无法理解您的问题。"


def _execute_tool(function_name: str, raw_arguments: Optional[str], tools_dict: Dict):
    """Run one registered tool and return a JSON-friendly result payload.

    Args:
        function_name: Name of the tool the model asked for.
        raw_arguments: The model-supplied argument string (expected to be a
            JSON object; empty/None is treated as "no arguments").
        tools_dict: Mapping of tool name -> tool class, from ToolRegistry.get_all().

    Returns:
        A ``{"success", "data", "error"}`` dict for ToolResult outputs or
        failures, otherwise whatever the tool returned, unchanged.
        Failures are reported here instead of raised, so the model can see
        and explain them rather than the whole request failing.
    """
    if function_name not in tools_dict:
        return {"success": False, "error": f"工具 '{function_name}' 未找到"}

    # Some models send "" (not "{}") for parameterless tools.
    try:
        function_args = json.loads(raw_arguments) if raw_arguments else {}
    except json.JSONDecodeError as exc:
        return {"success": False, "error": f"工具 '{function_name}' 参数无效: {exc}"}
    if not isinstance(function_args, dict):
        return {"success": False, "error": f"工具 '{function_name}' 参数无效: 需要 JSON 对象"}

    try:
        tool_result = tools_dict[function_name]().execute(**function_args)
    except Exception as exc:  # tool boundary: report to the model, don't abort the request
        logger.exception("Tool %r failed", function_name)
        return {"success": False, "error": f"工具 '{function_name}' 执行失败: {exc}"}

    if isinstance(tool_result, ToolResult):
        return {
            "success": tool_result.success,
            "data": tool_result.data,
            "error": tool_result.error,
        }
    return tool_result


def _run_tool_call(tool_call, tools_dict: Dict) -> Dict[str, str]:
    """Execute one model-requested tool call and wrap the outcome as a ``tool`` message.

    ``tool_call`` is presumably an OpenAI SDK tool-call object (it must expose
    ``.id``, ``.function.name`` and ``.function.arguments``).
    """
    function_name = tool_call.function.name
    result_data = _execute_tool(function_name, tool_call.function.arguments, tools_dict)
    return {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": function_name,
        # default=str: tool payloads may contain values json can't encode
        # natively (e.g. datetime); the model only needs a readable rendering.
        "content": json.dumps(result_data, ensure_ascii=False, default=str),
    }


def _assistant_tool_call_message(message) -> Dict:
    """Re-serialize the model's tool-calling turn so it can be replayed in the transcript."""
    return {
        "role": "assistant",
        "content": message.content,
        "tool_calls": [
            {
                "id": tc.id,
                "type": "function",
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in message.tool_calls
        ],
    }


async def _complete_with_tool_results(messages: List[Dict]) -> Optional[str]:
    """Ask the model for its final answer given a transcript that includes tool results."""
    # Kept as lazy imports, matching the original code path.
    from openai import AsyncOpenAI
    from app.config import settings

    # The async client is awaited so the event loop stays free while waiting on
    # the API; the context manager closes the client's HTTP connections.
    async with AsyncOpenAI(
        api_key=settings.openai_api_key,
        base_url=settings.openai_base_url,
    ) as client:
        final_response = await client.chat.completions.create(
            model=settings.openai_model,
            messages=messages,
        )
    return final_response.choices[0].message.content


@router.post("/chat-with-tools", response_model=ChatResponse)
async def chat_with_tools(request: ChatRequest):
    """Chat with the model, letting it call registered tools before answering.

    Flow: send the user message plus every registered tool schema to the model.
    If the model requests tool calls, execute each one and append the results
    as ``tool`` messages. Then ask the model once more for the final answer.

    Args:
        request: Incoming chat request; only ``message`` is used here.
            NOTE(review): ``system_prompt`` and ``history`` are ignored by this
            endpoint — confirm whether that is intended.

    Returns:
        ChatResponse whose ``response`` is the model's final text, or a fixed
        apology when the model returns no text.

    Raises:
        HTTPException: status 500 carrying the error text for any unexpected
            failure; HTTPExceptions raised by dependencies pass through unchanged.
    """
    try:
        tools = [tool.to_openai_format() for tool in ToolRegistry.get_instances()]
        system_prompt = _TOOL_SYSTEM_PROMPT

        # First round: the model either answers directly or requests tool calls.
        result = await ai_service.chat_with_tools(
            message=request.message,
            tools=tools,
            system_prompt=system_prompt,
        )
        message = result["message"]

        if message.tool_calls:
            tools_dict = ToolRegistry.get_all()  # hoisted: loop-invariant
            tool_results = [_run_tool_call(tc, tools_dict) for tc in message.tool_calls]
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": request.message},
                _assistant_tool_call_message(message),
                *tool_results,
            ]
            final_content = await _complete_with_tool_results(messages)
        else:
            final_content = message.content

        # The model can return content=None (in either branch); ChatResponse
        # requires a str, so fall back to the apology text.
        return ChatResponse(
            success=True,
            message="对话成功",
            response=final_content or _FALLBACK_REPLY,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("chat-with-tools request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e