408 lines
16 KiB
Python
408 lines
16 KiB
Python
"""
|
||
mcp/skill_registry.py
|
||
统一 Skill 注册表
|
||
|
||
将本地工具(LocalTool)和在线 MCP Skill(RemoteTool)统一注册,
|
||
对外提供一致的接口:
|
||
- get_all_schemas() → 返回所有工具的 function calling schema
|
||
- dispatch() → 根据工具名路由到本地或远端执行
|
||
- refresh_skills() → 重新拉取在线 Skill 工具列表
|
||
"""
|
||
import importlib
import importlib.util
import os
import time
from dataclasses import dataclass
from typing import Any

from config.settings import settings
from mcp.mcp_skill_client import MCPSkillClient, RemoteTool, ToolCallResult
from mcp.skill_loader import SkillLoader
from tools.base_tool import BaseTool
from utils.logger import get_logger
|
||
|
||
# Module-level logger, obtained from the project's get_logger helper under
# the name "MCP.SkillRegistry".
logger = get_logger("MCP.SkillRegistry")
|
||
|
||
|
||
class BaseEntry:
    """Common base for registry entries that can render a function schema."""

    def to_function_schema(self) -> dict:
        """Return the entry's function calling schema; subclasses override this."""
        raise NotImplementedError()
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 本地SKILL.md包装
|
||
# ════════════════════════════════════════════════════════════════
|
||
@dataclass
class LocalSkillEntry(BaseEntry):
    """Registry entry for a local skill that only carries a name and a description."""

    name: str
    description: str

    def to_function_schema(self) -> dict:
        """Return the skill as a function calling schema with empty parameters."""
        # NOTE(review): "parameters" is an empty dict here; some function
        # calling APIs may expect {"type": "object", "properties": {}} for a
        # no-argument function — confirm against the LLM backend in use.
        schema = dict(name=self.name, description=self.description)
        schema["parameters"] = {}
        return schema
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 本地工具包装
|
||
# ════════════════════════════════════════════════════════════════
@dataclass
class LocalToolEntry(BaseEntry):
    """Registry entry wrapping a local tool instance."""

    name: str
    description: str
    parameters: dict[str, Any]
    instance: Any  # expected to provide execute(**kwargs) → str

    def to_function_schema(self) -> dict:
        """Return the tool's name, description and parameters as a schema dict."""
        schema = {"name": self.name}
        schema["description"] = self.description
        schema["parameters"] = self.parameters
        return schema
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 调用结果统一封装
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
@dataclass
class DispatchResult:
    """Uniform wrapper around the outcome of one tool call."""

    tool_name: str
    source: str  # "local" | skill_name
    success: bool
    content: str = ""
    error: str = ""
    elapsed_sec: float = 0.0

    def __str__(self) -> str:
        """Render the content on success, otherwise the error behind a ❌ marker."""
        if self.success:
            return self.content
        return f"❌ {self.error}"
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 统一 Skill 注册表
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class SkillRegistry:
    """
    Unified registry for local tools, local skills and remote MCP Skill tools.

    Usage:
        registry = SkillRegistry()
        registry.register_local_tool(tool_instance)
        registry.connect_mcp_skills()
        schemas = registry.get_all_schemas()
        result = registry.dispatch("tool_name", {"arg": "val"})
        registry.close()

    The registry is also a context manager; leaving the ``with`` block calls
    ``close()``.
    """

    def __init__(self):
        # Local table: tool_name → LocalToolEntry | LocalSkillEntry
        self._local: dict[str, LocalToolEntry | LocalSkillEntry] = {}
        # Remote table: tool_name → (MCPSkillClient, RemoteTool)
        self._remote: dict[str, tuple[MCPSkillClient, RemoteTool]] = {}
        # Online Skill clients, kept for lifecycle management
        self._clients: list[MCPSkillClient] = []

    # ── Loading local skills / tools ──────────────────────────

    def load_local_skills(self, directory: str) -> None:
        """
        Load skills through ``SkillLoader.load_skills_from_directory(directory)``
        and register each of them in the local table.
        """
        skills = SkillLoader.load_skills_from_directory(directory)
        for skill_name, skill_info in skills.items():
            logger.info(f"📦 加载技能: {skill_name}")
            # Skills loaded from disk are registered in the LOCAL table.
            self.register_local_skill(skill_info)

    def load_local_tools(self):
        """
        Register every tool named in ``settings.mcp.enabled_tools``.

        Each name is resolved to ``tools/<name>.py`` (a relative path), the file
        is imported dynamically and its ``Tool`` class is instantiated and
        registered when it is a subclass of ``BaseTool``. A tool that is
        missing, fails to import, or exposes no suitable ``Tool`` class is
        logged and skipped so that the remaining tools still get registered.
        """
        enabled = settings.mcp.enabled_tools
        logger.info(f"🔧 注册本地工具: {enabled}")
        for tool_name in enabled:
            tool_path = f"tools/{tool_name}.py"
            if not os.path.exists(tool_path):
                logger.warning(f"⚠️ 未知工具: {tool_name},跳过")
                continue
            try:
                tool_cls = self._load_tool_class(tool_name, tool_path)
            except Exception as e:
                # One broken tool module must not block the others.
                logger.error(f"❌ 本地工具加载失败: {tool_name} {e}")
                continue
            if tool_cls is None:
                logger.warning(f"⚠️ 未知工具: {tool_name},跳过")
                continue
            self.register_local_tool(tool_cls())

    @staticmethod
    def _load_tool_class(tool_name: str, tool_path: str):
        """
        Import ``tool_path`` as module ``tool_name`` and return its ``Tool`` class.

        Returns:
            The ``Tool`` attribute when it is a class deriving from
            ``BaseTool``; otherwise ``None``.

        Raises:
            ImportError: when no import spec/loader can be built for the file.
            Exception: whatever executing the module raises.
        """
        spec = importlib.util.spec_from_file_location(tool_name, tool_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot build an import spec for {tool_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        candidate = getattr(module, "Tool", None)
        # isinstance guard: issubclass() raises TypeError on non-class objects.
        if isinstance(candidate, type) and issubclass(candidate, BaseTool):
            return candidate
        return None

    def register_local_skill(self, skill_info: dict[str, Any]) -> None:
        """
        Register a local skill from its info dict.

        Reads the ``"name"`` and ``"description"`` keys; an entry without a
        name is logged and ignored.
        """
        name = skill_info.get("name")
        if not name:
            logger.warning(f"⚠️ 技能信息缺少 name 字段,跳过: {skill_info}")
            return
        self._local[name] = LocalSkillEntry(
            name=name,
            description=skill_info.get("description", ""),
        )
        logger.debug(f"📌 注册本地工具: {name}")

    # ── Registering local tools ───────────────────────────────

    def register_local_tool(self, tool_instance: Any) -> None:
        """
        Register a local tool instance.

        The instance is expected to provide
        ``.name`` / ``.description`` / ``.parameters`` / ``.execute(**kwargs)``.
        An instance without a truthy ``name`` is logged and skipped; missing
        ``description`` / ``parameters`` default to ``""`` / ``{}``.
        """
        name = getattr(tool_instance, "name", None)
        if not name:
            logger.warning(f"⚠️ 工具实例缺少 name 属性,跳过: {tool_instance}")
            return
        self._local[name] = LocalToolEntry(
            name=name,
            description=getattr(tool_instance, "description", ""),
            parameters=getattr(tool_instance, "parameters", {}),
            instance=tool_instance,
        )
        logger.debug(f"📌 注册本地工具: {name}")

    def register_local_many(self, *tool_instances: Any) -> None:
        """Register several local tool instances, one after another."""
        for tool_instance in tool_instances:
            self.register_local_tool(tool_instance)

    # ── Connecting online MCP Skills ──────────────────────────

    def connect_mcp_skills(self) -> dict[str, list[str]]:
        """
        Connect every Skill listed in ``settings.enabled_mcp_skills`` and
        register its tools in the remote table.

        A Skill whose connection or tool listing fails is logged, its client is
        closed, and processing continues with the next Skill.

        Returns:
            {skill_name: [tool_name, ...]} for the Skills registered successfully.
        """
        enabled = settings.enabled_mcp_skills
        if not enabled:
            logger.info("ℹ️ 未配置任何在线 MCP Skill")
            return {}

        logger.info(f"🌐 开始连接在线 MCP Skills,数量={len(enabled)}")
        registered_map: dict[str, list[str]] = {}

        for skill_cfg in enabled:
            client = MCPSkillClient(skill_cfg)
            try:
                client.connect()
                tools = client.list_tools()
                self._clients.append(client)

                names = []
                for tool in tools:
                    # Name clash with a local tool: the local one wins, because
                    # dispatch() and get_all_schemas() look at the local table first.
                    if tool.name in self._local:
                        logger.warning(
                            f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill "
                            f"[{skill_cfg.name}] 的同名工具被本地工具覆盖"
                        )
                    # Name clash with an earlier remote Skill: the later one replaces it.
                    if tool.name in self._remote:
                        prev = self._remote[tool.name][1].skill_name
                        logger.warning(
                            f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill [{prev}] "
                            f"被 [{skill_cfg.name}] 覆盖"
                        )
                    self._remote[tool.name] = (client, tool)
                    names.append(tool.name)

                registered_map[skill_cfg.name] = names
                logger.info(
                    f"✅ Skill [{skill_cfg.name}] 注册完成 "
                    f"工具数={len(names)}: {names}"
                )

            except Exception as e:
                logger.error(
                    f"❌ Skill [{skill_cfg.name}] 连接失败,跳过\n"
                    f" 错误: {e}"
                )
                try:
                    client.close()
                except Exception as close_err:
                    # Best-effort cleanup: record the failure instead of hiding it.
                    logger.debug(
                        f"关闭 Skill [{skill_cfg.name}] 客户端时异常: {close_err}"
                    )

        logger.info(
            f"📊 SkillRegistry 初始化完成\n"
            f" 本地工具 : {len(self._local)} 个 {list(self._local.keys())}\n"
            f" 远端工具 : {len(self._remote)} 个 {list(self._remote.keys())}"
        )
        return registered_map

    def refresh_skills(self) -> None:
        """
        Re-fetch the tool list of every connected Skill client, without
        creating new clients.

        The remote table is cleared first, so tools of a client whose refresh
        fails are absent afterwards.
        """
        logger.info("🔄 刷新在线 Skill 工具列表...")
        self._remote.clear()
        for client in self._clients:
            try:
                tools = client.list_tools(force_refresh=True)
                for tool in tools:
                    self._remote[tool.name] = (client, tool)
                logger.info(
                    f" ✅ [{client.cfg.name}] 刷新完成 "
                    f"工具数={len(tools)}"
                )
            except Exception as e:
                logger.error(f" ❌ [{client.cfg.name}] 刷新失败: {e}")

    # ── Tool queries ──────────────────────────────────────────

    def get_all_schemas(self) -> list[dict]:
        """
        Return the function calling schemas of all tools (local + remote),
        e.g. for building the ``tools`` argument of an LLM call.

        Local entries come first; a remote tool sharing its name with a local
        entry is left out.
        """
        schemas = [entry.to_function_schema() for entry in self._local.values()]
        schemas.extend(
            tool.to_function_schema()
            for name, (_, tool) in self._remote.items()
            if name not in self._local
        )
        return schemas

    def get_tool_info(self, tool_name: str) -> dict | None:
        """Return name, source and description of one tool, or None if unknown."""
        if tool_name in self._local:
            entry = self._local[tool_name]
            return {
                "name": entry.name,
                "source": "local",
                "description": entry.description,
            }
        if tool_name in self._remote:
            _, tool = self._remote[tool_name]
            return {
                "name": tool.name,
                "source": f"remote:{tool.skill_name}",
                "description": tool.description,
            }
        return None

    def list_all_tools(self) -> list[dict]:
        """
        List every tool with its source (for debugging / display).

        Descriptions are cut to 80 characters; remote tools shadowed by a local
        entry of the same name are omitted.
        """
        listing = [
            {
                "name": name,
                "source": "local",
                "description": entry.description[:80],
            }
            for name, entry in self._local.items()
        ]
        listing.extend(
            {
                "name": name,
                "source": f"remote:{tool.skill_name}",
                "description": tool.description[:80],
            }
            for name, (_, tool) in self._remote.items()
            if name not in self._local
        )
        return listing

    def has_tool(self, tool_name: str) -> bool:
        """Return True when the name is registered locally or remotely."""
        return tool_name in self._local or tool_name in self._remote

    # ── Tool call routing ─────────────────────────────────────

    def dispatch(
        self,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> DispatchResult:
        """
        Single entry point for tool calls; routes to local or remote execution.

        Priority: local tool > remote Skill tool. An unregistered name yields a
        failed ``DispatchResult`` (source ``"unknown"``) listing the available
        tool names.
        """
        if tool_name in self._local:
            return self._dispatch_local(tool_name, arguments)

        if tool_name in self._remote:
            return self._dispatch_remote(tool_name, arguments)

        # Not found: list each available name once, even if it exists in both tables.
        available = list(self._local) + [
            name for name in self._remote if name not in self._local
        ]
        return DispatchResult(
            tool_name=tool_name,
            source="unknown",
            success=False,
            error=(
                f"工具 '{tool_name}' 未注册\n"
                f"可用工具: {available}"
            ),
        )

    def _dispatch_local(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> DispatchResult:
        """
        Call a local tool via ``instance.execute(**arguments)``.

        Any exception raised by the tool is turned into a failed
        ``DispatchResult``. Entries without an executable instance (local
        skills) produce a failed result as well.
        """
        entry = self._local[tool_name]
        instance = getattr(entry, "instance", None)
        if instance is None:
            # LocalSkillEntry carries no instance, so there is nothing to execute.
            return DispatchResult(
                tool_name=tool_name,
                source="local",
                success=False,
                error=f"本地 Skill '{tool_name}' 没有可执行实例,无法直接调用",
            )

        start = time.perf_counter()  # monotonic clock for durations
        logger.info(f"🔧 调用本地工具: {tool_name} 参数={arguments}")
        try:
            content = instance.execute(**arguments)
        except Exception as e:
            elapsed = time.perf_counter() - start
            logger.error(f"❌ 本地工具异常: {tool_name} {e}")
            return DispatchResult(
                tool_name=tool_name,
                source="local",
                success=False,
                error=str(e),
                elapsed_sec=elapsed,
            )
        elapsed = time.perf_counter() - start
        logger.info(f"✅ 本地工具完成: {tool_name} 耗时={elapsed:.2f}s")
        return DispatchResult(
            tool_name=tool_name,
            source="local",
            success=True,
            content=str(content),
            elapsed_sec=elapsed,
        )

    def _dispatch_remote(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> DispatchResult:
        """
        Call a remote MCP Skill tool through its client.

        The client's result fields are copied into a ``DispatchResult``. An
        exception escaping ``call_tool`` is converted into a failed result,
        mirroring the local path.
        """
        client, tool = self._remote[tool_name]
        source = f"remote:{tool.skill_name}"
        logger.info(
            f"🌐 调用远端工具: [{tool.skill_name}] / {tool_name} "
            f"参数={arguments}"
        )
        start = time.perf_counter()
        try:
            result = client.call_tool(tool_name, arguments)
        except Exception as e:
            elapsed = time.perf_counter() - start
            logger.error(f"❌ 远端工具异常: {tool_name} {e}")
            return DispatchResult(
                tool_name=tool_name,
                source=source,
                success=False,
                error=str(e),
                elapsed_sec=elapsed,
            )
        return DispatchResult(
            tool_name=tool_name,
            source=source,
            success=result.success,
            content=result.content,
            error=result.error,
            elapsed_sec=result.elapsed_sec,
        )

    # ── Lifecycle ─────────────────────────────────────────────

    def close(self) -> None:
        """Close every online Skill client and empty the client and remote tables."""
        for client in self._clients:
            try:
                client.close()
            except Exception as e:
                logger.warning(f"⚠️ 关闭 Skill [{client.cfg.name}] 时异常: {e}")
        self._clients.clear()
        self._remote.clear()
        logger.info("🔌 SkillRegistry 已关闭所有连接")

    def __enter__(self):
        return self

    def __exit__(self, *_):
        # Returns None, so exceptions raised inside the with-block propagate.
        self.close()

    def __repr__(self) -> str:
        return (
            f"SkillRegistry("
            f"local={list(self._local.keys())}, "
            f"remote={list(self._remote.keys())})"
        )