base_agent/config/settings.py

278 lines
9.9 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
config/settings.py
配置加载与管理模块(新增 OpenAI 专用字段)
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# PyYAML is optional: when it cannot be imported, ConfigLoader skips file
# loading and falls back to the built-in defaults.
try:
    import yaml
except ImportError:
    _YAML_AVAILABLE = False
else:
    _YAML_AVAILABLE = True
# ════════════════════════════════════════════════════════════════
# 配置数据类
# ════════════════════════════════════════════════════════════════
@dataclass
class LLMConfig:
    """LLM backend settings, including OpenAI-specific fields.

    After construction, the environment variables ``LLM_API_KEY``,
    ``LLM_API_BASE_URL``, ``LLM_MODEL_NAME`` and ``LLM_MODEL_PATH`` -- when
    set -- replace the corresponding field values, so they take precedence
    over anything passed in (e.g. values read from the YAML file).
    """

    provider: str = "openai"
    model_name: str = "gpt-4o"
    # repr=False keeps the secret out of repr()/str() of this dataclass, so it
    # cannot leak through logging or debug prints of the config object.
    api_key: str = field(default="", repr=False)
    api_base_url: str = ""
    max_tokens: int = 4096
    temperature: float = 0.7
    timeout: int = 60
    max_retries: int = 3
    # OpenAI-specific
    function_calling: bool = True
    stream: bool = False
    # Ollama / local models
    model_path: str = ""
    ollama_host: str = "http://localhost:11434"

    def __post_init__(self) -> None:
        # Environment variables override constructor / YAML values.
        self.api_key = os.getenv("LLM_API_KEY", self.api_key)
        self.api_base_url = os.getenv("LLM_API_BASE_URL", self.api_base_url)
        self.model_name = os.getenv("LLM_MODEL_NAME", self.model_name)
        self.model_path = os.getenv("LLM_MODEL_PATH", self.model_path)
@dataclass
class MCPConfig:
    """Settings for the MCP server the agent exposes or talks to."""

    server_name: str = "DemoMCPServer"
    transport: str = "stdio"
    host: str = "localhost"
    port: int = 3000
    # The factory builds a fresh list per instance, so instances never share
    # (and mutate) the same default list.
    enabled_tools: list[str] = field(
        default_factory=lambda: list(
            ("calculator", "web_search", "file_reader", "code_executor")
        )
    )
@dataclass
class WebSearchToolConfig:
    """Settings for the ``web_search`` tool.

    If the ``SEARCH_API_KEY`` environment variable is set, it replaces
    ``api_key`` after construction.
    """

    max_results: int = 5
    timeout: int = 10
    # repr=False keeps the secret out of repr()/str() of this dataclass.
    api_key: str = field(default="", repr=False)
    engine: str = "mock"

    def __post_init__(self) -> None:
        # The environment variable takes precedence over constructor / YAML values.
        self.api_key = os.getenv("SEARCH_API_KEY", self.api_key)
@dataclass
class FileReaderToolConfig:
    """Settings for the ``file_reader`` tool."""

    # Presumably the directory reads are confined to -- enforcement lives in
    # the tool itself, not here (confirm against the tool implementation).
    allowed_root: str = "./workspace"
    # Size limit; the name suggests kilobytes.
    max_file_size_kb: int = 512
@dataclass
class CodeExecutorToolConfig:
    """Settings for the ``code_executor`` tool."""

    # Execution time limit (units are defined by the tool -- presumably seconds).
    timeout: int = 5
    # Flag read by the tool; whatever isolation it implies is implemented there.
    sandbox: bool = True
@dataclass
class CalculatorToolConfig:
    """Settings for the ``calculator`` tool."""

    # Numeric precision handed to the calculator (interpretation is up to the tool).
    precision: int = 10
