125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
"""
|
||
llm/providers/base_provider.py
|
||
LLM Provider 抽象基类:定义所有 Provider 必须实现的统一接口
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from typing import Any
|
||
|
||
from mcp.mcp_protocol import ChainPlan, ToolSchema
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# Provider 返回数据结构
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
@dataclass
|
||
class PlanResult:
|
||
"""
|
||
工具调用链规划结果
|
||
|
||
Attributes:
|
||
plan: 解析出的 ChainPlan(成功时)
|
||
raw_response: 原始 API 响应(用于调试)
|
||
usage: Token 用量统计
|
||
success: 是否成功
|
||
error: 失败原因
|
||
"""
|
||
plan: ChainPlan | None
|
||
raw_response: Any = None
|
||
usage: dict[str, int] = field(default_factory=dict)
|
||
success: bool = True
|
||
error: str = ""
|
||
|
||
@property
|
||
def prompt_tokens(self) -> int:
|
||
return self.usage.get("prompt_tokens", 0)
|
||
|
||
@property
|
||
def completion_tokens(self) -> int:
|
||
return self.usage.get("completion_tokens", 0)
|
||
|
||
|
||
@dataclass
|
||
class ReplyResult:
|
||
"""
|
||
最终回复生成结果
|
||
|
||
Attributes:
|
||
content: 生成的自然语言回复
|
||
usage: Token 用量统计
|
||
success: 是否成功
|
||
error: 失败原因
|
||
"""
|
||
content: str
|
||
usage: dict[str, int] = field(default_factory=dict)
|
||
success: bool = True
|
||
error: str = ""
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 抽象基类
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class BaseProvider(ABC):
|
||
"""
|
||
LLM Provider 抽象基类
|
||
|
||
所有具体 Provider(OpenAI / Anthropic / Ollama)必须继承此类
|
||
并实现以下两个核心方法:
|
||
- plan_with_tools() 工具调用链规划(Function Calling)
|
||
- generate_reply() 最终回复生成
|
||
"""
|
||
|
||
@property
|
||
@abstractmethod
|
||
def provider_name(self) -> str:
|
||
"""Provider 名称标识,如 'openai' / 'anthropic'"""
|
||
...
|
||
|
||
@abstractmethod
|
||
def plan_with_tools(
|
||
self,
|
||
messages: list[dict],
|
||
tool_schemas: list[ToolSchema],
|
||
) -> PlanResult:
|
||
"""
|
||
使用 Function Calling 规划工具调用链
|
||
|
||
Args:
|
||
messages: 对话历史消息列表(OpenAI 格式)
|
||
tool_schemas: 可用工具的 Schema 列表
|
||
|
||
Returns:
|
||
PlanResult 实例
|
||
"""
|
||
...
|
||
|
||
@abstractmethod
|
||
def generate_reply(
|
||
self,
|
||
messages: list[dict],
|
||
) -> ReplyResult:
|
||
"""
|
||
基于完整对话历史(含工具执行结果)生成最终回复
|
||
|
||
Args:
|
||
messages: 包含 tool 角色消息的完整对话历史
|
||
|
||
Returns:
|
||
ReplyResult 实例
|
||
"""
|
||
...
|
||
|
||
def health_check(self) -> bool:
|
||
"""
|
||
连通性检测(可选实现)
|
||
|
||
Returns:
|
||
True 表示 API 可用
|
||
"""
|
||
return True
|
||
|
||
def __repr__(self) -> str:
|
||
return f"{self.__class__.__name__}(provider={self.provider_name})" |