base_agent/mcp/mcp_server.py

213 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
mcp/mcp_server.py
本地 MCP Server —— 集成 SkillRegistry统一处理本地工具和在线 Skill 调用
"""
import json
import sys
from typing import Any
from config.settings import settings
from mcp.skill_registry import SkillRegistry
from tools.calculator import CalculatorTool
from tools.code_executor import CodeExecutorTool
from tools.file_reader import FileReaderTool
from tools.ssh_docker import SSHDockerTool
from tools.static_analyzer import StaticAnalyzerTool
from tools.web_search import WebSearchTool
from utils.logger import get_logger
logger = get_logger("MCP.Server")

# Maps each tool name that may appear in config.yaml ``mcp.enabled_tools`` to
# the local tool class implementing it; MCPServer instantiates one per name.
_LOCAL_TOOL_CLASSES: dict[str, type] = dict(
    calculator=CalculatorTool,
    web_search=WebSearchTool,
    file_reader=FileReaderTool,
    code_executor=CodeExecutorTool,
    static_analyzer=StaticAnalyzerTool,
    ssh_docker=SSHDockerTool,
)
class MCPServer:
"""
本地 MCP Server
启动流程:
1. 根据 config.yaml mcp.enabled_tools 实例化本地工具
2. 通过 SkillRegistry 注册本地工具
3. 连接 config.yaml mcp_skills 中所有 enabled 的在线 MCP Skill
4. 进入请求处理循环stdio 模式)
"""
def __init__(self):
self.registry = SkillRegistry()
self._setup()
def _setup(self) -> None:
"""初始化:注册本地工具 + 连接在线 Skill"""
# ── 注册本地工具 ──────────────────────────────────────
enabled = settings.mcp.enabled_tools
logger.info(f"🔧 注册本地工具: {enabled}")
for tool_name in enabled:
cls = _LOCAL_TOOL_CLASSES.get(tool_name)
if cls:
self.registry.register_local(cls())
else:
logger.warning(f"⚠️ 未知工具: {tool_name},跳过")
# ── 连接在线 MCP Skill ────────────────────────────────
skill_map = self.registry.connect_skills()
if skill_map:
logger.info(
"🌐 在线 Skill 注册汇总:\n" +
"\n".join(
f" [{name}]: {tools}"
for name, tools in skill_map.items()
)
)
# ── 打印工具总览 ──────────────────────────────────────
all_tools = self.registry.list_all_tools()
logger.info(
f"📦 工具总览(共 {len(all_tools)} 个):\n" +
"\n".join(
f" {'🔵' if t['source'] == 'local' else '🟢'} "
f"[{t['source']:20s}] {t['name']}: {t['description']}"
for t in all_tools
)
)
# ── 请求处理 ──────────────────────────────────────────────
def handle_request(self, request: dict) -> dict:
"""
处理单条 JSON-RPC 请求
支持的 method:
initialize → 握手
tools/list → 返回所有工具 schema
tools/call → 调用工具(自动路由本地/远端)
ping → 心跳
"""
method = request.get("method", "")
req_id = request.get("id")
params = request.get("params", {})
logger.debug(f"📨 收到请求: method={method} id={req_id}")
try:
match method:
case "initialize":
result = self._handle_initialize(params)
case "tools/list":
result = self._handle_list_tools()
case "tools/call":
result = self._handle_call_tool(params)
case "ping":
result = {}
case _:
return self._error_response(
req_id, -32601, f"Method not found: {method}"
)
return {"jsonrpc": "2.0", "id": req_id, "result": result}
except Exception as e:
logger.error(f"❌ 处理请求异常: {e}")
return self._error_response(req_id, -32603, str(e))
def _handle_initialize(self, params: dict) -> dict:
client_info = params.get("clientInfo", {})
logger.info(
f"🤝 MCP 握手\n"
f" 客户端: {client_info.get('name', 'unknown')} "
f"v{client_info.get('version', '?')}\n"
f" 协议版本: {params.get('protocolVersion', 'unknown')}"
)
return {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {"listChanged": False}},
"serverInfo": {
"name": settings.mcp.server_name,
"version": "1.0.0",
},
}
def _handle_list_tools(self) -> dict:
schemas = self.registry.get_all_schemas()
logger.debug(f"📋 tools/list → {len(schemas)} 个工具")
return {"tools": schemas}
def _handle_call_tool(self, params: dict) -> dict:
tool_name = params.get("name", "")
arguments = params.get("arguments", {})
if not tool_name:
raise ValueError("tools/call 缺少 name 参数")
result = self.registry.dispatch(tool_name, arguments)
logger.info(
f"{'' if result.success else ''} "
f"tools/call [{result.source}] {tool_name} "
f"耗时={result.elapsed_sec:.2f}s"
)
if not result.success:
raise RuntimeError(result.error)
return {
"content": [{"type": "text", "text": result.content}]
}
@staticmethod
def _error_response(req_id: Any, code: int, message: str) -> dict:
return {
"jsonrpc": "2.0",
"id": req_id,
"error": {"code": code, "message": message},
}
# ── stdio 运行模式 ────────────────────────────────────────
def run_stdio(self) -> None:
"""
stdio 模式主循环
从 stdin 逐行读取 JSON-RPC 请求,向 stdout 写入响应
"""
logger.info(
f"🚀 {settings.mcp.server_name} 已启动stdio 模式)\n"
f"{settings.display()}"
)
try:
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
request = json.loads(line)
response = self.handle_request(request)
print(json.dumps(response, ensure_ascii=False), flush=True)
except json.JSONDecodeError as e:
err = self._error_response(None, -32700, f"Parse error: {e}")
print(json.dumps(err, ensure_ascii=False), flush=True)
except KeyboardInterrupt:
logger.info("⏹ 收到中断信号,正在关闭...")
finally:
self.registry.close()
logger.info("👋 MCP Server 已关闭")
def close(self) -> None:
self.registry.close()
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
# ── Entry point ───────────────────────────────────────────────
if __name__ == "__main__":
    # Build the server, then let the context manager close it on exit.
    server = MCPServer()
    with server:
        server.run_stdio()