278 lines
9.9 KiB
Python
278 lines
9.9 KiB
Python
|
|
"""
|
|||
|
|
config/settings.py
|
|||
|
|
配置加载与管理模块(新增 OpenAI 专用字段)
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import yaml
|
|||
|
|
_YAML_AVAILABLE = True
|
|||
|
|
except ImportError:
|
|||
|
|
_YAML_AVAILABLE = False
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
# 配置数据类
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
|
|||
|
|
@dataclass
class LLMConfig:
    """Settings for the language-model backend, including OpenAI-specific fields.

    After construction, four fields can be overridden from the environment:
    ``LLM_API_KEY`` -> ``api_key``, ``LLM_API_BASE_URL`` -> ``api_base_url``,
    ``LLM_MODEL_NAME`` -> ``model_name`` and ``LLM_MODEL_PATH`` -> ``model_path``.
    A variable that is set (even to an empty string) wins over the value
    passed to the constructor.
    """

    provider: str = "openai"
    model_name: str = "gpt-4o"
    api_key: str = ""
    api_base_url: str = ""
    max_tokens: int = 4096
    temperature: float = 0.7
    timeout: int = 60
    max_retries: int = 3
    # OpenAI-specific options
    function_calling: bool = True
    stream: bool = False
    # Ollama / local-model options
    model_path: str = ""
    ollama_host: str = "http://localhost:11434"

    def __post_init__(self):
        # (attribute, environment variable) pairs, applied in this order.
        env_overrides = (
            ("api_key", "LLM_API_KEY"),
            ("api_base_url", "LLM_API_BASE_URL"),
            ("model_name", "LLM_MODEL_NAME"),
            ("model_path", "LLM_MODEL_PATH"),
        )
        for attr, var in env_overrides:
            setattr(self, attr, os.environ.get(var, getattr(self, attr)))
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class MCPConfig:
    """MCP server settings: identity, transport, address and enabled tool names."""

    server_name: str = "DemoMCPServer"
    transport: str = "stdio"
    host: str = "localhost"
    port: int = 3000
    # default_factory builds a fresh list per instance, so mutating one
    # config's tool list never leaks into another instance.
    enabled_tools: list[str] = field(
        default_factory=lambda: [
            "calculator",
            "web_search",
            "file_reader",
            "code_executor",
        ]
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class WebSearchToolConfig:
    """Settings for the web-search tool.

    If the ``SEARCH_API_KEY`` environment variable is set, it replaces
    ``api_key`` after construction.
    """

    max_results: int = 5
    timeout: int = 10
    api_key: str = ""
    engine: str = "mock"

    def __post_init__(self):
        env_key = os.environ.get("SEARCH_API_KEY", self.api_key)
        self.api_key = env_key
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class FileReaderToolConfig:
    """Settings for the file-reader tool."""

    # Root directory for file access; presumably the tool refuses paths
    # outside it — TODO confirm against the tool implementation.
    allowed_root: str = "./workspace"
    # Size limit in kilobytes, as the field name indicates.
    max_file_size_kb: int = 512
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class CodeExecutorToolConfig:
    """Settings for the code-executor tool.

    The meaning of ``sandbox`` is defined by the tool that reads it, not here.
    """

    timeout: int = 5
    sandbox: bool = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class CalculatorToolConfig:
    """Settings for the calculator tool (numeric ``precision``, default 10)."""

    precision: int = 10
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ToolsConfig:
    """Per-tool settings, one sub-config per built-in tool.

    Each field defaults to a freshly constructed sub-config, so instances
    never share tool settings.
    """

    web_search: WebSearchToolConfig = field(
        default_factory=WebSearchToolConfig
    )
    file_reader: FileReaderToolConfig = field(
        default_factory=FileReaderToolConfig
    )
    code_executor: CodeExecutorToolConfig = field(
        default_factory=CodeExecutorToolConfig
    )
    calculator: CalculatorToolConfig = field(
        default_factory=CalculatorToolConfig
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class MemoryConfig:
    """Conversation-memory settings."""

    # Upper bound on retained history entries (unit defined by the consumer).
    max_history: int = 20
    enable_long_term: bool = False
    # Presumably only used when enable_long_term is on — TODO confirm.
    vector_db_url: str = ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class LoggingConfig:
    """Logging settings.

    After construction, ``level`` is replaced by the ``LOG_LEVEL``
    environment variable when that is set, and is always upper-cased.
    """

    level: str = "DEBUG"
    enable_file: bool = True
    log_dir: str = "./logs"
    log_file: str = "agent.log"

    def __post_init__(self):
        chosen = os.environ.get("LOG_LEVEL", self.level)
        self.level = chosen.upper()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class AgentConfig:
    """Agent behaviour settings."""

    max_chain_steps: int = 10
    enable_multi_step: bool = True
    session_timeout: int = 3600
    # Fall back to the rule engine when the LLM API call fails.
    fallback_to_rules: bool = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class AppConfig:
    """Top-level application configuration aggregating every section."""

    llm: LLMConfig = field(default_factory=LLMConfig)
    mcp: MCPConfig = field(default_factory=MCPConfig)
    tools: ToolsConfig = field(default_factory=ToolsConfig)
    memory: MemoryConfig = field(default_factory=MemoryConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)

    def display(self) -> str:
        """Return a multi-line, human-readable summary of the key settings.

        The API key is never printed: only whether it is set is shown.
        """
        rule = "─" * 52
        llm, mcp = self.llm, self.mcp
        # (section tag, label, value) — one output line per row, in order.
        rows = (
            ("LLM", "provider", llm.provider),
            ("LLM", "model_name", llm.model_name),
            ("LLM", "api_key", "***" if llm.api_key else "(未设置)"),
            ("LLM", "api_base_url", llm.api_base_url or "(默认)"),
            ("LLM", "temperature", llm.temperature),
            ("LLM", "max_tokens", llm.max_tokens),
            ("LLM", "function_calling", llm.function_calling),
            ("LLM", "stream", llm.stream),
            ("LLM", "max_retries", llm.max_retries),
            ("MCP", "server_name", mcp.server_name),
            ("MCP", "enabled_tools", mcp.enabled_tools),
            ("MEMORY", "max_history", self.memory.max_history),
            ("AGENT", "multi_step", self.agent.enable_multi_step),
            ("AGENT", "fallback_rules", self.agent.fallback_to_rules),
            ("LOG", "level", self.logging.level),
        )
        body = [f" [{section}] {label} = {value}" for section, label, value in rows]
        return "\n".join([rule, " 📋 当前配置", rule, *body, rule])
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
# 配置加载器
|
|||
|
|
# ════════════════════════════════════════════════════════════════
|
|||
|
|
|
|||
|
|
class ConfigLoader:
|
|||
|
|
_CONFIG_SEARCH_PATHS = [
|
|||
|
|
Path(os.getenv("AGENT_CONFIG_PATH", "./config.yaml")),
|
|||
|
|
Path("config") / "config.yaml",
|
|||
|
|
Path("config.yaml"),
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def load(cls) -> AppConfig:
|
|||
|
|
raw = cls._read_yaml()
|
|||
|
|
return cls._parse(raw) if raw else AppConfig()
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def _read_yaml(cls) -> dict[str, Any] | None:
|
|||
|
|
if not _YAML_AVAILABLE:
|
|||
|
|
print("⚠️ PyYAML 未安装(pip install pyyaml),使用默认配置")
|
|||
|
|
return None
|
|||
|
|
for path in cls._CONFIG_SEARCH_PATHS:
|
|||
|
|
if path and path.exists():
|
|||
|
|
with open(path, encoding="utf-8") as f:
|
|||
|
|
data = yaml.safe_load(f)
|
|||
|
|
print(f"✅ 已加载配置文件: {path.resolve()}")
|
|||
|
|
return data or {}
|
|||
|
|
print("ℹ️ 未找到配置文件,使用默认配置")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def _parse(cls, raw: dict[str, Any]) -> AppConfig:
|
|||
|
|
return AppConfig(
|
|||
|
|
llm=cls._parse_llm(raw.get("llm", {})),
|
|||
|
|
mcp=cls._parse_mcp(raw.get("mcp", {})),
|
|||
|
|
tools=cls._parse_tools(raw.get("tools", {})),
|
|||
|
|
memory=cls._parse_memory(raw.get("memory", {})),
|
|||
|
|
logging=cls._parse_logging(raw.get("logging", {})),
|
|||
|
|
agent=cls._parse_agent(raw.get("agent", {})),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_llm(d: dict) -> LLMConfig:
|
|||
|
|
return LLMConfig(
|
|||
|
|
provider=d.get("provider", "openai"),
|
|||
|
|
model_name=d.get("model_name", "gpt-4o"),
|
|||
|
|
api_key=d.get("api_key", ""),
|
|||
|
|
api_base_url=d.get("api_base_url", ""),
|
|||
|
|
max_tokens=int(d.get("max_tokens", 4096)),
|
|||
|
|
temperature=float(d.get("temperature", 0.7)),
|
|||
|
|
timeout=int(d.get("timeout", 60)),
|
|||
|
|
max_retries=int(d.get("max_retries", 3)),
|
|||
|
|
function_calling=bool(d.get("function_calling", True)),
|
|||
|
|
stream=bool(d.get("stream", False)),
|
|||
|
|
model_path=d.get("model_path", ""),
|
|||
|
|
ollama_host=d.get("ollama_host", "http://localhost:11434"),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_mcp(d: dict) -> MCPConfig:
|
|||
|
|
return MCPConfig(
|
|||
|
|
server_name=d.get("server_name", "DemoMCPServer"),
|
|||
|
|
transport=d.get("transport", "stdio"),
|
|||
|
|
host=d.get("host", "localhost"),
|
|||
|
|
port=int(d.get("port", 3000)),
|
|||
|
|
enabled_tools=d.get("enabled_tools", [
|
|||
|
|
"calculator", "web_search", "file_reader", "code_executor"
|
|||
|
|
]),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_tools(d: dict) -> ToolsConfig:
|
|||
|
|
ws = d.get("web_search", {})
|
|||
|
|
fr = d.get("file_reader", {})
|
|||
|
|
ce = d.get("code_executor", {})
|
|||
|
|
ca = d.get("calculator", {})
|
|||
|
|
return ToolsConfig(
|
|||
|
|
web_search=WebSearchToolConfig(
|
|||
|
|
max_results=int(ws.get("max_results", 5)),
|
|||
|
|
timeout=int(ws.get("timeout", 10)),
|
|||
|
|
api_key=ws.get("api_key", ""),
|
|||
|
|
engine=ws.get("engine", "mock"),
|
|||
|
|
),
|
|||
|
|
file_reader=FileReaderToolConfig(
|
|||
|
|
allowed_root=fr.get("allowed_root", "./workspace"),
|
|||
|
|
max_file_size_kb=int(fr.get("max_file_size_kb", 512)),
|
|||
|
|
),
|
|||
|
|
code_executor=CodeExecutorToolConfig(
|
|||
|
|
timeout=int(ce.get("timeout", 5)),
|
|||
|
|
sandbox=bool(ce.get("sandbox", True)),
|
|||
|
|
),
|
|||
|
|
calculator=CalculatorToolConfig(
|
|||
|
|
precision=int(ca.get("precision", 10)),
|
|||
|
|
),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_memory(d: dict) -> MemoryConfig:
|
|||
|
|
return MemoryConfig(
|
|||
|
|
max_history=int(d.get("max_history", 20)),
|
|||
|
|
enable_long_term=bool(d.get("enable_long_term", False)),
|
|||
|
|
vector_db_url=d.get("vector_db_url", ""),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_logging(d: dict) -> LoggingConfig:
|
|||
|
|
return LoggingConfig(
|
|||
|
|
level=d.get("level", "DEBUG"),
|
|||
|
|
enable_file=bool(d.get("enable_file", True)),
|
|||
|
|
log_dir=d.get("log_dir", "./logs"),
|
|||
|
|
log_file=d.get("log_file", "agent.log"),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _parse_agent(d: dict) -> AgentConfig:
|
|||
|
|
return AgentConfig(
|
|||
|
|
max_chain_steps=int(d.get("max_chain_steps", 10)),
|
|||
|
|
enable_multi_step=bool(d.get("enable_multi_step", True)),
|
|||
|
|
session_timeout=int(d.get("session_timeout", 3600)),
|
|||
|
|
fallback_to_rules=bool(d.get("fallback_to_rules", True)),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Global singleton: the configuration is loaded once, when this module is first
# imported. Importing this module therefore reads the config file (if any),
# prints a status line from ConfigLoader, and lets any read/parse error
# propagate out of the import.
settings: AppConfig = ConfigLoader.load()
|