385 lines
14 KiB
Python
385 lines
14 KiB
Python
|
|
"""
|
|||
|
|
llm/providers/openai_provider.py
|
|||
|
|
OpenAI Provider:使用 Function Calling 实现工具链规划与回复生成
|
|||
|
|
|
|||
|
|
核心流程:
|
|||
|
|
1. plan_with_tools()
|
|||
|
|
messages + tools → OpenAI API
|
|||
|
|
→ 解析 tool_calls → ChainPlan
|
|||
|
|
|
|||
|
|
2. generate_reply()
|
|||
|
|
messages(含 tool 结果)→ OpenAI API
|
|||
|
|
→ 最终自然语言回复
|
|||
|
|
|
|||
|
|
依赖:
|
|||
|
|
pip install openai>=1.0.0
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import time
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from config.settings import LLMConfig
|
|||
|
|
from llm.providers.base_provider import BaseProvider, PlanResult, ReplyResult
|
|||
|
|
from mcp.mcp_protocol import ChainPlan, ToolSchema, ToolStep
|
|||
|
|
from utils.logger import get_logger
|
|||
|
|
|
|||
|
|
# OpenAI SDK(运行时导入,避免未安装时整体崩溃)
|
|||
|
|
try:
|
|||
|
|
from openai import (
|
|||
|
|
APIConnectionError,
|
|||
|
|
APIStatusError,
|
|||
|
|
APITimeoutError,
|
|||
|
|
AuthenticationError,
|
|||
|
|
OpenAI,
|
|||
|
|
RateLimitError,
|
|||
|
|
)
|
|||
|
|
_OPENAI_AVAILABLE = True
|
|||
|
|
except ImportError:
|
|||
|
|
_OPENAI_AVAILABLE = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OpenAIProvider(BaseProvider):
    """
    OpenAI provider implementation.

    Supports:
    - the standard OpenAI API (gpt-4o / gpt-4-turbo / gpt-3.5-turbo, etc.)
    - third-party proxies speaking the OpenAI protocol (via ``api_base_url``)
    - Function Calling based multi-tool planning
    - automatic retries (rate limits / timeouts — delegated to the SDK via
      ``max_retries``)
    - degradation when the API is unavailable (``_client`` stays ``None`` and
      every entry point returns a failed result instead of raising)

    Configuration example (config.yaml)::

        llm:
          provider: "openai"
          model_name: "gpt-4o"
          api_key: "sk-..."
          api_base_url: ""    # leave empty for the official endpoint
          temperature: 0.2    # low temperature recommended for planning
          function_calling: true
    """

    # System prompt steering the LLM toward multi-step tool planning.
    # NOTE: prompt text is runtime data sent to the model — intentionally
    # kept in Chinese, do not translate.
    _PLANNER_SYSTEM_PROMPT = """\
你是一个智能任务规划助手,擅长将用户需求分解为多个工具调用步骤。

## 工作原则
1. 仔细分析用户需求,判断是否需要调用工具
2. 如需多个工具,按逻辑顺序依次调用
3. 当后续步骤依赖前步结果时,先完成前步再继续
4. 每次只规划并调用当前最合适的工具
5. 所有工具执行完毕后,整合结果给出最终回复
"""

    # System prompt for final reply generation (runtime data, kept in Chinese).
    _REPLY_SYSTEM_PROMPT = """\
