base_agent/llm/providers/base_provider.py

125 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
llm/providers/base_provider.py
LLM Provider 抽象基类:定义所有 Provider 必须实现的统一接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from mcp.mcp_protocol import ChainPlan, ToolSchema
# ════════════════════════════════════════════════════════════════
# Provider 返回数据结构
# ════════════════════════════════════════════════════════════════
@dataclass
class PlanResult:
"""
工具调用链规划结果
Attributes:
plan: 解析出的 ChainPlan成功时
raw_response: 原始 API 响应(用于调试)
usage: Token 用量统计
success: 是否成功
error: 失败原因
"""
plan: ChainPlan | None
raw_response: Any = None
usage: dict[str, int] = field(default_factory=dict)
success: bool = True
error: str = ""
@property
def prompt_tokens(self) -> int:
return self.usage.get("prompt_tokens", 0)
@property
def completion_tokens(self) -> int:
return self.usage.get("completion_tokens", 0)
@dataclass
class ReplyResult:
"""
最终回复生成结果
Attributes:
content: 生成的自然语言回复
usage: Token 用量统计
success: 是否成功
error: 失败原因
"""
content: str
usage: dict[str, int] = field(default_factory=dict)
success: bool = True
error: str = ""
# ════════════════════════════════════════════════════════════════
# 抽象基类
# ════════════════════════════════════════════════════════════════
class BaseProvider(ABC):
"""
LLM Provider 抽象基类
所有具体 ProviderOpenAI / Anthropic / Ollama必须继承此类
并实现以下两个核心方法:
- plan_with_tools() 工具调用链规划Function Calling
- generate_reply() 最终回复生成
"""
@property
@abstractmethod
def provider_name(self) -> str:
"""Provider 名称标识,如 'openai' / 'anthropic'"""
...
@abstractmethod
def plan_with_tools(
self,
messages: list[dict],
tool_schemas: list[ToolSchema],
) -> PlanResult:
"""
使用 Function Calling 规划工具调用链
Args:
messages: 对话历史消息列表OpenAI 格式)
tool_schemas: 可用工具的 Schema 列表
Returns:
PlanResult 实例
"""
...
@abstractmethod
def generate_reply(
self,
messages: list[dict],
) -> ReplyResult:
"""
基于完整对话历史(含工具执行结果)生成最终回复
Args:
messages: 包含 tool 角色消息的完整对话历史
Returns:
ReplyResult 实例
"""
...
def health_check(self) -> bool:
"""
连通性检测(可选实现)
Returns:
True 表示 API 可用
"""
return True
def __repr__(self) -> str:
return f"{self.__class__.__name__}(provider={self.provider_name})"