base_agent/config/settings.py

278 lines
9.9 KiB
Python
Raw Normal View History

2026-03-09 05:37:29 +00:00
"""
config/settings.py
配置加载与管理模块（新增 OpenAI 专用字段）
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
try:
    import yaml
except ImportError:
    # PyYAML is optional: without it ConfigLoader falls back to default settings.
    _YAML_AVAILABLE = False
else:
    _YAML_AVAILABLE = True
# ════════════════════════════════════════════════════════════════
# 配置数据类
# ════════════════════════════════════════════════════════════════
@dataclass
class LLMConfig:
    """Language-model settings, including OpenAI-specific and local-model fields.

    After construction, the environment variables LLM_API_KEY,
    LLM_API_BASE_URL, LLM_MODEL_NAME and LLM_MODEL_PATH (when set, even to an
    empty string) replace the corresponding field values.
    """

    provider: str = "openai"
    model_name: str = "gpt-4o"
    api_key: str = ""
    api_base_url: str = ""
    max_tokens: int = 4096
    temperature: float = 0.7
    timeout: int = 60
    max_retries: int = 3
    # OpenAI-specific
    function_calling: bool = True
    stream: bool = False
    # Ollama / local models
    model_path: str = ""
    ollama_host: str = "http://localhost:11434"

    def __post_init__(self) -> None:
        # field name -> environment variable that overrides it (applied in this order)
        env_overrides = {
            "api_key": "LLM_API_KEY",
            "api_base_url": "LLM_API_BASE_URL",
            "model_name": "LLM_MODEL_NAME",
            "model_path": "LLM_MODEL_PATH",
        }
        for attr, var in env_overrides.items():
            setattr(self, attr, os.getenv(var, getattr(self, attr)))
@dataclass
class MCPConfig:
    """MCP server identity, transport/endpoint and the list of exposed tools."""

    server_name: str = "DemoMCPServer"
    transport: str = "stdio"
    host: str = "localhost"
    port: int = 3000
    # Built by a factory so every instance gets its own list object.
    enabled_tools: list[str] = field(
        default_factory=lambda: "calculator web_search file_reader code_executor".split()
    )
@dataclass
class WebSearchToolConfig:
    """Settings for the web-search tool.

    The SEARCH_API_KEY environment variable, when set, replaces ``api_key``.
    """

    max_results: int = 5
    timeout: int = 10  # presumably seconds — confirm in the tool implementation
    api_key: str = ""
    engine: str = "mock"  # search backend name

    def __post_init__(self) -> None:
        env_key = os.getenv("SEARCH_API_KEY")
        if env_key is not None:
            self.api_key = env_key
@dataclass
class FileReaderToolConfig:
    """Limits for the file-reader tool."""

    # Root directory for file access — presumably the tool is confined to it; confirm in the tool.
    allowed_root: str = "./workspace"
    # Size cap; the _kb suffix suggests kilobytes — TODO confirm against the tool.
    max_file_size_kb: int = 512
@dataclass
class CodeExecutorToolConfig:
    """Execution limits for the code-executor tool."""

    timeout: int = 5  # presumably seconds — confirm in the executor
    sandbox: bool = True  # sandboxing flag; its exact meaning is defined by the tool
@dataclass
class CalculatorToolConfig:
    """Settings for the calculator tool."""

    # Numeric precision setting; how it is applied (digits/places) is up to the tool.
    precision: int = 10
@dataclass
class ToolsConfig:
    """Per-tool settings: one nested config object for each built-in tool."""

    # Each factory runs per instance, so ToolsConfig objects never share sub-configs.
    web_search: WebSearchToolConfig = field(default_factory=lambda: WebSearchToolConfig())
    file_reader: FileReaderToolConfig = field(default_factory=lambda: FileReaderToolConfig())
    code_executor: CodeExecutorToolConfig = field(default_factory=lambda: CodeExecutorToolConfig())
    calculator: CalculatorToolConfig = field(default_factory=lambda: CalculatorToolConfig())
@dataclass
class MemoryConfig:
    """Conversation-memory settings."""

    max_history: int = 20  # number of history entries to keep (unit defined by the memory component)
    enable_long_term: bool = False
    vector_db_url: str = ""  # presumably only relevant when enable_long_term is on — confirm
@dataclass
class LoggingConfig:
    """Logging verbosity and file destination.

    ``level`` is always upper-cased, and the LOG_LEVEL environment variable,
    when set, takes precedence over the value passed in.
    """

    level: str = "DEBUG"
    enable_file: bool = True
    log_dir: str = "./logs"
    log_file: str = "agent.log"

    def __post_init__(self) -> None:
        chosen = os.environ.get("LOG_LEVEL", self.level)
        self.level = chosen.upper()
@dataclass
class AgentConfig:
    """Agent run-loop behaviour."""

    max_chain_steps: int = 10
    enable_multi_step: bool = True
    session_timeout: int = 3600  # presumably seconds — confirm where sessions are managed
    fallback_to_rules: bool = True  # fall back to the rule engine when the API call fails
@dataclass
class AppConfig:
    """Top-level application configuration aggregating every section.

    Each section is created by its own ``default_factory``, so separate
    ``AppConfig`` instances never share section objects.
    """

    llm: LLMConfig = field(default_factory=LLMConfig)
    mcp: MCPConfig = field(default_factory=MCPConfig)
    tools: ToolsConfig = field(default_factory=ToolsConfig)
    memory: MemoryConfig = field(default_factory=MemoryConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)

    def display(self) -> str:
        """Render a human-readable summary of the most relevant settings.

        The API key itself is never printed — only whether one is set.

        Returns:
            A multi-line string framed by horizontal rules of 52 characters.
        """
        # A visible rule character is required: an empty string repeated gives
        # an empty string, which would turn the frame into blank lines.
        rule = "─" * 52
        lines = [
            rule,
            " 📋 当前配置",
            rule,
            f" [LLM] provider = {self.llm.provider}",
            f" [LLM] model_name = {self.llm.model_name}",
            f" [LLM] api_key = {'***' if self.llm.api_key else '(未设置)'}",
            f" [LLM] api_base_url = {self.llm.api_base_url or '(默认)'}",
            f" [LLM] temperature = {self.llm.temperature}",
            f" [LLM] max_tokens = {self.llm.max_tokens}",
            f" [LLM] function_calling = {self.llm.function_calling}",
            f" [LLM] stream = {self.llm.stream}",
            f" [LLM] max_retries = {self.llm.max_retries}",
            f" [MCP] server_name = {self.mcp.server_name}",
            f" [MCP] enabled_tools = {self.mcp.enabled_tools}",
            f" [MEMORY] max_history = {self.memory.max_history}",
            f" [AGENT] multi_step = {self.agent.enable_multi_step}",
            f" [AGENT] fallback_rules = {self.agent.fallback_to_rules}",
            f" [LOG] level = {self.logging.level}",
            rule,
        ]
        return "\n".join(lines)
# ════════════════════════════════════════════════════════════════
# 配置加载器
# ════════════════════════════════════════════════════════════════
class ConfigLoader:
    """Locate the YAML config file and turn it into an ``AppConfig``.

    Missing sections and keys fall back to the defaults written below. A key
    or section that is present but left empty in YAML (parsed as ``None``) is
    treated exactly like a missing one. Some config dataclasses (e.g.
    ``LLMConfig``) apply environment-variable overrides in ``__post_init__``,
    so those env values win over the YAML.
    """

    # Candidate files, tried in order; the first one that is a regular file wins.
    # AGENT_CONFIG_PATH is read once, when this class body executes (import time).
    _CONFIG_SEARCH_PATHS = [
        Path(os.getenv("AGENT_CONFIG_PATH", "./config.yaml")),
        Path("config") / "config.yaml",
        Path("config.yaml"),
    ]

    # Accepted spellings for booleans supplied as strings (e.g. quoted YAML values).
    _TRUE_STRINGS = frozenset({"true", "yes", "on", "y", "1"})
    _FALSE_STRINGS = frozenset({"false", "no", "off", "n", "0", ""})

    @classmethod
    def load(cls) -> AppConfig:
        """Read the first config file found and parse it.

        Returns:
            The parsed ``AppConfig``, or an all-default ``AppConfig`` when PyYAML
            is missing, no config file exists, or the file is empty.

        Raises:
            ValueError: If the file's top level or a section is not a mapping,
                a boolean key holds an unrecognised string, or a numeric key
                holds a non-numeric string.
            TypeError: If a numeric key holds a value ``int()``/``float()`` rejects.
        """
        raw = cls._read_yaml()
        return cls._parse(raw) if raw else AppConfig()

    @classmethod
    def _read_yaml(cls) -> dict[str, Any] | None:
        """Return the mapping stored in the first candidate config file.

        Returns:
            ``None`` when PyYAML is unavailable or no candidate file exists,
            ``{}`` for an empty file, otherwise the top-level mapping.

        Raises:
            ValueError: If the file's top-level YAML value is not a mapping.
        """
        if not _YAML_AVAILABLE:
            print("⚠️ PyYAML 未安装（pip install pyyaml），使用默认配置")
            return None
        for path in cls._CONFIG_SEARCH_PATHS:
            # is_file(), not exists(): AGENT_CONFIG_PATH="" becomes Path("."),
            # an existing directory that open() would fail on.
            if not path.is_file():
                continue
            with open(path, encoding="utf-8") as f:
                # safe_load only builds plain YAML types, never arbitrary objects.
                data = yaml.safe_load(f)
            print(f"✅ 已加载配置文件: {path.resolve()}")
            if not data:
                return {}
            if not isinstance(data, dict):
                raise ValueError(
                    f"Config file {path} must contain a mapping at the top level, "
                    f"got {type(data).__name__}"
                )
            return data
        print(" 未找到配置文件,使用默认配置")
        return None

    @classmethod
    def _parse(cls, raw: dict[str, Any]) -> AppConfig:
        """Build an ``AppConfig`` from the top-level mapping, one section at a time."""
        return AppConfig(
            llm=cls._parse_llm(cls._section(raw, "llm")),
            mcp=cls._parse_mcp(cls._section(raw, "mcp")),
            tools=cls._parse_tools(cls._section(raw, "tools")),
            memory=cls._parse_memory(cls._section(raw, "memory")),
            logging=cls._parse_logging(cls._section(raw, "logging")),
            agent=cls._parse_agent(cls._section(raw, "agent")),
        )

    # ── value helpers ────────────────────────────────────────────

    @staticmethod
    def _section(d: dict[str, Any], key: str) -> dict[str, Any]:
        """Return the sub-mapping ``d[key]``; a missing or empty (``None``) section yields ``{}``.

        Raises:
            ValueError: If the section is present but is not a mapping.
        """
        value = d.get(key)
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ValueError(
                f"Config section '{key}' must be a mapping, got {type(value).__name__}"
            )
        return value

    @staticmethod
    def _get(d: dict[str, Any], key: str, default: Any) -> Any:
        """Return ``d[key]``, or *default* when the key is missing or its value is ``None``."""
        value = d.get(key)
        return default if value is None else value

    @classmethod
    def _get_bool(cls, d: dict[str, Any], key: str, default: bool) -> bool:
        """Return ``d[key]`` as a bool, accepting string spellings such as ``"false"``.

        Plain ``bool("false")`` is ``True``, so strings are matched explicitly
        (case-insensitive, surrounding whitespace ignored); other values use ``bool()``.

        Raises:
            ValueError: For a string that is not a recognised boolean spelling.
        """
        value = cls._get(d, key, default)
        if isinstance(value, str):
            text = value.strip().lower()
            if text in cls._TRUE_STRINGS:
                return True
            if text in cls._FALSE_STRINGS:
                return False
            raise ValueError(f"Config key '{key}' is not a boolean: {value!r}")
        return bool(value)

    # ── section parsers ──────────────────────────────────────────

    @classmethod
    def _parse_llm(cls, d: dict) -> LLMConfig:
        """Build ``LLMConfig`` from the ``llm`` section."""
        return LLMConfig(
            provider=str(cls._get(d, "provider", "openai")),
            model_name=str(cls._get(d, "model_name", "gpt-4o")),
            api_key=str(cls._get(d, "api_key", "")),
            api_base_url=str(cls._get(d, "api_base_url", "")),
            max_tokens=int(cls._get(d, "max_tokens", 4096)),
            temperature=float(cls._get(d, "temperature", 0.7)),
            timeout=int(cls._get(d, "timeout", 60)),
            max_retries=int(cls._get(d, "max_retries", 3)),
            function_calling=cls._get_bool(d, "function_calling", True),
            stream=cls._get_bool(d, "stream", False),
            model_path=str(cls._get(d, "model_path", "")),
            ollama_host=str(cls._get(d, "ollama_host", "http://localhost:11434")),
        )

    @classmethod
    def _parse_mcp(cls, d: dict) -> MCPConfig:
        """Build ``MCPConfig`` from the ``mcp`` section."""
        return MCPConfig(
            server_name=str(cls._get(d, "server_name", "DemoMCPServer")),
            transport=str(cls._get(d, "transport", "stdio")),
            host=str(cls._get(d, "host", "localhost")),
            port=int(cls._get(d, "port", 3000)),
            # An explicit empty list is kept as-is (no tools); only None/missing uses the default.
            enabled_tools=cls._get(d, "enabled_tools", [
                "calculator", "web_search", "file_reader", "code_executor"
            ]),
        )

    @classmethod
    def _parse_tools(cls, d: dict) -> ToolsConfig:
        """Build ``ToolsConfig`` from the per-tool sub-sections of ``tools``."""
        ws = cls._section(d, "web_search")
        fr = cls._section(d, "file_reader")
        ce = cls._section(d, "code_executor")
        ca = cls._section(d, "calculator")
        return ToolsConfig(
            web_search=WebSearchToolConfig(
                max_results=int(cls._get(ws, "max_results", 5)),
                timeout=int(cls._get(ws, "timeout", 10)),
                api_key=str(cls._get(ws, "api_key", "")),
                engine=str(cls._get(ws, "engine", "mock")),
            ),
            file_reader=FileReaderToolConfig(
                allowed_root=str(cls._get(fr, "allowed_root", "./workspace")),
                max_file_size_kb=int(cls._get(fr, "max_file_size_kb", 512)),
            ),
            code_executor=CodeExecutorToolConfig(
                timeout=int(cls._get(ce, "timeout", 5)),
                sandbox=cls._get_bool(ce, "sandbox", True),
            ),
            calculator=CalculatorToolConfig(
                precision=int(cls._get(ca, "precision", 10)),
            ),
        )

    @classmethod
    def _parse_memory(cls, d: dict) -> MemoryConfig:
        """Build ``MemoryConfig`` from the ``memory`` section."""
        return MemoryConfig(
            max_history=int(cls._get(d, "max_history", 20)),
            enable_long_term=cls._get_bool(d, "enable_long_term", False),
            vector_db_url=str(cls._get(d, "vector_db_url", "")),
        )

    @classmethod
    def _parse_logging(cls, d: dict) -> LoggingConfig:
        """Build ``LoggingConfig`` from the ``logging`` section."""
        return LoggingConfig(
            level=str(cls._get(d, "level", "DEBUG")),
            enable_file=cls._get_bool(d, "enable_file", True),
            log_dir=str(cls._get(d, "log_dir", "./logs")),
            log_file=str(cls._get(d, "log_file", "agent.log")),
        )

    @classmethod
    def _parse_agent(cls, d: dict) -> AgentConfig:
        """Build ``AgentConfig`` from the ``agent`` section."""
        return AgentConfig(
            max_chain_steps=int(cls._get(d, "max_chain_steps", 10)),
            enable_multi_step=cls._get_bool(d, "enable_multi_step", True),
            session_timeout=int(cls._get(d, "session_timeout", 3600)),
            fallback_to_rules=cls._get_bool(d, "fallback_to_rules", True),
        )
# Module-level configuration object shared by importers; built once when this
# module is first imported (loading may print status messages).
settings: AppConfig
settings = ConfigLoader.load()