你是一个友好、专业的 AI 助手。
请基于已执行的工具调用结果,用清晰、自然的语言回答用户的问题。
回复要简洁明了,重点突出工具执行的关键结果。
"""

    def __init__(self, cfg: LLMConfig):
        """Store config and eagerly build the SDK client when openai is installed."""
        self.cfg = cfg
        self.logger = get_logger("LLM")
        # None until _init_client succeeds; every public method checks this first.
        self._client: "OpenAI | None" = None

        if not _OPENAI_AVAILABLE:
            self.logger.warning("⚠️ openai 包未安装,请执行: pip install openai>=1.0.0")
        else:
            self._init_client()

    # ── Provider identity ────────────────────────────────────────

    @property
    def provider_name(self) -> str:
        """Stable identifier for this provider."""
        return "openai"

    # ── Client initialization ────────────────────────────────────

    def _init_client(self) -> None:
        """Build the OpenAI client from LLMConfig (key, base URL, timeout, retries)."""
        if not self.cfg.api_key:
            self.logger.warning(
                "⚠️ LLM_API_KEY 未设置,OpenAI API 调用将失败。\n"
                " 请设置环境变量: export LLM_API_KEY=sk-..."
            )

        kwargs: dict[str, Any] = {
            # Placeholder keeps the SDK constructor from raising when no key is
            # configured; real requests will then fail with AuthenticationError.
            "api_key": self.cfg.api_key or "sk-placeholder",
            "timeout": self.cfg.timeout,
            "max_retries": self.cfg.max_retries,
        }
        if self.cfg.api_base_url:
            kwargs["base_url"] = self.cfg.api_base_url
            self.logger.info(f"🔗 使用自定义 API 地址: {self.cfg.api_base_url}")

        self._client = OpenAI(**kwargs)
        self.logger.info(
            f"✅ OpenAI 客户端初始化完成\n"
            f" model = {self.cfg.model_name}\n"
            f" base_url = {self.cfg.api_base_url or 'https://api.openai.com/v1'}\n"
            f" max_retries= {self.cfg.max_retries}"
        )

    # ════════════════════════════════════════════════════════════
    # Core interface implementation
    # ════════════════════════════════════════════════════════════

    def plan_with_tools(
        self,
        messages: list[dict],
        tool_schemas: list[ToolSchema],
    ) -> PlanResult:
        """
        Plan a tool-call chain via OpenAI Function Calling.

        Args:
            messages: OpenAI-format chat history, e.g.::

                [
                    {"role": "system", "content": "..."},
                    {"role": "user", "content": "user input"},
                ]

            tool_schemas: tool descriptions; converted to the OpenAI ``tools``
                parameter format::

                    [
                        {
                            "type": "function",
                            "function": {
                                "name": "calculator",
                                "description": "计算数学表达式",
                                "parameters": {"type": "object", "properties": {...}}
                            }
                        }
                    ]

        Returns:
            PlanResult whose ``plan`` is the parsed ChainPlan. An empty plan
            (zero steps) means the model chose to answer without tools; a
            failed result (``success=False``) carries a human-readable error.

        Example of the ``tool_calls`` structure the API returns::

            [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {"name": "calculator", "arguments": '{"expression":"1+1"}'}
                }
            ]
        """
        if not self._client:
            return PlanResult(plan=None, success=False, error="OpenAI 客户端未初始化")

        # Convert ToolSchema objects into the OpenAI "tools" parameter.
        tools = self._build_openai_tools(tool_schemas)
        self.logger.debug(f"📤 发送规划请求,tools 数量: {len(tools)}")
        self.logger.debug(f"📤 消息历史长度: {len(messages)}")

        try:
            response = self._client.chat.completions.create(
                model=self.cfg.model_name,
                messages=messages,
                tools=tools,
                tool_choice="auto",  # let the model decide whether to call tools
                temperature=self.cfg.temperature,
                max_tokens=self.cfg.max_tokens,
            )

            usage = self._extract_usage(response)
            self.logger.info(
                f"📊 Token 用量: prompt={usage.get('prompt_tokens', 0)}, "
                f"completion={usage.get('completion_tokens', 0)}"
            )

            # Parse tool_calls → ChainPlan
            choice = response.choices[0]
            message = choice.message

            if not message.tool_calls:
                # Model decided no tools are needed and answered directly.
                # NOTE(review): message.content (the direct answer) is not
                # surfaced here — callers can still reach it via raw_response.
                self.logger.info("💬 模型决策: 无需工具,直接回复")
                return PlanResult(
                    plan=ChainPlan(goal="", steps=[]),
                    raw_response=response,
                    usage=usage,
                )

            plan = self._parse_tool_calls(message.tool_calls)
            self.logger.info(f"📋 解析到 {plan.step_count} 个工具调用步骤")
            return PlanResult(
                plan=plan,
                raw_response=response,
                usage=usage,
            )

        # Handlers ordered most-specific first; all funnel into _handle_error.
        except AuthenticationError as e:
            return self._handle_error("认证失败,请检查 API Key", e)
        except RateLimitError as e:
            return self._handle_error("请求频率超限,请稍后重试", e)
        except APITimeoutError as e:
            return self._handle_error(f"请求超时(>{self.cfg.timeout}s)", e)
        except APIConnectionError as e:
            return self._handle_error("网络连接失败,请检查网络或 api_base_url", e)
        except APIStatusError as e:
            return self._handle_error(f"API 错误 HTTP {e.status_code}: {e.message}", e)
        except Exception as e:
            return self._handle_error(f"未知错误: {e}", e)

    def generate_reply(
        self,
        messages: list[dict],
    ) -> ReplyResult:
        """
        Generate the final natural-language reply from the full conversation
        history (including executed tool results).

        Expected message shape (with tool results interleaved)::

            [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "搜索天气然后计算..."},
                {"role": "assistant", "content": None,
                 "tool_calls": [{"id":"call_1","function":{"name":"web_search",...}}]},
                {"role": "tool", "content": "搜索结果...", "tool_call_id": "call_1"},
                {"role": "assistant", "content": None,
                 "tool_calls": [{"id":"call_2","function":{"name":"calculator",...}}]},
                {"role": "tool", "content": "计算结果: 312", "tool_call_id": "call_2"},
            ]

        Returns:
            ReplyResult with the reply text, or a failed result carrying the
            error message (never raises to the caller).
        """
        if not self._client:
            return ReplyResult(content="", success=False, error="OpenAI 客户端未初始化")

        self.logger.debug(f"📤 发送回复生成请求,消息长度: {len(messages)}")

        try:
            # No tools parameter here — this call only produces text.
            response = self._client.chat.completions.create(
                model=self.cfg.model_name,
                messages=messages,
                temperature=self.cfg.temperature,
                max_tokens=self.cfg.max_tokens,
            )

            # content can be None on tool-call responses; normalize to "".
            content = response.choices[0].message.content or ""
            usage = self._extract_usage(response)
            self.logger.info(
                f"✅ 回复生成成功,长度: {len(content)} chars,"
                f"Token: {usage.get('completion_tokens', 0)}"
            )
            return ReplyResult(content=content, usage=usage)

        except AuthenticationError as e:
            return ReplyResult(content="", success=False,
                               error=f"认证失败: {e}")
        except RateLimitError as e:
            return ReplyResult(content="", success=False,
                               error=f"频率超限: {e}")
        except APITimeoutError as e:
            return ReplyResult(content="", success=False,
                               error=f"请求超时: {e}")
        except Exception as e:
            return ReplyResult(content="", success=False,
                               error=f"生成回复失败: {e}")

    def health_check(self) -> bool:
        """Probe API connectivity with a minimal (1-token) request; never raises."""
        if not self._client:
            return False
        try:
            self._client.chat.completions.create(
                model=self.cfg.model_name,
                messages=[{"role": "user", "content": "hi"}],
                max_tokens=1,
            )
            return True
        except Exception:
            # Any failure (auth, network, bad model) means "unhealthy".
            return False

    # ════════════════════════════════════════════════════════════
    # Helpers
    # ════════════════════════════════════════════════════════════

    @staticmethod
    def _build_openai_tools(tool_schemas: list[ToolSchema]) -> list[dict]:
        """
        Convert ToolSchema objects into the OpenAI ``tools`` parameter format.

        Target format::

            {
                "type": "function",
                "function": {
                    "name": "calculator",
                    "description": "计算数学表达式",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "expression": {"type": "string", "description": "..."}
                        },
                        "required": ["expression"]
                    }
                }
            }

        NOTE(review): every parameter is marked as required. If ToolSchema
        ever carries optional parameters, this over-constrains the model —
        confirm against the ToolSchema definition.
        """
        tools = []
        for schema in tool_schemas:
            tools.append({
                "type": "function",
                "function": {
                    "name": schema.name,
                    "description": schema.description,
                    "parameters": {
                        "type": "object",
                        "properties": schema.parameters,
                        "required": list(schema.parameters.keys()),
                    },
                },
            })
        return tools

    @staticmethod
    def _parse_tool_calls(tool_calls: list) -> ChainPlan:
        """
        Parse OpenAI ``tool_calls`` into a ChainPlan.

        Input format::

            [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "calculator",
                        "arguments": '{"expression": "1+2"}'   # JSON *string*
                    }
                }
            ]

        Each call becomes a ToolStep with 1-based ``step_id``; every step is
        marked as depending on all earlier steps, i.e. the plan is treated as
        strictly sequential regardless of how the model batched the calls.
        """
        steps: list[ToolStep] = []
        for idx, tc in enumerate(tool_calls):
            fn = tc.function
            tool_name = fn.name
            try:
                arguments = json.loads(fn.arguments)
            except json.JSONDecodeError:
                # Malformed JSON from the model: preserve the raw string so the
                # executor (or a human) can still inspect it.
                arguments = {"raw": fn.arguments}

            steps.append(ToolStep(
                step_id=idx + 1,
                tool_name=tool_name,
                arguments=arguments,
                description=f"调用 {tool_name}(由 OpenAI Function Calling 规划)",
                depends_on=list(range(1, idx + 1)) if idx > 0 else [],
            ))

        # Human-readable summary of the chain, e.g. "web_search → calculator".
        goal = " → ".join(s.tool_name for s in steps)
        return ChainPlan(
            goal=goal,
            steps=steps,
            is_single=len(steps) == 1,
        )

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, int]:
        """Extract token usage from a completions response; {} when absent."""
        if hasattr(response, "usage") and response.usage:
            return {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
        return {}

    def _handle_error(self, msg: str, exc: Exception) -> PlanResult:
        """Log an API failure and wrap it in a failed PlanResult (never raises)."""
        self.logger.error(f"❌ OpenAI API 错误: {msg}")
        # Full exception detail only at debug level to keep normal logs clean.
        self.logger.debug(f" 原始异常: {exc}")
        return PlanResult(plan=None, success=False, error=msg)