379 lines
16 KiB
Python
379 lines
16 KiB
Python
"""
|
||
llm/llm_engine.py
|
||
LLM 引擎:通过 Provider 接口调用真实 OpenAI API
|
||
支持 Function Calling 多步骤规划 + 工具结果整合回复
|
||
"""
|
||
|
||
import re
|
||
from dataclasses import dataclass
|
||
|
||
from config.settings import LLMConfig, settings
|
||
from llm.provider_factory import create_provider
|
||
from llm.providers.base_provider import BaseProvider
|
||
from mcp.mcp_protocol import ChainPlan, MCPMethod, MCPRequest, ToolSchema, ToolStep
|
||
from utils.logger import get_logger
|
||
|
||
|
||
@dataclass
class ToolDecision:
    """Outcome of a single-step tool decision.

    Attributes:
        need_tool: Whether a tool call is required to answer the user.
        tool_name: Name of the tool to invoke (empty when no tool is needed).
        arguments: Keyword arguments for the tool call; ``None`` is
            normalized to an empty dict in ``__post_init__``.
        reasoning: Human-readable justification for the decision.
    """

    need_tool: bool
    tool_name: str = ""
    # dict is mutable, so the default must be None (normalized below)
    # rather than a shared {} literal. Annotation fixed: the field may
    # legitimately be None before __post_init__ runs.
    arguments: dict | None = None
    reasoning: str = ""

    def __post_init__(self):
        # Normalize None (or any falsy value) to {} so callers can index
        # `arguments` without a None check.
        self.arguments = self.arguments or {}

    def to_mcp_request(self) -> MCPRequest | None:
        """Convert this decision into an MCP ``tools/call`` request.

        Returns:
            An MCPRequest when a tool is needed, otherwise None.
        """
        if not self.need_tool:
            return None
        return MCPRequest(
            method=MCPMethod.TOOLS_CALL,
            params={"name": self.tool_name, "arguments": self.arguments},
        )