@dataclass
class ToolsConfig:
    """Container holding one settings object per built-in tool.

    Every attribute defaults to a freshly constructed instance of its
    tool-specific config class.
    """

    web_search: WebSearchToolConfig = field(default_factory=WebSearchToolConfig)
    file_reader: FileReaderToolConfig = field(default_factory=FileReaderToolConfig)
    code_executor: CodeExecutorToolConfig = field(default_factory=CodeExecutorToolConfig)
    calculator: CalculatorToolConfig = field(default_factory=CalculatorToolConfig)
@dataclass
class MemoryConfig:
    """Conversation-memory settings."""

    # Presumably the number of past turns/messages retained -- confirm in the
    # memory implementation.
    max_history: int = 20
    enable_long_term: bool = False
    # Empty string means "not configured".
    vector_db_url: str = ""
@dataclass
class LoggingConfig:
    """Logging settings.

    After construction, ``level`` is taken from the ``LOG_LEVEL`` environment
    variable when that is set, and is always normalised to upper case.
    """

    level: str = "DEBUG"
    enable_file: bool = True
    log_dir: str = "./logs"
    log_file: str = "agent.log"

    def __post_init__(self) -> None:
        chosen_level = os.getenv("LOG_LEVEL", self.level)
        self.level = chosen_level.upper()
@dataclass
class AgentConfig:
    """Agent orchestration settings."""

    max_chain_steps: int = 10
    enable_multi_step: bool = True
    session_timeout: int = 3600
    # Fall back to the rule engine when the API call fails.
    fallback_to_rules: bool = True
@dataclass
class AppConfig:
    """Root configuration object grouping every settings section."""

    llm: LLMConfig = field(default_factory=LLMConfig)
    mcp: MCPConfig = field(default_factory=MCPConfig)
    tools: ToolsConfig = field(default_factory=ToolsConfig)
    memory: MemoryConfig = field(default_factory=MemoryConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)

    def display(self) -> str:
        """Return a printable, multi-line summary of the key settings.

        The API key value is never included: the summary shows ``***`` when a
        key is set and ``(未设置)`` otherwise.
        """
        # Horizontal rule framing the summary (a repeated empty string would
        # render as blank lines).
        rule = "─" * 52
        lines = [
            rule,
            " 📋 当前配置",
            rule,
            f" [LLM] provider = {self.llm.provider}",
            f" [LLM] model_name = {self.llm.model_name}",
            f" [LLM] api_key = {'***' if self.llm.api_key else '(未设置)'}",
            f" [LLM] api_base_url = {self.llm.api_base_url or '(默认)'}",
            f" [LLM] temperature = {self.llm.temperature}",
            f" [LLM] max_tokens = {self.llm.max_tokens}",
            f" [LLM] function_calling = {self.llm.function_calling}",
            f" [LLM] stream = {self.llm.stream}",
            f" [LLM] max_retries = {self.llm.max_retries}",
            f" [MCP] server_name = {self.mcp.server_name}",
            f" [MCP] enabled_tools = {self.mcp.enabled_tools}",
            f" [MEMORY] max_history = {self.memory.max_history}",
            f" [AGENT] multi_step = {self.agent.enable_multi_step}",
            f" [AGENT] fallback_rules = {self.agent.fallback_to_rules}",
            f" [LOG] level = {self.logging.level}",
            rule,
        ]
        return "\n".join(lines)
