""" mcp/skill_registry.py 统一 Skill 注册表 将本地工具(LocalTool)和在线 MCP Skill(RemoteTool)统一注册, 对外提供一致的接口: - get_all_schemas() → 返回所有工具的 function calling schema - dispatch() → 根据工具名路由到本地或远端执行 - refresh_skills() → 重新拉取在线 Skill 工具列表 """ import time from dataclasses import dataclass from typing import Any from config.settings import settings from mcp.mcp_skill_client import MCPSkillClient, RemoteTool, ToolCallResult from utils.logger import get_logger logger = get_logger("MCP.SkillRegistry") # ════════════════════════════════════════════════════════════════ # 本地工具包装 # ════════════════════════════════════════════════════════════════ @dataclass class LocalToolEntry: """本地工具注册条目""" name: str description: str parameters: dict[str, Any] instance: Any # 需有 execute(**kwargs) → str def to_function_schema(self) -> dict: return { "name": self.name, "description": self.description, "parameters": self.parameters, } # ════════════════════════════════════════════════════════════════ # 调用结果统一封装 # ════════════════════════════════════════════════════════════════ @dataclass class DispatchResult: tool_name: str source: str # "local" | skill_name success: bool content: str = "" error: str = "" elapsed_sec: float = 0.0 def __str__(self) -> str: return self.content if self.success else f"❌ {self.error}" # ════════════════════════════════════════════════════════════════ # 统一 Skill 注册表 # ════════════════════════════════════════════════════════════════ class SkillRegistry: """ 统一 Skill 注册表 用法: registry = SkillRegistry() registry.register_local(tool_instance) registry.connect_skills() schemas = registry.get_all_schemas() result = registry.dispatch("tool_name", {"arg": "val"}) registry.close() """ def __init__(self): # 本地工具表: tool_name → LocalToolEntry self._local: dict[str, LocalToolEntry] = {} # 远端工具表: tool_name → (MCPSkillClient, RemoteTool) self._remote: dict[str, tuple[MCPSkillClient, RemoteTool]] = {} # 在线 Skill 客户端列表(用于生命周期管理) self._clients: list[MCPSkillClient] = [] # ── 注册本地工具 ────────────────────────────────────────── def register_local(self, tool_instance: Any) -> None: """ 注册本地工具实例 工具实例需具备: .name / .description / .parameters / .execute(**kwargs) """ name = getattr(tool_instance, "name", None) if not name: logger.warning(f"⚠️ 工具实例缺少 name 属性,跳过: {tool_instance}") return self._local[name] = LocalToolEntry( name=name, description=getattr(tool_instance, "description", ""), parameters=getattr(tool_instance, "parameters", {}), instance=tool_instance, ) logger.debug(f"📌 注册本地工具: {name}") def register_local_many(self, *tool_instances: Any) -> None: for t in tool_instances: self.register_local(t) # ── 连接在线 MCP Skill ──────────────────────────────────── def connect_skills(self) -> dict[str, list[str]]: """ 连接所有 config.yaml 中 enabled=true 的在线 MCP Skill 并将其工具注册到远端工具表 Returns: {skill_name: [tool_name, ...]} 成功注册的工具映射 """ enabled = settings.enabled_mcp_skills if not enabled: logger.info("ℹ️ 未配置任何在线 MCP Skill") return {} logger.info(f"🌐 开始连接在线 MCP Skills,数量={len(enabled)}") registered_map: dict[str, list[str]] = {} for skill_cfg in enabled: client = MCPSkillClient(skill_cfg) try: client.connect() tools = client.list_tools() self._clients.append(client) names = [] for tool in tools: # 冲突警告 if tool.name in self._local: logger.warning( f"⚠️ 工具名冲突 [{tool.name}]:本地工具被远端 " f"Skill [{skill_cfg.name}] 覆盖" ) if tool.name in self._remote: prev = self._remote[tool.name][1].skill_name logger.warning( f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill [{prev}] " f"被 [{skill_cfg.name}] 覆盖" ) self._remote[tool.name] = (client, tool) names.append(tool.name) registered_map[skill_cfg.name] = names logger.info( f"✅ Skill [{skill_cfg.name}] 注册完成 " f"工具数={len(names)}: {names}" ) except Exception as e: logger.error( f"❌ Skill [{skill_cfg.name}] 连接失败,跳过\n" f" 错误: {e}" ) try: client.close() except Exception: pass logger.info( f"📊 SkillRegistry 初始化完成\n" f" 本地工具 : {len(self._local)} 个 {list(self._local.keys())}\n" f" 远端工具 : {len(self._remote)} 个 {list(self._remote.keys())}" ) return registered_map def refresh_skills(self) -> None: """重新拉取所有在线 Skill 的工具列表(不重新建立连接)""" logger.info("🔄 刷新在线 Skill 工具列表...") self._remote.clear() for client in self._clients: try: tools = client.list_tools(force_refresh=True) for tool in tools: self._remote[tool.name] = (client, tool) logger.info( f" ✅ [{client.cfg.name}] 刷新完成 " f"工具数={len(tools)}" ) except Exception as e: logger.error(f" ❌ [{client.cfg.name}] 刷新失败: {e}") # ── 工具查询 ────────────────────────────────────────────── def get_all_schemas(self) -> list[dict]: """ 返回所有工具(本地 + 远端)的 function calling schema 列表 用于构造 LLM 的 tools 参数 """ schemas = [] # 本地工具 for entry in self._local.values(): schemas.append(entry.to_function_schema()) # 远端工具(不被本地同名工具覆盖的) for name, (_, tool) in self._remote.items(): if name not in self._local: schemas.append(tool.to_function_schema()) return schemas def get_tool_info(self, tool_name: str) -> dict | None: """查询单个工具的来源和描述信息""" if tool_name in self._local: entry = self._local[tool_name] return { "name": entry.name, "source": "local", "description": entry.description, } if tool_name in self._remote: _, tool = self._remote[tool_name] return { "name": tool.name, "source": f"remote:{tool.skill_name}", "description": tool.description, } return None def list_all_tools(self) -> list[dict]: """列出所有工具及其来源(用于调试/展示)""" result = [] for name, entry in self._local.items(): result.append({ "name": name, "source": "local", "description": entry.description[:80], }) for name, (_, tool) in self._remote.items(): if name not in self._local: result.append({ "name": name, "source": f"remote:{tool.skill_name}", "description": tool.description[:80], }) return result def has_tool(self, tool_name: str) -> bool: return tool_name in self._local or tool_name in self._remote # ── 工具调用路由 ────────────────────────────────────────── def dispatch( self, tool_name: str, arguments: dict[str, Any], ) -> DispatchResult: """ 统一工具调用入口:自动路由到本地或远端 优先级: 本地工具 > 远端 Skill 工具 """ # ── 本地工具 ────────────────────────────────────────── if tool_name in self._local: return self._dispatch_local(tool_name, arguments) # ── 远端工具 ────────────────────────────────────────── if tool_name in self._remote: return self._dispatch_remote(tool_name, arguments) # ── 未找到 ──────────────────────────────────────────── available = list(self._local.keys()) + list(self._remote.keys()) return DispatchResult( tool_name=tool_name, source="unknown", success=False, error=( f"工具 '{tool_name}' 未注册\n" f"可用工具: {available}" ), ) def _dispatch_local( self, tool_name: str, arguments: dict[str, Any] ) -> DispatchResult: """调用本地工具""" entry = self._local[tool_name] start = time.time() logger.info(f"🔧 调用本地工具: {tool_name} 参数={arguments}") try: content = entry.instance.execute(**arguments) elapsed = time.time() - start logger.info(f"✅ 本地工具完成: {tool_name} 耗时={elapsed:.2f}s") return DispatchResult( tool_name=tool_name, source="local", success=True, content=str(content), elapsed_sec=elapsed, ) except Exception as e: elapsed = time.time() - start logger.error(f"❌ 本地工具异常: {tool_name} {e}") return DispatchResult( tool_name=tool_name, source="local", success=False, error=str(e), elapsed_sec=elapsed, ) def _dispatch_remote( self, tool_name: str, arguments: dict[str, Any] ) -> DispatchResult: """调用远端 MCP Skill 工具""" client, tool = self._remote[tool_name] logger.info( f"🌐 调用远端工具: [{tool.skill_name}] / {tool_name} " f"参数={arguments}" ) result = client.call_tool(tool_name, arguments) return DispatchResult( tool_name=tool_name, source=f"remote:{tool.skill_name}", success=result.success, content=result.content, error=result.error, elapsed_sec=result.elapsed_sec, ) # ── 生命周期 ────────────────────────────────────────────── def close(self) -> None: """关闭所有在线 Skill 连接""" for client in self._clients: try: client.close() except Exception as e: logger.warning(f"⚠️ 关闭 Skill [{client.cfg.name}] 时异常: {e}") self._clients.clear() self._remote.clear() logger.info("🔌 SkillRegistry 已关闭所有连接") def __enter__(self): return self def __exit__(self, *_): self.close() def __repr__(self) -> str: return ( f"SkillRegistry(" f"local={list(self._local.keys())}, " f"remote={list(self._remote.keys())})" )