385 lines
14 KiB
Python
385 lines
14 KiB
Python
"""
|
||
llm/providers/openai_provider.py
|
||
OpenAI Provider:使用 Function Calling 实现工具链规划与回复生成
|
||
|
||
核心流程:
|
||
1. plan_with_tools()
|
||
messages + tools → OpenAI API
|
||
→ 解析 tool_calls → ChainPlan
|
||
|
||
2. generate_reply()
|
||
messages(含 tool 结果)→ OpenAI API
|
||
→ 最终自然语言回复
|
||
|
||
依赖:
|
||
pip install openai>=1.0.0
|
||
"""
|
||
|
||
import json
import time  # NOTE(review): appears unused in this module — confirm before removing
from typing import Any

from config.settings import LLMConfig
from llm.providers.base_provider import BaseProvider, PlanResult, ReplyResult
from mcp.mcp_protocol import ChainPlan, ToolSchema, ToolStep
from utils.logger import get_logger

# OpenAI SDK: imported inside a guard so a missing package degrades to a
# runtime warning instead of crashing the whole application at import time.
# Availability is recorded in _OPENAI_AVAILABLE and checked in __init__.
try:
    from openai import (
        APIConnectionError,
        APIStatusError,
        APITimeoutError,
        AuthenticationError,
        OpenAI,
        RateLimitError,
    )
    _OPENAI_AVAILABLE = True
except ImportError:
    _OPENAI_AVAILABLE = False
