# base_agent/llm/llm_engine.py
"""
llm/llm_engine.py
2026-03-09 05:40:27 +00:00
修复generate_chain_reply / generate_final_reply 消息序列构造
确保 tool 消息始终紧跟在含 tool_calls assistant 消息之后
2026-02-28 08:21:35 +00:00
"""
import re
from dataclasses import dataclass

from config.settings import LLMConfig, settings
from llm.provider_factory import create_provider
from llm.providers.base_provider import BaseProvider
from mcp.mcp_protocol import ChainPlan, MCPMethod, MCPRequest, ToolSchema, ToolStep
from utils.logger import get_logger
@dataclass
class ToolDecision:
2026-02-28 14:59:41 +00:00
need_tool: bool
2026-03-09 05:37:29 +00:00
tool_name: str = ""
2026-02-28 14:59:41 +00:00
arguments: dict = None
2026-03-09 05:37:29 +00:00
reasoning: str = ""
2026-02-28 08:21:35 +00:00
def __post_init__(self):
self.arguments = self.arguments or {}
def to_mcp_request(self) -> MCPRequest | None:
if not self.need_tool:
return None
return MCPRequest(
method=MCPMethod.TOOLS_CALL,
params={"name": self.tool_name, "arguments": self.arguments},
)
class LLMEngine:
    """
    LLM inference engine.

    Message-sequence contract (post-fix):

    generate_chain_reply() builds the complete message list:
    [
        {"role": "system", "content": REPLY_SYSTEM_PROMPT},
        {"role": "user", "content": "<user input>"},
        # openai_tool_block comes from AgentClient._execute_chain
        {"role": "assistant", "content": null,
         "tool_calls": [{"id": "call_001", "type": "function", "function": {...}},
                        {"id": "call_002", "type": "function", "function": {...}}]},
        {"role": "tool", "content": "<result 1>", "tool_call_id": "call_001"},
        {"role": "tool", "content": "<result 2>", "tool_call_id": "call_002"},
    ]
    and provider.generate_reply() produces the final reply.

    Never insert any other message (especially system/user) in front of the
    tool messages — the API rejects such sequences.
    """

    # Heuristic markers hinting that a request needs multiple tool steps.
    # NOTE(review): the original list contained an empty string "" (almost
    # certainly a character lost in transit). "" is a substring of every
    # string, so any `k in text` scan over this list was trivially True;
    # the empty entry has been removed. Entries such as "先.*再" look like
    # regex fragments — confirm how callers match against this list.
    _MULTI_STEP_KEYWORDS = [
        "然后", "接着", "并且", "同时", "之后",
        "先.*再", "首先.*然后", "搜索.*计算", "读取.*执行",
        "多个", "分别", "依次",
    ]
