148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
|
|
"""
|
|||
|
|
API 路由模块
|
|||
|
|
定义 FastAPI 接口路由
|
|||
|
|
"""
|
|||
|
|
from fastapi import APIRouter, HTTPException
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
from typing import Optional, List, Dict
|
|||
|
|
from app.services import ai_service
|
|||
|
|
from app.tools.base import ToolResult
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
from app.tools.registry import ToolRegistry
|
|||
|
|
|
|||
|
|
# Router that the endpoint decorators in this module register their routes on.
router = APIRouter()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatRequest(BaseModel):
    """Chat request model (request body of the chat endpoint)."""

    # The user's message text (required).
    message: str
    # Optional system prompt supplied by the caller; defaults to None.
    system_prompt: Optional[str] = None
    # Optional earlier conversation turns, each a str -> str mapping; defaults to None.
    # NOTE(review): presumably OpenAI-style {"role": ..., "content": ...} entries -- confirm.
    history: Optional[List[Dict[str, str]]] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatResponse(BaseModel):
    """Chat response model."""

    # Success flag for the request.
    success: bool
    # Short status message describing the outcome.
    message: str
    # The reply text produced by the AI.
    response: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TimeResponse(BaseModel):
    """Time response model."""

    # Success flag for the request.
    success: bool
    # Payload dictionary; its keys are not constrained by this model.
    # NOTE(review): presumably carries the current-time fields -- confirm against the endpoint that returns it.
    data: Dict
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _execute_tool_call(tool_call, tools_dict):
    """Run one tool call requested by the model and build its ``tool`` message.

    Failures (unknown tool, malformed arguments, an exception raised by the
    tool) are reported back to the model as an unsuccessful result instead of
    aborting the whole HTTP request.

    Args:
        tool_call: A tool-call entry from the assistant message; must expose
            ``id``, ``function.name`` and ``function.arguments`` (JSON text).
        tools_dict: Mapping of tool name to tool class, as returned by
            ``ToolRegistry.get_all()``.

    Returns:
        A dict with ``tool_call_id``, ``role``, ``name`` and ``content`` keys,
        ready to be appended to the chat message list.
    """
    function_name = tool_call.function.name

    if function_name not in tools_dict:
        result_data = {
            "success": False,
            "error": f"工具 '{function_name}' 未找到"
        }
    else:
        try:
            # Zero-argument tools may arrive with an empty argument string,
            # which json.loads would reject.
            function_args = json.loads(tool_call.function.arguments or "{}")
            tool_result = tools_dict[function_name]().execute(**function_args)
        except Exception as exc:
            # Boundary around arbitrary tool code: surface the failure to the
            # model so it can still produce an answer.
            result_data = {
                "success": False,
                "error": f"工具 '{function_name}' 执行失败: {exc}"
            }
        else:
            if isinstance(tool_result, ToolResult):
                # Flatten the ToolResult object into a JSON-friendly dict.
                result_data = {
                    "success": tool_result.success,
                    "data": tool_result.data,
                    "error": tool_result.error
                }
            else:
                result_data = tool_result

    return {
        "tool_call_id": tool_call.id,
        "role": "tool",
        "name": function_name,
        # default=str is deliberate: the content is only read by the model,
        # so stringifying e.g. datetimes beats failing the whole request.
        "content": json.dumps(result_data, ensure_ascii=False, default=str)
    }


@router.post("/chat-with-tools", response_model=ChatResponse)
async def chat_with_tools(request: ChatRequest):
    """
    AI chat endpoint with tool calling.

    The AI may call registered tools (e.g. fetching the current time) to
    answer the question. If the first model reply requests tool calls, the
    tools are executed and the model is called a second time with their
    results to produce the final answer.

    Args:
        request: The chat request. ``message`` is sent as the user turn and
            ``system_prompt`` (when provided) replaces the default prompt.

    Returns:
        ChatResponse carrying the AI reply (possibly based on tool results).

    Raises:
        HTTPException: 500 with the error text when anything unexpected fails.
    """
    default_system_prompt = """你是一个智能助手,可以使用工具来帮助回答问题"""
    fallback_reply = "抱歉,我无法理解您的问题。"

    try:
        # Schemas of every available tool, in OpenAI format.
        tools = [tool.to_openai_format() for tool in ToolRegistry.get_instances()]

        # Honour the caller-supplied system prompt; fall back to the default.
        system_prompt = request.system_prompt or default_system_prompt
        # NOTE(review): request.history is still not forwarded -- the signature
        # of ai_service.chat_with_tools is not visible here; confirm and wire it.

        # First model call: the model either answers or asks for tools.
        result = await ai_service.chat_with_tools(
            message=request.message,
            tools=tools,
            system_prompt=system_prompt
        )
        message = result["message"]

        if not message.tool_calls:
            # No tool calls: return the model's reply directly.
            final_content = message.content or fallback_reply
        else:
            # The registry lookup does not depend on the call, so do it once.
            tools_dict = ToolRegistry.get_all()
            tool_results = [
                _execute_tool_call(tool_call, tools_dict)
                for tool_call in message.tool_calls
            ]

            # Rebuild the conversation, including the assistant's tool request.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": request.message},
                {
                    "role": "assistant",
                    "content": message.content,
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments
                            }
                        }
                        for tc in message.tool_calls
                    ]
                }
            ]
            messages.extend(tool_results)

            # Second model call: generate the final reply from the tool results.
            from fastapi.concurrency import run_in_threadpool
            from openai import OpenAI
            from app.config import settings

            client = OpenAI(
                api_key=settings.openai_api_key,
                base_url=settings.openai_base_url
            )
            # The OpenAI client is synchronous; run it in the thread pool so
            # the event loop keeps serving other requests meanwhile.
            final_response = await run_in_threadpool(
                client.chat.completions.create,
                model=settings.openai_model,
                messages=messages
            )
            # content may be None; ChatResponse.response requires a str.
            final_content = final_response.choices[0].message.content or fallback_reply

        return ChatResponse(
            success=True,
            message="对话成功",
            response=final_content
        )

    except HTTPException:
        # Do not rewrap deliberate HTTP errors as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|