base_agent/mcp/skill_registry.py

349 lines
14 KiB
Python
Raw Normal View History

2026-03-30 08:48:36 +00:00
"""
mcp/skill_registry.py
统一 Skill 注册表
将本地工具（LocalTool）和在线 MCP Skill（RemoteTool）统一注册，
对外提供一致的接口：
- get_all_schemas() 返回所有工具的 function calling schema
- dispatch() 根据工具名路由到本地或远端执行
- refresh_skills() 重新拉取在线 Skill 工具列表
"""
import time
from dataclasses import dataclass
from typing import Any
from config.settings import settings
from mcp.mcp_skill_client import MCPSkillClient, RemoteTool, ToolCallResult
from utils.logger import get_logger
# Logger for this module, registered under the channel name "MCP.SkillRegistry".
logger = get_logger("MCP.SkillRegistry")
# ════════════════════════════════════════════════════════════════
# 本地工具包装
# ════════════════════════════════════════════════════════════════
@dataclass
class LocalToolEntry:
    """Registry record for a tool implemented in-process.

    The metadata fields are copied from the tool object when it is
    registered; ``instance`` is the object that actually runs the tool.
    """

    # Tool name; used as the dispatch key and in the function-calling schema.
    name: str
    # Description exposed to the model alongside the schema.
    description: str
    # Parameter description passed through unchanged
    # (presumably a JSON Schema object -- confirm with tool authors).
    parameters: dict[str, Any]
    # The tool object itself; it must provide execute(**kwargs) -> str.
    instance: Any

    def to_function_schema(self) -> dict:
        """Build the function-calling schema with keys name/description/parameters."""
        schema: dict = dict(
            name=self.name,
            description=self.description,
            parameters=self.parameters,
        )
        return schema
# ════════════════════════════════════════════════════════════════
# 调用结果统一封装
# ════════════════════════════════════════════════════════════════
@dataclass
class DispatchResult:
    """Uniform outcome of a tool call, whichever backend served it."""

    # Name of the tool that was invoked.
    tool_name: str
    # Where the call was routed: "local", "remote:<skill_name>" or "unknown".
    source: str
    # True when the tool ran successfully.
    success: bool
    # Tool output on success.
    content: str = ""
    # Error description on failure.
    error: str = ""
    # Duration of the call in seconds.
    elapsed_sec: float = 0.0

    def __str__(self) -> str:
        """Return the tool output on success, otherwise the error text."""
        if not self.success:
            return f"{self.error}"
        return self.content
