# base_agent/config/settings.py
# NOTE(review): GitHub file-viewer chrome ("Raw Blame History" and the
# invisible/confusable-character warnings) was captured in this copy and has
# been replaced by this comment; those warnings explain why some separator
# characters in the original file were stripped.
"""
config/settings.py
配置加载与管理 —— 使用纯字典存储工具配置,通过 settings.tools['tool_name']['key'] 访问
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
try:
import yaml
_YAML_AVAILABLE = True
except ImportError:
_YAML_AVAILABLE = False
# ════════════════════════════════════════════════════════════════
# 默认配置(与 config.yaml 结构完全对应,作为 fallback
# ════════════════════════════════════════════════════════════════
_DEFAULTS: dict[str, Any] = {
"llm": {
"provider": "openai",
"model_name": "gpt-4o",
"api_key": "",
"api_base_url": "",
"max_tokens": 4096,
"temperature": 0.7,
"timeout": 60,
"max_retries": 3,
"function_calling": True,
"stream": False,
"model_path": "",
"ollama_host": "http://localhost:11434",
},
"mcp": {
"server_name": "DemoMCPServer",
"transport": "stdio",
"host": "localhost",
"port": 3000,
"enabled_tools": [
"calculator", "web_search", "file_reader",
"code_executor", "static_analyzer", "ssh_docker",
],
},
"tools": {
"calculator": {
"precision": 10,
},
"web_search": {
"max_results": 5,
"timeout": 10,
"api_key": "",
"engine": "mock",
},
"file_reader": {
"allowed_root": "./workspace",
"max_file_size_kb": 512,
},
"code_executor": {
"timeout": 5,
"sandbox": True,
},
"static_analyzer": {
"default_tool": "cppcheck",
"default_std": "c++17",
"timeout": 120,
"jobs": 4,
"output_format": "summary",
"max_issues": 500,
"allowed_roots": [],
"tool_extra_args": {
"cppcheck": "--suppress=missingIncludeSystem --suppress=unmatchedSuppression",
"clang-tidy": "--checks=*,-fuchsia-*,-google-*,-zircon-*",
"infer": "",
},
},
"ssh_docker": {
"default_ssh_port": 22,
"default_username": "root",
"connect_timeout": 30,
"cmd_timeout": 120,
"deploy_timeout": 300,
"default_restart_policy": "unless-stopped",
"default_tail_lines": 100,
"allowed_hosts": [],
"blocked_images": [],
"allow_privileged": False,
"servers": {},
},
},
"memory": {
"max_history": 20,
"enable_long_term": False,
"vector_db_url": "",
},
"logging": {
"level": "DEBUG",
"enable_file": True,
"log_dir": "./logs",
"log_file": "agent.log",
},
"agent": {
"max_chain_steps": 10,
"enable_multi_step": True,
"session_timeout": 3600,
"fallback_to_rules": True,
},
}
# ════════════════════════════════════════════════════════════════
# Dict-style view over tool configs (settings.tools['web_search']['timeout'])
# ════════════════════════════════════════════════════════════════
class ToolsView:
    """Read-only, dict-like view over the per-tool configuration mapping.

    Usage:
        settings.tools['web_search']['timeout']         -> 10
        settings.tools['static_analyzer']['jobs']       -> 4
        settings.tools['ssh_docker']['connect_timeout'] -> 30
        settings.tools['ssh_docker']['servers']         -> {...}
        'web_search' in settings.tools                  -> True
    """

    def __init__(self, data: dict[str, dict]):
        # Keep a reference (not a copy) so all consumers share one mapping.
        self._data = data

    def __getitem__(self, tool_name: str) -> dict[str, Any]:
        """Return the config dict for *tool_name*.

        Raises KeyError (listing the known tools) when the tool is not
        configured, so typos fail loudly instead of returning None.
        """
        if tool_name not in self._data:
            raise KeyError(
                f"工具 '{tool_name}' 未在配置中定义。"
                f"可用工具: {list(self._data.keys())}"
            )
        return self._data[tool_name]

    def __contains__(self, tool_name: str) -> bool:
        return tool_name in self._data

    def __iter__(self):
        # Iterate tool names, mirroring plain-dict semantics.
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        return f"ToolsView({list(self._data.keys())})"

    def get(self, tool_name: str, default: Any = None) -> Any:
        """dict.get-style access: no KeyError for unknown tools."""
        return self._data.get(tool_name, default)

    def keys(self):
        return self._data.keys()

    def items(self):
        # Added for dict parity; existing callers are unaffected.
        return self._data.items()
# ════════════════════════════════════════════════════════════════
# Lightweight config objects for LLM / MCP / Memory / Logging / Agent
# (kept as dataclasses for attribute access; tool configs stay plain dicts)
# ════════════════════════════════════════════════════════════════
@dataclass
class LLMConfig:
    """LLM provider settings; secrets can be injected via environment vars."""
    provider: str = "openai"
    model_name: str = "gpt-4o"
    api_key: str = ""
    api_base_url: str = ""
    max_tokens: int = 4096
    temperature: float = 0.7
    timeout: int = 60
    max_retries: int = 3
    function_calling: bool = True
    stream: bool = False
    model_path: str = ""
    ollama_host: str = "http://localhost:11434"

    def __post_init__(self):
        # Environment variables win over both YAML and defaults, so the API
        # key never has to be committed to the config file.
        self.api_key = os.getenv("LLM_API_KEY", self.api_key)
        self.api_base_url = os.getenv("LLM_API_BASE_URL", self.api_base_url)
        self.model_name = os.getenv("LLM_MODEL_NAME", self.model_name)
@dataclass
class MCPConfig:
    """MCP server settings: identity, transport, and the tool whitelist."""
    server_name: str = "DemoMCPServer"
    transport: str = "stdio"
    host: str = "localhost"
    port: int = 3000
    # default_factory avoids a shared mutable default list across instances.
    enabled_tools: list[str] = field(default_factory=lambda: [
        "calculator", "web_search", "file_reader",
        "code_executor", "static_analyzer", "ssh_docker",
    ])
@dataclass
class MemoryConfig:
    """Conversation memory settings (history window + optional vector store)."""
    max_history: int = 20
    enable_long_term: bool = False
    vector_db_url: str = ""
@dataclass
class LoggingConfig:
    """Logging settings; the LOG_LEVEL env var overrides the configured level."""
    level: str = "DEBUG"
    enable_file: bool = True
    log_dir: str = "./logs"
    log_file: str = "agent.log"

    def __post_init__(self):
        # Normalize to upper case so "info" and "INFO" behave identically.
        self.level = os.getenv("LOG_LEVEL", self.level).upper()
@dataclass
class AgentConfig:
    """Agent behaviour settings (multi-step chains, timeouts, rule fallback)."""
    max_chain_steps: int = 10
    enable_multi_step: bool = True
    session_timeout: int = 3600
    fallback_to_rules: bool = True
# ════════════════════════════════════════════════════════════════
# Top-level AppConfig
# ════════════════════════════════════════════════════════════════
class AppConfig:
    """Global configuration singleton.

    Access patterns:
        settings.llm.model_name
        settings.mcp.enabled_tools
        settings.tools['web_search']['timeout']
        settings.tools['static_analyzer']['tool_extra_args']['cppcheck']
        settings.tools['ssh_docker']['servers']['prod']['host']
        settings.memory.max_history
        settings.agent.fallback_to_rules
        settings.logging.level
    """

    # Annotations are quoted forward references so this class carries no
    # hard import-order dependency on the config dataclasses above.
    def __init__(
        self,
        llm: "LLMConfig",
        mcp: "MCPConfig",
        tools: "ToolsView",
        memory: "MemoryConfig",
        logging: "LoggingConfig",
        agent: "AgentConfig",
    ):
        self.llm = llm
        self.mcp = mcp
        self.tools = tools
        self.memory = memory
        self.logging = logging
        self.agent = agent

    def display(self) -> str:
        """Render a human-readable multi-line summary of the active config."""
        sa = self.tools['static_analyzer']
        ssh = self.tools['ssh_docker']
        ws = self.tools['web_search']
        fr = self.tools['file_reader']
        ce = self.tools['code_executor']
        calc = self.tools['calculator']
        # FIX: the separator glyph was lost to invisible-character stripping
        # ("" * 62 produced an empty string); restored to the ═ used by the
        # section banners elsewhere in this file.
        sep = "═" * 62
        lines = [
            sep,
            " 📋 当前配置",
            sep,
            f" [LLM] provider = {self.llm.provider}",
            f" [LLM] model_name = {self.llm.model_name}",
            f" [LLM] api_key = {'***' + self.llm.api_key[-4:] if len(self.llm.api_key) > 4 else '(未设置)'}",
            f" [LLM] api_base_url = {self.llm.api_base_url or '(默认)'}",
            f" [LLM] function_calling = {self.llm.function_calling}",
            f" [LLM] temperature = {self.llm.temperature}",
            f" [MCP] enabled_tools = {self.mcp.enabled_tools}",
            f" [TOOL] calculator.precision= {calc['precision']}",
            f" [TOOL] web_search.engine = {ws['engine']}",
            f" [TOOL] web_search.timeout = {ws['timeout']}s",
            f" [TOOL] file_reader.root = {fr['allowed_root']}",
            f" [TOOL] code_executor.timeout={ce['timeout']}s",
            f" [TOOL] static_analyzer.tool = {sa['default_tool']}",
            f" [TOOL] static_analyzer.std = {sa['default_std']}",
            f" [TOOL] static_analyzer.timeout = {sa['timeout']}s",
            f" [TOOL] static_analyzer.jobs = {sa['jobs']}",
            f" [TOOL] static_analyzer.roots = {sa['allowed_roots'] or '(不限制)'}",
            f" [TOOL] ssh_docker.port = {ssh['default_ssh_port']}",
            f" [TOOL] ssh_docker.user = {ssh['default_username']}",
            f" [TOOL] ssh_docker.conn_timeout = {ssh['connect_timeout']}s",
            f" [TOOL] ssh_docker.deploy_timeout= {ssh['deploy_timeout']}s",
            f" [TOOL] ssh_docker.allowed_hosts = {ssh['allowed_hosts'] or '(不限制)'}",
            f" [TOOL] ssh_docker.servers = {list(ssh['servers'].keys()) or '(无预设)'}",
            f" [MEM] max_history = {self.memory.max_history}",
            f" [AGT] fallback_rules = {self.agent.fallback_to_rules}",
            f" [LOG] level = {self.logging.level}",
            sep,
        ]
        return "\n".join(lines)
# ════════════════════════════════════════════════════════════════
# 配置加载器
# ════════════════════════════════════════════════════════════════
class ConfigLoader:
_SEARCH_PATHS = [
Path(os.getenv("AGENT_CONFIG_PATH", "__none__")),
Path("config") / "config.yaml",
Path("config.yaml"),
]
@classmethod
def load(cls) -> AppConfig:
raw = cls._read_yaml()
return cls._build(raw if raw is not None else {})
@classmethod
def _read_yaml(cls) -> dict[str, Any] | None:
if not _YAML_AVAILABLE:
print("⚠️ PyYAML 未安装pip install pyyaml使用默认配置")
return None
for path in cls._SEARCH_PATHS:
if path and path.exists() and path.suffix in (".yaml", ".yml"):
with open(path, encoding="utf-8") as f:
data = yaml.safe_load(f)
print(f"✅ 已加载配置文件: {path.resolve()}")
return data or {}
print(" 未找到配置文件,使用默认配置")
return None
@classmethod
def _build(cls, raw: dict[str, Any]) -> AppConfig:
return AppConfig(
llm=cls._build_llm(raw.get("llm", {})),
mcp=cls._build_mcp(raw.get("mcp", {})),
tools=cls._build_tools(raw.get("tools", {})),
memory=cls._build_memory(raw.get("memory", {})),
logging=cls._build_logging(raw.get("logging", {})),
agent=cls._build_agent(raw.get("agent", {})),
)
# ── LLM ───────────────────────────────────────────────────
@staticmethod
def _build_llm(d: dict) -> LLMConfig:
df = _DEFAULTS["llm"]
return LLMConfig(
provider=d.get("provider", df["provider"]),
model_name=d.get("model_name", df["model_name"]),
api_key=d.get("api_key", df["api_key"]),
api_base_url=d.get("api_base_url", df["api_base_url"]),
max_tokens=int(d.get("max_tokens", df["max_tokens"])),
temperature=float(d.get("temperature", df["temperature"])),
timeout=int(d.get("timeout", df["timeout"])),
max_retries=int(d.get("max_retries", df["max_retries"])),
function_calling=bool(d.get("function_calling", df["function_calling"])),
stream=bool(d.get("stream", df["stream"])),
model_path=d.get("model_path", df["model_path"]),
ollama_host=d.get("ollama_host", df["ollama_host"]),
)
# ── MCP ───────────────────────────────────────────────────
@staticmethod
def _build_mcp(d: dict) -> MCPConfig:
df = _DEFAULTS["mcp"]
return MCPConfig(
server_name=d.get("server_name", df["server_name"]),
transport=d.get("transport", df["transport"]),
host=d.get("host", df["host"]),
port=int(d.get("port", df["port"])),
enabled_tools=d.get("enabled_tools", df["enabled_tools"]),
)
# ── Tools纯字典深度合并默认值────────────────────────
@classmethod
def _build_tools(cls, d: dict) -> ToolsView:
df = _DEFAULTS["tools"]
merged: dict[str, dict] = {}
# 遍历所有已知工具,深度合并 yaml 值与默认值
for tool_name, tool_defaults in df.items():
yaml_tool = d.get(tool_name, {})
merged[tool_name] = cls._deep_merge(tool_defaults, yaml_tool)
# 处理 yaml 中额外定义的工具(不在默认列表中)
for tool_name, tool_cfg in d.items():
if tool_name not in merged:
merged[tool_name] = tool_cfg if isinstance(tool_cfg, dict) else {}
# 环境变量覆盖
cls._apply_env_overrides(merged)
return ToolsView(merged)
@staticmethod
def _deep_merge(base: dict, override: dict) -> dict:
"""
深度合并两个字典override 中的值覆盖 base 中的值
对于嵌套字典递归合并,其他类型直接覆盖
"""
result = dict(base)
for key, val in override.items():
if (
key in result
and isinstance(result[key], dict)
and isinstance(val, dict)
):
result[key] = ConfigLoader._deep_merge(result[key], val)
else:
result[key] = val
return result
@staticmethod
def _apply_env_overrides(tools: dict[str, dict]) -> None:
"""从环境变量覆盖特定工具配置"""
# web_search.api_key
if api_key := os.getenv("SEARCH_API_KEY"):
tools["web_search"]["api_key"] = api_key
# ssh_docker servers 密码(格式: SSH_<SERVER_NAME>_PASSWORD
for server_name, srv in tools.get("ssh_docker", {}).get("servers", {}).items():
if isinstance(srv, dict) and not srv.get("password"):
env_key = f"SSH_{server_name.upper()}_PASSWORD"
if pw := os.getenv(env_key):
srv["password"] = pw
# ── Memory / Logging / Agent ──────────────────────────────
@staticmethod
def _build_memory(d: dict) -> MemoryConfig:
df = _DEFAULTS["memory"]
return MemoryConfig(
max_history=int(d.get("max_history", df["max_history"])),
enable_long_term=bool(d.get("enable_long_term",df["enable_long_term"])),
vector_db_url=d.get("vector_db_url", df["vector_db_url"]),
)
@staticmethod
def _build_logging(d: dict) -> LoggingConfig:
df = _DEFAULTS["logging"]
return LoggingConfig(
level=d.get("level", df["level"]),
enable_file=bool(d.get("enable_file", df["enable_file"])),
log_dir=d.get("log_dir", df["log_dir"]),
log_file=d.get("log_file", df["log_file"]),
)
@staticmethod
def _build_agent(d: dict) -> AgentConfig:
df = _DEFAULTS["agent"]
return AgentConfig(
max_chain_steps=int(d.get("max_chain_steps", df["max_chain_steps"])),
enable_multi_step=bool(d.get("enable_multi_step", df["enable_multi_step"])),
session_timeout=int(d.get("session_timeout", df["session_timeout"])),
fallback_to_rules=bool(d.get("fallback_to_rules", df["fallback_to_rules"])),
)
# ── Global singleton: built once at import time (reads config.yaml if present) ──
settings: AppConfig = ConfigLoader.load()