""" config/settings.py 配置加载与管理模块(新增 OpenAI 专用字段) """ import os from dataclasses import dataclass, field from pathlib import Path from typing import Any try: import yaml _YAML_AVAILABLE = True except ImportError: _YAML_AVAILABLE = False # ════════════════════════════════════════════════════════════════ # 配置数据类 # ════════════════════════════════════════════════════════════════ @dataclass class LLMConfig: """LLM 模型配置(含 OpenAI 专用字段)""" provider: str = "openai" model_name: str = "gpt-4o" api_key: str = "" api_base_url: str = "" max_tokens: int = 4096 temperature: float = 0.7 timeout: int = 60 max_retries: int = 3 # OpenAI 专用 function_calling: bool = True stream: bool = False # Ollama / 本地模型 model_path: str = "" ollama_host: str = "http://localhost:11434" def __post_init__(self): self.api_key = os.getenv("LLM_API_KEY", self.api_key) self.api_base_url = os.getenv("LLM_API_BASE_URL", self.api_base_url) self.model_name = os.getenv("LLM_MODEL_NAME", self.model_name) self.model_path = os.getenv("LLM_MODEL_PATH", self.model_path) @dataclass class MCPConfig: server_name: str = "DemoMCPServer" transport: str = "stdio" host: str = "localhost" port: int = 3000 enabled_tools: list[str] = field(default_factory=lambda: [ "calculator", "web_search", "file_reader", "code_executor" ]) @dataclass class WebSearchToolConfig: max_results: int = 5 timeout: int = 10 api_key: str = "" engine: str = "mock" def __post_init__(self): self.api_key = os.getenv("SEARCH_API_KEY", self.api_key) @dataclass class FileReaderToolConfig: allowed_root: str = "./workspace" max_file_size_kb: int = 512 @dataclass class CodeExecutorToolConfig: timeout: int = 5 sandbox: bool = True @dataclass class CalculatorToolConfig: precision: int = 10 @dataclass class ToolsConfig: web_search: WebSearchToolConfig = field(default_factory=WebSearchToolConfig) file_reader: FileReaderToolConfig = field(default_factory=FileReaderToolConfig) code_executor: CodeExecutorToolConfig = field(default_factory=CodeExecutorToolConfig) calculator: CalculatorToolConfig = field(default_factory=CalculatorToolConfig) @dataclass class MemoryConfig: max_history: int = 20 enable_long_term: bool = False vector_db_url: str = "" @dataclass class LoggingConfig: level: str = "DEBUG" enable_file: bool = True log_dir: str = "./logs" log_file: str = "agent.log" def __post_init__(self): self.level = os.getenv("LOG_LEVEL", self.level).upper() @dataclass class AgentConfig: max_chain_steps: int = 10 enable_multi_step: bool = True session_timeout: int = 3600 fallback_to_rules: bool = True # API 失败时降级到规则引擎 @dataclass class AppConfig: llm: LLMConfig = field(default_factory=LLMConfig) mcp: MCPConfig = field(default_factory=MCPConfig) tools: ToolsConfig = field(default_factory=ToolsConfig) memory: MemoryConfig = field(default_factory=MemoryConfig) logging: LoggingConfig = field(default_factory=LoggingConfig) agent: AgentConfig = field(default_factory=AgentConfig) def display(self) -> str: lines = [ "─" * 52, " 📋 当前配置", "─" * 52, f" [LLM] provider = {self.llm.provider}", f" [LLM] model_name = {self.llm.model_name}", f" [LLM] api_key = {'***' if self.llm.api_key else '(未设置)'}", f" [LLM] api_base_url = {self.llm.api_base_url or '(默认)'}", f" [LLM] temperature = {self.llm.temperature}", f" [LLM] max_tokens = {self.llm.max_tokens}", f" [LLM] function_calling = {self.llm.function_calling}", f" [LLM] stream = {self.llm.stream}", f" [LLM] max_retries = {self.llm.max_retries}", f" [MCP] server_name = {self.mcp.server_name}", f" [MCP] enabled_tools = {self.mcp.enabled_tools}", f" [MEMORY] max_history = {self.memory.max_history}", f" [AGENT] multi_step = {self.agent.enable_multi_step}", f" [AGENT] fallback_rules = {self.agent.fallback_to_rules}", f" [LOG] level = {self.logging.level}", "─" * 52, ] return "\n".join(lines) # ════════════════════════════════════════════════════════════════ # 配置加载器 # ════════════════════════════════════════════════════════════════ class ConfigLoader: _CONFIG_SEARCH_PATHS = [ Path(os.getenv("AGENT_CONFIG_PATH", "./config.yaml")), Path("config") / "config.yaml", Path("config.yaml"), ] @classmethod def load(cls) -> AppConfig: raw = cls._read_yaml() return cls._parse(raw) if raw else AppConfig() @classmethod def _read_yaml(cls) -> dict[str, Any] | None: if not _YAML_AVAILABLE: print("⚠️ PyYAML 未安装(pip install pyyaml),使用默认配置") return None for path in cls._CONFIG_SEARCH_PATHS: if path and path.exists(): with open(path, encoding="utf-8") as f: data = yaml.safe_load(f) print(f"✅ 已加载配置文件: {path.resolve()}") return data or {} print("ℹ️ 未找到配置文件,使用默认配置") return None @classmethod def _parse(cls, raw: dict[str, Any]) -> AppConfig: return AppConfig( llm=cls._parse_llm(raw.get("llm", {})), mcp=cls._parse_mcp(raw.get("mcp", {})), tools=cls._parse_tools(raw.get("tools", {})), memory=cls._parse_memory(raw.get("memory", {})), logging=cls._parse_logging(raw.get("logging", {})), agent=cls._parse_agent(raw.get("agent", {})), ) @staticmethod def _parse_llm(d: dict) -> LLMConfig: return LLMConfig( provider=d.get("provider", "openai"), model_name=d.get("model_name", "gpt-4o"), api_key=d.get("api_key", ""), api_base_url=d.get("api_base_url", ""), max_tokens=int(d.get("max_tokens", 4096)), temperature=float(d.get("temperature", 0.7)), timeout=int(d.get("timeout", 60)), max_retries=int(d.get("max_retries", 3)), function_calling=bool(d.get("function_calling", True)), stream=bool(d.get("stream", False)), model_path=d.get("model_path", ""), ollama_host=d.get("ollama_host", "http://localhost:11434"), ) @staticmethod def _parse_mcp(d: dict) -> MCPConfig: return MCPConfig( server_name=d.get("server_name", "DemoMCPServer"), transport=d.get("transport", "stdio"), host=d.get("host", "localhost"), port=int(d.get("port", 3000)), enabled_tools=d.get("enabled_tools", [ "calculator", "web_search", "file_reader", "code_executor" ]), ) @staticmethod def _parse_tools(d: dict) -> ToolsConfig: ws = d.get("web_search", {}) fr = d.get("file_reader", {}) ce = d.get("code_executor", {}) ca = d.get("calculator", {}) return ToolsConfig( web_search=WebSearchToolConfig( max_results=int(ws.get("max_results", 5)), timeout=int(ws.get("timeout", 10)), api_key=ws.get("api_key", ""), engine=ws.get("engine", "mock"), ), file_reader=FileReaderToolConfig( allowed_root=fr.get("allowed_root", "./workspace"), max_file_size_kb=int(fr.get("max_file_size_kb", 512)), ), code_executor=CodeExecutorToolConfig( timeout=int(ce.get("timeout", 5)), sandbox=bool(ce.get("sandbox", True)), ), calculator=CalculatorToolConfig( precision=int(ca.get("precision", 10)), ), ) @staticmethod def _parse_memory(d: dict) -> MemoryConfig: return MemoryConfig( max_history=int(d.get("max_history", 20)), enable_long_term=bool(d.get("enable_long_term", False)), vector_db_url=d.get("vector_db_url", ""), ) @staticmethod def _parse_logging(d: dict) -> LoggingConfig: return LoggingConfig( level=d.get("level", "DEBUG"), enable_file=bool(d.get("enable_file", True)), log_dir=d.get("log_dir", "./logs"), log_file=d.get("log_file", "agent.log"), ) @staticmethod def _parse_agent(d: dict) -> AgentConfig: return AgentConfig( max_chain_steps=int(d.get("max_chain_steps", 10)), enable_multi_step=bool(d.get("enable_multi_step", True)), session_timeout=int(d.get("session_timeout", 3600)), fallback_to_rules=bool(d.get("fallback_to_rules", True)), ) # 全局单例 settings: AppConfig = ConfigLoader.load()