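"""
client/agent_client.py

Agent 客户端:驱动完整 OpenAI Function Calling + Tool Chain 执行流程
(Agent client: drives the complete OpenAI Function Calling + tool-chain execution flow.)

新增: OpenAI 格式消息序列管理,支持多轮工具调用上下文传递
(New: OpenAI-format message-sequence management, carrying tool-call context across turns.)
"""
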
import json
import uuid
from dataclasses import dataclass, field

from config.settings import settings
from llm.llm_engine import LLMEngine
from mcp.mcp_protocol import (
    ChainPlan, ChainResult, MCPRequest, MCPResponse,
    StepResult, ToolStep,
)
from mcp.mcp_server import MCPServer
from memory.memory_store import MemoryStore
from utils.logger import get_logger


@dataclass
class AgentResponse:
    """Result of one complete Agent invocation.

    Carries the user's input and the final natural-language reply. When tools
    were planned, it also carries the chain execution record; for single-step
    chains the client additionally fills in ``tool_used`` / ``tool_output``.
    """

    user_input: str
    final_reply: str
    chain_result: ChainResult | None = None
    tool_used: str | None = None
    tool_output: str | None = None
    success: bool = True
    error: str | None = None
    token_usage: dict = field(default_factory=dict)

    @property
    def is_multi_step(self) -> bool:
        """True when a tool chain was executed and it had more than one step."""
        chain = self.chain_result
        if chain is None:
            return False
        return chain.total_steps > 1


class AgentClient:
    """
    Agent client: OpenAI Function Calling + multi-step tool chain.

    Execution flow:
        1. [CLIENT] Receive the user input and write it to Memory.
        2. [LLM]    plan_tool_chain() -> OpenAI Function Calling -> ChainPlan.
        3. [CHAIN]  Execute every ToolStep serially:
                a. build an MCPRequest and let MCPServer run the tool;
                b. append the tool result to the OpenAI message sequence
                   ("tool" role);
                c. record a StepResult and update the chain context used
                   for placeholder substitution in later steps.
        4. [LLM]    generate_final_reply() (single step) or
                    generate_chain_reply() (multiple steps) -> final reply.
        5. [MEMORY] Persist the reply / the full chain record.

    Example OpenAI message sequence (multi-step). The tool_messages list built
    by this class starts at the assistant message:
        {"role": "system", "content": "<planner prompt>"}
        {"role": "user", "content": "search the weather, then calculate..."}
        {"role": "assistant", "tool_calls": [{web_search}, {calculator}]}
        {"role": "tool", "content": "<web_search result>", "tool_call_id": "call_1"}
        {"role": "tool", "content": "<calculator result>", "tool_call_id": "call_2"}
        -> generate_chain_reply() -> final natural-language reply

    Failure handling: a step is skipped when any of its dependencies failed
    or was itself skipped. A placeholder "tool" message is still emitted for
    it, so every declared tool_call gets a matching tool response.
    """

    # Cap on the cross-turn structured history: 20 messages, i.e. the most
    # recent 10 user/assistant turns.
    _MAX_HISTORY_MESSAGES = 20

    def __init__(
        self,
        llm: LLMEngine,
        mcp_server: MCPServer,
        memory: MemoryStore,
    ):
        """
        Args:
            llm: engine used for tool-chain planning and reply generation.
            mcp_server: runs tool requests and provides the tool schemas.
            memory: store for user/assistant messages and tool results.
        """
        self.llm = llm
        self.mcp_server = mcp_server
        self.memory = memory
        self.logger = get_logger("CLIENT")
        # OpenAI-format structured dialogue history (kept across turns).
        self._openai_history: list[dict] = []
        self.logger.info("💻 Agent Client 初始化完成(OpenAI Function Calling 模式)")

    # ════════════════════════════════════════════════════════════
    # Main entry point
    # ════════════════════════════════════════════════════════════

    def chat(self, user_input: str) -> AgentResponse:
        """Handle one user turn end to end.

        Args:
            user_input: the raw user message.

        Returns:
            An AgentResponse. ``chain_result`` is set only when the planner
            requested at least one tool.
        """
        sep = "═" * 60
        self.logger.info(sep)
        self.logger.info(f"📨 收到用户输入: {user_input}")
        self.logger.info(sep)

        # Step 1: record the user message.
        self.memory.add_user_message(user_input)
        context = self.memory.get_context_summary()

        # Step 2: let the LLM plan the tool chain.
        self.logger.info("🗺 Step 2 [LLM] 规划工具调用链...")
        tool_schemas = self.mcp_server.get_tool_schemas()
        plan: ChainPlan = self.llm.plan_tool_chain(
            user_input=user_input,
            tool_schemas=tool_schemas,
            context=context,
            history=self._openai_history,
        )

        # No tools needed: reply directly.
        if not plan.steps:
            return self._handle_direct_reply(user_input, context)

        # Steps 3-4: run the chain and build the OpenAI message sequence.
        chain_result, tool_messages = self._execute_chain(plan, user_input)

        # Step 5: have OpenAI merge the results into the final reply.
        return self._generate_response(user_input, chain_result, tool_messages, context)

    # ════════════════════════════════════════════════════════════
    # Serial execution engine
    # ════════════════════════════════════════════════════════════

    def _execute_chain(
        self,
        plan: ChainPlan,
        user_input: str,
    ) -> tuple[ChainResult, list[dict]]:
        """
        Execute the tool chain serially while building the OpenAI message sequence.

        Args:
            plan: the planned chain; its steps run in list order.
            user_input: the user message (not used here; kept for interface
                compatibility).

        Returns:
            (ChainResult, tool_messages). tool_messages is the OpenAI-format
            list (one assistant message declaring all tool_calls, then one
            "tool" message per step) passed on to reply generation.
        """
        self.logger.info(
            f"\n{'─' * 60}\n"
            f" 🔗 开始执行工具调用链\n"
            f" 目标: {plan.goal}\n"
            f" 步骤: {plan.step_count} 步\n"
            f"{'─' * 60}"
        )

        step_results: list[StepResult] = []
        chain_context: dict[str, str] = {}
        tool_messages: list[dict] = []
        # Most recent step that actually ran and failed (reported in ChainResult).
        failed_step: int | None = None
        # Every step that failed OR was skipped. Anything depending on one of
        # these cannot run, including transitive dependants of a skipped step.
        unusable_steps: set[int] = set()

        # Assistant message declaring every tool call up front.
        assistant_tool_calls = self._build_assistant_tool_calls(plan)
        tool_messages.append({
            "role": "assistant",
            "content": None,
            "tool_calls": assistant_tool_calls,
        })

        # tool_calls are built in plan.steps order, so pair them by position
        # instead of assuming step_id is a contiguous 1-based index.
        for step, tool_call in zip(plan.steps, assistant_tool_calls):
            tool_call_id = tool_call["id"]

            blocked_by = self._failed_dependencies(step, unusable_steps)
            if blocked_by:
                blocker = blocked_by[0]
                self.logger.warning(
                    f"⏭ Step {step.step_id} [{step.tool_name}] 跳过"
                    f"(依赖步骤 {blocker} 失败)"
                )
                step_results.append(StepResult(
                    step_id=step.step_id,
                    tool_name=step.tool_name,
                    success=False,
                    output="",
                    error=f"跳过:依赖步骤 {blocker} 失败",
                ))
                # Placeholder so the declared tool_call still gets a response.
                tool_messages.append({
                    "role": "tool",
                    "content": f"步骤跳过:依赖步骤 {blocker} 执行失败",
                    "tool_call_id": tool_call_id,
                })
                unusable_steps.add(step.step_id)
                continue

            result, tool_call_id = self._execute_single_step(
                step, chain_context, assistant_tool_calls, tool_call_id=tool_call_id
            )
            step_results.append(result)

            tool_messages.append({
                "role": "tool",
                "content": result.output if result.success else f"执行失败: {result.error}",
                "tool_call_id": tool_call_id,
            })

            if result.success:
                chain_context[result.context_key] = result.output
                self.memory.add_tool_result(step.tool_name, result.output)
            else:
                failed_step = step.step_id
                unusable_steps.add(step.step_id)

        overall_success = failed_step is None
        chain_result = ChainResult(
            goal=plan.goal,
            step_results=step_results,
            success=overall_success,
            failed_step=failed_step,
        )

        self.logger.info(
            f"{'─' * 60}\n"
            f" {'✅ 调用链执行完成' if overall_success else '⚠️ 调用链部分失败'}\n"
            f" 完成: {chain_result.completed_steps}/{chain_result.total_steps} 步\n"
            f"{'─' * 60}"
        )
        return chain_result, tool_messages

    def _execute_single_step(
        self,
        step: ToolStep,
        chain_context: dict[str, str],
        assistant_tool_calls: list[dict],
        tool_call_id: str | None = None,
    ) -> tuple[StepResult, str]:
        """
        Execute one step and return (StepResult, tool_call_id).

        Args:
            step: the planned step; placeholders are resolved from chain_context.
            chain_context: outputs of earlier successful steps.
            assistant_tool_calls: the declared tool_calls. Used only when
                ``tool_call_id`` is not given.
            tool_call_id: id of the tool_call matching this step. When omitted,
                falls back to the legacy ``step_id - 1`` positional lookup.

        Returns:
            StepResult: the outcome of the step.
            tool_call_id: the OpenAI tool_call_id linking the result into the
                message sequence.
        """
        # Substitute outputs of earlier steps into this step's placeholders.
        resolved_step = step.inject_context(chain_context)
        if tool_call_id is None:
            # Legacy lookup: assumes step_id equals the 1-based plan position.
            tool_call_id = assistant_tool_calls[step.step_id - 1]["id"]

        self.logger.info(
            f"\n ▶ Step {step.step_id} 执行中\n"
            f" 工具: [{resolved_step.tool_name}]\n"
            f" 说明: {resolved_step.description}\n"
            f" 参数: {resolved_step.arguments}\n"
            f" call_id: {tool_call_id}"
        )

        # Build and send the MCP request.
        mcp_request: MCPRequest = resolved_step.to_mcp_request()
        mcp_response: MCPResponse = self.mcp_server.handle_request(mcp_request)

        if mcp_response.success:
            output = mcp_response.content
            self.logger.info(f" ✅ Step {step.step_id} 成功: {output[:80]}...")
            return StepResult(
                step_id=step.step_id,
                tool_name=step.tool_name,
                success=True,
                output=output,
            ), tool_call_id

        # Guard against a failed response that carries no error payload.
        error_msg = (mcp_response.error or {}).get("message", "未知错误")
        self.logger.error(f" ❌ Step {step.step_id} 失败: {error_msg}")
        return StepResult(
            step_id=step.step_id,
            tool_name=step.tool_name,
            success=False,
            output="",
            error=error_msg,
        ), tool_call_id

    # ════════════════════════════════════════════════════════════
    # Reply generation
    # ════════════════════════════════════════════════════════════

    def _generate_response(
        self,
        user_input: str,
        chain_result: ChainResult,
        tool_messages: list[dict],
        context: str,
    ) -> AgentResponse:
        """Have OpenAI merge the tool results and build the final AgentResponse.

        Single-step chains use ``generate_final_reply``; longer chains use
        ``generate_chain_reply`` with the full message sequence. The turn is
        appended to the structured history and persisted to Memory.
        """
        self.logger.info("✍️ Step 5 [LLM] 调用 OpenAI 生成最终回复...")

        single = chain_result.step_results[0] if chain_result.total_steps == 1 else None

        if single is not None:
            # A failed step has empty output. Hand the model the error text
            # (the same content sent in the "tool" message) so it can explain it.
            tool_output = single.output if single.success else f"执行失败: {single.error}"
            final_reply = self.llm.generate_final_reply(
                user_input=user_input,
                tool_name=single.tool_name,
                tool_output=tool_output,
                context=context,
                tool_call_id=tool_messages[-1].get("tool_call_id", "") if tool_messages else "",
            )
        else:
            # Multi-step: pass the full OpenAI message sequence.
            final_reply = self.llm.generate_chain_reply(
                user_input=user_input,
                chain_summary=self._build_chain_summary(chain_result),
                context=context,
                tool_messages=tool_messages,
            )

        chain_result.final_reply = final_reply

        # Update the structured history used by the next turn.
        self._remember_turn(user_input, final_reply)

        if chain_result.total_steps > 1:
            self.memory.add_chain_result(chain_result)
        else:
            self.memory.add_assistant_message(final_reply)

        self.logger.info("🎉 [CLIENT] 流程完成,回复已返回")
        return AgentResponse(
            user_input=user_input,
            final_reply=final_reply,
            chain_result=chain_result,
            tool_used=single.tool_name if single is not None else None,
            tool_output=single.output if single is not None else None,
            success=chain_result.success,
            error=None if chain_result.success else self._first_error(chain_result),
        )

    def _handle_direct_reply(self, user_input: str, context: str) -> AgentResponse:
        """Generate a reply directly with OpenAI when no tool is needed."""
        self.logger.info("💬 无需工具,直接调用 OpenAI 生成回复")
        reply = self.llm.generate_direct_reply(user_input, context)
        self.memory.add_assistant_message(reply)
        self._remember_turn(user_input, reply)
        return AgentResponse(user_input=user_input, final_reply=reply)

    # ════════════════════════════════════════════════════════════
    # Helpers
    # ════════════════════════════════════════════════════════════

    def _remember_turn(self, user_input: str, reply: str) -> None:
        """Append one user/assistant turn to the history and enforce the size cap."""
        self._openai_history.append({"role": "user", "content": user_input})
        self._openai_history.append({"role": "assistant", "content": reply})
        overflow = len(self._openai_history) - self._MAX_HISTORY_MESSAGES
        if overflow > 0:
            # Trim in place, like clear_session() does with .clear().
            del self._openai_history[:overflow]

    @staticmethod
    def _build_assistant_tool_calls(plan: ChainPlan) -> list[dict]:
        """
        Build the ``tool_calls`` field of the OpenAI assistant message.

        One entry per plan step, in plan order, each with a freshly generated id:
            [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "calculator",
                        "arguments": '{"expression": "1+2"}'
                    }
                }
            ]
        Arguments are the planned (pre-substitution) arguments, serialized as JSON.
        """
        return [
            {
                "id": f"call_{uuid.uuid4().hex[:8]}",
                "type": "function",
                "function": {
                    "name": step.tool_name,
                    "arguments": json.dumps(step.arguments, ensure_ascii=False),
                },
            }
            for step in plan.steps
        ]

    @staticmethod
    def _build_chain_summary(chain_result: ChainResult) -> str:
        """Format the chain results as an LLM-readable summary.

        Successful outputs are truncated to 300 characters.
        """
        lines = []
        for r in chain_result.step_results:
            if r.success:
                lines.append(
                    f"**Step {r.step_id} [{r.tool_name}]** ✅\n"
                    f"```\n{r.output[:300]}\n```"
                )
            else:
                lines.append(
                    f"**Step {r.step_id} [{r.tool_name}]** ❌\n"
                    f"错误: {r.error}"
                )
        return "\n\n".join(lines)

    @staticmethod
    def _first_error(chain_result: ChainResult) -> str | None:
        """Return the error of the first unsuccessful step, or None if every step succeeded."""
        return next((r.error for r in chain_result.step_results if not r.success), None)

    @staticmethod
    def _failed_dependencies(step: ToolStep, unusable: set[int]) -> list[int]:
        """Return the ids in ``step.depends_on`` that are in ``unusable``, in declared order."""
        return [dep for dep in (step.depends_on or ()) if dep in unusable]

    @staticmethod
    def _has_failed_dependency(step: ToolStep, failed_step: int | set[int] | None) -> bool:
        """Return True if ``step`` depends on a failed step.

        Accepts a single step id (the legacy form), a set of step ids, or None.
        """
        if failed_step is None:
            return False
        failed = {failed_step} if isinstance(failed_step, int) else set(failed_step)
        return bool(AgentClient._failed_dependencies(step, failed))

    def get_memory_stats(self) -> dict:
        """Return the Memory stats plus the length of the structured OpenAI history."""
        stats = self.memory.stats()
        stats["openai_history_len"] = len(self._openai_history)
        return stats

    def clear_session(self) -> None:
        """Clear the Memory history and the structured OpenAI history."""
        self.memory.clear_history()
        self._openai_history.clear()
        self.logger.info("🗑 会话已清空(含 OpenAI 消息历史)")