"""
llm/llm_engine.py

Fix: message-sequence construction in generate_chain_reply /
generate_final_reply — ensure that every tool message immediately follows
the assistant message carrying its tool_calls declaration.
"""
import json
import re
from dataclasses import dataclass

from config.settings import LLMConfig, settings
from llm.provider_factory import create_provider
from llm.providers.base_provider import BaseProvider
from mcp.mcp_protocol import ChainPlan, MCPMethod, MCPRequest, ToolSchema, ToolStep
from utils.logger import get_logger
@dataclass
|
||
class ToolDecision:
|
||
need_tool: bool
|
||
tool_name: str = ""
|
||
arguments: dict = None
|
||
reasoning: str = ""
|
||
|
||
def __post_init__(self):
|
||
self.arguments = self.arguments or {}
|
||
|
||
def to_mcp_request(self) -> MCPRequest | None:
|
||
if not self.need_tool:
|
||
return None
|
||
return MCPRequest(
|
||
method=MCPMethod.TOOLS_CALL,
|
||
params={"name": self.tool_name, "arguments": self.arguments},
|
||
)
|
||
|
||
|
||
class LLMEngine:
    """
    LLM inference engine.

    Message-sequence contract (the fix this module exists for):

    generate_chain_reply() must send exactly:

        [
            {"role": "system", "content": REPLY_SYSTEM_PROMPT},
            {"role": "user", "content": <original user input>},
            # ↓ openai_tool_block (built by AgentClient._execute_chain)
            {"role": "assistant", "content": None,
             "tool_calls": [{"id": "call_001", "type": "function", "function": {...}},
                            {"id": "call_002", "type": "function", "function": {...}}]},
            {"role": "tool", "content": <result 1>, "tool_call_id": "call_001"},
            {"role": "tool", "content": <result 2>, "tool_call_id": "call_002"},
        ]

    and then call provider.generate_reply() for the final answer.

    Never insert any other message (especially system/user) between the
    assistant(tool_calls) message and its tool messages — the OpenAI API
    rejects such sequences with HTTP 400.
    """

    # Heuristic keywords / regex fragments hinting at a multi-step request
    # (kept for rule-based planning and external callers).
    _MULTI_STEP_KEYWORDS = [
        "然后", "接着", "再", "并且", "同时", "之后",
        "先.*再", "首先.*然后", "搜索.*计算", "读取.*执行",
        "多个", "分别", "依次",
    ]

    def __init__(self, cfg: LLMConfig | None = None):
        """Initialize the engine; *cfg* defaults to the global ``settings.llm``."""
        self.cfg = cfg or settings.llm
        self.logger = get_logger("LLM")
        self.provider: BaseProvider = create_provider(self.cfg)
        self._log_init()

    def _log_init(self) -> None:
        """Log the effective configuration once at startup."""
        self.logger.info("🧠 LLM 引擎初始化完成")
        self.logger.info(f" provider = {self.cfg.provider}")
        self.logger.info(f" model_name = {self.cfg.model_name}")
        self.logger.info(f" function_calling = {self.cfg.function_calling}")
        self.logger.info(f" temperature = {self.cfg.temperature}")
        self.logger.info(f" fallback_rules = {settings.agent.fallback_to_rules}")

    def reconfigure(self, cfg: LLMConfig) -> None:
        """Swap in a new config and rebuild the provider to match."""
        self.cfg = cfg
        self.provider = create_provider(cfg)
        self.logger.info(f"🔄 LLM 配置已更新: model={cfg.model_name}")

    # ════════════════════════════════════════════════════════════
    # Tool planning
    # ════════════════════════════════════════════════════════════

    def plan_tool_chain(
        self,
        user_input: str,
        tool_schemas: list[ToolSchema],
        context: str = "",
        history: list[dict] | None = None,
    ) -> ChainPlan:
        """Plan a tool-call chain via OpenAI function calling.

        Falls back to the rule engine when function calling is disabled, or
        when planning fails and ``settings.agent.fallback_to_rules`` is on;
        otherwise returns an empty plan.

        Args:
            user_input: The raw user request.
            tool_schemas: Schemas of the tools available to the planner.
            context: Rendered conversation history (used when *history* is
                absent).
            history: Raw message history; tool-related entries are filtered
                out before planning.
        """
        self.logger.info(f"🗺 规划工具调用链: {user_input[:60]}...")

        messages = self._build_plan_messages(user_input, context, history)

        # Guard clause: no function calling → rule engine directly.
        if not self.cfg.function_calling:
            self.logger.info("⚙️ function_calling=false,使用规则引擎")
            return self._rule_based_plan(user_input)

        result = self.provider.plan_with_tools(messages, tool_schemas)

        if result.success and result.plan is not None:
            plan = result.plan
            if not plan.goal:
                # Planner omitted the goal — default to the user's request.
                plan.goal = user_input
            self.logger.info(f"📋 OpenAI 规划完成: {plan.step_count} 步")
            for step in plan.steps:
                self.logger.info(
                    f" Step {step.step_id}: [{step.tool_name}] "
                    f"args={step.arguments}"
                )
            return plan

        self.logger.warning(f"⚠️ OpenAI 规划失败: {result.error}")
        if settings.agent.fallback_to_rules:
            self.logger.info("🔄 降级到规则引擎...")
            return self._rule_based_plan(user_input)
        return ChainPlan(goal=user_input, steps=[])

    def think_and_decide(
        self,
        user_input: str,
        tool_schemas: list[ToolSchema],
        context: str = "",
    ) -> ToolDecision:
        """Reduce a full chain plan to a single-step ToolDecision (first step only)."""
        plan = self.plan_tool_chain(user_input, tool_schemas, context)
        if not plan.steps:
            return ToolDecision(need_tool=False, reasoning="无需工具,直接回复")
        first = plan.steps[0]
        return ToolDecision(
            need_tool=True,
            tool_name=first.tool_name,
            arguments=first.arguments,
            reasoning=first.description,
        )

    # ════════════════════════════════════════════════════════════
    # Reply generation (core fixed area)
    # ════════════════════════════════════════════════════════════

    def generate_chain_reply(
        self,
        user_input: str,
        chain_summary: str,
        context: str = "",
        openai_tool_block: list[dict] | None = None,
    ) -> str:
        """Integrate chain execution results into a final natural-language reply.

        Correct message order (see the class docstring):
            system → user → assistant(tool_calls) → tool → tool → ...
        where the assistant+tool suffix is *openai_tool_block*, appended
        verbatim. The broken pre-fix order (system → user → tool → ...)
        caused HTTP 400 because tool messages lacked their preceding
        assistant(tool_calls) message.

        Args:
            user_input: Original user input.
            chain_summary: Step summary used as fallback content on failure.
            context: Conversation history — planning-stage only; deliberately
                NOT injected here, since any extra message before the tool
                messages would break the sequence.
            openai_tool_block: Compliant block built by AgentClient, shaped
                ``[assistant(tool_calls), tool, tool, ...]``.
        """
        self.logger.info("✍️ 生成最终回复(工具调用链模式)...")

        if not openai_tool_block:
            self.logger.warning("⚠️ openai_tool_block 为空,降级到摘要模板")
            return self._fallback_chain_reply(user_input, chain_summary)

        # Reject malformed blocks up front rather than letting the API 400.
        if not self._validate_tool_block(openai_tool_block):
            self.logger.warning("⚠️ 消息块验证失败,降级到摘要模板")
            return self._fallback_chain_reply(user_input, chain_summary)

        messages = self._build_reply_messages_with_block(user_input, openai_tool_block)

        self.logger.debug(f"📤 发送回复请求,消息数: {len(messages)}")
        self._log_messages_structure(messages)

        result = self.provider.generate_reply(messages)

        if result.success and result.content:
            self.logger.info(
                f"✅ OpenAI 回复生成成功 ({len(result.content)} chars)"
            )
            return result.content

        self.logger.warning(f"⚠️ OpenAI 回复生成失败: {result.error}")
        return self._fallback_chain_reply(user_input, chain_summary)

    def generate_final_reply(
        self,
        user_input: str,
        tool_name: str,
        tool_output: str,
        context: str = "",
        tool_call_id: str = "",
    ) -> str:
        """Integrate a single tool result into a natural-language reply.

        Builds a minimal compliant block — assistant(tool_calls) followed by
        one tool message — and delegates to :meth:`generate_chain_reply`.
        Without a *tool_call_id* the function-calling protocol cannot be
        honored, so it degrades to a plain prompt instead.
        """
        self.logger.info(f"✍️ 整合单步工具结果 [{tool_name}]...")

        if not tool_call_id:
            # No id → cannot satisfy the tool-message protocol; degrade.
            self.logger.warning("⚠️ tool_call_id 为空,使用直接回复模式")
            return self._generate_simple_reply(user_input, tool_name, tool_output)

        # NOTE(review): the original call arguments are not available here,
        # so a truncated echo of the output is serialized as a placeholder
        # payload — confirm downstream consumers do not rely on it.
        single_tool_block = [
            {
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": json.dumps(
                            {"result": tool_output[:100]}, ensure_ascii=False
                        ),
                    },
                }],
            },
            {
                "role": "tool",
                "content": tool_output,
                "tool_call_id": tool_call_id,
            },
        ]

        return self.generate_chain_reply(
            user_input=user_input,
            chain_summary=tool_output,
            context=context,
            openai_tool_block=single_tool_block,
        )

    def generate_direct_reply(self, user_input: str, context: str = "") -> str:
        """Generate a reply with no tools involved (no tool messages, nothing to fix)."""
        self.logger.info("💬 直接生成回复(无需工具)...")
        messages = [
            {"role": "system", "content": "你是一个友好、专业的 AI 助手,请简洁准确地回答用户问题。"},
            {"role": "user", "content": user_input},
        ]
        result = self.provider.generate_reply(messages)
        if result.success and result.content:
            return result.content
        return (
            f"您好!关于「{user_input}」,我已收到您的问题。\n"
            f"(API 暂时不可用,请检查 API Key 配置)"
        )

    # ════════════════════════════════════════════════════════════
    # Message construction (core of the fix)
    # ════════════════════════════════════════════════════════════

    @staticmethod
    def _build_reply_messages_with_block(
        user_input: str,
        openai_tool_block: list[dict],
    ) -> list[dict]:
        """Build the full message list for the reply-generation stage.

        Structure:
            [system(REPLY_SYSTEM_PROMPT), user(user_input)] + openai_tool_block

        The block must be appended as a whole — never split or reordered —
        so each tool message stays adjacent to its assistant(tool_calls).
        """
        # Local import — presumably avoids a circular import at module load;
        # kept as in the original.
        from llm.providers.openai_provider import OpenAIProvider
        messages: list[dict] = [
            {"role": "system", "content": OpenAIProvider._REPLY_SYSTEM_PROMPT},
            {"role": "user", "content": user_input},
        ]
        messages.extend(openai_tool_block)
        return messages

    @staticmethod
    def _build_plan_messages(
        user_input: str,
        context: str,
        history: list[dict] | None,
    ) -> list[dict]:
        """Build the planning-stage message list (no tool messages involved)."""
        from llm.providers.openai_provider import OpenAIProvider
        messages: list[dict] = [
            {"role": "system", "content": OpenAIProvider._PLANNER_SYSTEM_PROMPT},
        ]
        if history:
            # Keep only plain user/assistant turns from the last 6 entries;
            # drop tool messages and tool_calls-bearing assistant messages.
            clean_history = [
                m for m in history[-6:]
                if m.get("role") in ("user", "assistant")
                and not m.get("tool_calls")
            ]
            messages.extend(clean_history)
        elif context and context != "(暂无对话历史)":
            messages.append({
                "role": "system",
                "content": f"## 对话历史\n{context}",
            })
        messages.append({"role": "user", "content": user_input})
        return messages

    @staticmethod
    def _validate_tool_block(openai_tool_block: list[dict]) -> bool:
        """Validate an openai_tool_block for protocol compliance.

        Rules:
            1. First message is role=assistant with non-empty tool_calls.
            2. Every following message is role=tool with a tool_call_id.
            3. Tool message count == tool_calls count.
            4. Each declared tool_call id is answered exactly once
               (duplicates — previously accepted by the set-size check —
               are now rejected, since a duplicated answer would leave
               another declared id unanswered and trigger an API 400).
        """
        if not openai_tool_block:
            return False

        first = openai_tool_block[0]
        if first.get("role") != "assistant" or not first.get("tool_calls"):
            return False

        tool_calls = first["tool_calls"]
        declared_ids = {tc["id"] for tc in tool_calls}
        if len(declared_ids) != len(tool_calls):
            return False  # duplicate ids in the declaration itself

        tool_msgs = openai_tool_block[1:]
        if len(tool_msgs) != len(tool_calls):
            return False

        answered: set = set()
        for msg in tool_msgs:
            if msg.get("role") != "tool":
                return False
            call_id = msg.get("tool_call_id")
            if call_id not in declared_ids or call_id in answered:
                return False
            answered.add(call_id)

        return True

    def _generate_simple_reply(
        self,
        user_input: str,
        tool_name: str,
        tool_output: str,
    ) -> str:
        """Fallback reply when no tool_call_id exists (skips the FC protocol)."""
        messages = [
            {"role": "system",
             "content": "你是一个友好的 AI 助手,请基于工具执行结果回答用户问题。"},
            {"role": "user",
             "content": (
                 f"用户问题: {user_input}\n\n"
                 f"工具 [{tool_name}] 执行结果:\n{tool_output}\n\n"
                 f"请基于以上结果给出清晰的回答。"
             )},
        ]
        result = self.provider.generate_reply(messages)
        # Require non-empty content (was: success alone, which could return
        # an empty reply) — consistent with generate_chain_reply.
        if result.success and result.content:
            return result.content
        return self._fallback_chain_reply(user_input, tool_output)

    def _log_messages_structure(self, messages: list[dict]) -> None:
        """Debug helper: log the message-sequence shape (content truncated)."""
        self.logger.debug("📋 消息序列结构:")
        for i, msg in enumerate(messages):
            role = msg.get("role", "?")
            if role == "assistant" and msg.get("tool_calls"):
                ids = [tc["id"] for tc in msg["tool_calls"]]
                names = [tc["function"]["name"] for tc in msg["tool_calls"]]
                self.logger.debug(
                    f" [{i}] {role:10s} tool_calls={names} ids={ids}"
                )
            elif role == "tool":
                self.logger.debug(
                    f" [{i}] {role:10s} tool_call_id={msg.get('tool_call_id')} "
                    f"content={str(msg.get('content',''))[:40]}..."
                )
            else:
                content_preview = str(msg.get("content", ""))[:50]
                self.logger.debug(f" [{i}] {role:10s} {content_preview}...")

    # ════════════════════════════════════════════════════════════
    # Fallback rule engine
    # ════════════════════════════════════════════════════════════

    def _rule_based_plan(self, user_input: str) -> ChainPlan:
        """Keyword-driven planner used when function calling is off or fails."""
        self.logger.info("⚙️ 使用规则引擎规划...")
        text = user_input.lower()

        # search + calculate → two-step chain
        if (any(k in text for k in ["搜索", "查询", "查一下"]) and
                any(k in text for k in ["计算", "算", "等于", "结果"])):
            return ChainPlan(
                goal=user_input,
                steps=[
                    ToolStep(1, "web_search",
                             {"query": user_input,
                              "max_results": settings.tools.web_search.max_results},
                             "搜索相关信息", []),
                    ToolStep(2, "calculator",
                             {"expression": self._extract_expression(user_input)},
                             "进行计算", [1]),
                ],
            )

        # read file + execute → two-step chain
        if (any(k in text for k in ["读取", "文件", "file"]) and
                any(k in text for k in ["执行", "运行", "run"])):
            fname = re.search(r"[\w\-\.]+\.\w+", user_input)
            return ChainPlan(
                goal=user_input,
                steps=[
                    ToolStep(1, "file_reader",
                             {"path": fname.group() if fname else "script.py"},
                             "读取文件", []),
                    ToolStep(2, "code_executor",
                             {"code": "{{STEP_1_OUTPUT}}",
                              "timeout": settings.tools.code_executor.timeout},
                             "执行代码", [1]),
                ],
            )

        return self._rule_single_step(user_input)

    def _rule_single_step(self, user_input: str) -> ChainPlan:
        """Single-step rule planner; empty plan when nothing matches."""
        text = user_input.lower()

        if any(k in text for k in ["计算", "等于", "×", "÷", "+", "-", "*", "/"]):
            expr = self._extract_expression(user_input)
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "calculator",
                                             {"expression": expr}, "数学计算")])

        if any(k in text for k in ["搜索", "查询", "天气", "新闻"]):
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "web_search",
                                             {"query": user_input,
                                              "max_results": settings.tools.web_search.max_results},
                                             "网络搜索")])

        if any(k in text for k in ["文件", "读取", "file"]):
            fname = re.search(r"[\w\-\.]+\.\w+", user_input)
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "file_reader",
                                             {"path": fname.group() if fname else "config.json"},
                                             "读取文件")])

        if any(k in text for k in ["执行", "运行", "代码", "python"]):
            code_m = re.search(r"[`'\"](.+?)[`'\"]", user_input)
            code = code_m.group(1) if code_m else 'print("Hello, Agent!")'
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "code_executor",
                                             {"code": code,
                                              "timeout": settings.tools.code_executor.timeout},
                                             "执行代码")])

        return ChainPlan(goal=user_input, is_single=True, steps=[])

    @staticmethod
    def _fallback_chain_reply(user_input: str, chain_summary: str) -> str:
        """Template reply used whenever the API cannot produce content."""
        return (
            f"✅ **任务已完成**\n\n"
            f"针对您的需求「{user_input}」,执行结果如下:\n\n"
            f"{chain_summary}"
        )

    @staticmethod
    def _extract_expression(text: str) -> str:
        """Extract an arithmetic expression from free text; defaults to '1+1'."""
        cleaned = text.replace("×", "*").replace("÷", "/").replace(",", "")
        match = re.search(r"[\d\s\+\-\*\/\(\)\.]+", cleaned)
        expr = match.group().strip() if match else "1+1"
        # A single stray digit/operator is not a usable expression.
        return expr if len(expr) > 1 else "1+1"