# ════════════════════════════════════════════════════════════════
# 配置加载器
# ════════════════════════════════════════════════════════════════
class ConfigLoader:
_CONFIG_SEARCH_PATHS = [
Path(os.getenv("AGENT_CONFIG_PATH", "./config.yaml")),
Path("config") / "config.yaml",
Path("config.yaml"),
]
@classmethod
def load(cls) -> AppConfig:
raw = cls._read_yaml()
return cls._parse(raw) if raw else AppConfig()
@classmethod
def _read_yaml(cls) -> dict[str, Any] | None:
if not _YAML_AVAILABLE:
print("⚠️ PyYAML 未安装pip install pyyaml使用默认配置")
return None
for path in cls._CONFIG_SEARCH_PATHS:
if path and path.exists():
with open(path, encoding="utf-8") as f:
data = yaml.safe_load(f)
print(f"✅ 已加载配置文件: {path.resolve()}")
return data or {}
print(" 未找到配置文件,使用默认配置")
return None
@classmethod
def _parse(cls, raw: dict[str, Any]) -> AppConfig:
return AppConfig(
llm=cls._parse_llm(raw.get("llm", {})),
mcp=cls._parse_mcp(raw.get("mcp", {})),
tools=cls._parse_tools(raw.get("tools", {})),
memory=cls._parse_memory(raw.get("memory", {})),
logging=cls._parse_logging(raw.get("logging", {})),
agent=cls._parse_agent(raw.get("agent", {})),
)
@staticmethod
def _parse_llm(d: dict) -> LLMConfig:
return LLMConfig(
provider=d.get("provider", "openai"),
model_name=d.get("model_name", "gpt-4o"),
api_key=d.get("api_key", ""),
api_base_url=d.get("api_base_url", ""),
max_tokens=int(d.get("max_tokens", 4096)),
temperature=float(d.get("temperature", 0.7)),
timeout=int(d.get("timeout", 60)),
max_retries=int(d.get("max_retries", 3)),
function_calling=bool(d.get("function_calling", True)),
stream=bool(d.get("stream", False)),
model_path=d.get("model_path", ""),
ollama_host=d.get("ollama_host", "http://localhost:11434"),
)
@staticmethod
def _parse_mcp(d: dict) -> MCPConfig:
return MCPConfig(
server_name=d.get("server_name", "DemoMCPServer"),
transport=d.get("transport", "stdio"),
host=d.get("host", "localhost"),
port=int(d.get("port", 3000)),
enabled_tools=d.get("enabled_tools", [
"calculator", "web_search", "file_reader", "code_executor"
]),
)
@staticmethod
def _parse_tools(d: dict) -> ToolsConfig:
ws = d.get("web_search", {})
fr = d.get("file_reader", {})
ce = d.get("code_executor", {})
ca = d.get("calculator", {})
return ToolsConfig(
web_search=WebSearchToolConfig(
max_results=int(ws.get("max_results", 5)),
timeout=int(ws.get("timeout", 10)),
api_key=ws.get("api_key", ""),
engine=ws.get("engine", "mock"),
),
file_reader=FileReaderToolConfig(
allowed_root=fr.get("allowed_root", "./workspace"),
max_file_size_kb=int(fr.get("max_file_size_kb", 512)),
),
code_executor=CodeExecutorToolConfig(
timeout=int(ce.get("timeout", 5)),
sandbox=bool(ce.get("sandbox", True)),
),
calculator=CalculatorToolConfig(
precision=int(ca.get("precision", 10)),
),
)
@staticmethod
def _parse_memory(d: dict) -> MemoryConfig:
return MemoryConfig(
max_history=int(d.get("max_history", 20)),
enable_long_term=bool(d.get("enable_long_term", False)),
vector_db_url=d.get("vector_db_url", ""),
)
@staticmethod
def _parse_logging(d: dict) -> LoggingConfig:
return LoggingConfig(
level=d.get("level", "DEBUG"),
enable_file=bool(d.get("enable_file", True)),
log_dir=d.get("log_dir", "./logs"),
log_file=d.get("log_file", "agent.log"),
)
@staticmethod
def _parse_agent(d: dict) -> AgentConfig:
return AgentConfig(
max_chain_steps=int(d.get("max_chain_steps", 10)),
enable_multi_step=bool(d.get("enable_multi_step", True)),
session_timeout=int(d.get("session_timeout", 3600)),
fallback_to_rules=bool(d.get("fallback_to_rules", True)),
)
# Global singleton: the application config, built by ConfigLoader.load() when
# this module is imported (any status messages load() prints appear then).
settings: AppConfig = ConfigLoader.load()