def __init__(self, cfg: LLMConfig | None = None):
self.cfg = cfg or settings.llm
self.logger = get_logger("LLM")
self.provider: BaseProvider = create_provider(self.cfg)
self._log_init()
def _log_init(self) -> None:
self.logger.info("🧠 LLM 引擎初始化完成")
self.logger.info(f" provider = {self.cfg.provider}")
self.logger.info(f" model_name = {self.cfg.model_name}")
self.logger.info(f" function_calling = {self.cfg.function_calling}")
self.logger.info(f" temperature = {self.cfg.temperature}")
self.logger.info(f" fallback_rules = {settings.agent.fallback_to_rules}")
def reconfigure(self, cfg: LLMConfig) -> None:
self.cfg = cfg
self.provider = create_provider(cfg)
self.logger.info(f"🔄 LLM 配置已更新: model={cfg.model_name}")
# ════════════════════════════════════════════════════════════
2026-03-09 05:40:27 +00:00
# 工具规划
2026-03-09 05:37:29 +00:00
# ════════════════════════════════════════════════════════════
def plan_tool_chain(
self,
user_input: str,
tool_schemas: list[ToolSchema],
context: str = "",
history: list[dict] | None = None,
) -> ChainPlan:
2026-03-09 05:40:27 +00:00
"""使用 OpenAI Function Calling 规划工具调用链"""
2026-03-09 05:37:29 +00:00
self.logger.info(f"🗺 规划工具调用链: {user_input[:60]}...")
messages = self._build_plan_messages(user_input, context, history)
if self.cfg.function_calling:
result = self.provider.plan_with_tools(messages, tool_schemas)
if result.success and result.plan is not None:
plan = result.plan
if not plan.goal:
plan.goal = user_input
self.logger.info(f"📋 OpenAI 规划完成: {plan.step_count}")
for step in plan.steps:
self.logger.info(
f" Step {step.step_id}: [{step.tool_name}] "
f"args={step.arguments}"
)
return plan
self.logger.warning(f"⚠️ OpenAI 规划失败: {result.error}")
if settings.agent.fallback_to_rules:
self.logger.info("🔄 降级到规则引擎...")
return self._rule_based_plan(user_input)
return ChainPlan(goal=user_input, steps=[])
else:
self.logger.info("⚙️ function_calling=false使用规则引擎")
return self._rule_based_plan(user_input)
2026-02-28 08:21:35 +00:00
2026-03-09 05:37:29 +00:00
def think_and_decide(
self,
user_input: str,
tool_schemas: list[ToolSchema],
context: str = "",
) -> ToolDecision:
plan = self.plan_tool_chain(user_input, tool_schemas, context)
if not plan.steps:
return ToolDecision(need_tool=False, reasoning="无需工具,直接回复")
first = plan.steps[0]
return ToolDecision(
need_tool=True,
tool_name=first.tool_name,
arguments=first.arguments,
reasoning=first.description,
2026-02-28 08:21:35 +00:00
)
2026-03-09 05:40:27 +00:00
# ════════════════════════════════════════════════════════════
# 回复生成(核心修复区域)
# ════════════════════════════════════════════════════════════
2026-03-09 05:37:29 +00:00
def generate_chain_reply(
self,
2026-03-09 05:40:27 +00:00
user_input: str,
chain_summary: str,
context: str = "",
openai_tool_block: list[dict] | None = None,
2026-02-28 08:21:35 +00:00
) -> str:
"""
2026-03-09 05:40:27 +00:00
整合工具执行结果调用 OpenAI 生成最终自然语言回复
2026-03-09 05:37:29 +00:00
2026-03-09 05:40:27 +00:00
修复后消息序列:
system 回复生成提示
user 原始用户输入
openai_tool_block 直接追加已包含 assistant+tool 消息
assistant(tool_calls=[...])
tool(tool_call_id=..., content=...)
tool(tool_call_id=..., content=...)
...
2026-02-28 08:21:35 +00:00
2026-03-09 05:40:27 +00:00
修复前的错误导致 400:
system user tool tool tool 前缺少 assistant(tool_calls)
2026-02-28 08:21:35 +00:00
2026-03-09 05:40:27 +00:00
Args:
user_input: 原始用户输入
chain_summary: 步骤摘要API 失败时的降级内容
context: 对话历史仅规划阶段使用回复阶段不注入
openai_tool_block: AgentClient 构造的合规消息块
格式: [assistant(tool_calls), tool, tool, ...]
2026-02-28 08:21:35 +00:00
"""
2026-03-09 05:40:27 +00:00
self.logger.info("✍️ 生成最终回复(工具调用链模式)...")
2026-02-28 08:21:35 +00:00
2026-03-09 05:40:27 +00:00
if not openai_tool_block:
self.logger.warning("⚠️ openai_tool_block 为空,降级到摘要模板")
return self._fallback_chain_reply(user_input, chain_summary)
2026-02-28 08:21:35 +00:00
2026-03-09 05:40:27 +00:00
# 验证消息块合规性
if not self._validate_tool_block(openai_tool_block):
self.logger.warning("⚠️ 消息块验证失败,降级到摘要模板")
return self._fallback_chain_reply(user_input, chain_summary)
# ✅ 正确的消息序列构造
messages = self._build_reply_messages_with_block(user_input, openai_tool_block)
self.logger.debug(f"📤 发送回复请求,消息数: {len(messages)}")
self._log_messages_structure(messages)
result = self.provider.generate_reply(messages)
2026-02-28 08:21:35 +00:00
2026-03-09 05:40:27 +00:00
if result.success and result.content:
self.logger.info(
f"✅ OpenAI 回复生成成功 ({len(result.content)} chars)"
)
return result.content
2026-02-28 08:21:35 +00:00
2026-03-09 05:40:27 +00:00
self.logger.warning(f"⚠️ OpenAI 回复生成失败: {result.error}")
2026-03-09 05:37:29 +00:00
return self._fallback_chain_reply(user_input, chain_summary)
def generate_final_reply(
self,
2026-03-09 05:40:27 +00:00
user_input: str,
tool_name: str,
tool_output: str,
context: str = "",
2026-03-09 05:37:29 +00:00
tool_call_id: str = "",
2026-02-28 08:21:35 +00:00
) -> str:
2026-03-09 05:40:27 +00:00
"""
单步工具结果整合调用 OpenAI 生成自然语言回复
修复后消息序列:
system 回复生成提示
user 原始用户输入
assistant tool_calls=[{id: tool_call_id, function: {name, arguments}}]
tool content=tool_output, tool_call_id=tool_call_id
修复前的错误:
system user tool 缺少 assistant(tool_calls) 前置消息
"""
2026-03-09 05:37:29 +00:00
self.logger.info(f"✍️ 整合单步工具结果 [{tool_name}]...")
2026-03-09 05:40:27 +00:00
if not tool_call_id:
# 无 tool_call_id 时降级到直接回复模式
self.logger.warning("⚠️ tool_call_id 为空,使用直接回复模式")
return self._generate_simple_reply(user_input, tool_name, tool_output)
import json
# 构造单步合规消息块
single_tool_block = [
{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": tool_call_id,
"type": "function",
"function": {
"name": tool_name,
"arguments": json.dumps({"result": tool_output[:100]},
ensure_ascii=False),
},
}],
},
{
"role": "tool",
"content": tool_output,
"tool_call_id": tool_call_id,
},
]
2026-03-09 05:37:29 +00:00
return self.generate_chain_reply(
user_input=user_input,
chain_summary=tool_output,
context=context,
2026-03-09 05:40:27 +00:00
openai_tool_block=single_tool_block,
2026-02-28 08:21:35 +00:00
)
2026-03-09 05:37:29 +00:00
def generate_direct_reply(self, user_input: str, context: str = "") -> str:
2026-03-09 05:40:27 +00:00
"""无需工具时直接调用 OpenAI 生成回复(不涉及 tool 消息,无需修复)"""
2026-03-09 05:37:29 +00:00
self.logger.info("💬 直接生成回复(无需工具)...")
messages = [
2026-03-09 05:40:27 +00:00
{"role": "system", "content": "你是一个友好、专业的 AI 助手,请简洁准确地回答用户问题。"},
{"role": "user", "content": user_input},
2026-03-09 05:37:29 +00:00
]
result = self.provider.generate_reply(messages)
if result.success and result.content:
return result.content
2026-02-28 08:21:35 +00:00
return (
2026-03-09 05:40:27 +00:00
f"您好!关于「{user_input}」,我已收到您的问题。\n"
2026-03-09 05:37:29 +00:00
f"API 暂时不可用,请检查 API Key 配置)"
2026-02-28 08:21:35 +00:00
)
2026-03-09 05:37:29 +00:00
# ════════════════════════════════════════════════════════════
2026-03-09 05:40:27 +00:00
# 消息构造(修复核心)
2026-03-09 05:37:29 +00:00
# ════════════════════════════════════════════════════════════
2026-03-09 05:40:27 +00:00
@staticmethod
def _build_reply_messages_with_block(
user_input: str,
openai_tool_block: list[dict],
) -> list[dict]:
"""
构造回复生成阶段的完整消息列表
正确结构:
[
{"role": "system", "content": REPLY_SYSTEM_PROMPT},
{"role": "user", "content": user_input},
{"role": "assistant", "content": null, "tool_calls": [...]}, tool_block[0]
{"role": "tool", "content": "...", "tool_call_id": "..."}, tool_block[1]
{"role": "tool", "content": "...", "tool_call_id": "..."}, tool_block[2]
...
]
注意: openai_tool_block 必须整体追加不能拆分或重排
"""
from llm.providers.openai_provider import OpenAIProvider
messages: list[dict] = [
{"role": "system", "content": OpenAIProvider._REPLY_SYSTEM_PROMPT},
{"role": "user", "content": user_input},
]
# 整体追加工具消息块assistant + tool(s)
messages.extend(openai_tool_block)
return messages
2026-03-09 05:37:29 +00:00
@staticmethod
def _build_plan_messages(
user_input: str,
context: str,
history: list[dict] | None,
) -> list[dict]:
2026-03-09 05:40:27 +00:00
"""构造规划阶段的消息列表(不含 tool 消息,无需修复)"""
2026-03-09 05:37:29 +00:00
from llm.providers.openai_provider import OpenAIProvider
messages: list[dict] = [
{"role": "system", "content": OpenAIProvider._PLANNER_SYSTEM_PROMPT},
]
if history:
2026-03-09 05:40:27 +00:00
# 过滤掉 tool 消息,只保留 user/assistant 对话历史
clean_history = [
m for m in history[-6:]
if m.get("role") in ("user", "assistant")
and not m.get("tool_calls") # 排除含 tool_calls 的 assistant 消息
]
messages.extend(clean_history)
2026-03-09 05:37:29 +00:00
elif context and context != "(暂无对话历史)":
messages.append({
"role": "system",
"content": f"## 对话历史\n{context}",
})
messages.append({"role": "user", "content": user_input})
return messages
@staticmethod
2026-03-09 05:40:27 +00:00
def _validate_tool_block(openai_tool_block: list[dict]) -> bool:
"""
验证 openai_tool_block 消息块合规性
规则:
1. 第一条必须是 role=assistant 且含 tool_calls
2. 后续每条必须是 role=tool 且含 tool_call_id
3. tool 消息数量 == assistant.tool_calls 数量
4. 所有 tool_call_id 必须在 assistant.tool_calls[].id 中存在
"""
if not openai_tool_block:
return False
first = openai_tool_block[0]
if first.get("role") != "assistant" or not first.get("tool_calls"):
return False
declared_ids = {tc["id"] for tc in first["tool_calls"]}
tool_msgs = openai_tool_block[1:]
if len(tool_msgs) != len(declared_ids):
return False
for msg in tool_msgs:
if msg.get("role") != "tool":
return False
if msg.get("tool_call_id") not in declared_ids:
return False
return True
def _generate_simple_reply(
self,
user_input: str,
tool_name: str,
tool_output: str,
) -> str:
"""无 tool_call_id 时的降级回复(不走 Function Calling 协议)"""
messages = [
{"role": "system",
"content": "你是一个友好的 AI 助手,请基于工具执行结果回答用户问题。"},
{"role": "user",
"content": (
f"用户问题: {user_input}\n\n"
f"工具 [{tool_name}] 执行结果:\n{tool_output}\n\n"
f"请基于以上结果给出清晰的回答。"
)},
2026-03-09 05:37:29 +00:00
]
2026-03-09 05:40:27 +00:00
result = self.provider.generate_reply(messages)
return result.content if result.success else self._fallback_chain_reply(
user_input, tool_output
)
def _log_messages_structure(self, messages: list[dict]) -> None:
"""调试:打印消息序列结构(不打印完整内容)"""
self.logger.debug("📋 消息序列结构:")
for i, msg in enumerate(messages):
role = msg.get("role", "?")
if role == "assistant" and msg.get("tool_calls"):
ids = [tc["id"] for tc in msg["tool_calls"]]
names = [tc["function"]["name"] for tc in msg["tool_calls"]]
self.logger.debug(
f" [{i}] {role:10s} tool_calls={names} ids={ids}"
)
elif role == "tool":
self.logger.debug(
f" [{i}] {role:10s} tool_call_id={msg.get('tool_call_id')} "
f"content={str(msg.get('content',''))[:40]}..."
)
else:
content_preview = str(msg.get("content", ""))[:50]
self.logger.debug(f" [{i}] {role:10s} {content_preview}...")
2026-03-09 05:37:29 +00:00
# ════════════════════════════════════════════════════════════
# 降级规则引擎
# ════════════════════════════════════════════════════════════
def _rule_based_plan(self, user_input: str) -> ChainPlan:
self.logger.info("⚙️ 使用规则引擎规划...")
text = user_input.lower()
if (any(k in text for k in ["搜索", "查询", "查一下"]) and
any(k in text for k in ["计算", "", "等于", "结果"])):
return ChainPlan(
goal=user_input,
steps=[
ToolStep(1, "web_search",
{"query": user_input,
"max_results": settings.tools.web_search.max_results},
"搜索相关信息", []),
ToolStep(2, "calculator",
{"expression": self._extract_expression(user_input)},
"进行计算", [1]),
],
2026-02-28 08:21:35 +00:00
)
2026-03-09 05:37:29 +00:00
if (any(k in text for k in ["读取", "文件", "file"]) and
any(k in text for k in ["执行", "运行", "run"])):
fname = re.search(r"[\w\-\.]+\.\w+", user_input)
return ChainPlan(
goal=user_input,
steps=[
ToolStep(1, "file_reader",
{"path": fname.group() if fname else "script.py"},
"读取文件", []),
ToolStep(2, "code_executor",
{"code": "{{STEP_1_OUTPUT}}",
"timeout": settings.tools.code_executor.timeout},
"执行代码", [1]),
],
2026-02-28 08:21:35 +00:00
)
2026-03-09 05:37:29 +00:00
return self._rule_single_step(user_input)
def _rule_single_step(self, user_input: str) -> ChainPlan:
text = user_input.lower()
if any(k in text for k in ["计算", "等于", "×", "÷", "+", "-", "*", "/"]):
expr = self._extract_expression(user_input)
return ChainPlan(goal=user_input, is_single=True,
steps=[ToolStep(1, "calculator",
{"expression": expr}, "数学计算")])
if any(k in text for k in ["搜索", "查询", "天气", "新闻"]):
return ChainPlan(goal=user_input, is_single=True,
steps=[ToolStep(1, "web_search",
{"query": user_input,
"max_results": settings.tools.web_search.max_results},
"网络搜索")])
if any(k in text for k in ["文件", "读取", "file"]):
fname = re.search(r"[\w\-\.]+\.\w+", user_input)
return ChainPlan(goal=user_input, is_single=True,
steps=[ToolStep(1, "file_reader",
{"path": fname.group() if fname else "config.json"},
"读取文件")])
if any(k in text for k in ["执行", "运行", "代码", "python"]):
code_m = re.search(r"[`'\"](.+?)[`'\"]", user_input)
code = code_m.group(1) if code_m else 'print("Hello, Agent!")'
return ChainPlan(goal=user_input, is_single=True,
steps=[ToolStep(1, "code_executor",
{"code": code,
"timeout": settings.tools.code_executor.timeout},
"执行代码")])
return ChainPlan(goal=user_input, is_single=True, steps=[])
@staticmethod
def _fallback_chain_reply(user_input: str, chain_summary: str) -> str:
2026-02-28 08:21:35 +00:00
return (
2026-03-09 05:37:29 +00:00
f"✅ **任务已完成**\n\n"
f"针对您的需求「{user_input}」,执行结果如下:\n\n"
f"{chain_summary}"
2026-02-28 14:59:41 +00:00
)
2026-03-09 05:37:29 +00:00
@staticmethod
def _extract_expression(text: str) -> str:
cleaned = text.replace("×", "*").replace("÷", "/").replace("", "")
match = re.search(r"[\d\s\+\-\*\/\(\)\.]+", cleaned)
expr = match.group().strip() if match else "1+1"
return expr if len(expr) > 1 else "1+1"