|
||
|
||
|
||
class OpenAIProvider(BaseProvider):
    """
    OpenAI provider implementation.

    Supports:
    - The standard OpenAI API (gpt-4o / gpt-4-turbo / gpt-3.5-turbo, etc.)
    - Third-party proxies speaking the OpenAI protocol (via ``api_base_url``)
    - Multi-tool parallel/serial planning through Function Calling
    - Automatic retries (rate limits / network timeouts, delegated to the
      SDK's ``max_retries`` client option)
    - Fails soft when the API is unavailable: methods return failed
      Plan/Reply results instead of raising, so callers can fall back
      (e.g. to a rule engine)

    Example configuration (config.yaml):
        llm:
          provider: "openai"
          model_name: "gpt-4o"
          api_key: "sk-..."
          api_base_url: ""   # leave empty to use the official endpoint
          temperature: 0.2   # low temperature recommended for planning
          function_calling: true
    """

    # System prompt instructing the LLM to decompose a request into
    # multi-step tool calls. The text is runtime data sent to the model,
    # so it is kept verbatim (in Chinese).
    _PLANNER_SYSTEM_PROMPT = """\
你是一个智能任务规划助手,擅长将用户需求分解为多个工具调用步骤。

## 工作原则
1. 仔细分析用户需求,判断是否需要调用工具
2. 如需多个工具,按逻辑顺序依次调用
3. 当后续步骤依赖前步结果时,先完成前步再继续
4. 每次只规划并调用当前最合适的工具
5. 所有工具执行完毕后,整合结果给出最终回复
"""

    # System prompt used when generating the final reply (runtime data,
    # kept verbatim).
    _REPLY_SYSTEM_PROMPT = """\
你是一个友好、专业的 AI 助手。
请基于已执行的工具调用结果,用清晰、自然的语言回答用户的问题。
回复要简洁明了,重点突出工具执行的关键结果。
"""

    def __init__(self, cfg: LLMConfig) -> None:
        """Store the config and create the OpenAI client when the SDK is installed.

        Args:
            cfg: LLM configuration (model name, API key, base URL, timeout,
                retry count, sampling parameters).
        """
        self.cfg = cfg
        self.logger = get_logger("LLM")
        # Stays None when the SDK is missing; every public method checks
        # this and returns a failed result instead of raising.
        self._client: "OpenAI | None" = None

        if not _OPENAI_AVAILABLE:
            self.logger.warning("⚠️ openai 包未安装,请执行: pip install openai>=1.0.0")
        else:
            self._init_client()

    # ── Provider identity ────────────────────────────────────────

    @property
    def provider_name(self) -> str:
        """Stable identifier for this provider ("openai")."""
        return "openai"

    # ── Client initialization ────────────────────────────────────

    def _init_client(self) -> None:
        """Build the ``OpenAI`` client from ``self.cfg`` and log the setup."""
        if not self.cfg.api_key:
            # Warn but continue: a placeholder key is substituted below so
            # the client object can still be constructed (calls will fail).
            self.logger.warning(
                "⚠️ LLM_API_KEY 未设置,OpenAI API 调用将失败。\n"
                " 请设置环境变量: export LLM_API_KEY=sk-..."
            )

        kwargs: dict[str, Any] = {
            "api_key": self.cfg.api_key or "sk-placeholder",
            "timeout": self.cfg.timeout,
            "max_retries": self.cfg.max_retries,
        }
        if self.cfg.api_base_url:
            # Route through a custom / proxy endpoint instead of the default.
            kwargs["base_url"] = self.cfg.api_base_url
            self.logger.info(f"🔗 使用自定义 API 地址: {self.cfg.api_base_url}")

        self._client = OpenAI(**kwargs)
        self.logger.info(
            f"✅ OpenAI 客户端初始化完成\n"
            f" model = {self.cfg.model_name}\n"
            f" base_url = {self.cfg.api_base_url or 'https://api.openai.com/v1'}\n"
            f" max_retries= {self.cfg.max_retries}"
        )

    # ════════════════════════════════════════════════════════════
    # Core interface implementation
    # ════════════════════════════════════════════════════════════

    def plan_with_tools(
        self,
        messages: list[dict],
        tool_schemas: list[ToolSchema],
    ) -> PlanResult:
        """
        Plan a tool-call chain via OpenAI Function Calling.

        Args:
            messages: Chat history in OpenAI message format, e.g.
                [
                    {"role": "system", "content": "..."},
                    {"role": "user", "content": "user input"},
                ]
            tool_schemas: Tool definitions exposed to the model; converted to
                the OpenAI ``tools`` parameter format:
                [
                    {
                        "type": "function",
                        "function": {
                            "name": "calculator",
                            "description": "...",
                            "parameters": {"type": "object", "properties": {...}}
                        }
                    }
                ]

        Returns:
            PlanResult whose ``plan`` is the ChainPlan parsed from the model's
            ``tool_calls`` — an empty plan (no steps) when the model chose to
            answer directly — or a failed PlanResult on any API error.

        Example ``tool_calls`` returned by the API:
            [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {"name": "calculator",
                                 "arguments": '{"expression":"1+1"}'}
                }
            ]
        """
        if not self._client:
            return PlanResult(plan=None, success=False, error="OpenAI 客户端未初始化")

        # Build the OpenAI ``tools`` parameter from the schemas.
        tools = self._build_openai_tools(tool_schemas)
        self.logger.debug(f"📤 发送规划请求,tools 数量: {len(tools)}")
        self.logger.debug(f"📤 消息历史长度: {len(messages)}")

        try:
            response = self._client.chat.completions.create(
                model=self.cfg.model_name,
                messages=messages,
                tools=tools,
                tool_choice="auto",  # let the model decide whether to call tools
                temperature=self.cfg.temperature,
                max_tokens=self.cfg.max_tokens,
            )

            usage = self._extract_usage(response)
            self.logger.info(
                f"📊 Token 用量: prompt={usage.get('prompt_tokens', 0)}, "
                f"completion={usage.get('completion_tokens', 0)}"
            )

            # Parse tool_calls → ChainPlan.
            choice = response.choices[0]
            message = choice.message

            if not message.tool_calls:
                # The model decided no tool is needed; an empty plan signals
                # "reply directly" to the caller.
                self.logger.info("💬 模型决策: 无需工具,直接回复")
                return PlanResult(
                    plan=ChainPlan(goal="", steps=[]),
                    raw_response=response,
                    usage=usage,
                )

            plan = self._parse_tool_calls(message.tool_calls)
            self.logger.info(f"📋 解析到 {plan.step_count} 个工具调用步骤")
            return PlanResult(
                plan=plan,
                raw_response=response,
                usage=usage,
            )

        # Map SDK exceptions to failed PlanResults, most specific first.
        except AuthenticationError as e:
            return self._handle_error("认证失败,请检查 API Key", e)
        except RateLimitError as e:
            return self._handle_error("请求频率超限,请稍后重试", e)
        except APITimeoutError as e:
            return self._handle_error(f"请求超时(>{self.cfg.timeout}s)", e)
        except APIConnectionError as e:
            return self._handle_error("网络连接失败,请检查网络或 api_base_url", e)
        except APIStatusError as e:
            return self._handle_error(f"API 错误 HTTP {e.status_code}: {e.message}", e)
        except Exception as e:
            return self._handle_error(f"未知错误: {e}", e)

    def generate_reply(
        self,
        messages: list[dict],
    ) -> ReplyResult:
        """
        Generate the final natural-language reply from the full conversation
        history, including executed tool results.

        Args:
            messages: OpenAI-format history containing tool results, e.g.
                [
                    {"role": "system", "content": "..."},
                    {"role": "user", "content": "..."},
                    {"role": "assistant", "content": None,
                     "tool_calls": [{"id":"call_1","function":{"name":"web_search",...}}]},
                    {"role": "tool", "content": "...", "tool_call_id": "call_1"},
                    {"role": "assistant", "content": None,
                     "tool_calls": [{"id":"call_2","function":{"name":"calculator",...}}]},
                    {"role": "tool", "content": "...", "tool_call_id": "call_2"},
                ]

        Returns:
            ReplyResult with the generated text, or a failed result on error.
        """
        if not self._client:
            return ReplyResult(content="", success=False, error="OpenAI 客户端未初始化")

        self.logger.debug(f"📤 发送回复生成请求,消息长度: {len(messages)}")

        try:
            response = self._client.chat.completions.create(
                model=self.cfg.model_name,
                messages=messages,
                temperature=self.cfg.temperature,
                max_tokens=self.cfg.max_tokens,
            )

            # ``content`` may be None (e.g. tool-call-only message); coerce to "".
            content = response.choices[0].message.content or ""
            usage = self._extract_usage(response)
            self.logger.info(
                f"✅ 回复生成成功,长度: {len(content)} chars,"
                f"Token: {usage.get('completion_tokens', 0)}"
            )
            return ReplyResult(content=content, usage=usage)

        except AuthenticationError as e:
            return ReplyResult(content="", success=False,
                               error=f"认证失败: {e}")
        except RateLimitError as e:
            return ReplyResult(content="", success=False,
                               error=f"频率超限: {e}")
        except APITimeoutError as e:
            return ReplyResult(content="", success=False,
                               error=f"请求超时: {e}")
        except Exception as e:
            return ReplyResult(content="", success=False,
                               error=f"生成回复失败: {e}")

    def health_check(self) -> bool:
        """Probe API connectivity with a minimal one-token request."""
        if not self._client:
            return False
        try:
            self._client.chat.completions.create(
                model=self.cfg.model_name,
                messages=[{"role": "user", "content": "hi"}],
                max_tokens=1,
            )
            return True
        except Exception:
            # Any failure (auth, network, quota, ...) counts as unhealthy.
            return False

    # ════════════════════════════════════════════════════════════
    # Helpers
    # ════════════════════════════════════════════════════════════

    @staticmethod
    def _build_openai_tools(tool_schemas: list[ToolSchema]) -> list[dict]:
        """
        Convert ToolSchema objects into the OpenAI ``tools`` parameter format.

        Target format per tool:
            {
                "type": "function",
                "function": {
                    "name": "calculator",
                    "description": "...",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "expression": {"type": "string", "description": "..."}
                        },
                        "required": ["expression"]
                    }
                }
            }
        """
        tools = []
        for schema in tool_schemas:
            tools.append({
                "type": "function",
                "function": {
                    "name": schema.name,
                    "description": schema.description,
                    "parameters": {
                        "type": "object",
                        "properties": schema.parameters,
                        # NOTE(review): every parameter is marked required here —
                        # confirm ToolSchema has no notion of optional parameters.
                        "required": list(schema.parameters.keys()),
                    },
                },
            })
        return tools

    @staticmethod
    def _parse_tool_calls(tool_calls: list) -> ChainPlan:
        """
        Parse OpenAI ``tool_calls`` into a ChainPlan.

        Input format:
            [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {
                        "name": "calculator",
                        "arguments": '{"expression": "1+2"}'   # JSON string
                    }
                }
            ]
        """
        steps: list[ToolStep] = []
        for idx, tc in enumerate(tool_calls):
            fn = tc.function
            tool_name = fn.name
            try:
                # ``arguments`` arrives as a JSON-encoded string.
                arguments = json.loads(fn.arguments)
            except json.JSONDecodeError:
                # Keep the malformed payload so downstream code can inspect it.
                arguments = {"raw": fn.arguments}

            steps.append(ToolStep(
                step_id=idx + 1,
                tool_name=tool_name,
                arguments=arguments,
                description=f"调用 {tool_name}(由 OpenAI Function Calling 规划)",
                # Conservatively mark each step as depending on ALL prior steps.
                depends_on=list(range(1, idx + 1)) if idx > 0 else [],
            ))

        goal = " → ".join(s.tool_name for s in steps)
        return ChainPlan(
            goal=goal,
            steps=steps,
            is_single=len(steps) == 1,
        )

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, int]:
        """Extract token usage from a completions response; {} when absent."""
        if hasattr(response, "usage") and response.usage:
            return {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
        return {}

    def _handle_error(self, msg: str, exc: Exception) -> PlanResult:
        """Log an API error (summary at ERROR, raw exception at DEBUG) and
        wrap it in a failed PlanResult."""
        self.logger.error(f"❌ OpenAI API 错误: {msg}")
        self.logger.debug(f" 原始异常: {exc}")
        return PlanResult(plan=None, success=False, error=msg)