"""
config/settings.py

配置加载与管理模块(新增 OpenAI 专用字段)
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

# PyYAML is an optional dependency: when it is not installed the loader
# prints a warning and falls back to built-in defaults instead of failing
# at import time (see ConfigLoader._read_yaml).
try:
    import yaml

    _YAML_AVAILABLE = True
except ImportError:
    _YAML_AVAILABLE = False

||
# ════════════════════════════════════════════════════════════════
|
||
# 配置数据类
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
@dataclass
|
||
class LLMConfig:
|
||
"""LLM 模型配置(含 OpenAI 专用字段)"""
|
||
provider: str = "openai"
|
||
model_name: str = "gpt-4o"
|
||
api_key: str = ""
|
||
api_base_url: str = ""
|
||
max_tokens: int = 4096
|
||
temperature: float = 0.7
|
||
timeout: int = 60
|
||
max_retries: int = 3
|
||
# OpenAI 专用
|
||
function_calling: bool = True
|
||
stream: bool = False
|
||
# Ollama / 本地模型
|
||
model_path: str = ""
|
||
ollama_host: str = "http://localhost:11434"
|
||
|
||
def __post_init__(self):
|
||
self.api_key = os.getenv("LLM_API_KEY", self.api_key)
|
||
self.api_base_url = os.getenv("LLM_API_BASE_URL", self.api_base_url)
|
||
self.model_name = os.getenv("LLM_MODEL_NAME", self.model_name)
|
||
self.model_path = os.getenv("LLM_MODEL_PATH", self.model_path)
|
||
|
||
|
||
@dataclass
class MCPConfig:
    """MCP server configuration."""

    server_name: str = "DemoMCPServer"
    transport: str = "stdio"    # transport kind; "stdio" by default
    host: str = "localhost"
    port: int = 3000
    # Tool names enabled on the server; a mutable default needs default_factory.
    enabled_tools: list[str] = field(default_factory=lambda: [
        "calculator", "web_search", "file_reader", "code_executor"
    ])

@dataclass
class WebSearchToolConfig:
    """Configuration for the web_search tool."""

    max_results: int = 5
    timeout: int = 10     # NOTE(review): presumably seconds — confirm in the tool
    api_key: str = ""     # normally injected via SEARCH_API_KEY
    engine: str = "mock"  # search backend identifier

    def __post_init__(self) -> None:
        # The environment variable wins over file/default values.
        self.api_key = os.getenv("SEARCH_API_KEY", self.api_key)

@dataclass
class FileReaderToolConfig:
    """Configuration for the file_reader tool."""

    # NOTE(review): presumably the directory reads are restricted to —
    # the enforcement lives in the tool, not here; confirm there.
    allowed_root: str = "./workspace"
    max_file_size_kb: int = 512

@dataclass
class CodeExecutorToolConfig:
    """Configuration for the code_executor tool."""

    timeout: int = 5      # NOTE(review): presumably seconds — confirm in the tool
    sandbox: bool = True  # NOTE(review): sandboxing semantics defined by the tool itself

@dataclass
class CalculatorToolConfig:
    """Configuration for the calculator tool."""

    precision: int = 10   # NOTE(review): presumably significant digits — confirm in the tool

@dataclass
class ToolsConfig:
    """Aggregate of the per-tool configuration sections.

    Each sub-config uses default_factory so instances never share state.
    """

    web_search: WebSearchToolConfig = field(default_factory=WebSearchToolConfig)
    file_reader: FileReaderToolConfig = field(default_factory=FileReaderToolConfig)
    code_executor: CodeExecutorToolConfig = field(default_factory=CodeExecutorToolConfig)
    calculator: CalculatorToolConfig = field(default_factory=CalculatorToolConfig)

@dataclass
class MemoryConfig:
    """Conversation-memory configuration."""

    max_history: int = 20            # NOTE(review): presumably a message-count cap — confirm
    enable_long_term: bool = False   # toggle for long-term (vector-store) memory
    vector_db_url: str = ""          # NOTE(review): presumably only used when enable_long_term

@dataclass
class LoggingConfig:
    """Logging configuration."""

    level: str = "DEBUG"         # normalised to upper case in __post_init__
    enable_file: bool = True
    log_dir: str = "./logs"
    log_file: str = "agent.log"

    def __post_init__(self) -> None:
        # LOG_LEVEL env var overrides; .upper() normalises e.g. "debug" -> "DEBUG".
        self.level = os.getenv("LOG_LEVEL", self.level).upper()

@dataclass
class AgentConfig:
    """Agent behaviour configuration."""

    max_chain_steps: int = 10
    enable_multi_step: bool = True
    session_timeout: int = 3600        # NOTE(review): presumably seconds
    fallback_to_rules: bool = True     # fall back to the rule engine when the API fails

@dataclass
class AppConfig:
    """Top-level application configuration aggregating every section."""

    llm: LLMConfig = field(default_factory=LLMConfig)
    mcp: MCPConfig = field(default_factory=MCPConfig)
    tools: ToolsConfig = field(default_factory=ToolsConfig)
    memory: MemoryConfig = field(default_factory=MemoryConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)

    def display(self) -> str:
        """Return a multi-line, human-readable summary of the key settings.

        The API key is masked ('***' / not-set); only a curated subset of
        fields is shown, not the full configuration.
        """
        lines = [
            "─" * 52,
            " 📋 当前配置",
            "─" * 52,
            f" [LLM] provider = {self.llm.provider}",
            f" [LLM] model_name = {self.llm.model_name}",
            f" [LLM] api_key = {'***' if self.llm.api_key else '(未设置)'}",
            f" [LLM] api_base_url = {self.llm.api_base_url or '(默认)'}",
            f" [LLM] temperature = {self.llm.temperature}",
            f" [LLM] max_tokens = {self.llm.max_tokens}",
            f" [LLM] function_calling = {self.llm.function_calling}",
            f" [LLM] stream = {self.llm.stream}",
            f" [LLM] max_retries = {self.llm.max_retries}",
            f" [MCP] server_name = {self.mcp.server_name}",
            f" [MCP] enabled_tools = {self.mcp.enabled_tools}",
            f" [MEMORY] max_history = {self.memory.max_history}",
            f" [AGENT] multi_step = {self.agent.enable_multi_step}",
            f" [AGENT] fallback_rules = {self.agent.fallback_to_rules}",
            f" [LOG] level = {self.logging.level}",
            "─" * 52,
        ]
        return "\n".join(lines)

# ════════════════════════════════════════════════════════════════
# Configuration loader
# ════════════════════════════════════════════════════════════════

class ConfigLoader:
    """Builds an AppConfig from the first YAML file found on the search path.

    Degrades gracefully: a missing PyYAML, a missing config file, or (fixed
    here) an unreadable/malformed file all yield the built-in defaults
    instead of raising — important because the module-level ``settings``
    singleton is created at import time.
    """

    # Search order: explicit env override first, then conventional locations.
    # NOTE: AGENT_CONFIG_PATH is read once, when this class is defined.
    _CONFIG_SEARCH_PATHS = [
        Path(os.getenv("AGENT_CONFIG_PATH", "./config.yaml")),
        Path("config") / "config.yaml",
        Path("config.yaml"),
    ]

    @classmethod
    def load(cls) -> AppConfig:
        """Return the parsed configuration, or pure defaults when no YAML is usable."""
        raw = cls._read_yaml()
        return cls._parse(raw) if raw else AppConfig()

    @classmethod
    def _read_yaml(cls) -> dict[str, Any] | None:
        """Read the first existing config file; None means "use defaults"."""
        if not _YAML_AVAILABLE:
            print("⚠️ PyYAML 未安装(pip install pyyaml),使用默认配置")
            return None
        for path in cls._CONFIG_SEARCH_PATHS:
            if path and path.exists():
                # BUGFIX: a malformed or unreadable file previously raised out
                # of module import and crashed start-up; degrade to defaults
                # instead, matching the module's fallback philosophy.
                try:
                    with open(path, encoding="utf-8") as f:
                        data = yaml.safe_load(f)
                except (OSError, yaml.YAMLError) as exc:
                    print(f"⚠️ 配置文件读取失败: {path} ({exc}),使用默认配置")
                    return None
                print(f"✅ 已加载配置文件: {path.resolve()}")
                # An empty file parses to None -> {} -> caller uses defaults.
                return data or {}
        print("ℹ️ 未找到配置文件,使用默认配置")
        return None

    @classmethod
    def _parse(cls, raw: dict[str, Any]) -> AppConfig:
        """Assemble an AppConfig; each missing section falls back to its defaults."""
        return AppConfig(
            llm=cls._parse_llm(raw.get("llm", {})),
            mcp=cls._parse_mcp(raw.get("mcp", {})),
            tools=cls._parse_tools(raw.get("tools", {})),
            memory=cls._parse_memory(raw.get("memory", {})),
            logging=cls._parse_logging(raw.get("logging", {})),
            agent=cls._parse_agent(raw.get("agent", {})),
        )

    @staticmethod
    def _parse_llm(d: dict) -> LLMConfig:
        """Build LLMConfig; numeric fields are coerced so YAML strings still work."""
        return LLMConfig(
            provider=d.get("provider", "openai"),
            model_name=d.get("model_name", "gpt-4o"),
            api_key=d.get("api_key", ""),
            api_base_url=d.get("api_base_url", ""),
            max_tokens=int(d.get("max_tokens", 4096)),
            temperature=float(d.get("temperature", 0.7)),
            timeout=int(d.get("timeout", 60)),
            max_retries=int(d.get("max_retries", 3)),
            function_calling=bool(d.get("function_calling", True)),
            stream=bool(d.get("stream", False)),
            model_path=d.get("model_path", ""),
            ollama_host=d.get("ollama_host", "http://localhost:11434"),
        )

    @staticmethod
    def _parse_mcp(d: dict) -> MCPConfig:
        """Build MCPConfig."""
        # BUGFIX: an explicit `enabled_tools: null` in YAML used to pass None
        # through; treat it like an absent key (an explicit [] is honoured).
        tools = d.get("enabled_tools")
        if tools is None:
            tools = ["calculator", "web_search", "file_reader", "code_executor"]
        return MCPConfig(
            server_name=d.get("server_name", "DemoMCPServer"),
            transport=d.get("transport", "stdio"),
            host=d.get("host", "localhost"),
            port=int(d.get("port", 3000)),
            enabled_tools=tools,
        )

    @staticmethod
    def _parse_tools(d: dict) -> ToolsConfig:
        """Build ToolsConfig from the per-tool sub-sections."""
        ws = d.get("web_search", {})
        fr = d.get("file_reader", {})
        ce = d.get("code_executor", {})
        ca = d.get("calculator", {})
        return ToolsConfig(
            web_search=WebSearchToolConfig(
                max_results=int(ws.get("max_results", 5)),
                timeout=int(ws.get("timeout", 10)),
                api_key=ws.get("api_key", ""),
                engine=ws.get("engine", "mock"),
            ),
            file_reader=FileReaderToolConfig(
                allowed_root=fr.get("allowed_root", "./workspace"),
                max_file_size_kb=int(fr.get("max_file_size_kb", 512)),
            ),
            code_executor=CodeExecutorToolConfig(
                timeout=int(ce.get("timeout", 5)),
                sandbox=bool(ce.get("sandbox", True)),
            ),
            calculator=CalculatorToolConfig(
                precision=int(ca.get("precision", 10)),
            ),
        )

    @staticmethod
    def _parse_memory(d: dict) -> MemoryConfig:
        """Build MemoryConfig."""
        return MemoryConfig(
            max_history=int(d.get("max_history", 20)),
            enable_long_term=bool(d.get("enable_long_term", False)),
            vector_db_url=d.get("vector_db_url", ""),
        )

    @staticmethod
    def _parse_logging(d: dict) -> LoggingConfig:
        """Build LoggingConfig."""
        return LoggingConfig(
            level=d.get("level", "DEBUG"),
            enable_file=bool(d.get("enable_file", True)),
            log_dir=d.get("log_dir", "./logs"),
            log_file=d.get("log_file", "agent.log"),
        )

    @staticmethod
    def _parse_agent(d: dict) -> AgentConfig:
        """Build AgentConfig."""
        return AgentConfig(
            max_chain_steps=int(d.get("max_chain_steps", 10)),
            enable_multi_step=bool(d.get("enable_multi_step", True)),
            session_timeout=int(d.get("session_timeout", 3600)),
            fallback_to_rules=bool(d.get("fallback_to_rules", True)),
        )

# Global singleton: the configuration is resolved once, at import time.
settings: AppConfig = ConfigLoader.load()