base_agent/config/settings.py

"""
config/settings.py
配置加载与管理 —— 新增 mcp_skills 在线 MCP Server 配置节
工具配置通过 settings.tools['tool_name']['key'] 访问
在线 skill 配置通过 settings.mcp_skills 列表访问
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

try:
    import yaml
    _YAML_AVAILABLE = True
except ImportError:
    _YAML_AVAILABLE = False
# ════════════════════════════════════════════════════════════════
# Default configuration (mirrors the structure of config.yaml)
# ════════════════════════════════════════════════════════════════
_DEFAULTS: dict[str, Any] = {
    "llm": {
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": "",
        "api_base_url": "",
        "max_tokens": 4096,
        "temperature": 0.7,
        "timeout": 60,
        "max_retries": 3,
        "function_calling": True,
        "stream": False,
        "model_path": "",
        "ollama_host": "http://localhost:11434",
    },
    "mcp": {
        "server_name": "DemoMCPServer",
        "transport": "stdio",
        "host": "localhost",
        "port": 3000,
        "enabled_tools": [
            "calculator", "web_search", "file_reader",
            "code_executor", "static_analyzer", "ssh_docker",
        ],
    },
    "mcp_skills": [],
    "tools": {
        "calculator": {"precision": 10},
        "web_search": {"max_results": 5, "timeout": 10, "api_key": "", "engine": "mock"},
        "file_reader": {"allowed_root": "./workspace", "max_file_size_kb": 512},
        "code_executor": {"timeout": 5, "sandbox": True},
        "static_analyzer": {
            "default_tool": "cppcheck",
            "default_std": "c++17",
            "timeout": 120,
            "jobs": 4,
            "output_format": "summary",
            "max_issues": 500,
            "allowed_roots": [],
            "tool_extra_args": {
                "cppcheck": "--suppress=missingIncludeSystem --suppress=unmatchedSuppression",
                "clang-tidy": "--checks=*,-fuchsia-*,-google-*,-zircon-*",
                "infer": "",
            },
        },
        "ssh_docker": {
            "default_ssh_port": 22,
            "default_username": "root",
            "connect_timeout": 30,
            "cmd_timeout": 120,
            "deploy_timeout": 300,
            "default_restart_policy": "unless-stopped",
            "default_tail_lines": 100,
            "allowed_hosts": [],
            "blocked_images": [],
            "allow_privileged": False,
            "servers": {},
        },
    },
    "memory": {"max_history": 20, "enable_long_term": False, "vector_db_url": ""},
    "logging": {"level": "DEBUG", "enable_file": True, "log_dir": "./logs", "log_file": "agent.log"},
    "agent": {"max_chain_steps": 10, "enable_multi_step": True, "session_timeout": 3600, "fallback_to_rules": True},
}
# ════════════════════════════════════════════════════════════════
# Online MCP Skill configuration object
# ════════════════════════════════════════════════════════════════
@dataclass
class MCPSkillConfig:
    """
    Connection configuration for a single online MCP Server.
    Access pattern:
        for skill in settings.mcp_skills:
            skill.name
            skill.transport
            skill.url
            skill.headers
            skill.command / skill.args / skill.env   # stdio mode
            skill.timeout
            skill.retry
            skill.include_tools
            skill.exclude_tools
    """
    name: str                       # skill group name
    enabled: bool = True
    transport: str = "sse"          # sse | http | stdio
    url: str = ""                   # sse / http mode
    headers: dict[str, str] = field(default_factory=dict)
    command: str = ""               # stdio mode: executable
    args: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)
    timeout: int = 30
    retry: int = 2
    include_tools: list[str] = field(default_factory=list)
    exclude_tools: list[str] = field(default_factory=list)

    def __post_init__(self):
        # Auto-fill empty header values from environment variables.
        # Rule: when a header value is an empty string, try to read the
        # MCP_{NAME}_{HEADER_KEY} environment variable.
        # Example: name=everything, header key=Authorization
        #          → environment variable: MCP_EVERYTHING_AUTHORIZATION
        prefix = f"MCP_{self.name.upper().replace('-', '_')}"
        for key, val in self.headers.items():
            if not val:
                env_key = f"{prefix}_{key.upper().replace('-', '_')}"
                if env_val := os.getenv(env_key):
                    self.headers[key] = env_val

    def is_tool_allowed(self, tool_name: str) -> bool:
        """Decide whether a tool should be exposed (include/exclude filtering)."""
        if self.exclude_tools and tool_name in self.exclude_tools:
            return False
        if self.include_tools and tool_name not in self.include_tools:
            return False
        return True

    def display(self) -> str:
        if self.transport == "stdio":
            conn = f"stdio cmd={self.command} {' '.join(self.args)}"
        else:
            conn = f"{self.transport} url={self.url}"
        return (
            f" MCPSkill[{self.name}] enabled={self.enabled} {conn}\n"
            f"   timeout={self.timeout}s retry={self.retry}\n"
            f"   include={self.include_tools or '(all)'} "
            f"exclude={self.exclude_tools or '(none)'}"
        )
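# Illustrative config.yaml fragment for the mcp_skills section. The field names
# match MCPSkillConfig above; the skill name, URL, and header values are
# hypothetical placeholders, not shipped defaults:
#
#   mcp_skills:
#     - name: everything
#       enabled: true
#       transport: sse
#       url: https://example.com/mcp/sse
#       headers:
#         Authorization: ""    # empty → filled from MCP_EVERYTHING_AUTHORIZATION
#       timeout: 30
#       retry: 2
#       include_tools: []
#       exclude_tools: []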
# ════════════════════════════════════════════════════════════════
# ToolsView: dictionary view over tool configuration
# ════════════════════════════════════════════════════════════════
class ToolsView:
    """
    Dictionary view over tool configuration, supporting:
        settings.tools['web_search']['timeout']
        settings.tools['static_analyzer']['tool_extra_args']['cppcheck']
        'ssh_docker' in settings.tools
    """

    def __init__(self, data: dict[str, dict]):
        self._data = data

    def __getitem__(self, tool_name: str) -> dict[str, Any]:
        if tool_name not in self._data:
            raise KeyError(
                f"Tool '{tool_name}' is not defined in the configuration. "
                f"Available tools: {list(self._data.keys())}"
            )
        return self._data[tool_name]

    def __contains__(self, tool_name: str) -> bool:
        return tool_name in self._data

    def __repr__(self) -> str:
        return f"ToolsView({list(self._data.keys())})"

    def get(self, tool_name: str, default: Any = None) -> Any:
        return self._data.get(tool_name, default)

    def keys(self):
        return self._data.keys()
# ════════════════════════════════════════════════════════════════
# Other configuration dataclasses
# ════════════════════════════════════════════════════════════════
@dataclass
class LLMConfig:
    provider: str = "openai"
    model_name: str = "gpt-4o"
    api_key: str = ""
    api_base_url: str = ""
    max_tokens: int = 4096
    temperature: float = 0.7
    timeout: int = 60
    max_retries: int = 3
    function_calling: bool = True
    stream: bool = False
    model_path: str = ""
    ollama_host: str = "http://localhost:11434"

    def __post_init__(self):
        self.api_key = os.getenv("LLM_API_KEY", self.api_key)
        self.api_base_url = os.getenv("LLM_API_BASE_URL", self.api_base_url)
        self.model_name = os.getenv("LLM_MODEL_NAME", self.model_name)


@dataclass
class MCPConfig:
    server_name: str = "DemoMCPServer"
    transport: str = "stdio"
    host: str = "localhost"
    port: int = 3000
    enabled_tools: list[str] = field(default_factory=lambda: [
        "calculator", "web_search", "file_reader",
        "code_executor", "static_analyzer", "ssh_docker",
    ])


@dataclass
class MemoryConfig:
    max_history: int = 20
    enable_long_term: bool = False
    vector_db_url: str = ""


@dataclass
class LoggingConfig:
    level: str = "DEBUG"
    enable_file: bool = True
    log_dir: str = "./logs"
    log_file: str = "agent.log"

    def __post_init__(self):
        self.level = os.getenv("LOG_LEVEL", self.level).upper()


@dataclass
class AgentConfig:
    max_chain_steps: int = 10
    enable_multi_step: bool = True
    session_timeout: int = 3600
    fallback_to_rules: bool = True
# ════════════════════════════════════════════════════════════════
# Top-level AppConfig
# ════════════════════════════════════════════════════════════════
class AppConfig:
    """
    Global configuration singleton.
    Access patterns:
        settings.llm.model_name
        settings.mcp.enabled_tools
        settings.mcp_skills                  # list[MCPSkillConfig]
        settings.mcp_skills[0].name
        settings.mcp_skills[0].url
        settings.tools['web_search']['timeout']
        settings.tools['ssh_docker']['servers']['prod']['host']
        settings.memory.max_history
        settings.agent.fallback_to_rules
        settings.logging.level
    """

    def __init__(
        self,
        llm: LLMConfig,
        mcp: MCPConfig,
        mcp_skills: list[MCPSkillConfig],
        tools: ToolsView,
        memory: MemoryConfig,
        logging: LoggingConfig,
        agent: AgentConfig,
    ):
        self.llm = llm
        self.mcp = mcp
        self.mcp_skills = mcp_skills         # list of online MCP Skills
        self.tools = tools
        self.memory = memory
        self.logging = logging
        self.agent = agent

    @property
    def enabled_mcp_skills(self) -> list[MCPSkillConfig]:
        """Return all online MCP Skills with enabled=True."""
        return [s for s in self.mcp_skills if s.enabled]
    def display(self) -> str:
        sa = self.tools['static_analyzer']
        ssh = self.tools['ssh_docker']
        ws = self.tools['web_search']
        lines = [
            "═" * 64,
            " 📋 Current configuration",
            "═" * 64,
            f" [LLM] provider = {self.llm.provider}",
            f" [LLM] model_name = {self.llm.model_name}",
            f" [LLM] api_key = {'***' + self.llm.api_key[-4:] if len(self.llm.api_key) > 4 else '(not set)'}",
            f" [LLM] temperature = {self.llm.temperature}",
            f" [MCP] enabled_tools = {self.mcp.enabled_tools}",
            "",
            f" [MCP_SKILLS] online skills: {len(self.mcp_skills)} "
            f"(enabled: {len(self.enabled_mcp_skills)})",
        ]
        for skill in self.mcp_skills:
            icon = "✅" if skill.enabled else "❌"
            lines.append(f" {icon} {skill.display()}")
        lines += [
            "",
            f" [TOOL] web_search.engine = {ws['engine']}",
            f" [TOOL] web_search.timeout = {ws['timeout']}s",
            f" [TOOL] static_analyzer.tool = {sa['default_tool']}",
            f" [TOOL] ssh_docker.port = {ssh['default_ssh_port']}",
            f" [MEM] max_history = {self.memory.max_history}",
            f" [AGT] max_chain_steps = {self.agent.max_chain_steps}",
            f" [LOG] level = {self.logging.level}",
            "═" * 64,
        ]
        return "\n".join(lines)
# ════════════════════════════════════════════════════════════════
# Configuration loader
# ════════════════════════════════════════════════════════════════
class ConfigLoader:
    _SEARCH_PATHS = [
        Path(os.getenv("AGENT_CONFIG_PATH", "__none__")),
        Path("config") / "config.yaml",
        Path("config.yaml"),
    ]

    @classmethod
    def load(cls) -> AppConfig:
        raw = cls._read_yaml()
        return cls._build(raw if raw is not None else {})

    @classmethod
    def _read_yaml(cls) -> dict[str, Any] | None:
        if not _YAML_AVAILABLE:
            print("⚠️ PyYAML is not installed (pip install pyyaml); using default configuration")
            return None
        for path in cls._SEARCH_PATHS:
            if path and path.exists() and path.suffix in (".yaml", ".yml"):
                with open(path, encoding="utf-8") as f:
                    data = yaml.safe_load(f)
                print(f"✅ Loaded configuration file: {path.resolve()}")
                return data or {}
        print("No configuration file found; using default configuration")
        return None

    @classmethod
    def _build(cls, raw: dict[str, Any]) -> AppConfig:
        return AppConfig(
            llm=cls._build_llm(raw.get("llm", {})),
            mcp=cls._build_mcp(raw.get("mcp", {})),
            mcp_skills=cls._build_mcp_skills(raw.get("mcp_skills", [])),
            tools=cls._build_tools(raw.get("tools", {})),
            memory=cls._build_memory(raw.get("memory", {})),
            logging=cls._build_logging(raw.get("logging", {})),
            agent=cls._build_agent(raw.get("agent", {})),
        )
    # ── LLM ───────────────────────────────────────────────────
    @staticmethod
    def _build_llm(d: dict) -> LLMConfig:
        df = _DEFAULTS["llm"]
        return LLMConfig(
            provider=d.get("provider", df["provider"]),
            model_name=d.get("model_name", df["model_name"]),
            api_key=d.get("api_key", df["api_key"]),
            api_base_url=d.get("api_base_url", df["api_base_url"]),
            max_tokens=int(d.get("max_tokens", df["max_tokens"])),
            temperature=float(d.get("temperature", df["temperature"])),
            timeout=int(d.get("timeout", df["timeout"])),
            max_retries=int(d.get("max_retries", df["max_retries"])),
            function_calling=bool(d.get("function_calling", df["function_calling"])),
            stream=bool(d.get("stream", df["stream"])),
            model_path=d.get("model_path", df["model_path"]),
            ollama_host=d.get("ollama_host", df["ollama_host"]),
        )

    # ── MCP ───────────────────────────────────────────────────
    @staticmethod
    def _build_mcp(d: dict) -> MCPConfig:
        df = _DEFAULTS["mcp"]
        return MCPConfig(
            server_name=d.get("server_name", df["server_name"]),
            transport=d.get("transport", df["transport"]),
            host=d.get("host", df["host"]),
            port=int(d.get("port", df["port"])),
            enabled_tools=d.get("enabled_tools", df["enabled_tools"]),
        )
    # ── MCP Skills (list of online MCP Servers) ────────────────
    @staticmethod
    def _build_mcp_skills(raw_list: list) -> list[MCPSkillConfig]:
        skills = []
        if not isinstance(raw_list, list):
            return skills
        for item in raw_list:
            if not isinstance(item, dict):
                continue
            name = item.get("name", "")
            if not name:
                continue
            skills.append(MCPSkillConfig(
                name=name,
                enabled=bool(item.get("enabled", True)),
                transport=item.get("transport", "sse"),
                url=item.get("url", ""),
                headers=dict(item.get("headers", {})),
                command=item.get("command", ""),
                args=list(item.get("args", [])),
                env=dict(item.get("env", {})),
                timeout=int(item.get("timeout", 30)),
                retry=int(item.get("retry", 2)),
                include_tools=list(item.get("include_tools", [])),
                exclude_tools=list(item.get("exclude_tools", [])),
            ))
        return skills
    # ── Tools (plain-dict deep merge) ──────────────────────────
    @classmethod
    def _build_tools(cls, d: dict) -> ToolsView:
        df = _DEFAULTS["tools"]
        merged: dict[str, dict] = {}
        for tool_name, tool_defaults in df.items():
            yaml_tool = d.get(tool_name, {})
            merged[tool_name] = cls._deep_merge(tool_defaults, yaml_tool)
        for tool_name, tool_cfg in d.items():
            if tool_name not in merged:
                merged[tool_name] = tool_cfg if isinstance(tool_cfg, dict) else {}
        cls._apply_env_overrides(merged)
        return ToolsView(merged)

    @staticmethod
    def _deep_merge(base: dict, override: dict) -> dict:
        """Recursively merge override onto base: nested dicts are merged, other values are replaced."""
        result = dict(base)
        for key, val in override.items():
            if key in result and isinstance(result[key], dict) and isinstance(val, dict):
                result[key] = ConfigLoader._deep_merge(result[key], val)
            else:
                result[key] = val
        return result

    @staticmethod
    def _apply_env_overrides(tools: dict[str, dict]) -> None:
        """Fill sensitive tool settings (search API key, SSH passwords) from environment variables."""
        if api_key := os.getenv("SEARCH_API_KEY"):
            tools["web_search"]["api_key"] = api_key
        for server_name, srv in tools.get("ssh_docker", {}).get("servers", {}).items():
            if isinstance(srv, dict) and not srv.get("password"):
                env_key = f"SSH_{server_name.upper()}_PASSWORD"
                if pw := os.getenv(env_key):
                    srv["password"] = pw
    # ── Memory / Logging / Agent ──────────────────────────────
    @staticmethod
    def _build_memory(d: dict) -> MemoryConfig:
        df = _DEFAULTS["memory"]
        return MemoryConfig(
            max_history=int(d.get("max_history", df["max_history"])),
            enable_long_term=bool(d.get("enable_long_term", df["enable_long_term"])),
            vector_db_url=d.get("vector_db_url", df["vector_db_url"]),
        )

    @staticmethod
    def _build_logging(d: dict) -> LoggingConfig:
        df = _DEFAULTS["logging"]
        return LoggingConfig(
            level=d.get("level", df["level"]),
            enable_file=bool(d.get("enable_file", df["enable_file"])),
            log_dir=d.get("log_dir", df["log_dir"]),
            log_file=d.get("log_file", df["log_file"]),
        )

    @staticmethod
    def _build_agent(d: dict) -> AgentConfig:
        df = _DEFAULTS["agent"]
        return AgentConfig(
            max_chain_steps=int(d.get("max_chain_steps", df["max_chain_steps"])),
            enable_multi_step=bool(d.get("enable_multi_step", df["enable_multi_step"])),
            session_timeout=int(d.get("session_timeout", df["session_timeout"])),
            fallback_to_rules=bool(d.get("fallback_to_rules", df["fallback_to_rules"])),
        )
# ── Global singleton ──────────────────────────────────────────
settings: AppConfig = ConfigLoader.load()
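
# Minimal usage sketch (illustrative, not part of the module's API): importing
# this module already runs ConfigLoader.load() once, so the block below only
# prints the resulting configuration when the file is executed directly.
if __name__ == "__main__":
    print(settings.display())
    print("web_search timeout:", settings.tools['web_search']['timeout'])
    for skill in settings.enabled_mcp_skills:
        print("enabled skill:", skill.name, skill.transport)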