in dev
This commit is contained in:
parent
09aea3a693
commit
246b2a485f
|
|
@ -8,9 +8,9 @@ import time
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from config.settings import settings
|
||||
from mcp.skill_registry import DispatchResult, SkillRegistry
|
||||
from utils.logger import get_logger
|
||||
from agent.config.settings import settings
|
||||
from agent.mcp.skill_registry import DispatchResult, SkillRegistry
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("Agent")
|
||||
|
||||
|
|
@ -210,14 +210,6 @@ class Agent:
|
|||
|
||||
可用工具如下:
|
||||
{tools}
|
||||
|
||||
请严格按照以下格式进行回应:
|
||||
|
||||
Thought: 你的思考过程,用于分析问题、拆解任务和规划下一步行动。
|
||||
Action: 你决定采取的行动,必须是以下格式之一:
|
||||
- `{{tool_name}}[{{tool_input}}]`:调用一个可用工具。
|
||||
- `Finish[最终答案]`:当你认为已经获得最终答案时。
|
||||
- 当你收集到足够的信息,能够回答用户的最终问题时,你必须在Action:字段后使用 Finish[最终答案] 来输出最终答案。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -380,7 +372,7 @@ class Agent:
|
|||
# ── 消息构造 ──────────────────────────────────────────────
|
||||
|
||||
def _build_messages(self, tools, history: list[Message]) -> list[dict]:
|
||||
prompt = self.system_prompt.format(tools="\n".join(f"{tool['name']}: {tool['description']}" for tool in tools))
|
||||
prompt = self.system_prompt
|
||||
messages = [{"role": "system", "content":prompt}]
|
||||
messages += [m.to_api_dict() for m in history]
|
||||
return messages
|
||||
|
|
@ -404,7 +396,7 @@ class Agent:
|
|||
# Demo 入口
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
def create_agent() -> tuple[Agent, SkillRegistry]:
|
||||
def create_agent(agent_prompt=None) -> tuple[Agent, SkillRegistry]:
|
||||
"""
|
||||
工厂函数:创建并初始化 Agent + SkillRegistry
|
||||
|
||||
|
|
@ -416,7 +408,7 @@ def create_agent() -> tuple[Agent, SkillRegistry]:
|
|||
registry.load_local_tools()
|
||||
# 连接在线 MCP Skill(来自 config.yaml mcp_skills)
|
||||
registry.connect_mcp_skills()
|
||||
agent = Agent(registry)
|
||||
agent = Agent(registry, system_prompt=agent_prompt)
|
||||
return agent, registry
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,16 +8,14 @@ client/agent_client.py
|
|||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from config.settings import settings
|
||||
from llm.llm_engine import LLMEngine
|
||||
from mcp.mcp_protocol import (
|
||||
from agent.llm.llm_engine import LLMEngine
|
||||
from agent.mcp.mcp_protocol import (
|
||||
ChainPlan, ChainResult, MCPRequest, MCPResponse,
|
||||
StepResult, ToolStep,
|
||||
)
|
||||
from mcp.mcp_server import MCPServer
|
||||
from memory.memory_store import MemoryStore
|
||||
from utils.logger import get_logger
|
||||
from agent.mcp.mcp_server import MCPServer
|
||||
from agent.memory.memory_store import MemoryStore
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
170
config.yaml
170
config.yaml
|
|
@ -1,170 +0,0 @@
|
|||
# ════════════════════════════════════════════════════════════════
|
||||
# config/config.yaml — Agent 系统全局配置文件
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
# ── LLM 模型配置 ───────────────────────────────────────────────
|
||||
llm:
|
||||
provider: "openai"
|
||||
model_name: "gpt-4o"
|
||||
api_key: "sk-AUmOuFI731Ty5Nob38jY26d8lydfDT-QkE2giqb0sCuPCAE2JH6zjLM4lZLpvL5WMYPOocaMe2FwVDmqM_9KimmKACjR" # 优先读取环境变量 LLM_API_KEY
|
||||
api_base_url: "https://openapi.monica.im/v1" # 自定义代理地址,留空使用官方
|
||||
max_tokens: 4096
|
||||
temperature: 0.7
|
||||
timeout: 60
|
||||
max_retries: 3
|
||||
function_calling: true
|
||||
stream: false
|
||||
model_path: ""
|
||||
ollama_host: "http://localhost:11434"
|
||||
database:
|
||||
type: "sqlite"
|
||||
url: "sqlite:///skills.db"
|
||||
skills_directory: "./" # 新增:SKILL.md 文件所在目录
|
||||
# ── 本地 MCP Server 配置 ───────────────────────────────────────
|
||||
mcp:
|
||||
server_name: "MCPServer"
|
||||
transport: "stdio"
|
||||
host: "localhost"
|
||||
port: 3000
|
||||
# 本地注册的工具列表
|
||||
enabled_tools:
|
||||
# - calculator
|
||||
- web_search
|
||||
# - file_reader
|
||||
# - code_executor
|
||||
|
||||
# ── 在线 MCP Skill 配置 ────────────────────────────────────────
|
||||
# 每一项代表一个远端 MCP Server,其暴露的所有工具将作为 skill 注册到 Agent
|
||||
mcp_skills:
|
||||
|
||||
# 示例一:SSE 传输(最常见的在线 MCP Server 形式)
|
||||
# - name: "everything" # skill 组名称(用于日志/调试)
|
||||
# enabled: true
|
||||
# transport: "sse" # sse | http | stdio
|
||||
# url: "http://localhost:3001/sse"
|
||||
# # 请求头(可用于 API Key 认证)
|
||||
# headers:
|
||||
# Authorization: "" # 优先读取环境变量 MCP_EVERYTHING_TOKEN
|
||||
# timeout: 30 # 连接超时(秒)
|
||||
# retry: 2 # 失败重试次数
|
||||
# # 只暴露指定工具(空列表=全部暴露)
|
||||
# include_tools: []
|
||||
# # 排除指定工具
|
||||
# exclude_tools: []
|
||||
|
||||
# 示例二:Streamable HTTP 传输
|
||||
# - name: "remote-tools"
|
||||
# enabled: false
|
||||
# transport: "http"
|
||||
# url: "http://api.example.com/mcp"
|
||||
# headers:
|
||||
# Authorization: "Bearer your_token_here"
|
||||
# X-Client-ID: "agent-demo"
|
||||
# timeout: 30
|
||||
# retry: 2
|
||||
# include_tools: []
|
||||
# exclude_tools: []
|
||||
|
||||
# 示例三:stdio 子进程(本地可执行文件作为 MCP Server)
|
||||
# - name: "filesystem"
|
||||
# enabled: false
|
||||
# transport: "stdio"
|
||||
# # stdio 模式使用 command 启动子进程,不需要 url
|
||||
# command: "npx"
|
||||
# args:
|
||||
# - "-y"
|
||||
# - "@modelcontextprotocol/server-filesystem"
|
||||
# - "/tmp"
|
||||
# env:
|
||||
# NODE_ENV: "production"
|
||||
# timeout: 30
|
||||
# retry: 1
|
||||
# include_tools: []
|
||||
# exclude_tools: []
|
||||
- name: "hexstrike-ai"
|
||||
enabled: true
|
||||
transport: "stdio"
|
||||
command: "python3"
|
||||
args:
|
||||
- "/Users/sontolau/Applications/hexstrike-ai/hexstrike_mcp.py"
|
||||
- "--server"
|
||||
- "http://localhost:8999"
|
||||
description: "HexStrike AI v6.0 - Advanced Cybersecurity Automation Platform"
|
||||
timeout: 300
|
||||
# 示例四:带鉴权的在线 MCP SaaS 服务
|
||||
# - name: "brave-search"
|
||||
# enabled: false
|
||||
# transport: "sse"
|
||||
# url: "https://mcp.brave.com/sse"
|
||||
# headers:
|
||||
# Authorization: "" # 优先读取环境变量 MCP_BRAVE_SEARCH_TOKEN
|
||||
# timeout: 20
|
||||
# retry: 2
|
||||
# include_tools:
|
||||
# - "brave_web_search"
|
||||
# - "brave_local_search"
|
||||
# exclude_tools: []
|
||||
|
||||
# ── 工具配置 ───────────────────────────────────────────────────
|
||||
tools:
|
||||
calculator:
|
||||
precision: 10
|
||||
|
||||
web_search:
|
||||
max_results: 5
|
||||
timeout: 10
|
||||
api_key: "7917bef5e46044af5209fdb78518be98be394f3fe763bbce3fbb503280408bd9"
|
||||
|
||||
file_reader:
|
||||
allowed_root: "./workspace"
|
||||
max_file_size_kb: 512
|
||||
|
||||
code_executor:
|
||||
timeout: 5
|
||||
sandbox: true
|
||||
|
||||
static_analyzer:
|
||||
default_tool: "cppcheck"
|
||||
default_std: "c++17"
|
||||
timeout: 120
|
||||
jobs: 4
|
||||
output_format: "summary"
|
||||
max_issues: 500
|
||||
allowed_roots: []
|
||||
tool_extra_args:
|
||||
cppcheck: "--suppress=missingIncludeSystem --suppress=unmatchedSuppression"
|
||||
clang-tidy: "--checks=*,-fuchsia-*,-google-*,-zircon-*"
|
||||
infer: ""
|
||||
|
||||
ssh_docker:
|
||||
default_ssh_port: 22
|
||||
default_username: "root"
|
||||
connect_timeout: 30
|
||||
cmd_timeout: 120
|
||||
deploy_timeout: 300
|
||||
default_restart_policy: "unless-stopped"
|
||||
default_tail_lines: 100
|
||||
allowed_hosts: []
|
||||
blocked_images: []
|
||||
allow_privileged: false
|
||||
servers: {}
|
||||
|
||||
# ── 记忆配置 ───────────────────────────────────────────────────
|
||||
memory:
|
||||
max_history: 20
|
||||
enable_long_term: false
|
||||
vector_db_url: ""
|
||||
|
||||
# ── 日志配置 ───────────────────────────────────────────────────
|
||||
logging:
|
||||
level: "DEBUG"
|
||||
enable_file: true
|
||||
log_dir: "./logs"
|
||||
log_file: "agent.log"
|
||||
|
||||
# ── Agent 行为配置 ─────────────────────────────────────────────
|
||||
agent:
|
||||
max_chain_steps: 10
|
||||
enable_multi_step: true
|
||||
session_timeout: 3600
|
||||
fallback_to_rules: true
|
||||
|
|
@ -221,6 +221,7 @@ class MCPConfig:
|
|||
transport: str = "stdio"
|
||||
host: str = "localhost"
|
||||
port: int = 3000
|
||||
tools_directory: str = "tools"
|
||||
enabled_tools: list[str] = field(default_factory=lambda: [
|
||||
"calculator", "web_search", "file_reader",
|
||||
"code_executor", "static_analyzer", "ssh_docker",
|
||||
|
|
@ -257,6 +258,12 @@ class DatabaseConfig:
|
|||
type: str = "sqlite"
|
||||
url: str = None
|
||||
|
||||
@dataclass
|
||||
class DeviceConfig:
|
||||
type: str
|
||||
device_id: str
|
||||
protocol: str
|
||||
params: dict
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 顶层 AppConfig
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
|
@ -287,6 +294,7 @@ class AppConfig:
|
|||
memory: MemoryConfig,
|
||||
logging: LoggingConfig,
|
||||
agent: AgentConfig,
|
||||
device: DeviceConfig,
|
||||
skills_directory: str = "./skills",
|
||||
database: DatabaseConfig = DatabaseConfig()
|
||||
):
|
||||
|
|
@ -297,6 +305,7 @@ class AppConfig:
|
|||
self.memory = memory
|
||||
self.logging = logging
|
||||
self.agent = agent
|
||||
self.device = device
|
||||
self.skills_directory = skills_directory
|
||||
self.database = database
|
||||
|
||||
|
|
@ -382,12 +391,22 @@ class ConfigLoader:
|
|||
logging=cls._build_logging(raw.get("logging", {})),
|
||||
agent=cls._build_agent(raw.get("agent", {})),
|
||||
skills_directory=raw.get("skills_directory", ""),
|
||||
device=cls._build_device(raw.get("device", {})),
|
||||
database=DatabaseConfig(
|
||||
type=raw.get("database", {}).get("type", "sqlite"),
|
||||
url=raw.get("database", {}).get("url", "sqlite:///skills.db"),
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_device(d: dict):
|
||||
return DeviceConfig(
|
||||
type=d.get("type", None),
|
||||
device_id=d.get("device_id", None),
|
||||
protocol=d.get("protocol", None),
|
||||
params=d.get("params", {})
|
||||
)
|
||||
|
||||
# ── LLM ───────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -418,6 +437,7 @@ class ConfigLoader:
|
|||
transport=d.get("transport", df["transport"]),
|
||||
host=d.get("host", df["host"]),
|
||||
port=int(d.get("port", df["port"])),
|
||||
tools_directory=d.get("tools_directory", "tools"),
|
||||
enabled_tools=d.get("enabled_tools", df["enabled_tools"]),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
from agent.config.settings import settings
|
||||
from core.uas_control.controllers.robot_dog_controller import RobotDogController
|
||||
from core.uas_control.controllers.uav_controller import UAVController
|
||||
from core.uas_control.protocols.mavlink_adapter import MAVLinkAdapter
|
||||
from core.uas_control.protocols.ros_adapter import ROSAdapter
|
||||
from core.uas_control.protocols.simulation_adapter import SimulationAdapter
|
||||
|
||||
uas_type = settings.device.type
|
||||
protocol = settings.device.protocol
|
||||
device_id = settings.device.device_id
|
||||
|
||||
if protocol == "mavlink":
|
||||
adapter = MAVLinkAdapter(device_id=device_id, **settings.device.params)
|
||||
elif protocol == "ros":
|
||||
adapter = ROSAdapter(device_id=device_id, **settings.device.params)
|
||||
else:
|
||||
adapter = SimulationAdapter(device_id=device_id, **settings.device.params)
|
||||
|
||||
if uas_type == "uav":
|
||||
controller = UAVController(device_id=device_id, adapter=adapter)
|
||||
elif uas_type == "robot_dog":
|
||||
controller = RobotDogController(device_id=device_id, adapter=adapter)
|
||||
else:
|
||||
controller = UAVController(device_id=device_id, adapter=adapter)
|
||||
|
||||
# results, errors = [], []
|
||||
# controller.command_result.connect(lambda r: results.append(r))
|
||||
# controller.error_occurred.connect(lambda m: errors.append(m))
|
||||
# assert controller.connect(), "连接应成功"
|
||||
|
||||
7
envs
7
envs
|
|
@ -1,7 +0,0 @@
|
|||
LOG_DIR="./logs" # 日志目录
|
||||
|
||||
# --- HEXSTRIKE_AI 配置项 ---
|
||||
HEXSTRIKE_AI_ENTRY="hexstrike_server:app" # WSGI 入口,格式为 模块名:变量名
|
||||
HEXSTRIKE_AI_BIND_ADDRESS="127.0.0.1:8000" # 监听地址
|
||||
HEXSTRIKE_AI_WORKERS=4 # 工作进程数
|
||||
HEXSTRIKE_AI_PID_FILE="./hexstrike_ai.pid" # PID 文件保存位置
|
||||
19
install.sh
19
install.sh
|
|
@ -1,19 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
SKILLS_DIR="./skills"
|
||||
|
||||
function install () {
|
||||
name=$1
|
||||
url=$2
|
||||
|
||||
if ! [ -d "${SKILLS_DIR}/$name" ]; then
|
||||
git clone $url ${SKILLS_DIR}/$name
|
||||
fi
|
||||
requirements_file="${SKILLS_DIR}/$name/requirements.txt"
|
||||
if [ -f ${requirements_file} ]; then
|
||||
pip install -r $requirements_file
|
||||
fi
|
||||
}
|
||||
|
||||
pip install -r ./requirements.txt
|
||||
install "hexstrike_ai" "https://github.com/0x4m4/hexstrike-ai.git"
|
||||
|
|
@ -7577,3 +7577,4 @@ The function `get_system_name()` uses `platform.system()` to determine the syste
|
|||
时间, 08:00, 11:00 ; 天气 ; 气温, 23.9℃,
|
||||
2026-06-01 15:52:17 [INFO ] agent.Agent │ 🔁 推理步骤 2/10
|
||||
2026-06-01 15:52:19 [DEBUG ] agent.Agent │ LLM 响应: finish=stop tool_calls=0 content=今天成都双流区的天气为多云,气温范围大约在22℃到32℃,风力较轻,适合外出活动。请注意根据具体天气情况,适当增减衣物。
|
||||
2026-06-01 15:54:14 [INFO ] agent.MCP.SkillRegistry │ 🔌 SkillRegistry 已关闭所有连接
|
||||
|
|
|
|||
159
main.py
159
main.py
|
|
@ -1,159 +0,0 @@
|
|||
"""
|
||||
main.py
|
||||
项目入口 —— 启动 Agent 交互式对话 或 MCP Server stdio 模式
|
||||
|
||||
用法:
|
||||
python main.py # 启动 Agent 交互式对话(默认)
|
||||
python main.py --mode agent # 同上
|
||||
python main.py --mode mcp # 启动本地 MCP Server(stdio 模式)
|
||||
python main.py --mode check # 检查配置和依赖
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import atexit
|
||||
import sys
|
||||
|
||||
|
||||
def run_agent() -> None:
|
||||
"""启动 Agent 交互式对话"""
|
||||
from agent.agent import create_agent
|
||||
from config.settings import settings
|
||||
|
||||
print(settings.display())
|
||||
|
||||
agent, registry = create_agent()
|
||||
atexit.register(registry.close)
|
||||
|
||||
print(agent.show_tools())
|
||||
print("─" * 60)
|
||||
print("💡 命令: exit=退出 reset=清空历史 tools=查看工具列表")
|
||||
print("─" * 60)
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n🧑 You: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n👋 再见!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
match user_input.lower():
|
||||
case "exit" | "quit":
|
||||
print("👋 再见!")
|
||||
break
|
||||
case "reset":
|
||||
agent.reset()
|
||||
print("🔄 对话历史已清空")
|
||||
case "tools":
|
||||
print(agent.show_tools())
|
||||
case _:
|
||||
reply = agent.chat(user_input)
|
||||
print(f"\n🤖 Agent: {reply}")
|
||||
|
||||
|
||||
def run_mcp_server() -> None:
|
||||
"""启动本地 MCP Server(stdio 模式)"""
|
||||
from mcp.mcp_server import MCPServer
|
||||
with MCPServer() as server:
|
||||
server.run_stdio()
|
||||
|
||||
|
||||
# def run_check() -> None:
|
||||
# """检查配置和依赖完整性"""
|
||||
# print("=" * 60)
|
||||
# print(" 🔍 项目依赖检查")
|
||||
# print("=" * 60)
|
||||
#
|
||||
# checks = [
|
||||
# ("pyyaml", "yaml", "pip install pyyaml"),
|
||||
# ("openai", "openai", "pip install openai>=1.0.0"),
|
||||
# ("httpx", "httpx", "pip install httpx>=0.27.0"),
|
||||
# ("httpx-sse", "httpx_sse", "pip install httpx-sse>=0.4.0"),
|
||||
# ("paramiko", "paramiko", "pip install paramiko>=3.0.0"),
|
||||
# ]
|
||||
#
|
||||
# all_ok = True
|
||||
# for pkg_name, import_name, install_cmd in checks:
|
||||
# try:
|
||||
# __import__(import_name)
|
||||
# print(f" ✅ {pkg_name:<15} 已安装")
|
||||
# except ImportError:
|
||||
# print(f" ❌ {pkg_name:<15} 未安装 → {install_cmd}")
|
||||
# all_ok = False
|
||||
#
|
||||
# print()
|
||||
#
|
||||
# # 配置检查
|
||||
# try:
|
||||
# from config.settings import settings
|
||||
# print(" ✅ config/settings.py 加载成功")
|
||||
# print(f" LLM : {settings.llm.provider} / {settings.llm.model_name}")
|
||||
# print(f" 本地工具: {settings.mcp.enabled_tools}")
|
||||
# skills = settings.enabled_mcp_skills
|
||||
# if skills:
|
||||
# print(f" 在线Skill: {[s.name for s in skills]}")
|
||||
# else:
|
||||
# print(" 在线Skill: (未配置)")
|
||||
# except Exception as e:
|
||||
# print(f" ❌ 配置加载失败: {e}")
|
||||
# all_ok = False
|
||||
#
|
||||
# print()
|
||||
#
|
||||
# # 工具注册检查
|
||||
# try:
|
||||
# from mcp.skill_registry import SkillRegistry
|
||||
# from tools.calculator import CalculatorTool
|
||||
# from tools.code_executor import CodeExecutorTool
|
||||
# from tools.file_reader import FileReaderTool
|
||||
# from tools.ssh_docker import SSHDockerTool
|
||||
# from tools.static_analyzer import StaticAnalyzerTool
|
||||
# from tools.web_search import WebSearchTool
|
||||
#
|
||||
# registry = SkillRegistry()
|
||||
# registry.register_local_many(
|
||||
# CalculatorTool(), WebSearchTool(), FileReaderTool(),
|
||||
# CodeExecutorTool(), StaticAnalyzerTool(), SSHDockerTool(),
|
||||
# )
|
||||
# tools = registry.list_all_tools()
|
||||
# print(f" ✅ 本地工具注册 共 {len(tools)} 个:")
|
||||
# for t in tools:
|
||||
# print(f" 🔵 {t['name']}: {t['description'][:50]}")
|
||||
# except Exception as e:
|
||||
# print(f" ❌ 工具注册失败: {e}")
|
||||
# all_ok = False
|
||||
#
|
||||
# print()
|
||||
# print("=" * 60)
|
||||
# if all_ok:
|
||||
# print(" ✅ 所有检查通过,项目可正常运行")
|
||||
# else:
|
||||
# print(" ⚠️ 存在问题,请按提示安装缺失依赖")
|
||||
# print("=" * 60)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Agent Demo —— 支持本地工具 + 在线 MCP Skill"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["agent", "mcp", "check"],
|
||||
default="agent",
|
||||
help="运行模式: agent(交互对话)| mcp(MCP Server)| check(依赖检查)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
match args.mode:
|
||||
case "agent":
|
||||
run_agent()
|
||||
case "mcp":
|
||||
run_mcp_server()
|
||||
# case "check":
|
||||
# run_check()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -6,9 +6,9 @@ mcp/mcp_server.py
|
|||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
from config.settings import settings
|
||||
from mcp.skill_registry import SkillRegistry
|
||||
from utils.logger import get_logger
|
||||
from agent.config.settings import settings
|
||||
from agent.mcp.skill_registry import SkillRegistry
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("MCP.Server")
|
||||
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ import uuid
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any, Iterator
|
||||
|
||||
from config.settings import MCPSkillConfig
|
||||
from utils.logger import get_logger
|
||||
from agent.config.settings import MCPSkillConfig
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("MCP.SkillClient")
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from utils.logger import get_logger
|
||||
from agent.utils.logger import get_logger
|
||||
import yaml
|
||||
|
||||
logger = get_logger("mcp.SkillLoader")
|
||||
|
|
|
|||
|
|
@ -14,11 +14,11 @@ import time
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from config.settings import settings
|
||||
from mcp.mcp_skill_client import MCPSkillClient, RemoteTool, ToolCallResult
|
||||
from mcp.skill_loader import SkillLoader
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
from agent.config.settings import settings
|
||||
from agent.mcp.mcp_skill_client import MCPSkillClient, RemoteTool, ToolCallResult
|
||||
from agent.mcp.skill_loader import SkillLoader
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("MCP.SkillRegistry")
|
||||
|
||||
|
|
@ -111,8 +111,10 @@ class SkillRegistry:
|
|||
enabled = settings.mcp.enabled_tools
|
||||
logger.info(f"🔧 注册本地工具: {enabled}")
|
||||
for tool_name in enabled:
|
||||
tool_path = f"tools/{tool_name}.py"
|
||||
tool_path =f"{settings.mcp.tools_directory}/{tool_name}.py"
|
||||
# tool_path = f"tools/{tool_name}.py"
|
||||
if not os.path.exists(tool_path):
|
||||
logger.warning(f"未找到工具:{tool_path}")
|
||||
continue
|
||||
# 动态加载模块
|
||||
spec = importlib.util.spec_from_file_location(tool_name, tool_path)
|
||||
|
|
|
|||
36
run.sh
36
run.sh
|
|
@ -1,36 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
if [ -f "./envs" ]; then
|
||||
source ./envs
|
||||
fi
|
||||
|
||||
# 创建日志目录
|
||||
if ! [ -d ${LOG_DIR} ]; then
|
||||
mkdir -p $LOG_DIR
|
||||
fi
|
||||
|
||||
function start_hexstrike_ai() {
|
||||
if [ -f "$HEXSTRIKE_AI_PID_FILE" ]; then
|
||||
echo "错误: hexstrike_ai 似乎已在运行 (PID: $(cat $HEXSTRIKE_AI_PID_FILE))"
|
||||
exit 1
|
||||
fi
|
||||
# 启动 Gunicorn
|
||||
echo "正在启动 hexstrike_ai"
|
||||
cd skills/hexstrike_ai && gunicorn -w ${HEXSTRIKE_AI_WORKERS} \
|
||||
-b ${HEXSTRIKE_AI_BIND_ADDRESS} \
|
||||
--pid ${HEXSTRIKE_AI_PID_FILE} \
|
||||
--access-logfile "$LOG_DIR/access.log" \
|
||||
--error-logfile "$LOG_DIR/error.log" \
|
||||
-D \
|
||||
${HEXSTRIKE_AI_ENTRY}
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "启动成功!PID: $(cat $PID_FILE)"
|
||||
else
|
||||
echo "启动失败,请检查日志。"
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
|
||||
start_hexstrike_ai
|
||||
23
stop.sh
23
stop.sh
|
|
@ -1,23 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
if [ -f "./envs" ]; then
|
||||
source ./envs
|
||||
fi
|
||||
|
||||
function stop() {
|
||||
PID_FILE=$1
|
||||
echo $PID_FILE
|
||||
if [ -f "$PID_FILE" ]; then
|
||||
PID=$(cat $PID_FILE)
|
||||
echo "正在停止进程 $PID..."
|
||||
kill $PID
|
||||
# 循环等待进程结束并清理 PID 文件
|
||||
while ps -p $PID > /dev/null; do sleep 1; done
|
||||
rm $PID_FILE
|
||||
echo "服务已停止。"
|
||||
else
|
||||
echo "未发现正在运行的服务 (未找到 PID 文件)。"
|
||||
fi
|
||||
}
|
||||
|
||||
stop ${HEXSTRIKE_AI_PID_FILE}
|
||||
|
|
@ -6,9 +6,8 @@ tools/base_tool.py
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from mcp.mcp_protocol import ToolSchema
|
||||
from utils.logger import get_logger
|
||||
from agent.mcp.mcp_protocol import ToolSchema
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -1,113 +0,0 @@
|
|||
"""
|
||||
tools/calculator.py
|
||||
数学计算工具 —— 支持基本四则运算及常用数学函数
|
||||
配置通过 settings.tools['calculator'] 读取
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
from config.settings import settings
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger("TOOL.Calculator")
|
||||
|
||||
|
||||
def _cfg(key: str, fallback=None):
|
||||
return settings.tools['calculator'].get(key, fallback)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
name = "calculator"
|
||||
description = (
|
||||
"执行数学计算,支持四则运算、幂运算、开方、三角函数、对数等。"
|
||||
"输入数学表达式字符串,返回计算结果。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"数学表达式,例如: '2 + 3 * 4', 'sqrt(16)', "
|
||||
"'sin(3.14159/2)', 'log(100, 10)', '2 ** 10'"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["expression"],
|
||||
}
|
||||
|
||||
# 安全的内置函数白名单
|
||||
_SAFE_GLOBALS: dict[str, Any] = {
|
||||
"__builtins__": {},
|
||||
# 基本数学
|
||||
"abs": abs,
|
||||
"round": round,
|
||||
"pow": pow,
|
||||
"min": min,
|
||||
"max": max,
|
||||
# math 模块常用函数
|
||||
"sqrt": math.sqrt,
|
||||
"ceil": math.ceil,
|
||||
"floor": math.floor,
|
||||
"log": math.log,
|
||||
"log2": math.log2,
|
||||
"log10": math.log10,
|
||||
"exp": math.exp,
|
||||
"sin": math.sin,
|
||||
"cos": math.cos,
|
||||
"tan": math.tan,
|
||||
"asin": math.asin,
|
||||
"acos": math.acos,
|
||||
"atan": math.atan,
|
||||
"atan2": math.atan2,
|
||||
"pi": math.pi,
|
||||
"e": math.e,
|
||||
"inf": math.inf,
|
||||
"factorial": math.factorial,
|
||||
"gcd": math.gcd,
|
||||
"hypot": math.hypot,
|
||||
}
|
||||
|
||||
def execute(self, expression: str = "", **_) -> str:
|
||||
if not expression or not expression.strip():
|
||||
return "❌ 参数错误: expression 不能为空"
|
||||
|
||||
expr = expression.strip()
|
||||
logger.info(f"🔢 计算表达式: {expr}")
|
||||
|
||||
# 安全检查:禁止危险关键字
|
||||
forbidden = ["import", "exec", "eval", "open", "os", "sys",
|
||||
"__", "compile", "globals", "locals", "getattr"]
|
||||
for kw in forbidden:
|
||||
if kw in expr:
|
||||
return f"❌ 安全限制: 表达式包含禁止关键字 '{kw}'"
|
||||
|
||||
try:
|
||||
precision = _cfg('precision', 10)
|
||||
result = eval(expr, self._SAFE_GLOBALS, {}) # noqa: S307
|
||||
|
||||
# 格式化输出
|
||||
if isinstance(result, float):
|
||||
# 去除多余尾零
|
||||
formatted = f"{result:.{precision}f}".rstrip("0").rstrip(".")
|
||||
elif isinstance(result, complex):
|
||||
formatted = str(result)
|
||||
else:
|
||||
formatted = str(result)
|
||||
|
||||
logger.info(f"✅ 计算结果: {expr} = {formatted}")
|
||||
return f"{expr} = {formatted}"
|
||||
|
||||
except ZeroDivisionError:
|
||||
return f"❌ 计算错误: 除零错误 表达式: {expr}"
|
||||
except OverflowError:
|
||||
return f"❌ 计算错误: 数值溢出 表达式: {expr}"
|
||||
except ValueError as e:
|
||||
return f"❌ 计算错误: {e} 表达式: {expr}"
|
||||
except SyntaxError:
|
||||
return f"❌ 语法错误: 无法解析表达式 '{expr}'"
|
||||
except Exception as e:
|
||||
return f"❌ 计算失败: {e} 表达式: {expr}"
|
||||
|
|
@ -1,149 +0,0 @@
|
|||
"""
|
||||
tools/code_executor.py
|
||||
代码执行工具 —— 在沙箱中执行 Python 代码片段
|
||||
配置通过 settings.tools['code_executor'] 读取
|
||||
"""
|
||||
|
||||
import io
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
|
||||
from config.settings import settings
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger("TOOL.CodeExecutor")
|
||||
|
||||
|
||||
def _cfg(key: str, fallback=None):
|
||||
return settings.tools['code_executor'].get(key, fallback)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
name = "code_executor"
|
||||
description = (
|
||||
"在安全沙箱中执行 Python 代码片段,返回标准输出和执行结果。"
|
||||
"适用于数据处理、计算、格式转换等任务。"
|
||||
"注意:沙箱模式下禁止文件系统写入、网络访问和系统调用。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "要执行的 Python 代码字符串",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "执行超时秒数(默认来自 config.yaml code_executor.timeout)",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
}
|
||||
|
||||
# 沙箱模式下禁止的模块和函数
|
||||
_FORBIDDEN_SANDBOX = [
|
||||
"import os", "import sys", "import subprocess",
|
||||
"import socket", "import requests", "import httpx",
|
||||
"import shutil", "open(", "__import__",
|
||||
"exec(", "eval(", "compile(",
|
||||
]
|
||||
|
||||
def execute(self, code: str = "", timeout: int | None = None, **_) -> str:
|
||||
if not code or not code.strip():
|
||||
return "❌ 参数错误: code 不能为空"
|
||||
|
||||
sandbox = _cfg('sandbox', True)
|
||||
t = timeout or _cfg('timeout', 5)
|
||||
code = textwrap.dedent(code)
|
||||
|
||||
logger.info(
|
||||
f"🐍 执行代码 sandbox={sandbox} timeout={t}s "
|
||||
f"[config timeout={_cfg('timeout')}s sandbox={_cfg('sandbox')}]\n"
|
||||
f" 代码预览: {code[:100]}"
|
||||
)
|
||||
|
||||
# 沙箱安全检查
|
||||
if sandbox:
|
||||
err = self._sandbox_check(code)
|
||||
if err:
|
||||
return err
|
||||
|
||||
# 使用线程超时执行
|
||||
return self._run_with_timeout(code, t)
|
||||
|
||||
def _run_with_timeout(self, code: str, timeout: int) -> str:
|
||||
"""在独立线程中执行代码,超时则终止"""
|
||||
import threading
|
||||
|
||||
result_box: list[str] = []
|
||||
error_box: list[str] = []
|
||||
|
||||
def _run():
|
||||
stdout_buf = io.StringIO()
|
||||
stderr_buf = io.StringIO()
|
||||
local_ns: dict = {}
|
||||
start = time.time()
|
||||
try:
|
||||
with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
|
||||
exec(code, {"__builtins__": __builtins__}, local_ns) # noqa: S102
|
||||
elapsed = time.time() - start
|
||||
stdout = stdout_buf.getvalue()
|
||||
stderr = stderr_buf.getvalue()
|
||||
# 尝试获取最后一个表达式的值
|
||||
last_val = ""
|
||||
lines = [l.strip() for l in code.strip().splitlines() if l.strip()]
|
||||
if lines:
|
||||
last_line = lines[-1]
|
||||
if not last_line.startswith(("#", "print", "import", "from",
|
||||
"def ", "class ", "if ", "for ",
|
||||
"while ", "try:", "with ")):
|
||||
try:
|
||||
val = eval(last_line, {"__builtins__": __builtins__}, local_ns) # noqa: S307
|
||||
if val is not None:
|
||||
last_val = f"\n返回值: {repr(val)}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
output = stdout + (f"\n[stderr]\n{stderr}" if stderr else "") + last_val
|
||||
result_box.append(
|
||||
f"✅ 执行成功 耗时={elapsed:.3f}s\n"
|
||||
f"{'─' * 40}\n"
|
||||
f"{output.strip() or '(无输出)'}"
|
||||
)
|
||||
except Exception:
|
||||
elapsed = time.time() - start
|
||||
tb = traceback.format_exc()
|
||||
error_box.append(
|
||||
f"❌ 执行错误 耗时={elapsed:.3f}s\n"
|
||||
f"{'─' * 40}\n{tb}"
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=_run, daemon=True)
|
||||
thread.start()
|
||||
thread.join(timeout=timeout)
|
||||
|
||||
if thread.is_alive():
|
||||
return (
|
||||
f"⏰ 执行超时(>{timeout}s)\n"
|
||||
f" 请增大 config.yaml → tools.code_executor.timeout\n"
|
||||
f" 或优化代码逻辑"
|
||||
)
|
||||
|
||||
if error_box:
|
||||
return error_box[0]
|
||||
return result_box[0] if result_box else "❌ 执行失败(未知错误)"
|
||||
|
||||
def _sandbox_check(self, code: str) -> str | None:
|
||||
"""沙箱模式下的静态安全检查"""
|
||||
for forbidden in self._FORBIDDEN_SANDBOX:
|
||||
if forbidden in code:
|
||||
return (
|
||||
f"❌ 沙箱限制: 代码包含禁止操作 '{forbidden}'\n"
|
||||
f" 如需完整权限请在 config.yaml → "
|
||||
f"tools.code_executor.sandbox 设置为 false"
|
||||
)
|
||||
return None
|
||||
|
|
@ -1,179 +0,0 @@
|
|||
"""
|
||||
tools/file_reader.py
|
||||
文件读取工具 —— 读取本地文件内容,支持文本/JSON/CSV
|
||||
配置通过 settings.tools['file_reader'] 读取
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from config.settings import settings
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger("TOOL.FileReader")
|
||||
|
||||
|
||||
def _cfg(key: str, fallback=None):
|
||||
return settings.tools['file_reader'].get(key, fallback)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
name = "file_reader"
|
||||
description = (
|
||||
"读取本地文件内容,支持 .txt / .md / .py / .json / .csv / .yaml / .log 等文本文件。"
|
||||
"文件必须位于 config.yaml file_reader.allowed_root 目录下。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "文件路径(相对于 allowed_root 或绝对路径)",
|
||||
},
|
||||
"encoding": {
|
||||
"type": "string",
|
||||
"description": "文件编码,默认 utf-8",
|
||||
},
|
||||
"max_lines": {
|
||||
"type": "integer",
|
||||
"description": "最多读取行数,0 表示全部读取",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
}
|
||||
|
||||
_TEXT_EXTENSIONS = {
|
||||
".txt", ".md", ".py", ".js", ".ts", ".java", ".c", ".cpp",
|
||||
".h", ".hpp", ".go", ".rs", ".rb", ".php", ".sh", ".bash",
|
||||
".yaml", ".yml", ".toml", ".ini", ".cfg", ".conf",
|
||||
".json", ".csv", ".log", ".xml", ".html", ".css", ".sql",
|
||||
".env", ".gitignore", ".dockerfile",
|
||||
}
|
||||
|
||||
def execute(
|
||||
self,
|
||||
file_path: str = "",
|
||||
encoding: str = "utf-8",
|
||||
max_lines: int = 0,
|
||||
**_,
|
||||
) -> str:
|
||||
if not file_path or not file_path.strip():
|
||||
return "❌ 参数错误: file_path 不能为空"
|
||||
|
||||
allowed_root = Path(_cfg('allowed_root', './workspace')).resolve()
|
||||
max_size_kb = _cfg('max_file_size_kb', 512)
|
||||
|
||||
# 路径解析
|
||||
path = Path(file_path)
|
||||
if not path.is_absolute():
|
||||
path = allowed_root / path
|
||||
path = path.resolve()
|
||||
|
||||
logger.info(
|
||||
f"📄 读取文件: {path}\n"
|
||||
f" allowed_root={allowed_root} "
|
||||
f"max_size={max_size_kb}KB [config]"
|
||||
)
|
||||
|
||||
# 安全检查:必须在 allowed_root 内
|
||||
try:
|
||||
path.relative_to(allowed_root)
|
||||
except ValueError:
|
||||
return (
|
||||
f"❌ 安全限制: 文件路径超出允许范围\n"
|
||||
f" 路径: {path}\n"
|
||||
f" 允许范围: {allowed_root}\n"
|
||||
f" 请在 config.yaml → tools.file_reader.allowed_root 中调整"
|
||||
)
|
||||
|
||||
if not path.exists():
|
||||
return f"❌ 文件不存在: {path}"
|
||||
if not path.is_file():
|
||||
return f"❌ 路径不是文件: {path}"
|
||||
|
||||
# 文件大小检查
|
||||
size_kb = path.stat().st_size / 1024
|
||||
if size_kb > max_size_kb:
|
||||
return (
|
||||
f"❌ 文件过大: {size_kb:.1f} KB > 限制 {max_size_kb} KB\n"
|
||||
f" 请在 config.yaml → tools.file_reader.max_file_size_kb 中调整"
|
||||
)
|
||||
|
||||
# 扩展名检查
|
||||
suffix = path.suffix.lower()
|
||||
if suffix not in self._TEXT_EXTENSIONS:
|
||||
return (
|
||||
f"❌ 不支持的文件类型: {suffix}\n"
|
||||
f" 支持类型: {', '.join(sorted(self._TEXT_EXTENSIONS))}"
|
||||
)
|
||||
|
||||
# 读取文件
|
||||
try:
|
||||
if suffix == ".json":
|
||||
return self._read_json(path, encoding)
|
||||
if suffix == ".csv":
|
||||
return self._read_csv(path, encoding, max_lines)
|
||||
return self._read_text(path, encoding, max_lines)
|
||||
except UnicodeDecodeError:
|
||||
return (
|
||||
f"❌ 编码错误: 无法以 {encoding} 解码文件\n"
|
||||
f" 请尝试指定 encoding 参数,例如 'gbk' 或 'latin-1'"
|
||||
)
|
||||
except Exception as e:
|
||||
return f"❌ 读取失败: {e}"
|
||||
|
||||
@staticmethod
|
||||
def _read_text(path: Path, encoding: str, max_lines: int) -> str:
|
||||
content = path.read_text(encoding=encoding)
|
||||
lines = content.splitlines()
|
||||
total = len(lines)
|
||||
if max_lines and max_lines < total:
|
||||
shown = lines[:max_lines]
|
||||
omitted = total - max_lines
|
||||
text = "\n".join(shown)
|
||||
return (
|
||||
f"📄 {path.name} ({total} 行,显示前 {max_lines} 行)\n"
|
||||
f"{'─' * 50}\n{text}\n"
|
||||
f"{'─' * 50}\n... 还有 {omitted} 行未显示"
|
||||
)
|
||||
return f"📄 {path.name} ({total} 行)\n{'─' * 50}\n{content}"
|
||||
|
||||
@staticmethod
|
||||
def _read_json(path: Path, encoding: str) -> str:
|
||||
content = path.read_text(encoding=encoding)
|
||||
try:
|
||||
data = json.loads(content)
|
||||
formatted = json.dumps(data, ensure_ascii=False, indent=2)
|
||||
return f"📄 {path.name} (JSON)\n{'─' * 50}\n{formatted}"
|
||||
except json.JSONDecodeError as e:
|
||||
return f"⚠️ JSON 解析失败: {e}\n原始内容:\n{content[:500]}"
|
||||
|
||||
@staticmethod
|
||||
def _read_csv(path: Path, encoding: str, max_lines: int) -> str:
|
||||
content = path.read_text(encoding=encoding)
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
rows = list(reader)
|
||||
total = len(rows)
|
||||
limit = max_lines if max_lines else min(total, 50)
|
||||
shown = rows[:limit]
|
||||
|
||||
# 计算列宽
|
||||
if not shown:
|
||||
return f"📄 {path.name} (CSV,空文件)"
|
||||
col_widths = [
|
||||
max(len(str(row[i])) if i < len(row) else 0 for row in shown)
|
||||
for i in range(len(shown[0]))
|
||||
]
|
||||
lines = [f"📄 {path.name} (CSV,{total} 行)", "─" * 50]
|
||||
for row in shown:
|
||||
cells = [
|
||||
str(row[i]).ljust(col_widths[i]) if i < len(row) else ""
|
||||
for i in range(len(shown[0]))
|
||||
]
|
||||
lines.append(" | ".join(cells))
|
||||
if total > limit:
|
||||
lines.append(f"... 还有 {total - limit} 行未显示")
|
||||
return "\n".join(lines)
|
||||
|
|
@ -1,733 +0,0 @@
|
|||
"""
|
||||
tools/ssh_docker.py
|
||||
SSH 远程 Docker 部署工具 —— 所有配置通过 settings.tools['ssh_docker'][key] 获取
|
||||
依赖: pip install paramiko>=3.0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from config.settings import settings
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger("TOOL.SSHDocker")
|
||||
|
||||
try:
|
||||
import paramiko
|
||||
_PARAMIKO_AVAILABLE = True
|
||||
except ImportError:
|
||||
_PARAMIKO_AVAILABLE = False
|
||||
logger.warning("⚠️ paramiko 未安装,请执行: pip install paramiko>=3.0.0")
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 配置访问快捷函数
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
def _cfg(key: str, fallback=None):
|
||||
"""读取 ssh_docker 工具配置,不存在时返回 fallback"""
|
||||
return settings.tools['ssh_docker'].get(key, fallback)
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 数据结构
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class SSHConfig:
|
||||
host: str
|
||||
port: int = 22
|
||||
username: str = "root"
|
||||
password: str = ""
|
||||
key_path: str = ""
|
||||
timeout: int = 30
|
||||
cmd_timeout: int = 120
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, kwargs: dict) -> "SSHConfig":
|
||||
"""
|
||||
从调用参数构造 SSHConfig
|
||||
支持通过 server 名称引用 config.yaml 中的预设
|
||||
缺省值全部来自 config.yaml → tools.ssh_docker
|
||||
"""
|
||||
server_name = kwargs.get("server", "")
|
||||
if server_name:
|
||||
servers = _cfg('servers', {})
|
||||
preset = servers.get(server_name)
|
||||
if not preset:
|
||||
raise ValueError(
|
||||
f"服务器预设 '{server_name}' 未在 config.yaml "
|
||||
f"tools.ssh_docker.servers 中定义\n"
|
||||
f"已有预设: {list(servers.keys())}"
|
||||
)
|
||||
logger.info(f"📋 使用服务器预设: {server_name} → {preset.get('host')}")
|
||||
return cls(
|
||||
host=preset.get("host", ""),
|
||||
port=int(preset.get("port", _cfg('default_ssh_port', 22))),
|
||||
username=preset.get("username", _cfg('default_username', 'root')),
|
||||
password=preset.get("password", ""),
|
||||
key_path=preset.get("key_path", ""),
|
||||
timeout=_cfg('connect_timeout', 30),
|
||||
cmd_timeout=_cfg('cmd_timeout', 120),
|
||||
)
|
||||
|
||||
return cls(
|
||||
host=kwargs.get("host", ""),
|
||||
port=int(kwargs.get("port", _cfg('default_ssh_port', 22))),
|
||||
username=kwargs.get("username", _cfg('default_username', 'root')),
|
||||
password=kwargs.get("password", ""),
|
||||
key_path=kwargs.get("key_path", ""),
|
||||
timeout=_cfg('connect_timeout', 30),
|
||||
cmd_timeout=_cfg('cmd_timeout', 120),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandResult:
|
||||
command: str
|
||||
stdout: str
|
||||
stderr: str
|
||||
exit_code: int
|
||||
success: bool = True
|
||||
|
||||
@property
|
||||
def output(self) -> str:
|
||||
return self.stdout.strip() or self.stderr.strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeployConfig:
|
||||
image: str
|
||||
container_name: str
|
||||
action: str = "deploy"
|
||||
ports: list[str] = field(default_factory=list)
|
||||
volumes: list[str] = field(default_factory=list)
|
||||
env_vars: dict[str, str] = field(default_factory=dict)
|
||||
network: str = ""
|
||||
restart_policy: str = ""
|
||||
command: str = ""
|
||||
compose_file: str = ""
|
||||
pull_latest: bool = True
|
||||
extra_args: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
# 重启策略缺省值来自 config.yaml
|
||||
if not self.restart_policy:
|
||||
self.restart_policy = _cfg('default_restart_policy', 'unless-stopped')
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# SSH 连接管理器
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
class SSHManager:
|
||||
def __init__(self, cfg: SSHConfig):
|
||||
self.cfg = cfg
|
||||
self.client: "paramiko.SSHClient | None" = None
|
||||
|
||||
def connect(self) -> None:
|
||||
if not _PARAMIKO_AVAILABLE:
|
||||
raise RuntimeError("paramiko 未安装,请执行: pip install paramiko>=3.0.0")
|
||||
|
||||
self.client = paramiko.SSHClient()
|
||||
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
connect_kwargs: dict = {
|
||||
"hostname": self.cfg.host,
|
||||
"port": self.cfg.port,
|
||||
"username": self.cfg.username,
|
||||
"timeout": self.cfg.timeout,
|
||||
}
|
||||
if self.cfg.key_path:
|
||||
logger.info(f"🔑 使用密钥认证: {self.cfg.key_path}")
|
||||
connect_kwargs["key_filename"] = self.cfg.key_path
|
||||
elif self.cfg.password:
|
||||
logger.info("🔐 使用密码认证")
|
||||
connect_kwargs["password"] = self.cfg.password
|
||||
else:
|
||||
logger.info("🔓 尝试 SSH Agent / 默认密钥认证")
|
||||
|
||||
self.client.connect(**connect_kwargs)
|
||||
logger.info(
|
||||
f"✅ SSH 连接成功: {self.cfg.username}@{self.cfg.host}:{self.cfg.port}\n"
|
||||
f" 连接超时: {self.cfg.timeout}s "
|
||||
f"[config.yaml connect_timeout={_cfg('connect_timeout')}s]\n"
|
||||
f" 命令超时: {self.cfg.cmd_timeout}s "
|
||||
f"[config.yaml cmd_timeout={_cfg('cmd_timeout')}s]"
|
||||
)
|
||||
|
||||
def exec(self, command: str, timeout: int | None = None) -> CommandResult:
|
||||
if not self.client:
|
||||
raise RuntimeError("SSH 未连接,请先调用 connect()")
|
||||
t = timeout or self.cfg.cmd_timeout
|
||||
logger.debug(f"🖥 执行命令 (timeout={t}s): {command}")
|
||||
|
||||
_, stdout, stderr = self.client.exec_command(command, timeout=t)
|
||||
exit_code = stdout.channel.recv_exit_status()
|
||||
out = stdout.read().decode("utf-8", errors="replace")
|
||||
err = stderr.read().decode("utf-8", errors="replace")
|
||||
|
||||
result = CommandResult(
|
||||
command=command, stdout=out, stderr=err,
|
||||
exit_code=exit_code, success=(exit_code == 0),
|
||||
)
|
||||
logger.debug(f" exit={exit_code} out={out[:80]} err={err[:80]}")
|
||||
return result
|
||||
|
||||
def close(self) -> None:
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, *_):
|
||||
self.close()
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# Docker 操作执行器
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
class DockerExecutor:
|
||||
ALLOWED_ACTIONS = {
|
||||
"deploy", "start", "stop", "restart",
|
||||
"status", "logs", "remove",
|
||||
"compose_up", "compose_down", "compose_ps",
|
||||
"pull", "inspect", "stats",
|
||||
}
|
||||
|
||||
def __init__(self, ssh: SSHManager):
|
||||
self.ssh = ssh
|
||||
|
||||
def check_docker(self) -> CommandResult:
|
||||
return self.ssh.exec(
|
||||
"docker --version && docker info --format '{{.ServerVersion}}'"
|
||||
)
|
||||
|
||||
def pull_image(self, image: str) -> CommandResult:
|
||||
logger.info(f"📥 拉取镜像: {image}")
|
||||
return self.ssh.exec(
|
||||
f"docker pull {image}",
|
||||
timeout=_cfg('deploy_timeout', 300),
|
||||
)
|
||||
|
||||
def deploy(self, cfg: DeployConfig) -> list[CommandResult]:
|
||||
results = []
|
||||
if cfg.pull_latest:
|
||||
results.append(self.pull_image(cfg.image))
|
||||
results.append(self.ssh.exec(
|
||||
f"docker stop {cfg.container_name} 2>/dev/null || true"
|
||||
))
|
||||
results.append(self.ssh.exec(
|
||||
f"docker rm {cfg.container_name} 2>/dev/null || true"
|
||||
))
|
||||
cmd = self._build_run_command(cfg)
|
||||
logger.info(f"🚀 启动容器: {cmd}")
|
||||
results.append(self.ssh.exec(cmd, timeout=_cfg('deploy_timeout', 300)))
|
||||
return results
|
||||
|
||||
def start(self, name: str) -> CommandResult: return self.ssh.exec(f"docker start {name}")
|
||||
def stop(self, name: str) -> CommandResult: return self.ssh.exec(f"docker stop {name}")
|
||||
def restart(self, name: str) -> CommandResult: return self.ssh.exec(f"docker restart {name}")
|
||||
|
||||
def remove(self, name: str, force: bool = True) -> CommandResult:
|
||||
return self.ssh.exec(f"docker rm {'-f' if force else ''} {name}")
|
||||
|
||||
def status(self, name: str) -> CommandResult:
|
||||
cmd = (
|
||||
f"docker inspect {name} "
|
||||
f"--format '{{{{.Name}}}} | {{{{.State.Status}}}} | "
|
||||
f"Started: {{{{.State.StartedAt}}}} | Image: {{{{.Config.Image}}}}'"
|
||||
f" 2>/dev/null || echo 'Container {name} not found'"
|
||||
)
|
||||
return self.ssh.exec(cmd)
|
||||
|
||||
def logs(self, name: str, tail: int | None = None) -> CommandResult:
|
||||
n = tail if tail is not None else _cfg('default_tail_lines', 100)
|
||||
logger.info(
|
||||
f"📋 获取日志: {name} tail={n} "
|
||||
f"[config.yaml default_tail_lines={_cfg('default_tail_lines')}]"
|
||||
)
|
||||
return self.ssh.exec(f"docker logs --tail={n} --timestamps {name} 2>&1")
|
||||
|
||||
def inspect(self, name: str) -> CommandResult:
|
||||
return self.ssh.exec(f"docker inspect {name}")
|
||||
|
||||
def stats(self, name: str) -> CommandResult:
|
||||
return self.ssh.exec(
|
||||
f"docker stats {name} --no-stream "
|
||||
f"--format 'table {{{{.Name}}}}\t{{{{.CPUPerc}}}}\t"
|
||||
f"{{{{.MemUsage}}}}\t{{{{.NetIO}}}}'"
|
||||
)
|
||||
|
||||
def compose_up(self, compose_file: str, detach: bool = True) -> CommandResult:
|
||||
work_dir = compose_file.rsplit("/", 1)[0] if "/" in compose_file else "."
|
||||
logger.info(f"🐙 Compose Up: {compose_file}")
|
||||
return self.ssh.exec(
|
||||
f"cd {work_dir} && docker compose -f {compose_file} "
|
||||
f"up {'-d' if detach else ''} --pull always",
|
||||
timeout=_cfg('deploy_timeout', 300),
|
||||
)
|
||||
|
||||
def compose_down(self, compose_file: str) -> CommandResult:
|
||||
work_dir = compose_file.rsplit("/", 1)[0] if "/" in compose_file else "."
|
||||
return self.ssh.exec(
|
||||
f"cd {work_dir} && docker compose -f {compose_file} down"
|
||||
)
|
||||
|
||||
def compose_ps(self, compose_file: str) -> CommandResult:
|
||||
work_dir = compose_file.rsplit("/", 1)[0] if "/" in compose_file else "."
|
||||
return self.ssh.exec(
|
||||
f"cd {work_dir} && docker compose -f {compose_file} ps"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_run_command(cfg: DeployConfig) -> str:
|
||||
"""
|
||||
构造 docker run 命令
|
||||
安全检查:config.yaml allow_privileged=false 时拒绝 --privileged
|
||||
"""
|
||||
if "--privileged" in cfg.extra_args and not _cfg('allow_privileged', False):
|
||||
logger.warning(
|
||||
"⚠️ 已移除 --privileged 参数\n"
|
||||
" 如需启用请在 config.yaml → "
|
||||
"tools.ssh_docker.allow_privileged 设置为 true"
|
||||
)
|
||||
cfg.extra_args = cfg.extra_args.replace("--privileged", "").strip()
|
||||
|
||||
parts = ["docker", "run", "-d", f"--name {cfg.container_name}"]
|
||||
if cfg.restart_policy:
|
||||
parts.append(f"--restart {cfg.restart_policy}")
|
||||
for p in cfg.ports:
|
||||
parts.append(f"-p {p}")
|
||||
for v in cfg.volumes:
|
||||
parts.append(f"-v {v}")
|
||||
for k, val in cfg.env_vars.items():
|
||||
safe_val = str(val).replace('"', '\\"')
|
||||
parts.append(f'-e {k}="{safe_val}"')
|
||||
if cfg.network:
|
||||
parts.append(f"--network {cfg.network}")
|
||||
if cfg.extra_args:
|
||||
parts.append(cfg.extra_args)
|
||||
parts.append(cfg.image)
|
||||
if cfg.command:
|
||||
parts.append(cfg.command)
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 主工具类
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""
|
||||
SSH 远程 Docker 部署工具
|
||||
所有配置均通过 settings.tools['ssh_docker'][key] 读取
|
||||
"""
|
||||
|
||||
name = "ssh_docker"
|
||||
description = (
|
||||
"通过 SSH 连接到远程服务器,使用 Docker 部署和管理容器应用。"
|
||||
"支持: deploy | start | stop | restart | status | logs | "
|
||||
"remove | compose_up | compose_down | compose_ps | pull | inspect | stats"
|
||||
)
|
||||
parameters = {
|
||||
"host": {
|
||||
"type": "string",
|
||||
"description": "远程服务器 IP 或域名(与 server 参数二选一)",
|
||||
},
|
||||
"server": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"使用 config.yaml tools.ssh_docker.servers 中预设的服务器名称"
|
||||
"(与 host 二选一),例如 'prod' 或 'staging'"
|
||||
),
|
||||
},
|
||||
"username": {
|
||||
"type": "string",
|
||||
"description": "SSH 用户名(不传则使用 config.yaml default_username)",
|
||||
},
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Docker 操作类型: deploy(部署)| start | stop | restart | "
|
||||
"status(查看状态)| logs(查看日志)| remove(删除)| "
|
||||
"compose_up | compose_down | compose_ps | pull | inspect | stats"
|
||||
),
|
||||
"enum": sorted(DockerExecutor.ALLOWED_ACTIONS),
|
||||
},
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Docker 镜像名称,例如 nginx:latest(deploy/pull 时必填)",
|
||||
},
|
||||
"container_name": {
|
||||
"type": "string",
|
||||
"description": "容器名称,例如 my-nginx",
|
||||
},
|
||||
"port": {
|
||||
"type": "integer",
|
||||
"description": "SSH 端口(不传则使用 config.yaml default_ssh_port)",
|
||||
},
|
||||
"password": {
|
||||
"type": "string",
|
||||
"description": "SSH 密码(与 key_path 二选一)",
|
||||
},
|
||||
"key_path": {
|
||||
"type": "string",
|
||||
"description": "SSH 私钥路径,例如 /home/user/.ssh/id_rsa",
|
||||
},
|
||||
"ports": {
|
||||
"type": "string",
|
||||
"description": "端口映射,逗号分隔,例如 '8080:80,443:443'",
|
||||
},
|
||||
"volumes": {
|
||||
"type": "string",
|
||||
"description": "数据卷挂载,逗号分隔,例如 '/data:/app/data,/logs:/var/log'",
|
||||
},
|
||||
"env_vars": {
|
||||
"type": "string",
|
||||
"description": "环境变量 JSON 字符串,例如 '{\"DB_HOST\":\"localhost\",\"DB_PORT\":\"5432\"}'",
|
||||
},
|
||||
"network": {
|
||||
"type": "string",
|
||||
"description": "Docker 网络名称,例如 bridge 或 my-network",
|
||||
},
|
||||
"restart_policy": {
|
||||
"type": "string",
|
||||
"description": "重启策略(不传则使用 config.yaml default_restart_policy): no | always | unless-stopped | on-failure",
|
||||
},
|
||||
"compose_file": {
|
||||
"type": "string",
|
||||
"description": "docker-compose.yml 在远程服务器上的绝对路径",
|
||||
},
|
||||
"pull_latest": {
|
||||
"type": "boolean",
|
||||
"description": "部署前是否拉取最新镜像,默认 true",
|
||||
},
|
||||
"tail_lines": {
|
||||
"type": "integer",
|
||||
"description": "查看日志时返回的行数(不传则使用 config.yaml default_tail_lines)",
|
||||
},
|
||||
"extra_args": {
|
||||
"type": "string",
|
||||
"description": "传递给 docker run 的额外参数,例如 '--memory=512m --cpus=1'",
|
||||
},
|
||||
}
|
||||
|
||||
def execute(self, **kwargs) -> str:
|
||||
# ── 解析参数,缺省值全部来自 config.yaml ──────────────
|
||||
action = kwargs.get("action", "status").lower()
|
||||
image = kwargs.get("image", "")
|
||||
container_name = kwargs.get("container_name", "")
|
||||
ports_str = kwargs.get("ports", "")
|
||||
volumes_str = kwargs.get("volumes", "")
|
||||
env_vars_str = kwargs.get("env_vars", "{}")
|
||||
network = kwargs.get("network", "")
|
||||
restart_policy = kwargs.get("restart_policy", "") # 空→由 DeployConfig.__post_init__ 填充
|
||||
compose_file = kwargs.get("compose_file", "")
|
||||
pull_latest = bool(kwargs.get("pull_latest", True))
|
||||
tail_lines_raw = kwargs.get("tail_lines", None)
|
||||
tail_lines = int(tail_lines_raw) if tail_lines_raw is not None else None
|
||||
extra_args = kwargs.get("extra_args", "")
|
||||
|
||||
logger.info(
|
||||
f"🐳 SSH Docker 操作启动\n"
|
||||
f" 操作 : {action}\n"
|
||||
f" 容器 : {container_name or '(未指定)'}\n"
|
||||
f" 镜像 : {image or '(未指定)'}\n"
|
||||
f" server预设: {kwargs.get('server', '(无)')} "
|
||||
f"host: {kwargs.get('host', '(无)')}\n"
|
||||
f" deploy_timeout : {_cfg('deploy_timeout')}s "
|
||||
f"[config.yaml]\n"
|
||||
f" allow_privileged: {_cfg('allow_privileged')} "
|
||||
f"[config.yaml]"
|
||||
)
|
||||
|
||||
# ── 参数校验 ──────────────────────────────────────────
|
||||
err = self._validate(kwargs, action, image, container_name, compose_file)
|
||||
if err:
|
||||
return err
|
||||
|
||||
# ── 解析复合参数 ──────────────────────────────────────
|
||||
ports = [p.strip() for p in ports_str.split(",") if p.strip()]
|
||||
volumes = [v.strip() for v in volumes_str.split(",") if v.strip()]
|
||||
try:
|
||||
env_vars: dict = json.loads(env_vars_str) if env_vars_str.strip() else {}
|
||||
except json.JSONDecodeError:
|
||||
return f"❌ env_vars 格式错误,请使用 JSON 格式: {env_vars_str}"
|
||||
|
||||
# ── 镜像黑名单检查(来自 config.yaml blocked_images)──
|
||||
if image:
|
||||
blocked = _cfg('blocked_images', [])
|
||||
if any(image.startswith(b) for b in blocked):
|
||||
return (
|
||||
f"❌ 安全限制: 镜像 '{image}' 在黑名单中\n"
|
||||
f" 黑名单: {blocked}\n"
|
||||
f" 请在 config.yaml → tools.ssh_docker.blocked_images 中移除"
|
||||
)
|
||||
|
||||
# ── 构造配置对象 ──────────────────────────────────────
|
||||
try:
|
||||
ssh_cfg = SSHConfig.from_kwargs(kwargs)
|
||||
except ValueError as e:
|
||||
return f"❌ SSH 配置错误: {e}"
|
||||
|
||||
deploy_cfg = DeployConfig(
|
||||
image=image,
|
||||
container_name=container_name,
|
||||
action=action,
|
||||
ports=ports,
|
||||
volumes=volumes,
|
||||
env_vars=env_vars,
|
||||
network=network,
|
||||
restart_policy=restart_policy,
|
||||
compose_file=compose_file,
|
||||
pull_latest=pull_latest,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
|
||||
# ── 执行操作 ──────────────────────────────────────────
|
||||
try:
|
||||
with SSHManager(ssh_cfg) as ssh:
|
||||
executor = DockerExecutor(ssh)
|
||||
return self._dispatch(action, executor, deploy_cfg, tail_lines)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"❌ SSH Docker 操作失败: {error_msg}")
|
||||
return self._format_error(action, ssh_cfg.host, error_msg)
|
||||
|
||||
# ── 操作分发 ──────────────────────────────────────────────
|
||||
|
||||
def _dispatch(
|
||||
self,
|
||||
action: str,
|
||||
executor: DockerExecutor,
|
||||
cfg: DeployConfig,
|
||||
tail_lines: int | None,
|
||||
) -> str:
|
||||
# 先检查 Docker 环境
|
||||
check = executor.check_docker()
|
||||
if not check.success:
|
||||
return (
|
||||
f"❌ 远程服务器 Docker 不可用\n"
|
||||
f" 错误: {check.stderr[:200]}\n"
|
||||
f" 请确认 Docker 已安装并运行: sudo systemctl start docker"
|
||||
)
|
||||
|
||||
match action:
|
||||
case "deploy":
|
||||
return self._do_deploy(executor, cfg)
|
||||
case "start":
|
||||
return self._fmt_single(executor.start(cfg.container_name), "start")
|
||||
case "stop":
|
||||
return self._fmt_single(executor.stop(cfg.container_name), "stop")
|
||||
case "restart":
|
||||
return self._fmt_single(executor.restart(cfg.container_name), "restart")
|
||||
case "status":
|
||||
return self._do_status(executor, cfg.container_name)
|
||||
case "logs":
|
||||
return self._do_logs(executor, cfg.container_name, tail_lines)
|
||||
case "remove":
|
||||
return self._fmt_single(executor.remove(cfg.container_name), "remove")
|
||||
case "pull":
|
||||
return self._fmt_single(executor.pull_image(cfg.image), "pull")
|
||||
case "inspect":
|
||||
return self._do_inspect(executor, cfg.container_name)
|
||||
case "stats":
|
||||
return self._fmt_single(executor.stats(cfg.container_name), "stats")
|
||||
case "compose_up":
|
||||
return self._fmt_single(executor.compose_up(cfg.compose_file), "compose_up")
|
||||
case "compose_down":
|
||||
return self._fmt_single(executor.compose_down(cfg.compose_file), "compose_down")
|
||||
case "compose_ps":
|
||||
return self._fmt_single(executor.compose_ps(cfg.compose_file), "compose_ps")
|
||||
case _:
|
||||
return f"❌ 不支持的操作: {action}"
|
||||
|
||||
def _do_deploy(self, executor: DockerExecutor, cfg: DeployConfig) -> str:
|
||||
if cfg.compose_file:
|
||||
result = executor.compose_up(cfg.compose_file)
|
||||
icon = "✅" if result.success else "❌"
|
||||
return (
|
||||
f"{icon} Compose 部署{'成功' if result.success else '失败'}\n"
|
||||
f"{'─' * 50}\n"
|
||||
f" Compose 文件: {cfg.compose_file}\n"
|
||||
f"{'─' * 50}\n"
|
||||
f"{result.output[:1500]}"
|
||||
)
|
||||
results = executor.deploy(cfg)
|
||||
return self._fmt_deploy(results, cfg)
|
||||
|
||||
def _do_status(self, executor: DockerExecutor, container_name: str) -> str:
|
||||
status_r = executor.status(container_name)
|
||||
stats_r = executor.stats(container_name)
|
||||
lines = [
|
||||
f"📊 容器状态: {container_name}",
|
||||
"─" * 50,
|
||||
status_r.output or "容器不存在或未运行",
|
||||
]
|
||||
if stats_r.success and stats_r.output:
|
||||
lines += ["", "📈 资源使用:", stats_r.output]
|
||||
return "\n".join(lines)
|
||||
|
||||
def _do_logs(
|
||||
self, executor: DockerExecutor, container_name: str, tail: int | None
|
||||
) -> str:
|
||||
result = executor.logs(container_name, tail)
|
||||
n = tail if tail is not None else _cfg('default_tail_lines', 100)
|
||||
if result.success:
|
||||
return (
|
||||
f"📋 容器日志: {container_name} (最近 {n} 行)\n"
|
||||
f"{'─' * 50}\n"
|
||||
f"{result.output or '(无日志输出)'}"
|
||||
)
|
||||
return f"❌ 获取日志失败: {result.stderr[:300]}"
|
||||
|
||||
def _do_inspect(self, executor: DockerExecutor, container_name: str) -> str:
|
||||
result = executor.inspect(container_name)
|
||||
if result.success:
|
||||
try:
|
||||
data = json.loads(result.stdout)
|
||||
if data:
|
||||
c = data[0]
|
||||
info = {
|
||||
"Name": c.get("Name", ""),
|
||||
"Status": c.get("State", {}).get("Status", ""),
|
||||
"Image": c.get("Config", {}).get("Image", ""),
|
||||
"Ports": c.get("NetworkSettings", {}).get("Ports", {}),
|
||||
"Mounts": [m.get("Source") for m in c.get("Mounts", [])],
|
||||
"Created": c.get("Created", ""),
|
||||
}
|
||||
return (
|
||||
f"🔍 容器详情: {container_name}\n"
|
||||
f"{'─' * 50}\n"
|
||||
f"{json.dumps(info, ensure_ascii=False, indent=2)}"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return f"❌ 获取容器详情失败: {result.stderr[:300]}"
|
||||
|
||||
# ── 格式化输出 ─────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _fmt_deploy(results: list[CommandResult], cfg: DeployConfig) -> str:
|
||||
lines = [
|
||||
"🚀 容器部署结果",
|
||||
"─" * 50,
|
||||
f" 镜像 : {cfg.image}",
|
||||
f" 容器名 : {cfg.container_name}",
|
||||
f" 端口 : {', '.join(cfg.ports) or '(无)'}",
|
||||
f" 数据卷 : {', '.join(cfg.volumes) or '(无)'}",
|
||||
f" 重启策略: {cfg.restart_policy} "
|
||||
f"[config.yaml default_restart_policy="
|
||||
f"{_cfg('default_restart_policy')}]",
|
||||
"─" * 50,
|
||||
]
|
||||
all_ok = True
|
||||
for r in results:
|
||||
icon = "✅" if r.success else "❌"
|
||||
lines.append(f" {icon} $ {r.command[:70]}")
|
||||
if r.output:
|
||||
lines.append(f" └─ {r.output[:150]}")
|
||||
if not r.success:
|
||||
all_ok = False
|
||||
lines.append(f" └─ 错误: {r.stderr[:150]}")
|
||||
lines.append("─" * 50)
|
||||
lines.append(
|
||||
f"✅ 部署成功!容器 [{cfg.container_name}] 已启动"
|
||||
if all_ok else
|
||||
"⚠️ 部署过程中有步骤失败,请检查上方错误信息"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _fmt_single(result: CommandResult, action: str) -> str:
|
||||
icon = "✅" if result.success else "❌"
|
||||
status = "成功" if result.success else "失败"
|
||||
return (
|
||||
f"{icon} {action} {status}\n"
|
||||
f"{'─' * 40}\n"
|
||||
f"{result.output[:500] or '(无输出)'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_error(action: str, host: str, error: str) -> str:
|
||||
lines = [
|
||||
f"❌ SSH Docker [{action}] 操作失败",
|
||||
"─" * 50,
|
||||
f" 服务器: {host}",
|
||||
f" 错误 : {error}",
|
||||
"─" * 50,
|
||||
"💡 排查建议:",
|
||||
]
|
||||
el = error.lower()
|
||||
if "authentication" in el or "auth" in el:
|
||||
lines += [
|
||||
" • 检查用户名/密码是否正确",
|
||||
" • 检查 SSH 密钥路径和权限(chmod 600 ~/.ssh/id_rsa)",
|
||||
" • 或在 config.yaml tools.ssh_docker.servers 中配置预设",
|
||||
]
|
||||
elif "connection" in el or "timed out" in el:
|
||||
lines += [
|
||||
" • 检查服务器 IP 和 SSH 端口是否正确",
|
||||
" • 检查防火墙是否开放 SSH 端口",
|
||||
f" • config.yaml connect_timeout={_cfg('connect_timeout')}s,可适当增大",
|
||||
]
|
||||
elif "docker" in el:
|
||||
lines += [
|
||||
" • 确认 Docker 已安装: docker --version",
|
||||
" • 确认 Docker 服务运行: sudo systemctl start docker",
|
||||
" • 确认用户有 Docker 权限: sudo usermod -aG docker $USER",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
# ── 参数校验 ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _validate(
|
||||
kwargs: dict,
|
||||
action: str,
|
||||
image: str,
|
||||
container_name: str,
|
||||
compose_file: str,
|
||||
) -> str | None:
|
||||
# 必须提供 host 或 server 之一
|
||||
if not kwargs.get("host") and not kwargs.get("server"):
|
||||
return "❌ 参数错误: 必须提供 host(服务器地址)或 server(预设名称)之一"
|
||||
|
||||
if action not in DockerExecutor.ALLOWED_ACTIONS:
|
||||
return (
|
||||
f"❌ 不支持的操作: {action}\n"
|
||||
f" 可选值: {', '.join(sorted(DockerExecutor.ALLOWED_ACTIONS))}"
|
||||
)
|
||||
|
||||
# 主机白名单检查(来自 config.yaml allowed_hosts)
|
||||
host = kwargs.get("host", "")
|
||||
allowed_hosts = _cfg('allowed_hosts', [])
|
||||
if host and allowed_hosts and host not in allowed_hosts:
|
||||
return (
|
||||
f"❌ 安全限制: 服务器 '{host}' 不在白名单中\n"
|
||||
f" 白名单: {allowed_hosts}\n"
|
||||
f" 请在 config.yaml → tools.ssh_docker.allowed_hosts 中添加"
|
||||
)
|
||||
|
||||
if action == "deploy" and not image and not compose_file:
|
||||
return "❌ deploy 操作需要指定 image(镜像名)或 compose_file(Compose 文件路径)"
|
||||
|
||||
needs_container = {
|
||||
"start", "stop", "restart", "logs", "remove", "inspect", "stats"
|
||||
}
|
||||
if action in needs_container and not container_name:
|
||||
return f"❌ {action} 操作需要指定 container_name(容器名称)"
|
||||
|
||||
needs_compose = {"compose_up", "compose_down", "compose_ps"}
|
||||
if action in needs_compose and not compose_file:
|
||||
return f"❌ {action} 操作需要指定 compose_file(docker-compose.yml 路径)"
|
||||
|
||||
return None
|
||||
|
|
@ -1,483 +0,0 @@
|
|||
"""
|
||||
tools/static_analyzer.py
|
||||
C/C++ 静态分析工具 —— 所有配置通过 settings.tools['static_analyzer'][key] 获取
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from config.settings import settings
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
|
||||
logger = get_logger("TOOL.StaticAnalyzer")
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 配置访问快捷函数(统一入口,便于调试)
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
def _cfg(key: str, fallback=None):
|
||||
"""读取 static_analyzer 工具配置,不存在时返回 fallback"""
|
||||
return settings.tools['static_analyzer'].get(key, fallback)
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 数据结构
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
@dataclass
|
||||
class AnalysisIssue:
|
||||
file: str
|
||||
line: int
|
||||
column: int
|
||||
severity: str # error | warning | style | performance | information
|
||||
rule_id: str
|
||||
message: str
|
||||
tool: str
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"file": self.file, "line": self.line, "column": self.column,
|
||||
"severity": self.severity, "rule_id": self.rule_id,
|
||||
"message": self.message, "tool": self.tool,
|
||||
}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"[{self.severity.upper():12s}] {self.file}:{self.line}:{self.column}"
|
||||
f" ({self.rule_id}) {self.message}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnalysisResult:
|
||||
project_dir: str
|
||||
tool: str
|
||||
success: bool
|
||||
issues: list[AnalysisIssue] = field(default_factory=list)
|
||||
raw_output: str = ""
|
||||
error: str = ""
|
||||
elapsed_sec: float = 0.0
|
||||
|
||||
@property
|
||||
def error_count(self) -> int: return sum(1 for i in self.issues if i.severity == "error")
|
||||
@property
|
||||
def warning_count(self) -> int: return sum(1 for i in self.issues if i.severity == "warning")
|
||||
@property
|
||||
def style_count(self) -> int: return sum(1 for i in self.issues if i.severity in ("style", "performance"))
|
||||
@property
|
||||
def total_count(self) -> int: return len(self.issues)
|
||||
|
||||
def summary(self) -> str:
|
||||
max_show = min(20, _cfg('max_issues', 500))
|
||||
if not self.success:
|
||||
return f"❌ 分析失败: {self.error}"
|
||||
lines = [
|
||||
f"📊 静态分析完成 [{self.tool}] 耗时: {self.elapsed_sec:.1f}s",
|
||||
f" 工程目录 : {self.project_dir}",
|
||||
f" 问题总计 : {self.total_count} 条",
|
||||
f" ├─ 错误 (error) : {self.error_count} 条",
|
||||
f" ├─ 警告 (warning): {self.warning_count} 条",
|
||||
f" └─ 风格 (style) : {self.style_count} 条",
|
||||
]
|
||||
if self.issues:
|
||||
lines.append(f"\n📋 问题详情(最多显示 {max_show} 条):")
|
||||
for issue in self.issues[:max_show]:
|
||||
lines.append(f" {issue}")
|
||||
if self.total_count > max_show:
|
||||
lines.append(f" ... 还有 {self.total_count - max_show} 条")
|
||||
else:
|
||||
lines.append(" ✅ 未发现任何问题!")
|
||||
return "\n".join(lines)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
max_issues = _cfg('max_issues', 500)
|
||||
return {
|
||||
"project_dir": self.project_dir,
|
||||
"tool": self.tool,
|
||||
"success": self.success,
|
||||
"elapsed_sec": round(self.elapsed_sec, 2),
|
||||
"stats": {
|
||||
"total": self.total_count,
|
||||
"error": self.error_count,
|
||||
"warning": self.warning_count,
|
||||
"style": self.style_count,
|
||||
},
|
||||
"issues": [i.to_dict() for i in self.issues[:max_issues]],
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 各工具解析器
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
class CppcheckParser:
|
||||
SEVERITY_MAP = {
|
||||
"error": "error", "warning": "warning", "style": "style",
|
||||
"performance": "performance", "portability": "style",
|
||||
"information": "information",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def build_command(cls, project_dir: str, standard: str, extra_args: str) -> list[str]:
|
||||
jobs = _cfg('jobs', 4)
|
||||
cfg_extra = _cfg('tool_extra_args', {}).get('cppcheck', '')
|
||||
full_args = f"{cfg_extra} {extra_args}".strip()
|
||||
|
||||
cmd = [
|
||||
"cppcheck",
|
||||
"--enable=all",
|
||||
"--xml", "--xml-version=2",
|
||||
f"--std={standard}",
|
||||
f"-j{jobs}",
|
||||
]
|
||||
if full_args:
|
||||
cmd.extend(full_args.split())
|
||||
cmd.append(project_dir)
|
||||
return cmd
|
||||
|
||||
@classmethod
|
||||
def parse(cls, output: str, tool: str = "cppcheck") -> list[AnalysisIssue]:
|
||||
issues: list[AnalysisIssue] = []
|
||||
try:
|
||||
import xml.etree.ElementTree as ET
|
||||
root = ET.fromstring(output)
|
||||
for error in root.iter("error"):
|
||||
severity = cls.SEVERITY_MAP.get(error.get("severity", "warning"), "warning")
|
||||
rule_id = error.get("id", "unknown")
|
||||
message = error.get("msg", "")
|
||||
loc = error.find("location")
|
||||
if loc is not None:
|
||||
file_path = loc.get("file", "unknown")
|
||||
line = int(loc.get("line", 0))
|
||||
column = int(loc.get("column", 0))
|
||||
else:
|
||||
file_path, line, column = "unknown", 0, 0
|
||||
issues.append(AnalysisIssue(
|
||||
file=file_path, line=line, column=column,
|
||||
severity=severity, rule_id=rule_id,
|
||||
message=message, tool=tool,
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ XML 解析失败,回退文本解析: {e}")
|
||||
issues = cls._parse_text(output, tool)
|
||||
return issues
|
||||
|
||||
@staticmethod
|
||||
def _parse_text(output: str, tool: str) -> list[AnalysisIssue]:
|
||||
issues = []
|
||||
pattern = re.compile(
|
||||
r"^(.+?):(\d+):(\d+):\s+(error|warning|style|performance|information):\s+"
|
||||
r"(.+?)(?:\s+\[(\w+)\])?$", re.MULTILINE,
|
||||
)
|
||||
for m in pattern.finditer(output):
|
||||
issues.append(AnalysisIssue(
|
||||
file=m.group(1), line=int(m.group(2)), column=int(m.group(3)),
|
||||
severity=m.group(4), rule_id=m.group(6) or "unknown",
|
||||
message=m.group(5), tool=tool,
|
||||
))
|
||||
return issues
|
||||
|
||||
|
||||
class ClangTidyParser:
|
||||
@classmethod
|
||||
def build_command(cls, project_dir: str, standard: str, extra_args: str) -> list[str]:
|
||||
cfg_extra = _cfg('tool_extra_args', {}).get('clang-tidy', '')
|
||||
full_extra = f"{cfg_extra} {extra_args}".strip()
|
||||
|
||||
# 从 extra 中提取 --checks 值
|
||||
m = re.search(r"--checks=(\S+)", full_extra)
|
||||
checks = m.group(1) if m else "*"
|
||||
|
||||
if shutil.which("run-clang-tidy"):
|
||||
cmd = [
|
||||
"run-clang-tidy",
|
||||
f"-checks={checks}",
|
||||
"-p", os.path.join(project_dir, "build"),
|
||||
]
|
||||
else:
|
||||
cmd = ["clang-tidy"]
|
||||
if full_extra:
|
||||
cmd.extend(full_extra.split())
|
||||
src_files = []
|
||||
for ext in ("*.cpp", "*.c", "*.cc", "*.cxx"):
|
||||
src_files.extend(Path(project_dir).rglob(ext))
|
||||
cmd.extend(str(f) for f in src_files[:50])
|
||||
return cmd
|
||||
|
||||
@classmethod
|
||||
def parse(cls, output: str, tool: str = "clang-tidy") -> list[AnalysisIssue]:
|
||||
issues = []
|
||||
pattern = re.compile(
|
||||
r"^(.+?):(\d+):(\d+):\s+(error|warning|note):\s+(.+?)(?:\s+\[([\w\-\.]+)\])?$",
|
||||
re.MULTILINE,
|
||||
)
|
||||
for m in pattern.finditer(output):
|
||||
if m.group(4) == "note":
|
||||
continue
|
||||
issues.append(AnalysisIssue(
|
||||
file=m.group(1), line=int(m.group(2)), column=int(m.group(3)),
|
||||
severity=m.group(4), rule_id=m.group(6) or "unknown",
|
||||
message=m.group(5), tool=tool,
|
||||
))
|
||||
return issues
|
||||
|
||||
|
||||
class InferParser:
|
||||
@classmethod
|
||||
def build_command(cls, project_dir: str, standard: str, extra_args: str) -> list[str]:
|
||||
cfg_extra = _cfg('tool_extra_args', {}).get('infer', '')
|
||||
full_extra = f"{cfg_extra} {extra_args}".strip()
|
||||
cmd = [
|
||||
"infer", "run",
|
||||
"--results-dir", os.path.join(project_dir, "infer-out"),
|
||||
]
|
||||
if full_extra:
|
||||
cmd.extend(full_extra.split())
|
||||
cmd += ["--", "make", "-C", project_dir]
|
||||
return cmd
|
||||
|
||||
@classmethod
|
||||
def parse(cls, output: str, tool: str = "infer") -> list[AnalysisIssue]:
|
||||
issues = []
|
||||
try:
|
||||
data = json.loads(output)
|
||||
for item in data:
|
||||
issues.append(AnalysisIssue(
|
||||
file=item.get("file", "unknown"),
|
||||
line=item.get("line", 0),
|
||||
column=0,
|
||||
severity="error" if item.get("severity") == "ERROR" else "warning",
|
||||
rule_id=item.get("bug_type", "unknown"),
|
||||
message=item.get("qualifier", ""),
|
||||
tool=tool,
|
||||
))
|
||||
except json.JSONDecodeError:
|
||||
pattern = re.compile(r"(.+\.(?:cpp|c|cc|h)):(\d+):\s+(?:error|warning):\s+(.+)")
|
||||
for m in pattern.finditer(output):
|
||||
issues.append(AnalysisIssue(
|
||||
file=m.group(1), line=int(m.group(2)), column=0,
|
||||
severity="warning", rule_id="infer",
|
||||
message=m.group(3), tool=tool,
|
||||
))
|
||||
return issues
|
||||
|
||||
|
||||
_TOOL_REGISTRY: dict[str, type] = {
|
||||
"cppcheck": CppcheckParser,
|
||||
"clang-tidy": ClangTidyParser,
|
||||
"infer": InferParser,
|
||||
}
|
||||
|
||||
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
# 主工具类
|
||||
# ════════════════════════════════════════════════════════════════
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""
|
||||
C/C++ 静态分析工具
|
||||
所有配置均通过 settings.tools['static_analyzer'][key] 读取
|
||||
"""
|
||||
|
||||
name = "static_analyzer"
|
||||
description = (
|
||||
"对指定目录下的 C/C++ 工程调用外部静态分析工具(cppcheck/clang-tidy/infer)"
|
||||
"进行代码质量检查,返回错误、警告及代码风格问题"
|
||||
)
|
||||
parameters = {
|
||||
"project_dir": {
|
||||
"type": "string",
|
||||
"description": "C/C++ 工程根目录的绝对路径,例如 /home/user/myproject",
|
||||
},
|
||||
"tool": {
|
||||
"type": "string",
|
||||
"description": "静态分析工具: cppcheck(默认)| clang-tidy | infer",
|
||||
"enum": ["cppcheck", "clang-tidy", "infer"],
|
||||
},
|
||||
"standard": {
|
||||
"type": "string",
|
||||
"description": "C/C++ 语言标准: c89 | c99 | c11 | c++11 | c++14 | c++17 | c++20",
|
||||
},
|
||||
"extra_args": {
|
||||
"type": "string",
|
||||
"description": "额外命令行参数(追加到 config.yaml tool_extra_args 之后)",
|
||||
},
|
||||
"output_format": {
|
||||
"type": "string",
|
||||
"description": "输出格式: summary(默认)| json | full",
|
||||
"enum": ["summary", "json", "full"],
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "分析超时秒数(不传则使用 config.yaml 中的 timeout)",
|
||||
},
|
||||
}
|
||||
|
||||
def execute(self, **kwargs) -> str:
|
||||
# ── 读取参数,未提供时使用 config.yaml 中的默认值 ──────
|
||||
project_dir = kwargs.get("project_dir", "")
|
||||
tool_name = kwargs.get("tool", _cfg('default_tool', 'cppcheck')).lower()
|
||||
standard = kwargs.get("standard", _cfg('default_std', 'c++17'))
|
||||
extra_args = kwargs.get("extra_args", "")
|
||||
output_format = kwargs.get("output_format", _cfg('output_format', 'summary'))
|
||||
timeout = int(kwargs.get("timeout", _cfg('timeout', 120)))
|
||||
|
||||
logger.info(
|
||||
f"🔍 静态分析启动\n"
|
||||
f" 工程目录 : {project_dir}\n"
|
||||
f" 分析工具 : {tool_name} "
|
||||
f"[config default_tool={_cfg('default_tool')}]\n"
|
||||
f" 语言标准 : {standard} "
|
||||
f"[config default_std={_cfg('default_std')}]\n"
|
||||
f" 超时 : {timeout}s "
|
||||
f"[config timeout={_cfg('timeout')}s]\n"
|
||||
f" 并行数 : {_cfg('jobs')} "
|
||||
f"[config jobs={_cfg('jobs')}]\n"
|
||||
f" 最大问题数: {_cfg('max_issues')}"
|
||||
)
|
||||
|
||||
# ── 参数校验 ──────────────────────────────────────────
|
||||
err = self._validate(project_dir, tool_name)
|
||||
if err:
|
||||
return err
|
||||
|
||||
# ── 构造并执行命令 ────────────────────────────────────
|
||||
parser_cls = _TOOL_REGISTRY[tool_name]
|
||||
try:
|
||||
cmd = parser_cls.build_command(project_dir, standard, extra_args)
|
||||
except Exception as e:
|
||||
return f"❌ 构造分析命令失败: {e}"
|
||||
|
||||
logger.info(f"🚀 执行命令: {' '.join(cmd)}")
|
||||
result = self._run_command(cmd, project_dir, timeout, tool_name)
|
||||
|
||||
# 截断超过 max_issues 的问题
|
||||
max_issues = _cfg('max_issues', 500)
|
||||
if len(result.issues) > max_issues:
|
||||
logger.info(f"⚠️ 问题数 {len(result.issues)} 超过上限 {max_issues},已截断")
|
||||
result.issues = result.issues[:max_issues]
|
||||
|
||||
return self._format_output(result, output_format)
|
||||
|
||||
# ── 私有方法 ──────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _validate(project_dir: str, tool_name: str) -> str | None:
|
||||
if not project_dir:
|
||||
return "❌ 参数错误: project_dir 不能为空"
|
||||
|
||||
path = Path(project_dir)
|
||||
if not path.exists():
|
||||
return f"❌ 目录不存在: {project_dir}"
|
||||
if not path.is_dir():
|
||||
return f"❌ 路径不是目录: {project_dir}"
|
||||
|
||||
# 白名单校验(来自 config.yaml allowed_roots)
|
||||
allowed_roots = _cfg('allowed_roots', [])
|
||||
if allowed_roots and not any(
|
||||
project_dir.startswith(r) for r in allowed_roots
|
||||
):
|
||||
return (
|
||||
f"❌ 安全限制: {project_dir} 不在白名单中\n"
|
||||
f" 白名单: {allowed_roots}\n"
|
||||
f" 请在 config.yaml → tools.static_analyzer.allowed_roots 中添加"
|
||||
)
|
||||
|
||||
# 检查是否包含 C/C++ 源文件
|
||||
src_files = (
|
||||
list(path.rglob("*.cpp")) + list(path.rglob("*.c")) +
|
||||
list(path.rglob("*.cc")) + list(path.rglob("*.h"))
|
||||
)
|
||||
if not src_files:
|
||||
return f"❌ 目录中未找到 C/C++ 源文件: {project_dir}"
|
||||
|
||||
if tool_name not in _TOOL_REGISTRY:
|
||||
return (
|
||||
f"❌ 不支持的分析工具: {tool_name}\n"
|
||||
f" 可选值: {', '.join(_TOOL_REGISTRY.keys())}"
|
||||
)
|
||||
|
||||
exe = "run-clang-tidy" if tool_name == "clang-tidy" else tool_name
|
||||
if not shutil.which(exe) and not shutil.which(tool_name):
|
||||
return (
|
||||
f"❌ 分析工具未安装: {tool_name}\n"
|
||||
f" 安装方式:\n"
|
||||
f" cppcheck : sudo apt install cppcheck\n"
|
||||
f" clang-tidy: sudo apt install clang-tidy\n"
|
||||
f" infer : https://fbinfer.com/docs/getting-started"
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _run_command(
|
||||
cmd: list[str], project_dir: str, timeout: int, tool_name: str,
|
||||
) -> AnalysisResult:
|
||||
start = time.time()
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
cmd, cwd=project_dir,
|
||||
capture_output=True, text=True,
|
||||
timeout=timeout, encoding="utf-8", errors="replace",
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
raw_output = proc.stderr if proc.stderr.strip() else proc.stdout
|
||||
logger.debug(f"📄 原始输出(前 500 字符):\n{raw_output[:500]}")
|
||||
|
||||
parser_cls = _TOOL_REGISTRY[tool_name]
|
||||
issues = parser_cls.parse(raw_output, tool_name)
|
||||
|
||||
if tool_name == "infer":
|
||||
report_path = Path(project_dir) / "infer-out" / "report.json"
|
||||
if report_path.exists():
|
||||
issues = InferParser.parse(
|
||||
report_path.read_text(encoding="utf-8"), "infer"
|
||||
)
|
||||
|
||||
logger.info(f"✅ 分析完成: {len(issues)} 个问题,耗时 {elapsed:.1f}s")
|
||||
return AnalysisResult(
|
||||
project_dir=project_dir, tool=tool_name,
|
||||
success=True, issues=issues,
|
||||
raw_output=raw_output, elapsed_sec=elapsed,
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
elapsed = time.time() - start
|
||||
msg = (
|
||||
f"分析超时(>{timeout}s)\n"
|
||||
f" 请增大 config.yaml → tools.static_analyzer.timeout"
|
||||
)
|
||||
logger.error(f"⏰ {msg}")
|
||||
return AnalysisResult(
|
||||
project_dir=project_dir, tool=tool_name,
|
||||
success=False, error=msg, elapsed_sec=elapsed,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return AnalysisResult(
|
||||
project_dir=project_dir, tool=tool_name,
|
||||
success=False, error=f"命令未找到: {cmd[0]}",
|
||||
)
|
||||
except Exception as e:
|
||||
return AnalysisResult(
|
||||
project_dir=project_dir, tool=tool_name,
|
||||
success=False, error=str(e),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_output(result: AnalysisResult, fmt: str) -> str:
|
||||
if fmt == "json":
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False, indent=2)
|
||||
if fmt == "full":
|
||||
return (
|
||||
f"{result.summary()}\n\n{'─' * 60}\n"
|
||||
f"📄 原始输出:\n{result.raw_output[:3000]}"
|
||||
)
|
||||
return result.summary()
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from typing import List
|
||||
from agent.device import controller
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
name = "uav_control"
|
||||
description = (
|
||||
"该工具用于对无人系统进行控制并在飞行控制中完成一系列的活动(Action),例如起飞、移动、悬停、照相等"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "无人系统飞行控制命令,例如arm、disarm、takeoff、rtl、goto等",
|
||||
},
|
||||
"params": {
|
||||
"type": "object",
|
||||
"description": "控制命令参数"
|
||||
},
|
||||
"actions": {
|
||||
"type": "array",
|
||||
"description": "过程中执行的一些列的动作,例如pause、capture_phone、standby等"
|
||||
}
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
|
||||
def execute(self, command: str, params: dict={}, actions: List[object]=()) -> ToolResult:
|
||||
if not controller.connect():
|
||||
return ToolResult(success=False, output="连接无人机失败")
|
||||
try:
|
||||
command_handler = getattr(controller, command)
|
||||
try:
|
||||
return ToolResult(success=True, output=str(command_handler(**params)))
|
||||
except Exception as e:
|
||||
return ToolResult(success=False, output=str(e))
|
||||
finally:
|
||||
controller.disconnect()
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
import json
|
||||
from dataclasses import asdict
|
||||
from typing import List
|
||||
|
||||
from agent.config.settings import settings
|
||||
from agent.device import controller
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from core.uas_control import command_registry
|
||||
from core.uas_control.controllers.robot_dog_controller import RobotDogController
|
||||
from core.uas_control.controllers.uav_controller import UAVController
|
||||
from core.uas_control.protocols.mavlink_adapter import MAVLinkAdapter
|
||||
from core.uas_control.protocols.ros_adapter import ROSAdapter
|
||||
from core.uas_control.protocols.simulation_adapter import SimulationAdapter
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
name = "uav_get_state"
|
||||
description = (
|
||||
"获取无人系统当前状态, 例如当前位置等信息"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
},
|
||||
"required": [""],
|
||||
}
|
||||
|
||||
def execute(self) -> ToolResult:
|
||||
if not controller.connect():
|
||||
return ToolResult(success=False, output="连接无人机失败")
|
||||
try:
|
||||
telemetry = controller.get_telemetry()
|
||||
return ToolResult(success=True, output=json.dumps(telemetry.to_dict(), ensure_ascii=True))
|
||||
finally:
|
||||
controller.disconnect()
|
||||
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
import time
|
||||
from pymavlink import mavutil
|
||||
from pymavlink.dialects.v20.all import MAVLink_sys_status_message
|
||||
|
||||
from agent.device import controller
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
name = "uav_self_check"
|
||||
description = ("对无人机设备进行使用前自检")
|
||||
parameters = {}
|
||||
|
||||
def execute(self, **kwargs) -> ToolResult:
|
||||
"""
|
||||
MAVLink 飞控系统健康自检
|
||||
解决:Arming denied: Resolve system health failures first
|
||||
"""
|
||||
# 等待健康数据
|
||||
|
||||
assert controller.connect(), "连接应成功"
|
||||
try:
|
||||
connection = controller.adapter.get_connection()
|
||||
check_list = []
|
||||
|
||||
health = connection.wait_heartbeat(timeout=5)
|
||||
if not health:
|
||||
return ToolResult(success=False, output="❌ 无法获取飞控状态,连接异常")
|
||||
|
||||
# 2. 获取系统健康状态(关键:判断解锁失败原因)
|
||||
check_list.append("\n===== 系统健康自检结果 =====")
|
||||
system_status = health.system_status
|
||||
check_list.append(f"系统状态码:{system_status}")
|
||||
|
||||
# 系统状态解释
|
||||
status_map = {
|
||||
0: "未初始化",
|
||||
1: "初始化中",
|
||||
2: "测试模式",
|
||||
3: "待机",
|
||||
4: "活动",
|
||||
5: "临界",
|
||||
6: "紧急",
|
||||
7: "故障"
|
||||
}
|
||||
check_list.append(f"系统状态:{status_map.get(system_status, '未知')}")
|
||||
# 3. 读取详细传感器健康(核心自检)
|
||||
check_list.append("\n===== 传感器健康检测 =====")
|
||||
for _ in range(30):
|
||||
msg = connection.recv_match(type="SYS_STATUS", blocking=True)
|
||||
if msg:
|
||||
# 传感器是否存在
|
||||
sensors_present = msg.onboard_control_sensors_present
|
||||
sensors_enable = msg.onboard_control_sensors_enabled
|
||||
|
||||
# 传感器列表
|
||||
sensor_list = [
|
||||
("3D陀螺仪", 0),
|
||||
("3D加速度计", 1),
|
||||
("3D磁力计", 2),
|
||||
("气压计", 3),
|
||||
("GPS", 4),
|
||||
("光学流量", 5),
|
||||
("视觉定位", 6),
|
||||
("测距传感器", 7),
|
||||
("电机控制", 8),
|
||||
("电池", 9),
|
||||
("遥控信号", 10),
|
||||
("飞行控制", 11),
|
||||
("避障系统", 12),
|
||||
("云台", 13),
|
||||
]
|
||||
|
||||
# 逐个检测
|
||||
for name, bit in sensor_list:
|
||||
present = (sensors_present >> bit) & 1
|
||||
enable = (sensors_enable >> bit) & 1
|
||||
|
||||
if present:
|
||||
if not enable:
|
||||
check_list.append(f"❌ {name}:故障")
|
||||
else:
|
||||
check_list.append(f"✅ {name}:正常")
|
||||
else:
|
||||
check_list.append(f"ℹ️ {name}:未安装")
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
# 4. 解锁能力判断(解决 Arming denied)
|
||||
check_list.append("\n===== 解锁能力检测 =====")
|
||||
can_arm = False
|
||||
for _ in range(20):
|
||||
msg = connection.recv_match(type="HEARTBEAT", blocking=True)
|
||||
if msg:
|
||||
base_mode = msg.base_mode
|
||||
custom_mode = msg.custom_mode
|
||||
|
||||
# 0b10000000 = 已解锁
|
||||
if (base_mode & mavutil.mavlink.MAV_MODE_FLAG_SAFETY_ARMED) != 0:
|
||||
check_list.append("✅ 当前状态:已解锁")
|
||||
can_arm = True
|
||||
else:
|
||||
check_list.append("⚠️ 当前状态:未解锁")
|
||||
|
||||
# 判断是否禁止解锁
|
||||
if system_status in [5, 6, 7]:
|
||||
check_list.append("❌ 禁止解锁原因:系统处于临界/紧急/故障状态")
|
||||
else:
|
||||
check_list.append("✅ 系统允许解锁,可执行 arm 指令")
|
||||
can_arm = True
|
||||
break
|
||||
|
||||
print("\n===== 自检完成 =====")
|
||||
if can_arm:
|
||||
check_list.append("🎉 结论:无健康故障,可以执行解锁(arm)")
|
||||
else:
|
||||
check_list.append("⚠️ 结论:存在系统故障,必须修复后才能 arm")
|
||||
|
||||
return ToolResult(success=True, output="\n".join(check_list))
|
||||
except Exception as e:
|
||||
return ToolResult(success=False, output=str(e))
|
||||
finally:
|
||||
controller.disconnect()
|
||||
|
|
@ -6,9 +6,9 @@ tools/web_search.py
|
|||
|
||||
from dataclasses import dataclass
|
||||
from serpapi import SerpApiClient
|
||||
from config.settings import settings
|
||||
from tools.base_tool import BaseTool
|
||||
from utils.logger import get_logger
|
||||
from agent.config.settings import settings
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("TOOL.WebSearch")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue