base_agent/llm/providers/base_provider.py

125 lines
3.7 KiB
Python
Raw Normal View History

2026-03-09 05:37:29 +00:00
"""
llm/providers/base_provider.py
LLM Provider 抽象基类定义所有 Provider 必须实现的统一接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from mcp.mcp_protocol import ChainPlan, ToolSchema
# ════════════════════════════════════════════════════════════════
# Provider 返回数据结构
# ════════════════════════════════════════════════════════════════
@dataclass
class PlanResult:
"""
工具调用链规划结果
Attributes:
plan: 解析出的 ChainPlan成功时
raw_response: 原始 API 响应用于调试
usage: Token 用量统计
success: 是否成功
error: 失败原因
"""
plan: ChainPlan | None
raw_response: Any = None
usage: dict[str, int] = field(default_factory=dict)
success: bool = True
error: str = ""
@property
def prompt_tokens(self) -> int:
return self.usage.get("prompt_tokens", 0)
@property
def completion_tokens(self) -> int:
return self.usage.get("completion_tokens", 0)
@dataclass
class ReplyResult:
"""
最终回复生成结果
Attributes:
content: 生成的自然语言回复
usage: Token 用量统计
success: 是否成功
error: 失败原因
"""
content: str
usage: dict[str, int] = field(default_factory=dict)
success: bool = True
error: str = ""
# ════════════════════════════════════════════════════════════════
# 抽象基类
# ════════════════════════════════════════════════════════════════
class BaseProvider(ABC):
"""
LLM Provider 抽象基类
所有具体 ProviderOpenAI / Anthropic / Ollama必须继承此类
并实现以下两个核心方法:
- plan_with_tools() 工具调用链规划Function Calling
- generate_reply() 最终回复生成
"""
@property
@abstractmethod
def provider_name(self) -> str:
"""Provider 名称标识,如 'openai' / 'anthropic'"""
...
@abstractmethod
def plan_with_tools(
self,
messages: list[dict],
tool_schemas: list[ToolSchema],
) -> PlanResult:
"""
使用 Function Calling 规划工具调用链
Args:
messages: 对话历史消息列表OpenAI 格式
tool_schemas: 可用工具的 Schema 列表
Returns:
PlanResult 实例
"""
...
@abstractmethod
def generate_reply(
self,
messages: list[dict],
) -> ReplyResult:
"""
基于完整对话历史含工具执行结果生成最终回复
Args:
messages: 包含 tool 角色消息的完整对话历史
Returns:
ReplyResult 实例
"""
...
def health_check(self) -> bool:
"""
连通性检测可选实现
Returns:
True 表示 API 可用
"""
return True
def __repr__(self) -> str:
return f"{self.__class__.__name__}(provider={self.provider_name})"