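"""
client/agent_client.py

修复:OpenAI 消息序列必须满足
    user → assistant(tool_calls) → tool(s) → assistant(reply)
每个 tool 消息的 tool_call_id 必须与同一 assistant 消息中的 tool_calls[].id 完全匹配。

(Fix: the OpenAI message sequence must follow
    user → assistant(tool_calls) → tool(s) → assistant(reply),
and every tool message's tool_call_id must exactly match one tool_calls[].id
declared in the preceding assistant message.)
"""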
|
||
|
||
import json
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
|
||
from config.settings import settings
|
||
from llm.llm_engine import LLMEngine
|
||
from mcp.mcp_protocol import (
|
||
ChainPlan, ChainResult, MCPRequest, MCPResponse,
|
||
StepResult, ToolStep,
|
||
)
|
||
from mcp.mcp_server import MCPServer
|
||
from memory.memory_store import MemoryStore
|
||
from utils.logger import get_logger
|
||
|
||
|
||
@dataclass
class AgentResponse:
    """Outcome of one complete agent turn: the user's input and the final reply."""

    user_input: str
    final_reply: str
    # Result of the executed tool chain; stays None for tool-free replies.
    chain_result: ChainResult | None = None
    # Filled in only when the chain consisted of exactly one step.
    tool_used: str | None = None
    tool_output: str | None = None
    success: bool = True
    error: str | None = None
    token_usage: dict = field(default_factory=dict)

    @property
    def is_multi_step(self) -> bool:
        """Report whether this turn executed a tool chain with two or more steps."""
        chain = self.chain_result
        if chain is None:
            return False
        return chain.total_steps > 1
|
||
|
||
|
||
class AgentClient:
    """
    Agent client: OpenAI Function Calling + multi-step tool chains.

    Message sequence produced for OpenAI (the system/user messages are
    presumably added by LLMEngine; this class builds the assistant + tool part):

    [single tool call]
        {"role": "system", "content": "..."}
        {"role": "user", "content": "计算 1+1"}
        {"role": "assistant", "content": null,
         "tool_calls": [{"id": "call_abc", "type": "function",
                         "function": {"name": "calculator", "arguments": "{...}"}}]}
        {"role": "tool", "content": "结果: 2", "tool_call_id": "call_abc"}
        -> generate_reply() -> final reply

    [multiple tool calls (declared together, executed serially)]
        {"role": "system", "content": "..."}
        {"role": "user", "content": "搜索天气然后计算..."}
        {"role": "assistant", "content": null,
         "tool_calls": [{"id": "call_001", ...web_search...},
                        {"id": "call_002", ...calculator...}]}
        {"role": "tool", "content": "搜索结果", "tool_call_id": "call_001"}
        {"role": "tool", "content": "计算结果", "tool_call_id": "call_002"}
        -> generate_reply() -> final reply

    Key constraints:
        1. Tool messages must directly follow the assistant message that
           carries ``tool_calls``.
        2. Each tool message's ``tool_call_id`` must equal one
           ``assistant.tool_calls[i].id``.
        3. Every declared tool call needs exactly one matching tool message.
    """

    # Maximum number of user/assistant messages kept across turns.
    # Messages are appended in pairs, so an even limit keeps the history
    # starting on a user message.
    _OPENAI_HISTORY_LIMIT: int = 20

    def __init__(
        self,
        llm: LLMEngine,
        mcp_server: MCPServer,
        memory: MemoryStore,
    ):
        """
        Args:
            llm: engine used for planning and for generating replies.
            mcp_server: executes tool requests and exposes tool schemas.
            memory: conversation/tool-result store.
        """
        self.llm = llm
        self.mcp_server = mcp_server
        self.memory = memory
        self.logger = get_logger("CLIENT")
        # Cross-turn OpenAI history: plain user/assistant text only.
        self._openai_history: list[dict] = []
        self.logger.info("💻 Agent Client 初始化完成(OpenAI Function Calling 模式)")

    # ════════════════════════════════════════════════════════════
    # Main entry point
    # ════════════════════════════════════════════════════════════

    def chat(self, user_input: str) -> AgentResponse:
        """
        Handle one user turn end to end.

        Args:
            user_input: the raw user message.

        Returns:
            AgentResponse with the final reply and, when tools ran, the chain result.
        """
        sep = "═" * 60
        self.logger.info(sep)
        self.logger.info(f"📨 收到用户输入: {user_input}")
        self.logger.info(sep)

        self.memory.add_user_message(user_input)
        context = self.memory.get_context_summary()

        # Step 1: let the LLM plan the tool chain.
        self.logger.info("🗺 [LLM] 规划工具调用链...")
        tool_schemas = self.mcp_server.get_tool_schemas()
        plan: ChainPlan = self.llm.plan_tool_chain(
            user_input=user_input,
            tool_schemas=tool_schemas,
            context=context,
            history=self._openai_history,
        )

        # No tools needed: answer directly.
        if not plan.steps:
            return self._handle_direct_reply(user_input, context)

        # Step 2: run the chain and build the protocol-compliant message block.
        chain_result, openai_tool_block = self._execute_chain(plan, user_input)

        # Step 3: ask OpenAI to combine the tool results into the final reply.
        return self._generate_response(
            user_input, chain_result, openai_tool_block, context
        )

    # ════════════════════════════════════════════════════════════
    # Serial execution engine
    # ════════════════════════════════════════════════════════════

    def _execute_chain(
        self,
        plan: ChainPlan,
        user_input: str,
    ) -> tuple[ChainResult, list[dict]]:
        """
        Execute the tool chain serially and build an OpenAI-compliant message block.

        Args:
            plan: the planned chain (goal plus ordered steps).
            user_input: the user's message (not used in this method; kept for
                interface compatibility).

        Returns:
            chain_result: summary of all step results.
            openai_tool_block: ``[assistant(tool_calls=[...]), tool, tool, ...]``
                with exactly one tool message per declared tool call, sharing
                the same ids and order.

        A step is skipped when any step it depends on failed *or was itself
        skipped*, so a failure propagates through every transitive dependent
        instead of letting a dependent run without its inputs.
        """
        self.logger.info(
            f"\n{'─' * 60}\n"
            f" 🔗 开始执行工具调用链\n"
            f" 目标: {plan.goal}\n"
            f" 步骤: {plan.step_count} 步\n"
            f"{'─' * 60}"
        )

        # 1. Pre-generate one tool_call_id per step before anything runs: the
        #    assistant message and the tool messages must share the same ids.
        step_call_ids: dict[int, str] = {
            step.step_id: f"call_{uuid.uuid4().hex[:12]}"
            for step in plan.steps
        }
        self.logger.debug(f"🔑 预生成 tool_call_ids: {step_call_ids}")

        # 2. Declare every tool call at once in a single assistant message.
        assistant_msg = self._build_assistant_message(plan, step_call_ids)

        # 3. Run the steps in order, collecting one tool message per step.
        step_results: list[StepResult] = []
        tool_result_msgs: list[dict] = []
        chain_context: dict[str, str] = {}
        # Most recent step that actually executed and failed (skips excluded).
        failed_step: int | None = None
        # Steps that produced no usable output: failed or skipped.
        unavailable: set[int] = set()

        for step in plan.steps:
            call_id = step_call_ids[step.step_id]

            blocking = self._blocking_dependency(step, unavailable)
            if blocking is not None:
                self.logger.warning(
                    f"⏭ Step {step.step_id} [{step.tool_name}] 跳过"
                    f"(依赖步骤 {blocking} 失败)"
                )
                step_results.append(StepResult(
                    step_id=step.step_id,
                    tool_name=step.tool_name,
                    success=False,
                    output="",
                    error=f"跳过:依赖步骤 {blocking} 失败",
                ))
                # A skipped step still needs its tool message so that every
                # declared tool_call_id receives exactly one reply.
                tool_result_msgs.append({
                    "role": "tool",
                    "content": f"步骤跳过:依赖步骤 {blocking} 执行失败",
                    "tool_call_id": call_id,
                })
                unavailable.add(step.step_id)
                continue

            result = self._execute_single_step(step, chain_context, call_id)
            step_results.append(result)

            # Tool result message; its tool_call_id matches the assistant declaration.
            tool_result_msgs.append({
                "role": "tool",
                "content": result.output if result.success else f"执行失败: {result.error}",
                "tool_call_id": call_id,
            })

            if result.success:
                # Expose this output to later steps via context injection.
                chain_context[result.context_key] = result.output
                self.memory.add_tool_result(step.tool_name, result.output)
            else:
                failed_step = step.step_id
                unavailable.add(step.step_id)

        # 4. Assemble the block: [assistant(tool_calls), tool, tool, ...]
        openai_tool_block = [assistant_msg] + tool_result_msgs

        self.logger.debug("📦 OpenAI 消息块结构:")
        for i, msg in enumerate(openai_tool_block):
            if msg["role"] == "assistant":
                ids = [tc["id"] for tc in (msg.get("tool_calls") or [])]
                self.logger.debug(f" [{i}] assistant tool_calls.ids = {ids}")
            else:
                self.logger.debug(
                    f" [{i}] tool tool_call_id = {msg.get('tool_call_id')}"
                    f" content = {str(msg.get('content', ''))[:50]}..."
                )

        overall_success = failed_step is None
        chain_result = ChainResult(
            goal=plan.goal,
            step_results=step_results,
            success=overall_success,
            failed_step=failed_step,
        )

        self.logger.info(
            f"{'─' * 60}\n"
            f" {'✅ 调用链执行完成' if overall_success else '⚠️ 调用链部分失败'}\n"
            f" 完成: {chain_result.completed_steps}/{chain_result.total_steps} 步\n"
            f"{'─' * 60}"
        )
        return chain_result, openai_tool_block

    def _execute_single_step(
        self,
        step: ToolStep,
        chain_context: dict[str, str],
        call_id: str,
    ) -> StepResult:
        """
        Execute one step through the MCP server.

        Args:
            step: the step to run; its arguments are resolved against chain_context.
            chain_context: outputs of earlier successful steps.
            call_id: the tool_call_id assigned to this step (used for logging).

        Returns:
            StepResult with the tool output on success, or the error message.
        """
        resolved_step = step.inject_context(chain_context)

        self.logger.info(
            f"\n ▶ Step {step.step_id} 执行中\n"
            f" 工具 : [{resolved_step.tool_name}]\n"
            f" 说明 : {resolved_step.description}\n"
            f" 参数 : {resolved_step.arguments}\n"
            f" call_id : {call_id}"
        )

        mcp_request: MCPRequest = resolved_step.to_mcp_request()
        mcp_response: MCPResponse = self.mcp_server.handle_request(mcp_request)

        if mcp_response.success:
            output = mcp_response.content
            self.logger.info(f" ✅ Step {step.step_id} 成功: {output[:80]}...")
            return StepResult(
                step_id=step.step_id,
                tool_name=step.tool_name,
                success=True,
                output=output,
            )

        # A failed response may carry no error payload; fall back to the default text.
        error_msg = (mcp_response.error or {}).get("message", "未知错误")
        self.logger.error(f" ❌ Step {step.step_id} 失败: {error_msg}")
        return StepResult(
            step_id=step.step_id,
            tool_name=step.tool_name,
            success=False,
            output="",
            error=error_msg,
        )

    # ════════════════════════════════════════════════════════════
    # Reply generation
    # ════════════════════════════════════════════════════════════

    def _generate_response(
        self,
        user_input: str,
        chain_result: ChainResult,
        openai_tool_block: list[dict],
        context: str,
    ) -> AgentResponse:
        """
        Ask OpenAI to merge the tool results into the final reply and record the turn.

        Args:
            user_input: the user's message.
            chain_result: result of the executed chain (its final_reply is set here).
            openai_tool_block: ``[assistant(tool_calls), tool, ...]`` messages.
            context: memory context summary.

        Returns:
            AgentResponse for this turn.
        """
        self.logger.info("✍️ [LLM] 调用 OpenAI 生成最终回复...")

        chain_summary = self._build_chain_summary(chain_result)

        # Hand the complete, protocol-compliant message block to LLMEngine.
        final_reply = self.llm.generate_chain_reply(
            user_input=user_input,
            chain_summary=chain_summary,
            context=context,
            openai_tool_block=openai_tool_block,
        )

        chain_result.final_reply = final_reply

        # Cross-turn history keeps only user/assistant text, never tool messages.
        self._append_history(user_input, final_reply)

        if chain_result.total_steps > 1:
            self.memory.add_chain_result(chain_result)
        else:
            self.memory.add_assistant_message(final_reply)

        self.logger.info("🎉 流程完成,回复已返回")
        single = chain_result.total_steps == 1
        return AgentResponse(
            user_input=user_input,
            final_reply=final_reply,
            chain_result=chain_result,
            tool_used=chain_result.step_results[0].tool_name if single else None,
            tool_output=chain_result.step_results[0].output if single else None,
            success=chain_result.success,
        )

    def _handle_direct_reply(self, user_input: str, context: str) -> AgentResponse:
        """Generate a reply with OpenAI directly when no tools are needed."""
        self.logger.info("💬 无需工具,直接调用 OpenAI 生成回复")
        reply = self.llm.generate_direct_reply(user_input, context)
        self.memory.add_assistant_message(reply)
        # Same capped history as the tool path (previously grew without bound).
        self._append_history(user_input, reply)
        return AgentResponse(user_input=user_input, final_reply=reply)

    def _append_history(self, user_input: str, reply: str) -> None:
        """
        Append one user/assistant exchange and cap the cross-turn history.

        The oldest messages are removed in place, so the history list object
        itself is never replaced.
        """
        self._openai_history.append({"role": "user", "content": user_input})
        self._openai_history.append({"role": "assistant", "content": reply})
        overflow = len(self._openai_history) - self._OPENAI_HISTORY_LIMIT
        if overflow > 0:
            del self._openai_history[:overflow]

    # ════════════════════════════════════════════════════════════
    # Message construction helpers
    # ════════════════════════════════════════════════════════════

    @staticmethod
    def _build_assistant_message(
        plan: ChainPlan,
        step_call_ids: dict[int, str],
    ) -> dict:
        """
        Build the assistant message that declares every tool call at once.

        Shape::

            {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_abc123",   # must equal the tool message's tool_call_id
                        "type": "function",
                        "function": {
                            "name": "calculator",
                            "arguments": '{"expression": "1+2"}',  # a JSON *string*
                        },
                    }
                ],
            }
        """
        tool_calls = [
            {
                "id": step_call_ids[step.step_id],
                "type": "function",
                "function": {
                    "name": step.tool_name,
                    "arguments": json.dumps(step.arguments, ensure_ascii=False),
                },
            }
            for step in plan.steps
        ]
        return {
            "role": "assistant",
            "content": None,
            "tool_calls": tool_calls,
        }

    @staticmethod
    def _build_chain_summary(chain_result: ChainResult) -> str:
        """Format the chain results as an LLM-readable Markdown summary (used as a fallback)."""
        lines = []
        for r in chain_result.step_results:
            if r.success:
                lines.append(
                    f"**Step {r.step_id} [{r.tool_name}]** ✅\n"
                    f"```\n{r.output[:300]}\n```"
                )
            else:
                lines.append(
                    f"**Step {r.step_id} [{r.tool_name}]** ❌\n"
                    f"错误: {r.error}"
                )
        return "\n\n".join(lines)

    @staticmethod
    def _has_failed_dependency(step: ToolStep, failed_step: int | None) -> bool:
        """Return True if ``failed_step`` is one of the step's dependencies (single-id check)."""
        return failed_step is not None and failed_step in step.depends_on

    @staticmethod
    def _blocking_dependency(step: ToolStep, unavailable: set[int]) -> int | None:
        """
        Return the first dependency of ``step`` that failed or was skipped, else None.

        Returns early when nothing is unavailable, so ``depends_on`` is only
        inspected once some step has actually failed.
        """
        if not unavailable:
            return None
        for dep in step.depends_on or ():
            if dep in unavailable:
                return dep
        return None

    def get_memory_stats(self) -> dict:
        """Return the memory store's stats plus the cross-turn OpenAI history length."""
        stats = self.memory.stats()
        stats["openai_history_len"] = len(self._openai_history)
        return stats

    def clear_session(self) -> None:
        """Clear both the memory store history and the OpenAI message history."""
        self.memory.clear_history()
        self._openai_history.clear()
        self.logger.info("🗑 会话已清空(含 OpenAI 消息历史)")