base_agent/llm/llm_engine.py

488 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
llm/llm_engine.py
修复generate_chain_reply / generate_final_reply 消息序列构造
确保 tool 消息始终紧跟在含 tool_calls 的 assistant 消息之后
"""
import re
from dataclasses import dataclass
from config.settings import LLMConfig, settings
from llm.provider_factory import create_provider
from llm.providers.base_provider import BaseProvider
from mcp.mcp_protocol import ChainPlan, MCPMethod, MCPRequest, ToolSchema, ToolStep
from utils.logger import get_logger
@dataclass
class ToolDecision:
need_tool: bool
tool_name: str = ""
arguments: dict = None
reasoning: str = ""
def __post_init__(self):
self.arguments = self.arguments or {}
def to_mcp_request(self) -> MCPRequest | None:
if not self.need_tool:
return None
return MCPRequest(
method=MCPMethod.TOOLS_CALL,
params={"name": self.tool_name, "arguments": self.arguments},
)
class LLMEngine:
    """LLM inference engine.

    Builds OpenAI-compatible message sequences for planning and reply
    generation.  The critical invariant (the bug this module fixes) is that
    every "tool" message must immediately follow the "assistant" message
    carrying the matching tool_calls entry:

        [
            {"role": "system", "content": REPLY_SYSTEM_PROMPT},
            {"role": "user", "content": "<user input>"},
            # openai_tool_block, built by AgentClient._execute_chain:
            {"role": "assistant", "content": None,
             "tool_calls": [{"id": "call_001", "type": "function", ...},
                            {"id": "call_002", "type": "function", ...}]},
            {"role": "tool", "content": "result 1", "tool_call_id": "call_001"},
            {"role": "tool", "content": "result 2", "tool_call_id": "call_002"},
        ]

    Never insert any other message (especially system/user) between the
    assistant(tool_calls) message and its tool messages, or the API rejects
    the request with a 400.
    """

    # Heuristic keywords suggesting a multi-step request (rule-based planning).
    # NOTE(review): one entry was lost when "ambiguous Unicode" characters were
    # stripped from this file; the leftover empty string was removed because
    # `"" in text` is always True — recover the original entry from VCS.
    _MULTI_STEP_KEYWORDS = [
        "然后", "接着", "并且", "同时", "之后",
        "先.*再", "首先.*然后", "搜索.*计算", "读取.*执行",
        "多个", "分别", "依次",
    ]
def __init__(self, cfg: LLMConfig | None = None):
    """Create the engine: resolve config, build the provider, log the setup."""
    effective_cfg = cfg if cfg else settings.llm  # fall back to global settings
    self.cfg = effective_cfg
    self.logger = get_logger("LLM")
    self.provider: BaseProvider = create_provider(effective_cfg)
    self._log_init()
def _log_init(self) -> None:
    """Log a one-time summary of the active engine configuration."""
    summary_lines = [
        "🧠 LLM 引擎初始化完成",
        f" provider = {self.cfg.provider}",
        f" model_name = {self.cfg.model_name}",
        f" function_calling = {self.cfg.function_calling}",
        f" temperature = {self.cfg.temperature}",
        f" fallback_rules = {settings.agent.fallback_to_rules}",
    ]
    for line in summary_lines:
        self.logger.info(line)
def reconfigure(self, cfg: LLMConfig) -> None:
    """Swap in a new LLM configuration and rebuild the provider from it."""
    self.cfg = cfg
    self.provider = create_provider(cfg)
    self.logger.info(f"🔄 LLM 配置已更新: model={cfg.model_name}")
# ════════════════════════════════════════════════════════════
# 工具规划
# ════════════════════════════════════════════════════════════
def plan_tool_chain(
    self,
    user_input: str,
    tool_schemas: list[ToolSchema],
    context: str = "",
    history: list[dict] | None = None,
) -> ChainPlan:
    """Plan a tool-call chain.

    Uses OpenAI function calling when enabled; otherwise — or on API
    failure, when the fallback setting allows — delegates to the
    rule-based planner.

    Args:
        user_input: Raw user request to plan for.
        tool_schemas: Schemas of the tools available to the planner.
        context: Conversation context string (passed to message building).
        history: Optional recent message history (filtered before use).

    Returns:
        A ChainPlan; its steps list is empty when no tool is needed or
        planning failed without a fallback.
    """
    self.logger.info(f"🗺 规划工具调用链: {user_input[:60]}...")
    messages = self._build_plan_messages(user_input, context, history)
    # Guard clause: function calling disabled -> rule engine directly.
    if not self.cfg.function_calling:
        self.logger.info("⚙️ function_calling=false使用规则引擎")
        return self._rule_based_plan(user_input)
    result = self.provider.plan_with_tools(messages, tool_schemas)
    if result.success and result.plan is not None:
        plan = result.plan
        plan.goal = plan.goal or user_input  # ensure the goal is populated
        self.logger.info(f"📋 OpenAI 规划完成: {plan.step_count}")
        for step in plan.steps:
            self.logger.info(
                f" Step {step.step_id}: [{step.tool_name}] "
                f"args={step.arguments}"
            )
        return plan
    self.logger.warning(f"⚠️ OpenAI 规划失败: {result.error}")
    if not settings.agent.fallback_to_rules:
        return ChainPlan(goal=user_input, steps=[])
    self.logger.info("🔄 降级到规则引擎...")
    return self._rule_based_plan(user_input)
def think_and_decide(
    self,
    user_input: str,
    tool_schemas: list[ToolSchema],
    context: str = "",
) -> ToolDecision:
    """Collapse a full chain plan into a single-tool decision (first step)."""
    plan = self.plan_tool_chain(user_input, tool_schemas, context)
    if not plan.steps:
        return ToolDecision(need_tool=False, reasoning="无需工具,直接回复")
    first_step = plan.steps[0]
    return ToolDecision(
        need_tool=True,
        tool_name=first_step.tool_name,
        arguments=first_step.arguments,
        reasoning=first_step.description,
    )
# ════════════════════════════════════════════════════════════
# 回复生成(核心修复区域)
# ════════════════════════════════════════════════════════════
def generate_chain_reply(
    self,
    user_input: str,
    chain_summary: str,
    context: str = "",
    openai_tool_block: list[dict] | None = None,
) -> str:
    """Produce the final natural-language reply from tool-chain results.

    Message sequence sent to the provider:
        system  -> reply-generation prompt
        user    -> original user input
        <openai_tool_block appended verbatim>
            assistant(tool_calls=[...])
            tool(tool_call_id=..., content=...)
            ...

    The tool block must begin with the assistant(tool_calls) message;
    inserting anything between it and the tool messages yields an API 400.

    Args:
        user_input: Original user input.
        chain_summary: Step summary used as fallback content on API failure.
        context: Conversation history (planning-stage only; not injected here).
        openai_tool_block: Compliant message block built by AgentClient,
            shaped [assistant(tool_calls), tool, tool, ...].

    Returns:
        The model reply, or a template-based fallback string on failure.
    """
    self.logger.info("✍️ 生成最终回复(工具调用链模式)...")
    # Guard: nothing to integrate -> template fallback.
    if not openai_tool_block:
        self.logger.warning("⚠️ openai_tool_block 为空,降级到摘要模板")
        return self._fallback_chain_reply(user_input, chain_summary)
    # Guard: malformed block -> template fallback rather than an API 400.
    if not self._validate_tool_block(openai_tool_block):
        self.logger.warning("⚠️ 消息块验证失败,降级到摘要模板")
        return self._fallback_chain_reply(user_input, chain_summary)
    messages = self._build_reply_messages_with_block(user_input, openai_tool_block)
    self.logger.debug(f"📤 发送回复请求,消息数: {len(messages)}")
    self._log_messages_structure(messages)
    result = self.provider.generate_reply(messages)
    if result.success and result.content:
        self.logger.info(
            f"✅ OpenAI 回复生成成功 ({len(result.content)} chars)"
        )
        return result.content
    self.logger.warning(f"⚠️ OpenAI 回复生成失败: {result.error}")
    return self._fallback_chain_reply(user_input, chain_summary)
def generate_final_reply(
    self,
    user_input: str,
    tool_name: str,
    tool_output: str,
    context: str = "",
    tool_call_id: str = "",
) -> str:
    """Integrate a single tool result into a natural-language reply.

    Builds a minimal compliant message block — assistant(tool_calls)
    immediately followed by one tool message — and delegates to
    generate_chain_reply().  Without a tool_call_id the function-calling
    protocol cannot be used, so it degrades to a plain prompt-based reply.
    """
    import json

    self.logger.info(f"✍️ 整合单步工具结果 [{tool_name}]...")
    if not tool_call_id:
        # No id -> cannot satisfy the tool-message protocol; plain reply.
        self.logger.warning("⚠️ tool_call_id 为空,使用直接回复模式")
        return self._generate_simple_reply(user_input, tool_name, tool_output)
    # NOTE(review): "arguments" is synthesized from an output preview because
    # the original call arguments are not available at this point.
    synthesized_args = json.dumps({"result": tool_output[:100]},
                                  ensure_ascii=False)
    assistant_msg = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": tool_call_id,
            "type": "function",
            "function": {
                "name": tool_name,
                "arguments": synthesized_args,
            },
        }],
    }
    tool_msg = {
        "role": "tool",
        "content": tool_output,
        "tool_call_id": tool_call_id,
    }
    return self.generate_chain_reply(
        user_input=user_input,
        chain_summary=tool_output,
        context=context,
        openai_tool_block=[assistant_msg, tool_msg],
    )
def generate_direct_reply(self, user_input: str, context: str = "") -> str:
    """Answer directly via the provider when no tool call is involved."""
    self.logger.info("💬 直接生成回复(无需工具)...")
    messages = [
        {"role": "system", "content": "你是一个友好、专业的 AI 助手,请简洁准确地回答用户问题。"},
        {"role": "user", "content": user_input},
    ]
    result = self.provider.generate_reply(messages)
    if not (result.success and result.content):
        # Provider unavailable: return a canned acknowledgement instead.
        return (
            f"您好!关于「{user_input}」,我已收到您的问题。\n"
            f"API 暂时不可用,请检查 API Key 配置)"
        )
    return result.content
# ════════════════════════════════════════════════════════════
# 消息构造(修复核心)
# ════════════════════════════════════════════════════════════
@staticmethod
def _build_reply_messages_with_block(
    user_input: str,
    openai_tool_block: list[dict],
) -> list[dict]:
    """Assemble the complete reply-stage message list.

    Layout:
        [system(REPLY_SYSTEM_PROMPT),
         user(user_input),
         assistant(tool_calls=...),   <- openai_tool_block[0]
         tool, tool, ...]             <- openai_tool_block[1:]

    The tool block is appended verbatim — splitting or reordering it would
    violate the OpenAI tool-message protocol.
    """
    from llm.providers.openai_provider import OpenAIProvider

    prefix: list[dict] = [
        {"role": "system", "content": OpenAIProvider._REPLY_SYSTEM_PROMPT},
        {"role": "user", "content": user_input},
    ]
    return prefix + list(openai_tool_block)
@staticmethod
def _build_plan_messages(
    user_input: str,
    context: str,
    history: list[dict] | None,
) -> list[dict]:
    """Assemble planning-stage messages (no tool messages involved)."""
    from llm.providers.openai_provider import OpenAIProvider

    messages: list[dict] = [
        {"role": "system", "content": OpenAIProvider._PLANNER_SYSTEM_PROMPT},
    ]
    if history:
        # Keep only plain user/assistant turns from the recent history;
        # tool messages and assistant messages carrying tool_calls are
        # dropped so the planner sees a clean dialogue.
        for msg in history[-6:]:
            if msg.get("role") not in ("user", "assistant"):
                continue
            if msg.get("tool_calls"):
                continue
            messages.append(msg)
    elif context and context != "(暂无对话历史)":
        messages.append({
            "role": "system",
            "content": f"## 对话历史\n{context}",
        })
    messages.append({"role": "user", "content": user_input})
    return messages
@staticmethod
def _validate_tool_block(openai_tool_block: list[dict]) -> bool:
"""
验证 openai_tool_block 消息块合规性
规则:
1. 第一条必须是 role=assistant 且含 tool_calls
2. 后续每条必须是 role=tool 且含 tool_call_id
3. tool 消息数量 == assistant.tool_calls 数量
4. 所有 tool_call_id 必须在 assistant.tool_calls[].id 中存在
"""
if not openai_tool_block:
return False
first = openai_tool_block[0]
if first.get("role") != "assistant" or not first.get("tool_calls"):
return False
declared_ids = {tc["id"] for tc in first["tool_calls"]}
tool_msgs = openai_tool_block[1:]
if len(tool_msgs) != len(declared_ids):
return False
for msg in tool_msgs:
if msg.get("role") != "tool":
return False
if msg.get("tool_call_id") not in declared_ids:
return False
return True
def _generate_simple_reply(
    self,
    user_input: str,
    tool_name: str,
    tool_output: str,
) -> str:
    """Fallback reply without a tool_call_id (plain prompt, no FC protocol)."""
    user_prompt = (
        f"用户问题: {user_input}\n\n"
        f"工具 [{tool_name}] 执行结果:\n{tool_output}\n\n"
        f"请基于以上结果给出清晰的回答。"
    )
    messages = [
        {"role": "system",
         "content": "你是一个友好的 AI 助手,请基于工具执行结果回答用户问题。"},
        {"role": "user", "content": user_prompt},
    ]
    result = self.provider.generate_reply(messages)
    if result.success:
        return result.content
    return self._fallback_chain_reply(user_input, tool_output)
def _log_messages_structure(self, messages: list[dict]) -> None:
    """Debug-log the shape of a message sequence (content truncated)."""
    self.logger.debug("📋 消息序列结构:")
    for idx, msg in enumerate(messages):
        role = msg.get("role", "?")
        if role == "assistant" and msg.get("tool_calls"):
            call_ids = [tc["id"] for tc in msg["tool_calls"]]
            call_names = [tc["function"]["name"] for tc in msg["tool_calls"]]
            self.logger.debug(
                f" [{idx}] {role:10s} tool_calls={call_names} ids={call_ids}"
            )
        elif role == "tool":
            self.logger.debug(
                f" [{idx}] {role:10s} tool_call_id={msg.get('tool_call_id')} "
                f"content={str(msg.get('content',''))[:40]}..."
            )
        else:
            preview = str(msg.get("content", ""))[:50]
            self.logger.debug(f" [{idx}] {role:10s} {preview}...")
# ════════════════════════════════════════════════════════════
# 降级规则引擎
# ════════════════════════════════════════════════════════════
def _rule_based_plan(self, user_input: str) -> ChainPlan:
    """Rule-based fallback planner.

    Recognizes two multi-step patterns (search+calculate, read+execute);
    anything else is delegated to the single-step rules.
    """
    self.logger.info("⚙️ 使用规则引擎规划...")
    text = user_input.lower()
    # NOTE(review): the calculator keyword list originally held one more
    # entry that was stripped as an "ambiguous Unicode" character; the
    # leftover empty string made `any(k in text ...)` always True, so
    # every search request became a two-step plan.  The empty string is
    # removed — recover the original keyword from version control.
    search_kw = ("搜索", "查询", "查一下")
    calc_kw = ("计算", "等于", "结果")
    if any(k in text for k in search_kw) and any(k in text for k in calc_kw):
        return ChainPlan(
            goal=user_input,
            steps=[
                ToolStep(1, "web_search",
                         {"query": user_input,
                          "max_results": settings.tools.web_search.max_results},
                         "搜索相关信息", []),
                ToolStep(2, "calculator",
                         {"expression": self._extract_expression(user_input)},
                         "进行计算", [1]),
            ],
        )
    read_kw = ("读取", "文件", "file")
    exec_kw = ("执行", "运行", "run")
    if any(k in text for k in read_kw) and any(k in text for k in exec_kw):
        fname = re.search(r"[\w\-\.]+\.\w+", user_input)
        return ChainPlan(
            goal=user_input,
            steps=[
                ToolStep(1, "file_reader",
                         {"path": fname.group() if fname else "script.py"},
                         "读取文件", []),
                ToolStep(2, "code_executor",
                         {"code": "{{STEP_1_OUTPUT}}",
                          "timeout": settings.tools.code_executor.timeout},
                         "执行代码", [1]),
            ],
        )
    return self._rule_single_step(user_input)
def _rule_single_step(self, user_input: str) -> ChainPlan:
    """Map a request to one tool step via keyword heuristics (or none)."""
    text = user_input.lower()
    calc_markers = ("计算", "等于", "×", "÷", "+", "-", "*", "/")
    if any(marker in text for marker in calc_markers):
        return ChainPlan(
            goal=user_input, is_single=True,
            steps=[ToolStep(1, "calculator",
                            {"expression": self._extract_expression(user_input)},
                            "数学计算")],
        )
    if any(marker in text for marker in ("搜索", "查询", "天气", "新闻")):
        return ChainPlan(
            goal=user_input, is_single=True,
            steps=[ToolStep(1, "web_search",
                            {"query": user_input,
                             "max_results": settings.tools.web_search.max_results},
                            "网络搜索")],
        )
    if any(marker in text for marker in ("文件", "读取", "file")):
        fname = re.search(r"[\w\-\.]+\.\w+", user_input)
        target = fname.group() if fname else "config.json"
        return ChainPlan(
            goal=user_input, is_single=True,
            steps=[ToolStep(1, "file_reader", {"path": target}, "读取文件")],
        )
    if any(marker in text for marker in ("执行", "运行", "代码", "python")):
        code_m = re.search(r"[`'\"](.+?)[`'\"]", user_input)
        code = code_m.group(1) if code_m else 'print("Hello, Agent!")'
        return ChainPlan(
            goal=user_input, is_single=True,
            steps=[ToolStep(1, "code_executor",
                            {"code": code,
                             "timeout": settings.tools.code_executor.timeout},
                            "执行代码")],
        )
    # Nothing matched: an empty single-step plan means "reply directly".
    return ChainPlan(goal=user_input, is_single=True, steps=[])
@staticmethod
def _fallback_chain_reply(user_input: str, chain_summary: str) -> str:
return (
f"✅ **任务已完成**\n\n"
f"针对您的需求「{user_input}」,执行结果如下:\n\n"
f"{chain_summary}"
)
@staticmethod
def _extract_expression(text: str) -> str:
cleaned = text.replace("×", "*").replace("÷", "/").replace("", "")
match = re.search(r"[\d\s\+\-\*\/\(\)\.]+", cleaned)
expr = match.group().strip() if match else "1+1"
return expr if len(expr) > 1 else "1+1"