|
||
|
||
|
||
class LLMEngine:
    """LLM inference engine (provider pattern).

    Core flow:
        1. plan_tool_chain()
           Build OpenAI-format messages + tools
           -> Provider.plan_with_tools()
           -> parse tool_calls into a ChainPlan

        2. generate_chain_reply()
           Build the full message history, including tool results
           -> Provider.generate_reply()
           -> final natural-language reply

    Degradation strategy:
        When an API call fails and ``fallback_to_rules`` is true, the
        engine automatically falls back to a keyword rule engine so the
        system stays usable without the remote API.
    """

    # Rule-engine keywords (fallback path).
    # NOTE(review): not referenced by any method below — presumably kept
    # for external users of this attribute; confirm before removing.
    _MULTI_STEP_KEYWORDS = [
        "然后", "接着", "再", "并且", "同时", "之后",
        "先.*再", "首先.*然后", "搜索.*计算", "读取.*执行",
        "多个", "分别", "依次",
    ]

    def __init__(self, cfg: LLMConfig | None = None):
        """Initialize the engine and build its provider.

        Args:
            cfg: LLM configuration; defaults to the global ``settings.llm``.
        """
        self.cfg = cfg or settings.llm
        self.logger = get_logger("LLM")
        self.provider: BaseProvider = create_provider(self.cfg)
        self._log_init()

    def _log_init(self) -> None:
        """Log the effective configuration once at start-up."""
        self.logger.info("🧠 LLM 引擎初始化完成")
        self.logger.info(f" provider = {self.cfg.provider}")
        self.logger.info(f" model_name = {self.cfg.model_name}")
        self.logger.info(f" function_calling = {self.cfg.function_calling}")
        self.logger.info(f" temperature = {self.cfg.temperature}")
        self.logger.info(f" fallback_rules = {settings.agent.fallback_to_rules}")

    def reconfigure(self, cfg: LLMConfig) -> None:
        """Hot-reload the configuration and rebuild the provider."""
        self.cfg = cfg
        self.provider = create_provider(cfg)
        self.logger.info(f"🔄 LLM 配置已更新: model={cfg.model_name}")

    # ════════════════════════════════════════════════════════════
    # Core interface
    # ════════════════════════════════════════════════════════════

    def plan_tool_chain(
        self,
        user_input: str,
        tool_schemas: list[ToolSchema],
        context: str = "",
        history: list[dict] | None = None,
    ) -> ChainPlan:
        """Plan a tool-call chain via OpenAI Function Calling.

        Message layout:
            system  -> planner system prompt
            history -> prior conversation (optional)
            user    -> current user input

        Args:
            user_input: Raw user input text.
            tool_schemas: Schemas of the available tools.
            context: Plain-text conversation summary, used only when no
                structured *history* is supplied.
            history: Structured conversation history (OpenAI message
                format); takes precedence over *context*.

        Returns:
            A ChainPlan instance (empty ``steps`` when nothing is planned).
        """
        self.logger.info(f"🗺 规划工具调用链: {user_input[:60]}...")

        messages = self._build_plan_messages(user_input, context, history)

        if not self.cfg.function_calling:
            # Function calling disabled: go straight to the rule engine.
            self.logger.info("⚙️ function_calling=false,使用规则引擎")
            return self._rule_based_plan(user_input)

        # ── Real OpenAI Function Calling ──────────────────
        result = self.provider.plan_with_tools(messages, tool_schemas)

        if result.success and result.plan is not None:
            plan = result.plan
            # Backfill goal so downstream consumers always see one.
            if not plan.goal:
                plan.goal = user_input
            self.logger.info(f"📋 OpenAI 规划完成: {plan.step_count} 步")
            for step in plan.steps:
                self.logger.info(
                    f" Step {step.step_id}: [{step.tool_name}] "
                    f"args={step.arguments}"
                )
            return plan

        # API failure handling.
        self.logger.warning(f"⚠️ OpenAI 规划失败: {result.error}")
        if settings.agent.fallback_to_rules:
            self.logger.info("🔄 降级到规则引擎...")
            return self._rule_based_plan(user_input)
        return ChainPlan(goal=user_input, steps=[])

    def think_and_decide(
        self,
        user_input: str,
        tool_schemas: list[ToolSchema],
        context: str = "",
    ) -> ToolDecision:
        """Single-step tool decision (delegates to plan_tool_chain)."""
        plan = self.plan_tool_chain(user_input, tool_schemas, context)
        if not plan.steps:
            return ToolDecision(need_tool=False, reasoning="无需工具,直接回复")
        first = plan.steps[0]
        return ToolDecision(
            need_tool=True,
            tool_name=first.tool_name,
            arguments=first.arguments,
            reasoning=first.description,
        )

    def generate_chain_reply(
        self,
        user_input: str,
        chain_summary: str,
        context: str = "",
        tool_messages: list[dict] | None = None,
    ) -> str:
        """Fold multi-step execution results into the final reply.

        Message layout (with tool results):
            system    -> reply-generation system prompt
            user      -> original user input
            assistant -> tool-call decisions (tool_calls)
            tool      -> tool execution results
            ...          (repeated for multi-round tool calls)

        Args:
            user_input: Original user input.
            chain_summary: Step summary, used as fallback content when the
                API call fails.
            context: Conversation history, injected as a system message.
            tool_messages: Full tool-call message sequence (OpenAI format).

        Returns:
            The final reply string.
        """
        self.logger.info("✍️ 生成最终回复...")

        if tool_messages:
            # Fix: *context* used to be accepted but ignored here; pass it
            # through so the reply model can see prior conversation.
            messages = self._build_reply_messages(user_input, tool_messages, context)
            result = self.provider.generate_reply(messages)

            if result.success and result.content:
                self.logger.info(
                    f"✅ OpenAI 回复生成成功 ({len(result.content)} chars)"
                )
                return result.content

            self.logger.warning(f"⚠️ OpenAI 回复生成失败: {result.error}")

        # Fallback: template reply.
        return self._fallback_chain_reply(user_input, chain_summary)

    def generate_final_reply(
        self,
        user_input: str,
        tool_name: str,
        tool_output: str,
        context: str = "",
        tool_call_id: str = "",
    ) -> str:
        """Fold a single tool result into a natural-language reply."""
        self.logger.info(f"✍️ 整合单步工具结果 [{tool_name}]...")

        # A tool message is only valid in the OpenAI format when it carries
        # a tool_call_id; without one we fall back to the template path.
        tool_messages = []
        if tool_call_id:
            tool_messages = [
                {
                    "role": "tool",
                    "content": tool_output,
                    "tool_call_id": tool_call_id,
                }
            ]

        return self.generate_chain_reply(
            user_input=user_input,
            chain_summary=tool_output,
            context=context,
            tool_messages=tool_messages,
        )

    def generate_direct_reply(self, user_input: str, context: str = "") -> str:
        """Generate a reply directly when no tool is needed.

        Fix: *context* used to be accepted but ignored; it is now injected
        as an extra system message, mirroring ``_build_plan_messages``.
        """
        self.logger.info("💬 直接生成回复(无需工具)...")
        messages = [
            {"role": "system", "content": "你是一个友好、专业的 AI 助手,请简洁准确地回答用户问题。"},
        ]
        if context and context != "(暂无对话历史)":
            messages.append({
                "role": "system",
                "content": f"## 对话历史\n{context}",
            })
        messages.append({"role": "user", "content": user_input})
        result = self.provider.generate_reply(messages)
        if result.success and result.content:
            return result.content
        # Fallback when the API is unavailable.
        return (
            f"[{self.cfg.model_name}] 您好!\n"
            f"关于「{user_input}」,我已收到您的问题。\n"
            f"(API 暂时不可用,请检查 API Key 配置)"
        )

    # ════════════════════════════════════════════════════════════
    # Message construction
    # ════════════════════════════════════════════════════════════

    @staticmethod
    def _build_plan_messages(
        user_input: str,
        context: str,
        history: list[dict] | None,
    ) -> list[dict]:
        """Build the message list for the planning phase."""
        # Local import avoids a module-level import cycle with the provider.
        from llm.providers.openai_provider import OpenAIProvider

        messages: list[dict] = [
            {"role": "system", "content": OpenAIProvider._PLANNER_SYSTEM_PROMPT},
        ]
        # Prefer structured history; fall back to the text summary.
        if history:
            messages.extend(history[-6:])  # last 3 rounds (user+assistant pairs)
        elif context and context != "(暂无对话历史)":
            messages.append({
                "role": "system",
                "content": f"## 对话历史\n{context}",
            })
        messages.append({"role": "user", "content": user_input})
        return messages

    @staticmethod
    def _build_reply_messages(
        user_input: str,
        tool_messages: list[dict],
        context: str = "",
    ) -> list[dict]:
        """Build the reply-phase message list (including tool results).

        Args:
            user_input: Original user input.
            tool_messages: Tool-call message sequence, appended verbatim.
            context: Optional conversation summary, injected as a system
                message (new, backward-compatible parameter).
        """
        from llm.providers.openai_provider import OpenAIProvider

        messages: list[dict] = [
            {"role": "system", "content": OpenAIProvider._REPLY_SYSTEM_PROMPT},
        ]
        if context and context != "(暂无对话历史)":
            messages.append({
                "role": "system",
                "content": f"## 对话历史\n{context}",
            })
        messages.append({"role": "user", "content": user_input})
        messages.extend(tool_messages)
        return messages

    # ════════════════════════════════════════════════════════════
    # Fallback rule engine
    # ════════════════════════════════════════════════════════════

    def _rule_based_plan(self, user_input: str) -> ChainPlan:
        """Keyword rule engine (fallback when the API is unavailable)."""
        self.logger.info("⚙️ 使用规则引擎规划...")
        text = user_input.lower()

        # search + calculate
        if (any(k in text for k in ["搜索", "查询", "查一下"]) and
                any(k in text for k in ["计算", "算", "等于", "结果"])):
            return ChainPlan(
                goal=user_input,
                steps=[
                    ToolStep(1, "web_search",
                             {"query": user_input,
                              "max_results": settings.tools.web_search.max_results},
                             "搜索相关信息", []),
                    ToolStep(2, "calculator",
                             {"expression": self._extract_expression(user_input)},
                             "进行计算", [1]),
                ],
            )
        # read file + execute code
        if (any(k in text for k in ["读取", "文件", "file"]) and
                any(k in text for k in ["执行", "运行", "run"])):
            fname = re.search(r"[\w\-\.]+\.\w+", user_input)
            return ChainPlan(
                goal=user_input,
                steps=[
                    ToolStep(1, "file_reader",
                             {"path": fname.group() if fname else "script.py"},
                             "读取文件", []),
                    ToolStep(2, "code_executor",
                             # Placeholder substituted with step-1 output at run time.
                             {"code": "{{STEP_1_OUTPUT}}",
                              "timeout": settings.tools.code_executor.timeout},
                             "执行代码", [1]),
                ],
            )
        return self._rule_single_step(user_input)

    def _rule_single_step(self, user_input: str) -> ChainPlan:
        """Single-step keyword matching."""
        text = user_input.lower()
        if any(k in text for k in ["计算", "等于", "×", "÷", "+", "-", "*", "/"]):
            expr = self._extract_expression(user_input)
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "calculator",
                                             {"expression": expr}, "数学计算")])
        if any(k in text for k in ["搜索", "查询", "天气", "新闻"]):
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "web_search",
                                             {"query": user_input,
                                              "max_results": settings.tools.web_search.max_results},
                                             "网络搜索")])
        if any(k in text for k in ["文件", "读取", "file"]):
            fname = re.search(r"[\w\-\.]+\.\w+", user_input)
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "file_reader",
                                             {"path": fname.group() if fname else "config.json"},
                                             "读取文件")])
        if any(k in text for k in ["执行", "运行", "代码", "python"]):
            code_m = re.search(r"[`'\"](.+?)[`'\"]", user_input)
            code = code_m.group(1) if code_m else 'print("Hello, Agent!")'
            return ChainPlan(goal=user_input, is_single=True,
                             steps=[ToolStep(1, "code_executor",
                                             {"code": code,
                                              "timeout": settings.tools.code_executor.timeout},
                                             "执行代码")])
        return ChainPlan(goal=user_input, is_single=True, steps=[])

    @staticmethod
    def _fallback_chain_reply(user_input: str, chain_summary: str) -> str:
        """Template reply used when the API is unavailable."""
        return (
            f"✅ **任务已完成**\n\n"
            f"针对您的需求「{user_input}」,执行结果如下:\n\n"
            f"{chain_summary}"
        )

    @staticmethod
    def _extract_expression(text: str) -> str:
        """Pull an arithmetic expression out of free text.

        Normalizes × and ÷ to * and /, removes full-width commas, then
        takes the first run of digit/operator characters. Falls back to
        "1+1" when nothing usable (or only a single character) is found.
        """
        cleaned = text.replace("×", "*").replace("÷", "/").replace(",", "")
        match = re.search(r"[\d\s\+\-\*\/\(\)\.]+", cleaned)
        expr = match.group().strip() if match else "1+1"
        # Guard against degenerate single-character matches (e.g. bare "-").
        return expr if len(expr) > 1 else "1+1"