"""
client/agent_client.py
Agent 客户端:驱动完整 OpenAI Function Calling + Tool Chain 执行流程
新增: OpenAI 格式消息序列管理,支持多轮工具调用上下文传递
"""
import json
import uuid
from dataclasses import dataclass, field

from config.settings import settings
from llm.llm_engine import LLMEngine
from mcp.mcp_protocol import (
    ChainPlan, ChainResult, MCPRequest, MCPResponse,
    StepResult, ToolStep,
)
from mcp.mcp_server import MCPServer
from memory.memory_store import MemoryStore
from utils.logger import get_logger

@dataclass
class AgentResponse:
    """Result of one complete agent invocation."""
    user_input: str
    final_reply: str
    chain_result: ChainResult | None = None
    tool_used: str | None = None
    tool_output: str | None = None
    success: bool = True
    error: str | None = None
    token_usage: dict = field(default_factory=dict)

    @property
    def is_multi_step(self) -> bool:
        return self.chain_result is not None and self.chain_result.total_steps > 1

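# Shape note: for a single-step run, tool_used / tool_output are populated and
# chain_result.total_steps == 1; multi-step runs leave them None and callers
# read chain_result.step_results instead (see _generate_response below).
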
class AgentClient:
    """
    Agent client: OpenAI Function Calling + multi-step tool chain.

    Execution flow:
      1. [CLIENT] Receive user input, write it to Memory
      2. [LLM]    plan_tool_chain() → OpenAI Function Calling → ChainPlan
      3. [CHAIN]  Execute each ToolStep sequentially:
         a. Build an MCPRequest → MCPServer runs the tool
         b. Append the tool result to the OpenAI message sequence (tool role)
         c. Record a StepResult and update the chain context (placeholder substitution)
      4. [LLM]    generate_chain_reply() → OpenAI consolidates results → final reply
      5. [MEMORY] Persist the full call-chain record

    Example OpenAI message sequence (multi-step):
      {"role": "system",    "content": "planner prompt"}
      {"role": "user",      "content": "search the weather, then calculate..."}
      {"role": "assistant", "tool_calls": [{web_search}, {calculator}]}
      {"role": "tool",      "content": "web_search result", "tool_call_id": "call_1"}
      {"role": "tool",      "content": "calculator result", "tool_call_id": "call_2"}
      → generate_reply() → final natural-language reply
    """

    def __init__(
        self,
        llm: LLMEngine,
        mcp_server: MCPServer,
        memory: MemoryStore,
    ):
        self.llm = llm
        self.mcp_server = mcp_server
        self.memory = memory
        self.logger = get_logger("CLIENT")
        # Structured OpenAI-format conversation history (kept across turns)
        self._openai_history: list[dict] = []
        self.logger.info("💻 Agent Client initialized (OpenAI Function Calling mode)")

    # ════════════════════════════════════════════════════════════
    # Main entry point
    # ════════════════════════════════════════════════════════════
    def chat(self, user_input: str) -> AgentResponse:
        """Handle one turn of user dialogue."""
        sep = "═" * 60
        self.logger.info(sep)
        self.logger.info(f"📨 User input received: {user_input}")
        self.logger.info(sep)

        # Step 1: record the user message
        self.memory.add_user_message(user_input)
        context = self.memory.get_context_summary()

        # Step 2: the LLM plans the tool-call chain
        self.logger.info("🗺 Step 2 [LLM] Planning the tool-call chain...")
        tool_schemas = self.mcp_server.get_tool_schemas()
        plan: ChainPlan = self.llm.plan_tool_chain(
            user_input=user_input,
            tool_schemas=tool_schemas,
            context=context,
            history=self._openai_history,
        )

        # No tools needed: reply directly
        if not plan.steps:
            return self._handle_direct_reply(user_input, context)

        # Steps 3-4: run the tool chain and build the OpenAI message sequence
        chain_result, tool_messages = self._execute_chain(plan, user_input)

        # Step 5: have OpenAI consolidate the results into the final reply
        return self._generate_response(user_input, chain_result, tool_messages, context)

    # ════════════════════════════════════════════════════════════
    # Sequential execution engine
    # ════════════════════════════════════════════════════════════
    def _execute_chain(
        self,
        plan: ChainPlan,
        user_input: str,
    ) -> tuple[ChainResult, list[dict]]:
        """
        Execute the tool chain sequentially while building the OpenAI message sequence.

        Returns:
            (ChainResult, tool_messages)
            tool_messages is the OpenAI-format list of tool-call messages,
            later passed to generate_reply().
        """
        self.logger.info(
            f"\n{'═' * 60}\n"
            f" 🔗 Starting tool-chain execution\n"
            f" Goal:  {plan.goal}\n"
            f" Steps: {plan.step_count}\n"
            f"{'═' * 60}"
        )
        step_results: list[StepResult] = []
        chain_context: dict[str, str] = {}
        tool_messages: list[dict] = []
        failed_step: int | None = None

        # Build the assistant message (with the tool_calls declarations)
        assistant_tool_calls = self._build_assistant_tool_calls(plan)
        tool_messages.append({
            "role": "assistant",
            "content": None,
            "tool_calls": assistant_tool_calls,
        })

        for step in plan.steps:
            # Check upstream dependencies
            if self._has_failed_dependency(step, failed_step):
                self.logger.warning(
                    f"⏭ Step {step.step_id} [{step.tool_name}] skipped "
                    f"(dependency step {failed_step} failed)"
                )
                step_results.append(StepResult(
                    step_id=step.step_id,
                    tool_name=step.tool_name,
                    success=False,
                    output="",
                    error=f"Skipped: dependency step {failed_step} failed",
                ))
                # Write a failure placeholder into the OpenAI message sequence
                tool_messages.append({
                    "role": "tool",
                    "content": f"Step skipped: dependency step {failed_step} failed",
                    "tool_call_id": assistant_tool_calls[step.step_id - 1]["id"],
                })
                continue

            # Run the single step
            result, tool_call_id = self._execute_single_step(
                step, chain_context, assistant_tool_calls
            )
            step_results.append(result)

            # Append the tool message to the OpenAI sequence
            tool_messages.append({
                "role": "tool",
                "content": result.output if result.success else f"Execution failed: {result.error}",
                "tool_call_id": tool_call_id,
            })

            if result.success:
                chain_context[result.context_key] = result.output
                self.memory.add_tool_result(step.tool_name, result.output)
            else:
                failed_step = step.step_id

        overall_success = failed_step is None
        chain_result = ChainResult(
            goal=plan.goal,
            step_results=step_results,
            success=overall_success,
            failed_step=failed_step,
        )
        self.logger.info(
            f"{'═' * 60}\n"
            f" {'✅ Tool chain completed' if overall_success else '⚠️ Tool chain partially failed'}\n"
            f" Completed: {chain_result.completed_steps}/{chain_result.total_steps}\n"
            f"{'═' * 60}"
        )
        return chain_result, tool_messages
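
    # Illustration of the chain context threaded through _execute_single_step.
    # The exact placeholder syntax is owned by ToolStep.inject_context and
    # StepResult.context_key; the "{step_1}" form below is an assumption:
    #
    #   chain_context = {"step_1": "161"}                  # output of step 1
    #   step.arguments = {"query": "weather in {step_1}"}  # before injection
    #   resolved.arguments = {"query": "weather in 161"}   # after injection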

    def _execute_single_step(
        self,
        step: ToolStep,
        chain_context: dict[str, str],
        assistant_tool_calls: list[dict],
    ) -> tuple[StepResult, str]:
        """
        Execute a single step and return (StepResult, tool_call_id).

        Returns:
            StepResult: the step's execution result
            tool_call_id: the matching OpenAI tool_call_id, used to link
                messages in the sequence
        """
        # Inject context from earlier steps (placeholder substitution)
        resolved_step = step.inject_context(chain_context)
        tool_call_id = assistant_tool_calls[step.step_id - 1]["id"]
        self.logger.info(
            f"\n ▶ Step {step.step_id} running\n"
            f" Tool:    [{resolved_step.tool_name}]\n"
            f" Notes:   {resolved_step.description}\n"
            f" Args:    {resolved_step.arguments}\n"
            f" call_id: {tool_call_id}"
        )
        # Build and dispatch the MCP request
        mcp_request: MCPRequest = resolved_step.to_mcp_request()
        mcp_response: MCPResponse = self.mcp_server.handle_request(mcp_request)

        if mcp_response.success:
            output = mcp_response.content
            self.logger.info(f" ✅ Step {step.step_id} succeeded: {output[:80]}...")
            return StepResult(
                step_id=step.step_id,
                tool_name=step.tool_name,
                success=True,
                output=output,
            ), tool_call_id
        else:
            error_msg = mcp_response.error.get("message", "unknown error")
            self.logger.error(f" ❌ Step {step.step_id} failed: {error_msg}")
            return StepResult(
                step_id=step.step_id,
                tool_name=step.tool_name,
                success=False,
                output="",
                error=error_msg,
            ), tool_call_id

    # ════════════════════════════════════════════════════════════
    # Reply generation
    # ════════════════════════════════════════════════════════════
    def _generate_response(
        self,
        user_input: str,
        chain_result: ChainResult,
        tool_messages: list[dict],
        context: str,
    ) -> AgentResponse:
        """Have OpenAI consolidate the tool results into the final AgentResponse."""
        self.logger.info("✍️ Step 5 [LLM] Calling OpenAI to generate the final reply...")
        chain_summary = self._build_chain_summary(chain_result)

        # Single step takes the lightweight path
        if chain_result.total_steps == 1:
            r = chain_result.step_results[0]
            final_reply = self.llm.generate_final_reply(
                user_input=user_input,
                tool_name=r.tool_name,
                tool_output=r.output,
                context=context,
                tool_call_id=tool_messages[-1].get("tool_call_id", "") if tool_messages else "",
            )
        else:
            # Multi-step takes the full OpenAI message-sequence path
            final_reply = self.llm.generate_chain_reply(
                user_input=user_input,
                chain_summary=chain_summary,
                context=context,
                tool_messages=tool_messages,
            )
        chain_result.final_reply = final_reply

        # Update the structured OpenAI history (used by the next turn)
        self._openai_history.append({"role": "user", "content": user_input})
        self._openai_history.append({"role": "assistant", "content": final_reply})
        # Keep the most recent 10 turns (20 messages)
        if len(self._openai_history) > 20:
            self._openai_history = self._openai_history[-20:]

        # Persist to Memory
        if chain_result.total_steps > 1:
            self.memory.add_chain_result(chain_result)
        else:
            self.memory.add_assistant_message(final_reply)

        self.logger.info("🎉 [CLIENT] Flow complete, reply returned")
        return AgentResponse(
            user_input=user_input,
            final_reply=final_reply,
            chain_result=chain_result,
            tool_used=chain_result.step_results[0].tool_name if chain_result.total_steps == 1 else None,
            tool_output=chain_result.step_results[0].output if chain_result.total_steps == 1 else None,
            success=chain_result.success,
        )

    def _handle_direct_reply(self, user_input: str, context: str) -> AgentResponse:
        """When no tools are needed, call OpenAI directly for the reply."""
        self.logger.info("💬 No tools needed; calling OpenAI directly for the reply")
        reply = self.llm.generate_direct_reply(user_input, context)
        self.memory.add_assistant_message(reply)
        self._openai_history.append({"role": "user", "content": user_input})
        self._openai_history.append({"role": "assistant", "content": reply})
        return AgentResponse(user_input=user_input, final_reply=reply)

    # ════════════════════════════════════════════════════════════
    # Helpers
    # ════════════════════════════════════════════════════════════
    @staticmethod
    def _build_assistant_tool_calls(plan: ChainPlan) -> list[dict]:
        """
        Build the tool_calls field of the OpenAI assistant message.

        Format:
            [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "calculator",
                        "arguments": '{"expression": "1+2"}'
                    }
                }
            ]
        """
        tool_calls = []
        for step in plan.steps:
            tool_calls.append({
                "id": f"call_{uuid.uuid4().hex[:8]}",
                "type": "function",
                "function": {
                    "name": step.tool_name,
                    "arguments": json.dumps(step.arguments, ensure_ascii=False),
                },
            })
        return tool_calls

    @staticmethod
    def _build_chain_summary(chain_result: ChainResult) -> str:
        """Format the chain result as an LLM-readable summary."""
        lines = []
        for r in chain_result.step_results:
            if r.success:
                lines.append(
                    f"**Step {r.step_id} [{r.tool_name}]** ✅\n"
                    f"```\n{r.output[:300]}\n```"
                )
            else:
                lines.append(
                    f"**Step {r.step_id} [{r.tool_name}]** ❌\n"
                    f"Error: {r.error}"
                )
        return "\n\n".join(lines)

    @staticmethod
    def _has_failed_dependency(step: ToolStep, failed_step: int | None) -> bool:
        """Return True if this step depends on a step that already failed."""
        return failed_step is not None and failed_step in step.depends_on

    def get_memory_stats(self) -> dict:
        """Memory statistics plus the length of the OpenAI message history."""
        stats = self.memory.stats()
        stats["openai_history_len"] = len(self._openai_history)
        return stats

    def clear_session(self) -> None:
        """Clear the session, including the OpenAI message history."""
        self.memory.clear_history()
        self._openai_history.clear()
        self.logger.info("🗑 Session cleared (including OpenAI message history)")