base_agent/llm/providers/openai_provider.py

391 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
llm/providers/openai_provider.py
OpenAI Provider使用 Function Calling 实现工具链规划与回复生成
核心流程:
1. plan_with_tools()
messages + tools → OpenAI API
→ 解析 tool_calls → ChainPlan
2. generate_reply()
messages含 tool 结果)→ OpenAI API
→ 最终自然语言回复
依赖:
pip install openai>=1.0.0
"""
import json
import time
from typing import Any
from config.settings import LLMConfig
from llm.providers.base_provider import BaseProvider, PlanResult, ReplyResult
from mcp.mcp_protocol import ChainPlan, ToolSchema, ToolStep
from utils.logger import get_logger
# OpenAI SDK运行时导入避免未安装时整体崩溃
try:
from openai import (
APIConnectionError,
APIStatusError,
APITimeoutError,
AuthenticationError,
OpenAI,
RateLimitError,
)
_OPENAI_AVAILABLE = True
except ImportError:
_OPENAI_AVAILABLE = False
class OpenAIProvider(BaseProvider):
"""
OpenAI Provider 实现
支持:
- 标准 OpenAI APIgpt-4o / gpt-4-turbo / gpt-3.5-turbo 等)
- 兼容 OpenAI 协议的第三方代理(通过 api_base_url 配置)
- Function Calling 多工具并行/串行规划
- 自动重试RateLimit / 网络超时)
- API 不可用时降级到规则引擎
配置示例config.yaml:
llm:
provider: "openai"
model_name: "gpt-4o"
api_key: "sk-..."
api_base_url: "" # 留空使用官方地址
temperature: 0.2 # 规划任务建议低温度
function_calling: true
"""
# 系统 Prompt指导 LLM 进行多步骤工具规划
_PLANNER_SYSTEM_PROMPT = """\
你是一个智能任务规划助手,擅长将用户需求分解为多个工具调用步骤。
## 工作原则
1. 仔细分析用户需求,判断是否需要调用工具
2. 如需多个工具,按逻辑顺序依次调用
3. 当后续步骤依赖前步结果时,先完成前步再继续
4. 每次只规划并调用当前最合适的工具
5. 所有工具执行完毕后,整合结果给出最终回复
## 重要规则
- 数学计算必须使用 calculator 工具,不要自行计算
- 需要实时信息时使用 web_search 工具
- 文件操作使用 file_reader 工具
- 代码执行使用 code_executor 工具
"""
# 回复生成系统 Prompt
_REPLY_SYSTEM_PROMPT = """\
你是一个友好、专业的 AI 助手。
请基于已执行的工具调用结果,用清晰、自然的语言回答用户的问题。
回复要简洁明了,重点突出工具执行的关键结果。
"""
def __init__(self, cfg: LLMConfig):
self.cfg = cfg
self.logger = get_logger("LLM")
self._client: "OpenAI | None" = None
if not _OPENAI_AVAILABLE:
self.logger.warning("⚠️ openai 包未安装,请执行: pip install openai>=1.0.0")
else:
self._init_client()
# ── Provider 标识 ────────────────────────────────────────────
@property
def provider_name(self) -> str:
return "openai"
# ── 客户端初始化 ─────────────────────────────────────────────
def _init_client(self) -> None:
"""初始化 OpenAI 客户端"""
if not self.cfg.api_key:
self.logger.warning(
"⚠️ LLM_API_KEY 未设置OpenAI API 调用将失败。\n"
" 请设置环境变量: export LLM_API_KEY=sk-..."
)
kwargs: dict[str, Any] = {
"api_key": self.cfg.api_key or "sk-placeholder",
"timeout": self.cfg.timeout,
"max_retries": self.cfg.max_retries,
}
if self.cfg.api_base_url:
kwargs["base_url"] = self.cfg.api_base_url
self.logger.info(f"🔗 使用自定义 API 地址: {self.cfg.api_base_url}")
self._client = OpenAI(**kwargs)
self.logger.info(
f"✅ OpenAI 客户端初始化完成\n"
f" model = {self.cfg.model_name}\n"
f" base_url = {self.cfg.api_base_url or 'https://api.openai.com/v1'}\n"
f" max_retries= {self.cfg.max_retries}"
)
# ════════════════════════════════════════════════════════════
# 核心接口实现
# ════════════════════════════════════════════════════════════
def plan_with_tools(
self,
messages: list[dict],
tool_schemas: list[ToolSchema],
) -> PlanResult:
"""
调用 OpenAI Function Calling 规划工具调用链
OpenAI 消息格式:
[
{"role": "system", "content": "..."},
{"role": "user", "content": "用户输入"},
]
OpenAI tools 格式:
[
{
"type": "function",
"function": {
"name": "calculator",
"description": "计算数学表达式",
"parameters": {"type": "object", "properties": {...}}
}
}
]
返回 tool_calls 示例:
[
{
"id": "call_abc123",
"type": "function",
"function": {"name": "calculator", "arguments": '{"expression":"1+1"}'}
}
]
"""
if not self._client:
return PlanResult(plan=None, success=False, error="OpenAI 客户端未初始化")
# 构造 OpenAI tools 参数
tools = self._build_openai_tools(tool_schemas)
self.logger.debug(f"📤 发送规划请求tools 数量: {len(tools)}")
self.logger.debug(f"📤 消息历史长度: {len(messages)}")
try:
response = self._client.chat.completions.create(
model=self.cfg.model_name,
messages=messages,
tools=tools,
tool_choice="auto", # 由模型决定是否调用工具
temperature=self.cfg.temperature,
max_tokens=self.cfg.max_tokens,
)
usage = self._extract_usage(response)
self.logger.info(
f"📊 Token 用量: prompt={usage.get('prompt_tokens', 0)}, "
f"completion={usage.get('completion_tokens', 0)}"
)
# 解析 tool_calls → ChainPlan
choice = response.choices[0]
message = choice.message
if not message.tool_calls:
# 模型决定不调用工具,直接回复
self.logger.info("💬 模型决策: 无需工具,直接回复")
return PlanResult(
plan=ChainPlan(goal="", steps=[]),
raw_response=response,
usage=usage,
)
plan = self._parse_tool_calls(message.tool_calls)
self.logger.info(f"📋 解析到 {plan.step_count} 个工具调用步骤")
return PlanResult(
plan=plan,
raw_response=response,
usage=usage,
)
except AuthenticationError as e:
return self._handle_error("认证失败,请检查 API Key", e)
except RateLimitError as e:
return self._handle_error("请求频率超限,请稍后重试", e)
except APITimeoutError as e:
return self._handle_error(f"请求超时(>{self.cfg.timeout}s", e)
except APIConnectionError as e:
return self._handle_error("网络连接失败,请检查网络或 api_base_url", e)
except APIStatusError as e:
return self._handle_error(f"API 错误 HTTP {e.status_code}: {e.message}", e)
except Exception as e:
return self._handle_error(f"未知错误: {e}", e)
def generate_reply(
self,
messages: list[dict],
) -> ReplyResult:
"""
基于完整对话历史(含工具执行结果)生成最终自然语言回复
消息格式示例(含工具结果):
[
{"role": "system", "content": "..."},
{"role": "user", "content": "搜索天气然后计算..."},
{"role": "assistant", "content": None,
"tool_calls": [{"id":"call_1","function":{"name":"web_search",...}}]},
{"role": "tool", "content": "搜索结果...", "tool_call_id": "call_1"},
{"role": "assistant", "content": None,
"tool_calls": [{"id":"call_2","function":{"name":"calculator",...}}]},
{"role": "tool", "content": "计算结果: 312", "tool_call_id": "call_2"},
]
"""
if not self._client:
return ReplyResult(content="", success=False, error="OpenAI 客户端未初始化")
self.logger.debug(f"📤 发送回复生成请求,消息长度: {len(messages)}")
try:
response = self._client.chat.completions.create(
model=self.cfg.model_name,
messages=messages,
temperature=self.cfg.temperature,
max_tokens=self.cfg.max_tokens,
)
content = response.choices[0].message.content or ""
usage = self._extract_usage(response)
self.logger.info(
f"✅ 回复生成成功,长度: {len(content)} chars"
f"Token: {usage.get('completion_tokens', 0)}"
)
return ReplyResult(content=content, usage=usage)
except AuthenticationError as e:
return ReplyResult(content="", success=False,
error=f"认证失败: {e}")
except RateLimitError as e:
return ReplyResult(content="", success=False,
error=f"频率超限: {e}")
except APITimeoutError as e:
return ReplyResult(content="", success=False,
error=f"请求超时: {e}")
except Exception as e:
return ReplyResult(content="", success=False,
error=f"生成回复失败: {e}")
def health_check(self) -> bool:
"""发送最小请求检测 API 连通性"""
if not self._client:
return False
try:
self._client.chat.completions.create(
model=self.cfg.model_name,
messages=[{"role": "user", "content": "hi"}],
max_tokens=1,
)
return True
except Exception:
return False
# ════════════════════════════════════════════════════════════
# 工具方法
# ════════════════════════════════════════════════════════════
@staticmethod
def _build_openai_tools(tool_schemas: list[ToolSchema]) -> list[dict]:
"""
将 ToolSchema 列表转换为 OpenAI tools 参数格式
OpenAI 格式:
{
"type": "function",
"function": {
"name": "calculator",
"description": "计算数学表达式",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "..."}
},
"required": ["expression"]
}
}
}
"""
tools = []
for schema in tool_schemas:
tools.append({
"type": "function",
"function": {
"name": schema.name,
"description": schema.description,
"parameters": {
"type": "object",
"properties": schema.parameters,
"required": list(schema.parameters.keys()),
},
},
})
return tools
@staticmethod
def _parse_tool_calls(tool_calls: list) -> ChainPlan:
"""
将 OpenAI tool_calls 解析为 ChainPlan
OpenAI tool_calls 格式:
[
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "calculator",
"arguments": '{"expression": "1+2"}' ← JSON 字符串
}
}
]
"""
steps: list[ToolStep] = []
for idx, tc in enumerate(tool_calls):
fn = tc.function
tool_name = fn.name
try:
arguments = json.loads(fn.arguments)
except json.JSONDecodeError:
arguments = {"raw": fn.arguments}
steps.append(ToolStep(
step_id=idx + 1,
tool_name=tool_name,
arguments=arguments,
description=f"调用 {tool_name}(由 OpenAI Function Calling 规划)",
depends_on=list(range(1, idx + 1)) if idx > 0 else [],
))
goal = "".join(s.tool_name for s in steps)
return ChainPlan(
goal=goal,
steps=steps,
is_single=len(steps) == 1,
)
@staticmethod
def _extract_usage(response: Any) -> dict[str, int]:
"""提取 Token 用量信息"""
if hasattr(response, "usage") and response.usage:
return {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens,
}
return {}
def _handle_error(self, msg: str, exc: Exception) -> PlanResult:
self.logger.error(f"❌ OpenAI API 错误: {msg}")
self.logger.debug(f" 原始异常: {exc}")
return PlanResult(plan=None, success=False, error=msg)