125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
|
|
"""
|
|||
|
|
llm/providers/base_provider.py
|
|||
|
|
LLM Provider 抽象基类:定义所有 Provider 必须实现的统一接口
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from mcp.mcp_protocol import ChainPlan, ToolSchema
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
# Provider 返回数据结构
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class PlanResult:
|
|||
|
|
"""
|
|||
|
|
工具调用链规划结果
|
|||
|
|
|
|||
|
|
Attributes:
|
|||
|
|
plan: 解析出的 ChainPlan(成功时)
|
|||
|
|
raw_response: 原始 API 响应(用于调试)
|
|||
|
|
usage: Token 用量统计
|
|||
|
|
success: 是否成功
|
|||
|
|
error: 失败原因
|
|||
|
|
"""
|
|||
|
|
plan: ChainPlan | None
|
|||
|
|
raw_response: Any = None
|
|||
|
|
usage: dict[str, int] = field(default_factory=dict)
|
|||
|
|
success: bool = True
|
|||
|
|
error: str = ""
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def prompt_tokens(self) -> int:
|
|||
|
|
return self.usage.get("prompt_tokens", 0)
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def completion_tokens(self) -> int:
|
|||
|
|
return self.usage.get("completion_tokens", 0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class ReplyResult:
|
|||
|
|
"""
|
|||
|
|
最终回复生成结果
|
|||
|
|
|
|||
|
|
Attributes:
|
|||
|
|
content: 生成的自然语言回复
|
|||
|
|
usage: Token 用量统计
|
|||
|
|
success: 是否成功
|
|||
|
|
error: 失败原因
|
|||
|
|
"""
|
|||
|
|
content: str
|
|||
|
|
usage: dict[str, int] = field(default_factory=dict)
|
|||
|
|
success: bool = True
|
|||
|
|
error: str = ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
# 抽象基类
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
|
|||
|
|
class BaseProvider(ABC):
|
|||
|
|
"""
|
|||
|
|
LLM Provider 抽象基类
|
|||
|
|
|
|||
|
|
所有具体 Provider(OpenAI / Anthropic / Ollama)必须继承此类
|
|||
|
|
并实现以下两个核心方法:
|
|||
|
|
- plan_with_tools() 工具调用链规划(Function Calling)
|
|||
|
|
- generate_reply() 最终回复生成
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
@abstractmethod
|
|||
|
|
def provider_name(self) -> str:
|
|||
|
|
"""Provider 名称标识,如 'openai' / 'anthropic'"""
|
|||
|
|
...
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
def plan_with_tools(
|
|||
|
|
self,
|
|||
|
|
messages: list[dict],
|
|||
|
|
tool_schemas: list[ToolSchema],
|
|||
|
|
) -> PlanResult:
|
|||
|
|
"""
|
|||
|
|
使用 Function Calling 规划工具调用链
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
messages: 对话历史消息列表(OpenAI 格式)
|
|||
|
|
tool_schemas: 可用工具的 Schema 列表
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
PlanResult 实例
|
|||
|
|
"""
|
|||
|
|
...
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
def generate_reply(
|
|||
|
|
self,
|
|||
|
|
messages: list[dict],
|
|||
|
|
) -> ReplyResult:
|
|||
|
|
"""
|
|||
|
|
基于完整对话历史(含工具执行结果)生成最终回复
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
messages: 包含 tool 角色消息的完整对话历史
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
ReplyResult 实例
|
|||
|
|
"""
|
|||
|
|
...
|
|||
|
|
|
|||
|
|
def health_check(self) -> bool:
|
|||
|
|
"""
|
|||
|
|
连通性检测(可选实现)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
True 表示 API 可用
|
|||
|
|
"""
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
def __repr__(self) -> str:
|
|||
|
|
return f"{self.__class__.__name__}(provider={self.provider_name})"
|