base_agent/mcp/skill_registry.py

359 lines
14 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
mcp/skill_registry.py
统一 Skill 注册表
将本地工具 (LocalTool) 和在线 MCP Skill (RemoteTool) 统一注册
对外提供一致的接口:
- get_all_schemas() → 返回所有工具的 function calling schema
- dispatch() → 根据工具名路由到本地或远端执行
- refresh_skills() → 重新拉取在线 Skill 工具列表
"""
import time
from dataclasses import dataclass
from typing import Any
from config.settings import settings
from mcp.mcp_skill_client import MCPSkillClient, RemoteTool, ToolCallResult
from mcp.skill_loader import SkillLoader
from utils.logger import get_logger
# Module-level logger for this registry, obtained from the project's
# get_logger helper under the channel name "MCP.SkillRegistry".
logger = get_logger("MCP.SkillRegistry")
# ════════════════════════════════════════════════════════════════
# 本地工具包装
# ════════════════════════════════════════════════════════════════
@dataclass
class LocalToolEntry:
    """Registry entry that wraps a locally implemented tool.

    ``instance`` must provide ``execute(**kwargs)``; per the original note it
    returns ``str``.
    """

    name: str  # tool name used for routing and in the schema
    description: str  # description copied verbatim into the schema
    parameters: dict[str, Any]  # parameter spec copied verbatim into the schema
    instance: Any  # object exposing execute(**kwargs)

    def to_function_schema(self) -> dict:
        """Build the function-calling schema dict (name, description, parameters)."""
        schema: dict = {}
        schema["name"] = self.name
        schema["description"] = self.description
        schema["parameters"] = self.parameters
        return schema
# ════════════════════════════════════════════════════════════════
# 调用结果统一封装
# ════════════════════════════════════════════════════════════════
@dataclass
class DispatchResult:
    """Uniform outcome of one tool call routed through the registry."""

    tool_name: str  # name of the tool that was invoked
    source: str  # where the call was routed, e.g. "local"
    success: bool  # True when the call completed without error
    content: str = ""  # tool output when successful
    error: str = ""  # error text when unsuccessful
    elapsed_sec: float = 0.0  # measured duration of the call, in seconds

    def __str__(self) -> str:
        """Return the output on success, otherwise the error text."""
        if not self.success:
            return f"{self.error}"
        return self.content
# ════════════════════════════════════════════════════════════════
# 统一 Skill 注册表
# ════════════════════════════════════════════════════════════════
class SkillRegistry:
    """Unified registry for local tools and online MCP Skill tools.

    Local tools take priority over remote tools of the same name, both in
    ``get_all_schemas`` and in ``dispatch``.

    Usage::

        registry = SkillRegistry()
        registry.register_local(tool_instance)
        registry.connect_skills()
        schemas = registry.get_all_schemas()
        result = registry.dispatch("tool_name", {"arg": "val"})
        registry.close()
    """

    def __init__(self):
        # Local tool table: tool_name -> LocalToolEntry
        self._local: dict[str, LocalToolEntry] = {}
        # Remote tool table: tool_name -> (MCPSkillClient, RemoteTool)
        self._remote: dict[str, tuple[MCPSkillClient, RemoteTool]] = {}
        # Connected online Skill clients, kept for lifecycle management
        self._clients: list[MCPSkillClient] = []
        # SKILL.md entries that are not (client, tool) pairs. They are kept
        # apart because every reader of _remote unpacks (client, tool).
        self._md_skills: dict[str, Any] = {}
        # Load the skills described by SKILL.md files in the configured directory
        self.load_skills_from_md(settings.skills_directory)

    def load_skills_from_md(self, directory: str) -> None:
        """Load skills defined by SKILL.md files found under *directory*.

        Code that reads ``_remote`` (schemas, lookup, dispatch, conflict
        checks) unpacks each value as ``(client, tool)``. Entries returned by
        ``SkillLoader`` are therefore only placed there when they are 2-tuples.
        Any other entry is stored in ``_md_skills`` instead of breaking those
        call sites.

        NOTE(review): the actual return type of
        ``SkillLoader.load_skills_from_directory`` is not visible here —
        confirm how SKILL.md skills are meant to be exposed.
        """
        skills = SkillLoader.load_skills_from_directory(directory)
        for skill_name, skill_info in skills.items():
            logger.info(f"📦 加载技能: {skill_name}")
            if isinstance(skill_info, tuple) and len(skill_info) == 2:
                self._remote[skill_name] = skill_info
            else:
                self._md_skills[skill_name] = skill_info
                logger.warning(
                    f"⚠️ 技能 [{skill_name}] 不是 (client, tool) 结构,"
                    f"未注册为可调用工具"
                )

    # ── Local tool registration ───────────────────────────────

    def register_local(self, tool_instance: Any) -> None:
        """Register a local tool instance.

        The instance should expose ``.name``, ``.description``,
        ``.parameters`` and ``.execute(**kwargs)``. Instances without a
        truthy ``name`` are skipped with a warning. Registering an existing
        name replaces the previous entry.
        """
        name = getattr(tool_instance, "name", None)
        if not name:
            logger.warning(f"⚠️ 工具实例缺少 name 属性,跳过: {tool_instance}")
            return
        self._local[name] = LocalToolEntry(
            name=name,
            # `or ""` also covers an attribute that exists but is None
            description=getattr(tool_instance, "description", "") or "",
            parameters=getattr(tool_instance, "parameters", {}),
            instance=tool_instance,
        )
        logger.debug(f"📌 注册本地工具: {name}")

    def register_local_many(self, *tool_instances: Any) -> None:
        """Register each given tool instance via ``register_local``."""
        for t in tool_instances:
            self.register_local(t)

    # ── Online MCP Skill connection ───────────────────────────

    def connect_skills(self) -> dict[str, list[str]]:
        """Connect every enabled online MCP Skill and register its tools.

        For each config in ``settings.enabled_mcp_skills``, a client is
        created, connected and asked for its tool list. The client is only
        kept, and its tools only registered, after both steps succeed. A
        failing skill therefore never leaves a closed client or a partial
        set of tools behind.

        On a name clash with an earlier remote tool, the later skill wins.
        On a clash with a local tool, the local tool keeps priority.

        Returns:
            {skill_name: [tool_name, ...]} for each successfully registered skill.
        """
        enabled = settings.enabled_mcp_skills
        if not enabled:
            logger.info(" 未配置任何在线 MCP Skill")
            return {}
        logger.info(f"🌐 开始连接在线 MCP Skills数量={len(enabled)}")

        registered_map: dict[str, list[str]] = {}
        for skill_cfg in enabled:
            client = MCPSkillClient(skill_cfg)
            try:
                client.connect()
                tools = client.list_tools()
            except Exception as e:
                logger.error(
                    f"❌ Skill [{skill_cfg.name}] 连接失败,跳过\n"
                    f" 错误: {e}"
                )
                # Best-effort cleanup; the connect/list error above is the one
                # that matters, but don't hide the close failure entirely.
                try:
                    client.close()
                except Exception as close_err:
                    logger.debug(
                        f"关闭 Skill [{skill_cfg.name}] 客户端时异常(已忽略): {close_err}"
                    )
                continue

            # Commit only once the client has proven usable.
            self._clients.append(client)
            names: list[str] = []
            for tool in tools:
                if tool.name in self._local:
                    # dispatch()/get_all_schemas() prefer local tools, so the
                    # remote tool is the one that gets shadowed.
                    logger.warning(
                        f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill [{skill_cfg.name}] "
                        f"的同名工具被本地工具遮蔽(本地优先)"
                    )
                if tool.name in self._remote:
                    prev = getattr(self._remote[tool.name][1], "skill_name", "?")
                    logger.warning(
                        f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill [{prev}] "
                        f"被 [{skill_cfg.name}] 覆盖"
                    )
                self._remote[tool.name] = (client, tool)
                names.append(tool.name)
            registered_map[skill_cfg.name] = names
            logger.info(
                f"✅ Skill [{skill_cfg.name}] 注册完成 "
                f"工具数={len(names)}: {names}"
            )

        logger.info(
            f"📊 SkillRegistry 初始化完成\n"
            f" 本地工具 : {len(self._local)} {list(self._local.keys())}\n"
            f" 远端工具 : {len(self._remote)} {list(self._remote.keys())}"
        )
        return registered_map

    def refresh_skills(self) -> None:
        """Re-fetch the tool list of every connected Skill without reconnecting.

        The remote table is rebuilt from the refreshed lists. Entries whose
        owner was not successfully refreshed are kept as they were. This
        covers clients whose refresh failed and SKILL.md-sourced pairs, so a
        transient error does not make tools vanish.
        """
        logger.info("🔄 刷新在线 Skill 工具列表...")
        refreshed_ids: set[int] = set()
        fresh: dict[str, tuple[MCPSkillClient, RemoteTool]] = {}
        for client in self._clients:
            try:
                tools = client.list_tools(force_refresh=True)
            except Exception as e:
                logger.error(f" ❌ [{client.cfg.name}] 刷新失败: {e}")
                continue
            refreshed_ids.add(id(client))
            for tool in tools:
                fresh[tool.name] = (client, tool)
            logger.info(
                f" ✅ [{client.cfg.name}] 刷新完成 "
                f"工具数={len(tools)}"
            )
        kept = {
            name: entry
            for name, entry in self._remote.items()
            if id(entry[0]) not in refreshed_ids
        }
        # Freshly listed tools win over kept entries of the same name.
        self._remote = {**kept, **fresh}

    # ── Tool queries ──────────────────────────────────────────

    def get_all_schemas(self) -> list[dict]:
        """Return function-calling schemas for all tools (local + remote).

        Intended as the LLM ``tools`` argument. Remote tools shadowed by a
        same-name local tool are omitted, matching ``dispatch`` priority.
        """
        schemas = [entry.to_function_schema() for entry in self._local.values()]
        schemas.extend(
            tool.to_function_schema()
            for name, (_, tool) in self._remote.items()
            if name not in self._local
        )
        return schemas

    def get_tool_info(self, tool_name: str) -> dict | None:
        """Return name/source/description for one tool, or None if unknown."""
        if tool_name in self._local:
            entry = self._local[tool_name]
            return {
                "name": entry.name,
                "source": "local",
                "description": entry.description,
            }
        if tool_name in self._remote:
            _, tool = self._remote[tool_name]
            return {
                "name": tool.name,
                "source": f"remote:{tool.skill_name}",
                "description": tool.description,
            }
        return None

    def list_all_tools(self) -> list[dict]:
        """List every visible tool with its source (debug/display helper).

        Descriptions are truncated to 80 characters. Shadowed remote tools
        are skipped.
        """
        result = []
        for name, entry in self._local.items():
            result.append({
                "name": name,
                "source": "local",
                "description": (entry.description or "")[:80],
            })
        for name, (_, tool) in self._remote.items():
            if name not in self._local:
                result.append({
                    "name": name,
                    "source": f"remote:{tool.skill_name}",
                    "description": (tool.description or "")[:80],
                })
        return result

    def has_tool(self, tool_name: str) -> bool:
        """Return True if *tool_name* is registered locally or remotely."""
        return tool_name in self._local or tool_name in self._remote

    # ── Tool call routing ─────────────────────────────────────

    def dispatch(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
    ) -> DispatchResult:
        """Route a tool call to the local or remote implementation.

        Priority: local tool > remote Skill tool. ``arguments=None`` is
        treated as no arguments. Unknown tools yield a failed DispatchResult
        listing the available names; no exception is raised.
        """
        if arguments is None:
            arguments = {}

        if tool_name in self._local:
            return self._dispatch_local(tool_name, arguments)

        if tool_name in self._remote:
            return self._dispatch_remote(tool_name, arguments)

        # Names present in both tables would otherwise be listed twice.
        available = list(dict.fromkeys([*self._local, *self._remote]))
        return DispatchResult(
            tool_name=tool_name,
            source="unknown",
            success=False,
            error=(
                f"工具 '{tool_name}' 未注册\n"
                f"可用工具: {available}"
            ),
        )

    def _dispatch_local(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> DispatchResult:
        """Invoke a local tool and convert its return value to ``str``.

        Exceptions from ``execute`` are reported as a failed result instead
        of propagating.
        """
        entry = self._local[tool_name]
        # perf_counter is monotonic, so durations survive wall-clock changes.
        start = time.perf_counter()
        logger.info(f"🔧 调用本地工具: {tool_name} 参数={arguments}")
        try:
            content = str(entry.instance.execute(**arguments))
        except Exception as e:
            elapsed = time.perf_counter() - start
            logger.error(f"❌ 本地工具异常: {tool_name} {e}")
            return DispatchResult(
                tool_name=tool_name,
                source="local",
                success=False,
                error=str(e),
                elapsed_sec=elapsed,
            )
        elapsed = time.perf_counter() - start
        logger.info(f"✅ 本地工具完成: {tool_name} 耗时={elapsed:.2f}s")
        return DispatchResult(
            tool_name=tool_name,
            source="local",
            success=True,
            content=content,
            elapsed_sec=elapsed,
        )

    def _dispatch_remote(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> DispatchResult:
        """Invoke a remote MCP Skill tool through its owning client.

        Like ``_dispatch_local``, an exception from the client call (e.g. a
        dropped connection) is reported as a failed result, not raised.
        """
        client, tool = self._remote[tool_name]
        source = f"remote:{tool.skill_name}"
        logger.info(
            f"🌐 调用远端工具: [{tool.skill_name}] / {tool_name} "
            f"参数={arguments}"
        )
        start = time.perf_counter()
        try:
            result = client.call_tool(tool_name, arguments)
        except Exception as e:
            elapsed = time.perf_counter() - start
            logger.error(f"❌ 远端工具异常: [{tool.skill_name}] / {tool_name} {e}")
            return DispatchResult(
                tool_name=tool_name,
                source=source,
                success=False,
                error=str(e),
                elapsed_sec=elapsed,
            )
        return DispatchResult(
            tool_name=tool_name,
            source=source,
            success=result.success,
            content=result.content,
            error=result.error,
            elapsed_sec=result.elapsed_sec,
        )

    # ── Lifecycle ─────────────────────────────────────────────

    def close(self) -> None:
        """Close every online Skill client and clear the remote tool table.

        Errors while closing an individual client are logged and do not stop
        the remaining clients from being closed.
        """
        for client in self._clients:
            try:
                client.close()
            except Exception as e:
                logger.warning(f"⚠️ 关闭 Skill [{client.cfg.name}] 时异常: {e}")
        self._clients.clear()
        self._remote.clear()
        logger.info("🔌 SkillRegistry 已关闭所有连接")

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def __repr__(self) -> str:
        return (
            f"SkillRegistry("
            f"local={list(self._local.keys())}, "
            f"remote={list(self._remote.keys())})"
        )