# ════════════════════════════════════════════════════════════════
# 统一 Skill 注册表
# ════════════════════════════════════════════════════════════════
class SkillRegistry:
    """
    Unified skill registry.

    Registers local tools and tools exposed by online MCP skills in one
    place and offers a single interface over both:

    - get_all_schemas(): function-calling schemas for every reachable tool
    - dispatch(): route a call by tool name to the local or remote backend
    - refresh_skills(): re-fetch tool lists from connected online skills

    When a local and a remote tool share a name, the local tool wins.

    Usage:
        registry = SkillRegistry()
        registry.register_local(tool_instance)
        registry.connect_skills()
        schemas = registry.get_all_schemas()
        result = registry.dispatch("tool_name", {"arg": "val"})
        registry.close()

    Also usable as a context manager; close() runs on exit.
    """

    def __init__(self):
        # Local tool table: tool_name -> LocalToolEntry
        self._local: dict[str, LocalToolEntry] = {}
        # Remote tool table: tool_name -> (owning MCPSkillClient, RemoteTool)
        self._remote: dict[str, tuple[MCPSkillClient, RemoteTool]] = {}
        # Successfully connected online skill clients (lifecycle + refresh)
        self._clients: list[MCPSkillClient] = []

    # ── Local tool registration ──────────────────────────────────

    def register_local(self, tool_instance: Any) -> None:
        """
        Register a local tool instance.

        The instance must expose a truthy ``name``; ``description`` and
        ``parameters`` are optional (missing or None become "" / {}).
        It is invoked later via ``execute(**kwargs)``. Instances without
        a name are skipped with a warning; re-registering a name replaces
        the previous entry.
        """
        name = getattr(tool_instance, "name", None)
        if not name:
            logger.warning(f"⚠️ 工具实例缺少 name 属性,跳过: {tool_instance}")
            return
        self._local[name] = LocalToolEntry(
            name=name,
            # `or` also covers attributes explicitly set to None, which would
            # otherwise break the description slicing in list_all_tools().
            description=getattr(tool_instance, "description", "") or "",
            parameters=getattr(tool_instance, "parameters", {}) or {},
            instance=tool_instance,
        )
        logger.debug(f"📌 注册本地工具: {name}")

    def register_local_many(self, *tool_instances: Any) -> None:
        """Register each of *tool_instances* via register_local(), in order."""
        for tool_instance in tool_instances:
            self.register_local(tool_instance)

    # ── Online MCP skills ────────────────────────────────────────

    def _register_remote_tool(
        self,
        table: dict[str, tuple[MCPSkillClient, RemoteTool]],
        client: MCPSkillClient,
        tool: RemoteTool,
        skill_name: str,
    ) -> None:
        """Insert one remote tool into *table*, warning about name conflicts."""
        if tool.name in self._local:
            # dispatch()/get_all_schemas() prefer local tools, so it is the
            # remote tool that becomes unreachable, not the local one.
            logger.warning(
                f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill [{skill_name}] "
                f"的同名工具被本地工具屏蔽"
            )
        if tool.name in table:
            prev = table[tool.name][1].skill_name
            logger.warning(
                f"⚠️ 工具名冲突 [{tool.name}]:远端 Skill [{prev}] "
                f"被 [{skill_name}] 覆盖"
            )
        table[tool.name] = (client, tool)

    def _forget_client_tools(self, client: MCPSkillClient) -> None:
        """Remove every remote-table entry that routes to *client*."""
        stale = [
            name for name, (owner, _) in self._remote.items() if owner is client
        ]
        for name in stale:
            del self._remote[name]

    def connect_skills(self) -> dict[str, list[str]]:
        """
        Connect every enabled online MCP skill and register its tools.

        Skills are read from ``settings.enabled_mcp_skills``. If connecting,
        listing tools or registering them raises, the skill is logged,
        any partially registered tools are removed, the client is closed
        and the skill is skipped; the remaining skills are still processed.

        Returns:
            {skill_name: [tool_name, ...]} for each skill registered successfully.
        """
        enabled = settings.enabled_mcp_skills
        if not enabled:
            logger.info(" 未配置任何在线 MCP Skill")
            return {}

        logger.info(f"🌐 开始连接在线 MCP Skills数量={len(enabled)}")
        registered_map: dict[str, list[str]] = {}

        for skill_cfg in enabled:
            client = MCPSkillClient(skill_cfg)
            try:
                client.connect()
                tools = client.list_tools()
                names: list[str] = []
                for tool in tools:
                    self._register_remote_tool(
                        self._remote, client, tool, skill_cfg.name
                    )
                    names.append(tool.name)
            except Exception as e:
                logger.error(
                    f"❌ Skill [{skill_cfg.name}] 连接失败,跳过\n"
                    f" 错误: {e}"
                )
                # Never leave tools routed to a client that is being closed.
                self._forget_client_tools(client)
                try:
                    client.close()
                except Exception as close_err:
                    # Best-effort cleanup: the skill is skipped either way.
                    logger.debug(
                        f"关闭 Skill [{skill_cfg.name}] 时异常: {close_err}"
                    )
                continue

            # Track the client only once all of its tools are registered.
            self._clients.append(client)
            registered_map[skill_cfg.name] = names
            logger.info(
                f"✅ Skill [{skill_cfg.name}] 注册完成 "
                f"工具数={len(names)}: {names}"
            )

        logger.info(
            f"📊 SkillRegistry 初始化完成\n"
            f" 本地工具 : {len(self._local)}{list(self._local.keys())}\n"
            f" 远端工具 : {len(self._remote)}{list(self._remote.keys())}"
        )
        return registered_map

    def refresh_skills(self) -> None:
        """
        Re-fetch the tool list of every connected online skill (no reconnect).

        A fresh remote table is built and swapped in after all skills have
        been processed. If refreshing one skill fails, the tools it provided
        before are carried over instead of being silently unregistered.
        """
        logger.info("🔄 刷新在线 Skill 工具列表...")
        previous = self._remote
        refreshed: dict[str, tuple[MCPSkillClient, RemoteTool]] = {}
        for client in self._clients:
            skill_name = client.cfg.name
            try:
                tools = client.list_tools(force_refresh=True)
            except Exception as e:
                logger.error(f" ❌ [{skill_name}] 刷新失败: {e}")
                kept = [tool for owner, tool in previous.values() if owner is client]
                for tool in kept:
                    self._register_remote_tool(refreshed, client, tool, skill_name)
                if kept:
                    logger.warning(f" ⚠️ [{skill_name}] 保留原有工具数={len(kept)}")
                continue
            for tool in tools:
                self._register_remote_tool(refreshed, client, tool, skill_name)
            logger.info(
                f" ✅ [{skill_name}] 刷新完成 "
                f"工具数={len(tools)}"
            )
        self._remote = refreshed

    # ── Tool lookup ──────────────────────────────────────────────

    def get_all_schemas(self) -> list[dict]:
        """
        Return function-calling schemas for all reachable tools.

        Local tools come first; remote tools shadowed by a local tool of
        the same name are omitted. Intended for building the LLM ``tools``
        argument.
        """
        schemas = [entry.to_function_schema() for entry in self._local.values()]
        schemas.extend(
            tool.to_function_schema()
            for name, (_, tool) in self._remote.items()
            if name not in self._local
        )
        return schemas

    def get_tool_info(self, tool_name: str) -> dict | None:
        """Return name/source/description for *tool_name*, or None if unknown."""
        if tool_name in self._local:
            entry = self._local[tool_name]
            return {
                "name": entry.name,
                "source": "local",
                "description": entry.description,
            }
        if tool_name in self._remote:
            _, tool = self._remote[tool_name]
            return {
                "name": tool.name,
                "source": f"remote:{tool.skill_name}",
                "description": tool.description,
            }
        return None

    def list_all_tools(self) -> list[dict]:
        """List reachable tools with source and a description truncated to 80 chars."""
        result = [
            {
                "name": name,
                "source": "local",
                "description": (entry.description or "")[:80],
            }
            for name, entry in self._local.items()
        ]
        result.extend(
            {
                "name": name,
                "source": f"remote:{tool.skill_name}",
                "description": (tool.description or "")[:80],
            }
            for name, (_, tool) in self._remote.items()
            if name not in self._local
        )
        return result

    def has_tool(self, tool_name: str) -> bool:
        """Return True if *tool_name* is registered locally or remotely."""
        return tool_name in self._local or tool_name in self._remote

    # ── Call routing ─────────────────────────────────────────────

    def dispatch(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
    ) -> DispatchResult:
        """
        Single entry point for tool calls; routes to a local or remote tool.

        Priority: local tool > remote skill tool.

        Args:
            tool_name: registered tool name.
            arguments: keyword arguments for the tool; ``None`` is treated
                as an empty argument dict.

        Returns:
            A DispatchResult. Tool failures (including exceptions) are
            reported with success=False rather than raised; unknown tools
            yield source="unknown" and list the available tool names.
        """
        if arguments is None:
            arguments = {}

        if tool_name in self._local:
            return self._dispatch_local(tool_name, arguments)

        if tool_name in self._remote:
            return self._dispatch_remote(tool_name, arguments)

        # A name may exist in both tables; list it once, preserving order.
        available = list(dict.fromkeys([*self._local, *self._remote]))
        return DispatchResult(
            tool_name=tool_name,
            source="unknown",
            success=False,
            error=(
                f"工具 '{tool_name}' 未注册\n"
                f"可用工具: {available}"
            ),
        )

    def _dispatch_local(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> DispatchResult:
        """Invoke a local tool; any exception becomes a failed result."""
        entry = self._local[tool_name]
        logger.info(f"🔧 调用本地工具: {tool_name} 参数={arguments}")
        # perf_counter is monotonic, unlike time.time(), so durations are reliable.
        start = time.perf_counter()
        try:
            content = str(entry.instance.execute(**arguments))
        except Exception as e:
            elapsed = time.perf_counter() - start
            logger.error(f"❌ 本地工具异常: {tool_name} {e}")
            return DispatchResult(
                tool_name=tool_name,
                source="local",
                success=False,
                error=str(e),
                elapsed_sec=elapsed,
            )
        elapsed = time.perf_counter() - start
        logger.info(f"✅ 本地工具完成: {tool_name} 耗时={elapsed:.2f}s")
        return DispatchResult(
            tool_name=tool_name,
            source="local",
            success=True,
            content=content,
            elapsed_sec=elapsed,
        )

    def _dispatch_remote(
        self, tool_name: str, arguments: dict[str, Any]
    ) -> DispatchResult:
        """Invoke a remote MCP skill tool; exceptions become a failed result."""
        client, tool = self._remote[tool_name]
        source = f"remote:{tool.skill_name}"
        logger.info(
            f"🌐 调用远端工具: [{tool.skill_name}] / {tool_name} "
            f"参数={arguments}"
        )
        start = time.perf_counter()
        try:
            result: ToolCallResult = client.call_tool(tool_name, arguments)
        except Exception as e:
            # Mirror _dispatch_local: dispatch() reports failures, never raises.
            elapsed = time.perf_counter() - start
            logger.error(f"❌ 远端工具异常: [{tool.skill_name}] / {tool_name} {e}")
            return DispatchResult(
                tool_name=tool_name,
                source=source,
                success=False,
                error=str(e),
                elapsed_sec=elapsed,
            )
        return DispatchResult(
            tool_name=tool_name,
            source=source,
            success=result.success,
            content=result.content,
            error=result.error,
            elapsed_sec=result.elapsed_sec,
        )

    # ── Lifecycle ────────────────────────────────────────────────

    def close(self) -> None:
        """Close all online skill connections and drop the remote tool table."""
        for client in self._clients:
            try:
                client.close()
            except Exception as e:
                logger.warning(f"⚠️ 关闭 Skill [{client.cfg.name}] 时异常: {e}")
        self._clients.clear()
        self._remote.clear()
        logger.info("🔌 SkillRegistry 已关闭所有连接")

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()

    def __repr__(self) -> str:
        return (
            f"SkillRegistry("
            f"local={list(self._local.keys())}, "
            f"remote={list(self._remote.keys())})"
        )