base_agent/client/agent_client.py

412 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
client/agent_client.py
修复OpenAI 消息序列必须满足
user → assistant(tool_calls) → tool(s) → assistant(reply)
每个 tool 消息的 tool_call_id 必须与同一 assistant 消息中的 tool_calls[].id 完全匹配
"""
import json
import uuid
from dataclasses import dataclass, field
from config.settings import settings
from llm.llm_engine import LLMEngine
from mcp.mcp_protocol import (
ChainPlan, ChainResult, MCPRequest, MCPResponse,
StepResult, ToolStep,
)
from mcp.mcp_server import MCPServer
from memory.memory_store import MemoryStore
from utils.logger import get_logger
@dataclass
class AgentResponse:
    """Result of one complete agent invocation."""

    user_input: str                          # raw user message for this turn
    final_reply: str                         # text ultimately returned to the user
    chain_result: ChainResult | None = None  # per-step results when tools ran
    tool_used: str | None = None             # tool name (single-step calls only)
    tool_output: str | None = None           # raw tool output (single-step only)
    success: bool = True                     # overall success flag
    error: str | None = None                 # error description on failure
    token_usage: dict = field(default_factory=dict)  # LLM token accounting

    @property
    def is_multi_step(self) -> bool:
        """True when a tool chain with more than one step was executed."""
        if self.chain_result is None:
            return False
        return self.chain_result.total_steps > 1
class AgentClient:
    """
    Agent client: OpenAI Function Calling with multi-step tool chains.

    The OpenAI API requires message sequences of the form
        user → assistant(tool_calls) → tool(s) → assistant(reply)

    [Single tool call]
        {"role": "system", "content": "..."}
        {"role": "user", "content": "calculate 1+1"}
        {"role": "assistant", "content": null,
         "tool_calls": [{"id": "call_abc", "type": "function",
                         "function": {"name": "calculator", "arguments": "{...}"}}]}
        {"role": "tool", "content": "result: 2", "tool_call_id": "call_abc"}
        → generate_reply() → final answer

    [Multiple tool calls (declared in parallel, executed serially)]
        {"role": "system", "content": "..."}
        {"role": "user", "content": "search the weather, then compute ..."}
        {"role": "assistant", "content": null,
         "tool_calls": [{"id": "call_001", ...web_search...},
                        {"id": "call_002", ...calculator...}]}
        {"role": "tool", "content": "search result", "tool_call_id": "call_001"}
        {"role": "tool", "content": "computed result", "tool_call_id": "call_002"}
        → generate_reply() → final answer

    Hard constraints:
        1. tool messages must immediately follow the assistant message that
           declares the tool_calls;
        2. each tool message's tool_call_id must exactly match one
           assistant.tool_calls[i].id;
        3. every declared tool_call must receive exactly one tool message.
    """
def __init__(
    self,
    llm: LLMEngine,
    mcp_server: MCPServer,
    memory: MemoryStore,
):
    """Wire together the LLM engine, the MCP tool server and the memory store.

    Args:
        llm: Engine used for planning and reply generation.
        mcp_server: Server that executes tool calls via the MCP protocol.
        memory: Conversation memory store.
    """
    self.logger = get_logger("CLIENT")
    self.llm = llm
    self.mcp_server = mcp_server
    self.memory = memory
    # Cross-turn OpenAI-format history (user/assistant summaries only).
    self._openai_history: list[dict] = []
    self.logger.info("💻 Agent Client 初始化完成OpenAI Function Calling 模式)")
# ════════════════════════════════════════════════════════════
# 主入口
# ════════════════════════════════════════════════════════════
def chat(self, user_input: str) -> AgentResponse:
    """Run one full user turn: plan → execute tools → compose the reply."""
    # NOTE(review): the literal below multiplies an empty string; the
    # separator glyph was presumably lost in transcription — confirm upstream.
    divider = "" * 60
    for banner_line in (divider, f"📨 收到用户输入: {user_input}", divider):
        self.logger.info(banner_line)

    self.memory.add_user_message(user_input)
    context = self.memory.get_context_summary()

    # Step 1: ask the LLM to plan a tool-call chain.
    self.logger.info("🗺 [LLM] 规划工具调用链...")
    schemas = self.mcp_server.get_tool_schemas()
    plan: ChainPlan = self.llm.plan_tool_chain(
        user_input=user_input,
        tool_schemas=schemas,
        context=context,
        history=self._openai_history,
    )

    # No tools needed: answer directly.
    if not plan.steps:
        return self._handle_direct_reply(user_input, context)

    # Step 2: execute the chain, producing a protocol-compliant message block.
    chain_result, openai_tool_block = self._execute_chain(plan, user_input)

    # Step 3: have OpenAI merge the tool results into the final reply.
    return self._generate_response(
        user_input, chain_result, openai_tool_block, context
    )
# ════════════════════════════════════════════════════════════
# 串行执行引擎
# ════════════════════════════════════════════════════════════
def _execute_chain(
    self,
    plan: ChainPlan,
    user_input: str,
) -> tuple[ChainResult, list[dict]]:
    """
    Serially execute the tool chain and build an OpenAI-compliant message block.

    Args:
        plan: Tool-chain plan produced by the LLM.
        user_input: Original user message (kept for interface stability;
            not consumed directly here).

    Returns:
        chain_result: Aggregated execution results.
        openai_tool_block: Compliant message block shaped as
            [assistant(tool_calls=[id1, id2, ...]),
             tool(tool_call_id=id1, ...),
             tool(tool_call_id=id2, ...), ...]
            All tool_calls are declared once on the assistant message, then
            one tool message is appended per call so every id matches exactly.

    Fix: failed steps are accumulated in a list instead of a single
    "last failed step" variable, so a step is skipped when ANY of its
    dependencies failed — previously only the most recent failure was
    checked, letting steps run on top of an earlier failed dependency.
    """
    self.logger.info(
        f"\n{'' * 60}\n"
        f" 🔗 开始执行工具调用链\n"
        f" 目标: {plan.goal}\n"
        f" 步骤: {plan.step_count}\n"
        f"{'' * 60}"
    )
    # ── 1. Pre-generate a stable tool_call_id for every step ───────────
    # Must happen before execution: the assistant message and the tool
    # messages share this exact batch of ids.
    step_call_ids: dict[int, str] = {
        step.step_id: f"call_{uuid.uuid4().hex[:12]}"
        for step in plan.steps
    }
    self.logger.debug(f"🔑 预生成 tool_call_ids: {step_call_ids}")
    # ── 2. Assistant message declaring ALL tool_calls at once ──────────
    assistant_msg = self._build_assistant_message(plan, step_call_ids)
    # ── 3. Execute steps serially, collecting tool result messages ─────
    step_results: list[StepResult] = []
    tool_result_msgs: list[dict] = []
    chain_context: dict[str, str] = {}
    failed_ids: list[int] = []  # every failed step id, in failure order
    for step in plan.steps:
        call_id = step_call_ids[step.step_id]
        # First already-failed dependency (if any) blocks this step.
        blocker = next(
            (fid for fid in failed_ids if fid in step.depends_on), None
        )
        if blocker is not None:
            self.logger.warning(
                f"⏭ Step {step.step_id} [{step.tool_name}] 跳过"
                f"(依赖步骤 {blocker} 失败)"
            )
            step_results.append(StepResult(
                step_id=step.step_id,
                tool_name=step.tool_name,
                success=False,
                output="",
                error=f"跳过:依赖步骤 {blocker} 失败",
            ))
            # Even a skipped step needs its tool message so ids stay matched.
            tool_result_msgs.append({
                "role": "tool",
                "content": f"步骤跳过:依赖步骤 {blocker} 执行失败",
                "tool_call_id": call_id,
            })
            continue
        result = self._execute_single_step(step, chain_context, call_id)
        step_results.append(result)
        # tool_call_id matches the assistant declaration exactly.
        tool_result_msgs.append({
            "role": "tool",
            "content": result.output if result.success else f"执行失败: {result.error}",
            "tool_call_id": call_id,
        })
        if result.success:
            chain_context[result.context_key] = result.output
            self.memory.add_tool_result(step.tool_name, result.output)
        else:
            failed_ids.append(step.step_id)
    # ── 4. Assemble the compliant OpenAI message block ─────────────────
    # Shape: [assistant(tool_calls), tool, tool, ...]
    openai_tool_block = [assistant_msg] + tool_result_msgs
    self.logger.debug("📦 OpenAI 消息块结构:")
    for i, msg in enumerate(openai_tool_block):
        role = msg["role"]
        if role == "assistant":
            ids = [tc["id"] for tc in (msg.get("tool_calls") or [])]
            self.logger.debug(f" [{i}] assistant tool_calls.ids = {ids}")
        else:
            self.logger.debug(
                f" [{i}] tool tool_call_id = {msg.get('tool_call_id')}"
                f" content = {str(msg.get('content', ''))[:50]}..."
            )
    overall_success = not failed_ids
    chain_result = ChainResult(
        goal=plan.goal,
        step_results=step_results,
        success=overall_success,
        # Report the FIRST failure — the root cause of any cascade.
        failed_step=failed_ids[0] if failed_ids else None,
    )
    self.logger.info(
        f"{'' * 60}\n"
        f" {'✅ 调用链执行完成' if overall_success else '⚠️ 调用链部分失败'}\n"
        f" 完成: {chain_result.completed_steps}/{chain_result.total_steps}\n"
        f"{'' * 60}"
    )
    return chain_result, openai_tool_block
def _execute_single_step(
    self,
    step: ToolStep,
    chain_context: dict[str, str],
    call_id: str,
) -> StepResult:
    """Run one chain step through the MCP server and wrap it as a StepResult."""
    # Substitute outputs of earlier steps into this step's arguments.
    resolved = step.inject_context(chain_context)
    self.logger.info(
        f"\n ▶ Step {step.step_id} 执行中\n"
        f" 工具 : [{resolved.tool_name}]\n"
        f" 说明 : {resolved.description}\n"
        f" 参数 : {resolved.arguments}\n"
        f" call_id : {call_id}"
    )
    request: MCPRequest = resolved.to_mcp_request()
    response: MCPResponse = self.mcp_server.handle_request(request)
    if not response.success:
        reason = response.error.get("message", "未知错误")
        self.logger.error(f" ❌ Step {step.step_id} 失败: {reason}")
        return StepResult(
            step_id=step.step_id,
            tool_name=step.tool_name,
            success=False,
            output="",
            error=reason,
        )
    payload = response.content
    self.logger.info(f" ✅ Step {step.step_id} 成功: {payload[:80]}...")
    return StepResult(
        step_id=step.step_id,
        tool_name=step.tool_name,
        success=True,
        output=payload,
    )
# ════════════════════════════════════════════════════════════
# 回复生成
# ════════════════════════════════════════════════════════════
def _generate_response(
    self,
    user_input: str,
    chain_result: ChainResult,
    openai_tool_block: list[dict],
    context: str,
) -> AgentResponse:
    """Have OpenAI merge the tool results, then package the final AgentResponse."""
    self.logger.info("✍️ [LLM] 调用 OpenAI 生成最终回复...")
    summary = self._build_chain_summary(chain_result)
    # Hand the compliant message block straight to the LLM engine.
    final_reply = self.llm.generate_chain_reply(
        user_input=user_input,
        chain_summary=summary,
        context=context,
        openai_tool_block=openai_tool_block,
    )
    chain_result.final_reply = final_reply
    # Cross-turn history keeps only user/assistant summaries (no tool msgs),
    # capped at the 20 most recent entries.
    self._openai_history += [
        {"role": "user", "content": user_input},
        {"role": "assistant", "content": final_reply},
    ]
    if len(self._openai_history) > 20:
        self._openai_history = self._openai_history[-20:]
    if chain_result.total_steps > 1:
        self.memory.add_chain_result(chain_result)
    else:
        self.memory.add_assistant_message(final_reply)
    self.logger.info("🎉 流程完成,回复已返回")
    single = chain_result.total_steps == 1
    return AgentResponse(
        user_input=user_input,
        final_reply=final_reply,
        chain_result=chain_result,
        tool_used=chain_result.step_results[0].tool_name if single else None,
        tool_output=chain_result.step_results[0].output if single else None,
        success=chain_result.success,
    )
def _handle_direct_reply(self, user_input: str, context: str) -> AgentResponse:
    """
    Generate a reply straight from OpenAI when no tool is needed.

    Fix: also cap the cross-turn OpenAI history at 20 entries, matching
    _generate_response — previously this path let the history grow without
    bound in tool-free conversations.
    """
    self.logger.info("💬 无需工具,直接调用 OpenAI 生成回复")
    reply = self.llm.generate_direct_reply(user_input, context)
    self.memory.add_assistant_message(reply)
    self._openai_history.append({"role": "user", "content": user_input})
    self._openai_history.append({"role": "assistant", "content": reply})
    if len(self._openai_history) > 20:
        self._openai_history = self._openai_history[-20:]
    return AgentResponse(user_input=user_input, final_reply=reply)
# ════════════════════════════════════════════════════════════
# 消息构造工具方法
# ════════════════════════════════════════════════════════════
@staticmethod
def _build_assistant_message(
plan: ChainPlan,
step_call_ids: dict[int, str],
) -> dict:
"""
构造 assistant 消息,一次性声明全部 tool_calls
✅ 正确格式:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123", ← 与后续 tool 消息的 tool_call_id 完全一致
"type": "function",
"function": {
"name": "calculator",
"arguments": "{\"expression\": \"1+2\"}" ← 必须是 JSON 字符串
}
}
]
}
"""
tool_calls = []
for step in plan.steps:
tool_calls.append({
"id": step_call_ids[step.step_id],
"type": "function",
"function": {
"name": step.tool_name,
"arguments": json.dumps(step.arguments, ensure_ascii=False),
},
})
return {
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
}
@staticmethod
def _build_chain_summary(chain_result: ChainResult) -> str:
"""将调用链结果格式化为 LLM 可读的摘要(降级时使用)"""
lines = []
for r in chain_result.step_results:
if r.success:
lines.append(
f"**Step {r.step_id} [{r.tool_name}]** ✅\n"
f"```\n{r.output[:300]}\n```"
)
else:
lines.append(
f"**Step {r.step_id} [{r.tool_name}]** ❌\n"
f"错误: {r.error}"
)
return "\n\n".join(lines)
@staticmethod
def _has_failed_dependency(step: ToolStep, failed_step: int | None) -> bool:
return failed_step is not None and failed_step in step.depends_on
def get_memory_stats(self) -> dict:
stats = self.memory.stats()
stats["openai_history_len"] = len(self._openai_history)
return stats
def clear_session(self) -> None:
self.memory.clear_history()
self._openai_history.clear()
self.logger.info("🗑 会话已清空(含 OpenAI 消息历史)")