148 lines
4.7 KiB
Python
148 lines
4.7 KiB
Python
"""
|
||
API 路由模块
|
||
定义 FastAPI 接口路由
|
||
"""
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel
|
||
from typing import Optional, List, Dict
|
||
from app.services import ai_service
|
||
from app.tools.base import ToolResult
|
||
import json
|
||
|
||
from app.tools.registry import ToolRegistry
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
class ChatRequest(BaseModel):
    """Request payload accepted by the chat endpoint."""

    # Text of the user's message.
    message: str
    # Optional system prompt supplied by the caller.
    system_prompt: Optional[str] = None
    # Optional earlier conversation turns, each a str -> str mapping.
    history: Optional[List[Dict[str, str]]] = None
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Response payload returned by the chat endpoint."""

    # Whether the request was handled successfully.
    success: bool
    # Short status message describing the outcome.
    message: str
    # The assistant's reply text.
    response: str
|
||
|
||
|
||
class TimeResponse(BaseModel):
    """Response payload carrying time information."""

    # Whether the request was handled successfully.
    success: bool
    # Free-form result dictionary.
    data: Dict
|
||
|
||
|
||
@router.post("/chat-with-tools", response_model=ChatResponse)
|
||
async def chat_with_tools(request: ChatRequest):
|
||
"""
|
||
带工具调用的 AI 聊天接口
|
||
|
||
AI 可以调用工具(如获取当前时间)来回答问题
|
||
|
||
Args:
|
||
request: 聊天请求
|
||
|
||
Returns:
|
||
AI 的回复(可能包含工具调用结果)
|
||
"""
|
||
try:
|
||
# 获取所有可用工具的 schema
|
||
tools_instances = ToolRegistry.get_instances()
|
||
tools = [tool.to_openai_format() for tool in tools_instances]
|
||
|
||
# 系统提示,告诉 AI 可以使用哪些工具
|
||
system_prompt = """你是一个智能助手,可以使用工具来帮助回答问题"""
|
||
|
||
# 第一次调用 AI
|
||
result = await ai_service.chat_with_tools(
|
||
message=request.message,
|
||
tools=tools,
|
||
system_prompt=system_prompt
|
||
)
|
||
|
||
message = result["message"]
|
||
|
||
# 检查是否有工具调用
|
||
if message.tool_calls:
|
||
# 处理工具调用
|
||
tool_results = []
|
||
|
||
for tool_call in message.tool_calls:
|
||
function_name = tool_call.function.name
|
||
function_args = json.loads(tool_call.function.arguments)
|
||
|
||
# 从 ToolRegistry 获取并执行工具
|
||
tools_dict = ToolRegistry.get_all()
|
||
if function_name in tools_dict:
|
||
tool_instance = tools_dict[function_name]()
|
||
tool_result = tool_instance.execute(**function_args)
|
||
|
||
# 处理 ToolResult 对象
|
||
if isinstance(tool_result, ToolResult):
|
||
result_data = {
|
||
"success": tool_result.success,
|
||
"data": tool_result.data,
|
||
"error": tool_result.error
|
||
}
|
||
else:
|
||
result_data = tool_result
|
||
else:
|
||
result_data = {
|
||
"success": False,
|
||
"error": f"工具 '{function_name}' 未找到"
|
||
}
|
||
|
||
tool_results.append({
|
||
"tool_call_id": tool_call.id,
|
||
"role": "tool",
|
||
"name": function_name,
|
||
"content": json.dumps(result_data, ensure_ascii=False)
|
||
})
|
||
|
||
# 构建新的消息历史
|
||
messages = [
|
||
{"role": "system", "content": system_prompt},
|
||
{"role": "user", "content": request.message},
|
||
{"role": "assistant", "content": message.content, "tool_calls": [
|
||
{
|
||
"id": tc.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tc.function.name,
|
||
"arguments": tc.function.arguments
|
||
}
|
||
} for tc in message.tool_calls
|
||
]}
|
||
]
|
||
messages.extend(tool_results)
|
||
|
||
# 再次调用 AI,让它根据工具结果生成最终回复
|
||
from openai import OpenAI
|
||
from app.config import settings
|
||
|
||
client = OpenAI(
|
||
api_key=settings.openai_api_key,
|
||
base_url=settings.openai_base_url
|
||
)
|
||
|
||
final_response = client.chat.completions.create(
|
||
model=settings.openai_model,
|
||
messages=messages
|
||
)
|
||
|
||
final_content = final_response.choices[0].message.content
|
||
|
||
else:
|
||
# 没有工具调用,直接返回 AI 的回复
|
||
final_content = message.content or "抱歉,我无法理解您的问题。"
|
||
|
||
return ChatResponse(
|
||
success=True,
|
||
message="对话成功",
|
||
response=final_content
|
||
)
|
||
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e)) |