614 lines
22 KiB
Python
614 lines
22 KiB
Python
"""
mcp/mcp_skill_client.py
Online MCP Server client.

Connects to a single remote MCP Server, fetches its tool list, and proxies
tool calls to it.
Three transports are supported:
  - sse   : Server-Sent Events (the most common form of online MCP)
  - http  : Streamable HTTP
  - stdio : local subprocess (communicates over stdin/stdout)

Dependencies:
    pip install httpx>=0.27.0 httpx-sse>=0.4.0
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import subprocess
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Iterator
|
||
|
||
from config.settings import MCPSkillConfig
|
||
from utils.logger import get_logger
|
||
|
||
# Module-level logger used by every transport and by the client below.
logger = get_logger("MCP.SkillClient")

# httpx is probed at import time; the flag is checked by the SSE and HTTP
# transports so they can raise a clear error when it is missing.
try:
    import httpx
    _HTTPX_AVAILABLE = True
except ImportError:
    _HTTPX_AVAILABLE = False
    logger.warning("⚠️ httpx 未安装,请执行: pip install httpx>=0.27.0")

# httpx-sse is probed the same way; only the SSE transport checks this flag.
try:
    from httpx_sse import connect_sse
    _SSE_AVAILABLE = True
except ImportError:
    _SSE_AVAILABLE = False
    logger.warning("⚠️ httpx-sse 未安装,请执行: pip install httpx-sse>=0.4.0")
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# MCP JSON-RPC 协议常量
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
# JSON-RPC version tag written into every outgoing request envelope.
_JSONRPC = "2.0"

# MCP method names referenced by this module.
_METHOD_INITIALIZE = "initialize"
_METHOD_LIST_TOOLS = "tools/list"
_METHOD_CALL_TOOL = "tools/call"
_METHOD_PING = "ping"

# Client identity and protocol revision sent in the initialize request.
_CLIENT_INFO = dict(name="agent-demo", version="1.0.0")
_PROTOCOL_VERSION = "2024-11-05"
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 数据结构
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
@dataclass
class RemoteTool:
    """Description of a single tool exposed by a remote MCP Server."""

    name: str
    description: str
    parameters: dict[str, Any]  # JSON Schema
    skill_name: str  # name of the owning skill group (from config.yaml)

    def to_function_schema(self) -> dict:
        """Return this tool as an OpenAI function-calling schema dict.

        Only ``name``, ``description`` and ``parameters`` are exported;
        ``skill_name`` is not part of the schema.
        """
        exported = ("name", "description", "parameters")
        return {key: getattr(self, key) for key in exported}
|
||
|
||
|
||
@dataclass
class ToolCallResult:
    """Outcome of one remote tool invocation."""

    tool_name: str
    skill_name: str
    success: bool
    content: str = ""
    error: str = ""
    elapsed_sec: float = 0.0

    def __str__(self) -> str:
        """Return the content on success, otherwise a formatted failure line."""
        if not self.success:
            return f"❌ [{self.skill_name}/{self.tool_name}] 调用失败: {self.error}"
        return self.content
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 传输层基类
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class BaseTransport:
    """Base class for MCP transports."""

    def __init__(self, cfg: MCPSkillConfig):
        self.cfg = cfg

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send one JSON-RPC request; concrete transports must override this."""
        raise NotImplementedError

    def close(self):
        """Release transport resources; the base implementation does nothing."""
        return None

    def _make_request(self, method: str, params: dict | None = None) -> dict:
        """Build a JSON-RPC request envelope with a fresh UUID string as id."""
        envelope: dict = {"jsonrpc": _JSONRPC, "id": str(uuid.uuid4())}
        envelope["method"] = method
        # A missing or empty params value is sent as an empty object.
        envelope["params"] = params if params else {}
        return envelope

    def _check_response(self, resp: dict, method: str) -> dict:
        """Return the ``result`` member of *resp* (``{}`` when absent).

        Raises:
            RuntimeError: when *resp* contains an ``error`` member.
        """
        if "error" not in resp:
            return resp.get("result", {})
        failure = resp["error"]
        code = failure.get('code')
        message = failure.get('message')
        raise RuntimeError(f"MCP 错误 [{method}]: code={code} msg={message}")
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# SSE 传输层
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class SSETransport(BaseTransport):
    """
    Server-Sent Events transport.

    MCP-over-SSE flow implemented here:
      1. GET {url}           -> opens the SSE stream; the server pushes an
                                ``endpoint`` event carrying the POST URL
      2. POST {endpoint_url} -> sends each JSON-RPC request
      3. SSE stream          -> delivers the response as a ``message`` event

    A daemon thread owns the SSE stream. It publishes the endpoint URL and
    incoming responses under ``_cond`` so that callers block on a condition
    variable instead of sleep-polling, and are woken as soon as the stream
    ends.
    """

    # Seconds to wait for the server's ``endpoint`` event after connecting.
    _ENDPOINT_WAIT_SEC = 10

    def __init__(self, cfg: MCPSkillConfig):
        """Create the HTTP client and block until the endpoint URL is known.

        Raises:
            RuntimeError: httpx / httpx-sse are missing, or no ``endpoint``
                event was received.
        """
        super().__init__(cfg)
        if not _HTTPX_AVAILABLE or not _SSE_AVAILABLE:
            raise RuntimeError("SSE 传输需要: pip install httpx httpx-sse")
        self._client = httpx.Client(
            headers=cfg.headers,
            timeout=cfg.timeout,
        )
        self._endpoint_url: str = ""
        self._pending: dict = {}  # request id -> response envelope
        self._awaiting: set[str] = set()  # ids with a caller currently waiting
        self._sse_thread: threading.Thread | None = None
        self._connected: bool = False
        self._listener_done: bool = False  # set once the SSE stream has ended
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)
        try:
            self._connect()
        except BaseException:
            # The caller never receives this object when __init__ raises, so
            # nobody else could close the client: release it here.
            self._client.close()
            raise

    def _connect(self) -> None:
        """Start the listener thread and wait for the ``endpoint`` event."""
        logger.info(
            f"🔌 SSE 连接: {self.cfg.url} "
            f"timeout={self.cfg.timeout}s"
        )
        self._sse_thread = threading.Thread(
            target=self._sse_listener, daemon=True
        )
        self._sse_thread.start()

        # Wake up on the endpoint event OR on listener death, whichever comes
        # first, so a refused connection does not cost the full wait.
        with self._cond:
            self._cond.wait_for(
                lambda: bool(self._endpoint_url) or self._listener_done,
                timeout=self._ENDPOINT_WAIT_SEC,
            )
            endpoint = self._endpoint_url
            if endpoint:
                self._connected = not self._listener_done

        if not endpoint:
            raise RuntimeError(
                f"SSE 连接超时:未收到 endpoint 事件\n"
                f" URL: {self.cfg.url}\n"
                f" 请检查 MCP Server 是否正常运行"
            )
        logger.info(f"✅ SSE 已连接,endpoint: {self._endpoint_url}")

    def _sse_listener(self) -> None:
        """Background thread: consume the SSE stream until it ends or fails."""
        try:
            # The stream is long-lived and may be idle between responses, so
            # the read timeout is disabled for this one request; connect,
            # write and pool timeouts keep the configured value.
            stream_timeout = httpx.Timeout(self.cfg.timeout, read=None)
            with connect_sse(
                self._client, "GET", self.cfg.url, timeout=stream_timeout
            ) as event_source:
                for event in event_source.iter_sse():
                    self._handle_sse_event(event)
        except Exception as e:
            logger.error(f"❌ SSE 监听异常: {e}")
        finally:
            # Mark the stream as gone and wake every waiter so they can fail
            # immediately instead of running out their timeout.
            with self._cond:
                self._listener_done = True
                self._connected = False
                self._cond.notify_all()

    def _handle_sse_event(self, event) -> None:
        """Dispatch one SSE event (``endpoint`` or ``message``)."""
        if event.event == "endpoint":
            # The server announces the URL that requests must be POSTed to.
            raw = event.data.strip()
            if raw.startswith("http"):
                endpoint = raw
            else:
                # Relative path: resolve it against the configured base URL.
                from urllib.parse import urljoin
                endpoint = urljoin(self.cfg.url, raw)
            with self._cond:
                self._endpoint_url = endpoint
                self._cond.notify_all()

        elif event.event == "message":
            try:
                data = json.loads(event.data)
            except json.JSONDecodeError:
                logger.debug("Ignoring SSE message that is not valid JSON")
                return
            if not isinstance(data, dict):
                # A non-object payload has no id; calling .get() on it would
                # raise and terminate the listener thread.
                return
            req_id = str(data.get("id", ""))
            with self._cond:
                # Keep only messages somebody is waiting for; notifications
                # and late responses would otherwise pile up in _pending.
                if req_id in self._awaiting:
                    self._pending[req_id] = data
                    self._cond.notify_all()

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """POST a JSON-RPC request and wait for its response on the stream.

        Returns:
            The ``result`` member of the response.

        Raises:
            httpx.HTTPStatusError: the POST returned an error status.
            ConnectionError: the SSE stream ended before the response came.
            TimeoutError: no response within ``cfg.timeout`` seconds.
            RuntimeError: the response carries a JSON-RPC error.
        """
        req = self._make_request(method, params)
        req_id = req["id"]

        # Register before POSTing so a fast response is not discarded.
        with self._cond:
            self._awaiting.add(req_id)
        try:
            resp = self._client.post(
                self._endpoint_url,
                json=req,
                headers={"Content-Type": "application/json"},
            )
            resp.raise_for_status()

            # Monotonic clock: immune to wall-clock adjustments.
            deadline = time.monotonic() + self.cfg.timeout
            with self._cond:
                while req_id not in self._pending:
                    if self._listener_done:
                        raise ConnectionError(
                            f"SSE 连接已断开,无法接收响应 "
                            f"method={method} skill={self.cfg.name}"
                        )
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise TimeoutError(
                            f"等待 MCP 响应超时 (>{self.cfg.timeout}s) "
                            f"method={method} skill={self.cfg.name}"
                        )
                    self._cond.wait(remaining)
                result = self._pending.pop(req_id)
        finally:
            with self._cond:
                self._awaiting.discard(req_id)
                self._pending.pop(req_id, None)

        return self._check_response(result, method)

    def close(self):
        """Close the HTTP client and mark the transport as disconnected."""
        self._client.close()
        self._connected = False
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# HTTP 传输层(Streamable HTTP)
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class HTTPTransport(BaseTransport):
    """
    Streamable HTTP transport.

    Each JSON-RPC request is POSTed to the fixed ``cfg.url``; the reply is
    either a plain JSON body or an SSE-framed body (``text/event-stream``).

    NOTE(review): the original docstring tied this transport to MCP revision
    2024-11-05; Streamable HTTP appears to have been introduced in a later
    revision — confirm target servers accept the announced protocol version.
    """

    def __init__(self, cfg: MCPSkillConfig):
        """Create the HTTP client.

        Raises:
            RuntimeError: httpx is not installed.
        """
        super().__init__(cfg)
        if not _HTTPX_AVAILABLE:
            raise RuntimeError("HTTP 传输需要: pip install httpx")
        self._client = httpx.Client(
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json, text/event-stream",
                # Configured headers come last so they can override defaults.
                **cfg.headers,
            },
            timeout=cfg.timeout,
        )
        logger.info(f"🔌 HTTP 传输初始化: {cfg.url}")

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """POST one JSON-RPC request and return the ``result`` of its reply.

        Raises:
            httpx.HTTPStatusError: the server returned an error status.
            RuntimeError: the reply carries a JSON-RPC error, or an SSE body
                contains no response for this request.
        """
        req = self._make_request(method, params)
        resp = self._client.post(self.cfg.url, json=req)
        resp.raise_for_status()

        content_type = resp.headers.get("content-type", "")
        if "text/event-stream" in content_type:
            # An SSE body can interleave notifications with the response, so
            # the request id is passed along to pick the right message.
            return self._parse_sse_response(resp.text, method, req["id"])
        return self._check_response(resp.json(), method)

    def _parse_sse_response(
        self, text: str, method: str, req_id: str | None = None
    ) -> dict:
        """Extract the JSON-RPC response from an SSE-framed HTTP body.

        Args:
            text: raw response body.
            method: JSON-RPC method name, used in error messages.
            req_id: id of the request being answered. When given, only a
                message with that id is accepted; when ``None``, the first
                response-shaped message is used.

        Raises:
            RuntimeError: no usable response was found, or the response
                carries a JSON-RPC error.
        """
        for data in self._iter_sse_messages(text):
            if "method" in data:
                # Server-initiated notification or request, not our response.
                continue
            if req_id is not None and str(data.get("id")) != req_id:
                continue
            return self._check_response(data, method)
        raise RuntimeError(f"无法解析 SSE 响应: {text[:200]}")

    @staticmethod
    def _iter_sse_messages(text: str) -> Iterator[dict]:
        """Yield each JSON object carried in the ``data:`` fields of *text*.

        Consecutive ``data:`` lines of one event are joined with newlines
        before decoding. If the joined payload is not valid JSON, each line
        is tried on its own so bodies without blank-line separators between
        messages are still handled.
        """

        def decode(lines: list[str]) -> Iterator[dict]:
            joined = "\n".join(lines)
            candidates = [joined] if len(lines) == 1 else [joined, *lines]
            for index, raw in enumerate(candidates):
                if not raw or raw == "[DONE]":
                    continue
                try:
                    parsed = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if isinstance(parsed, dict):
                    yield parsed
                    if index == 0:
                        # The whole event decoded; do not re-read its lines.
                        return

        buffered: list[str] = []
        for line in text.splitlines():
            if not line:
                # A blank line terminates the current event.
                if buffered:
                    yield from decode(buffered)
                    buffered = []
            elif line.startswith("data:"):
                buffered.append(line[5:].strip())
        if buffered:
            yield from decode(buffered)

    def close(self):
        """Close the underlying HTTP client."""
        self._client.close()
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# stdio 传输层
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class StdioTransport(BaseTransport):
    """
    stdio transport: spawns a local subprocess and exchanges newline-delimited
    JSON-RPC messages over its stdin/stdout.
    """

    def __init__(self, cfg: MCPSkillConfig):
        """Spawn the configured command.

        Raises:
            ValueError: ``cfg.command`` is empty.
            OSError: the subprocess could not be started.
        """
        super().__init__(cfg)
        if not cfg.command:
            raise ValueError(
                f"stdio 传输需要配置 command\n"
                f" skill: {cfg.name}\n"
                f" 请在 config.yaml mcp_skills[{cfg.name}].command 中设置"
            )
        import os as _os
        # Configured variables are layered over the inherited environment.
        env = {**_os.environ, **(cfg.env or {})}
        cmd = [cfg.command, *(cfg.args or [])]
        logger.info(f"🔌 stdio 启动子进程: {' '.join(cmd)}")
        self._proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
            encoding="utf-8",
            # Undecodable bytes must not abort a read loop.
            errors="replace",
        )
        self._lock = threading.Lock()
        # stderr is piped, so it has to be drained: once the OS pipe buffer
        # fills, the child blocks on write and stops answering requests.
        self._stderr_thread = threading.Thread(
            target=self._drain_stderr, daemon=True
        )
        self._stderr_thread.start()
        logger.info(f"✅ stdio 子进程已启动 PID={self._proc.pid}")

    def _drain_stderr(self) -> None:
        """Background thread: forward the child's stderr to the debug log."""
        stream = self._proc.stderr
        if stream is None:
            return
        try:
            for line in stream:
                text = line.rstrip()
                if text:
                    logger.debug(f"[{self.cfg.name}] stderr: {text}")
        except (OSError, ValueError):
            # The pipe was closed during shutdown; nothing left to drain.
            pass

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """Write one request line and read lines until its response arrives.

        Blank lines, non-JSON lines, server-initiated messages and responses
        with a different id are skipped.

        Raises:
            RuntimeError: stdout reached EOF before the response arrived, or
                the response carries a JSON-RPC error.
        """
        req = self._make_request(method, params)
        req_id = req["id"]
        line = json.dumps(req, ensure_ascii=False) + "\n"
        # The lock covers write + read so concurrent callers cannot consume
        # each other's responses.
        with self._lock:
            self._proc.stdin.write(line)
            self._proc.stdin.flush()
            while True:
                resp_line = self._proc.stdout.readline()
                if not resp_line:
                    raise RuntimeError(
                        f"stdio 子进程无响应 skill={self.cfg.name} method={method}"
                    )
                if not resp_line.strip():
                    continue
                try:
                    data = json.loads(resp_line)
                except json.JSONDecodeError:
                    logger.debug(
                        f"[{self.cfg.name}] ignoring non-JSON stdout line: "
                        f"{resp_line.strip()[:200]}"
                    )
                    continue
                if not isinstance(data, dict) or "method" in data:
                    # Notification or server-to-client request.
                    continue
                if str(data.get("id")) != req_id:
                    continue
                break
        return self._check_response(data, method)

    def close(self):
        """Stop the subprocess (terminate, then kill after 5s) and free pipes."""
        proc = self._proc
        if not proc:
            return
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                proc.wait()

        # Let the drain thread see EOF before its stream is closed; closing a
        # stream another thread is still blocked on is not safe.
        self._stderr_thread.join(timeout=1)
        streams = [proc.stdin, proc.stdout]
        if not self._stderr_thread.is_alive():
            streams.append(proc.stderr)
        for stream in streams:
            if stream is None:
                continue
            try:
                stream.close()
            except OSError:
                # Best effort: a broken pipe on flush is expected here.
                pass
        logger.debug(f"🔌 stdio 子进程已关闭 skill={self.cfg.name}")
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# 传输层工厂
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
def _create_transport(cfg: MCPSkillConfig) -> BaseTransport:
    """Instantiate the transport named by ``cfg.transport`` (case-insensitive).

    Raises:
        ValueError: the transport name is not one of sse / http / stdio.
    """
    kind = cfg.transport.lower()
    if kind == "sse":
        return SSETransport(cfg)
    if kind == "http":
        return HTTPTransport(cfg)
    if kind == "stdio":
        return StdioTransport(cfg)
    raise ValueError(
        f"不支持的传输协议: {cfg.transport}\n"
        f" skill: {cfg.name}\n"
        f" 可选值: sse | http | stdio"
    )
|
||
|
||
|
||
# ════════════════════════════════════════════════════════════════
|
||
# MCP Skill 客户端
|
||
# ════════════════════════════════════════════════════════════════
|
||
|
||
class MCPSkillClient:
|
||
"""
|
||
单个在线 MCP Server 的客户端
|
||
|
||
职责:
|
||
1. 建立连接(SSE / HTTP / stdio)
|
||
2. 执行 MCP initialize 握手
|
||
3. 获取工具列表(list_tools)
|
||
4. 代理调用工具(call_tool)
|
||
5. 支持 include/exclude 过滤
|
||
6. 支持失败重试(retry 次数来自 config.yaml)
|
||
|
||
用法:
|
||
client = MCPSkillClient(skill_cfg)
|
||
client.connect()
|
||
tools = client.list_tools()
|
||
result = client.call_tool("tool_name", {"arg": "value"})
|
||
client.close()
|
||
"""
|
||
|
||
def __init__(self, cfg: MCPSkillConfig):
|
||
self.cfg: MCPSkillConfig = cfg
|
||
self._transport: BaseTransport | None = None
|
||
self._tools: list[RemoteTool] = []
|
||
self._initialized: bool = False
|
||
|
||
# ── 连接管理 ──────────────────────────────────────────────
|
||
|
||
def connect(self) -> None:
|
||
"""建立连接并完成 MCP 握手"""
|
||
logger.info(
|
||
f"🌐 连接在线 MCP Skill: [{self.cfg.name}]\n"
|
||
f" 传输协议: {self.cfg.transport}\n"
|
||
f" 地址 : {self.cfg.url or self.cfg.command}\n"
|
||
f" 超时 : {self.cfg.timeout}s\n"
|
||
f" 重试 : {self.cfg.retry} 次"
|
||
)
|
||
last_err: Exception | None = None
|
||
for attempt in range(self.cfg.retry + 1):
|
||
try:
|
||
self._transport = _create_transport(self.cfg)
|
||
self._handshake()
|
||
self._initialized = True
|
||
logger.info(f"✅ MCP Skill [{self.cfg.name}] 连接成功")
|
||
return
|
||
except Exception as e:
|
||
last_err = e
|
||
if attempt < self.cfg.retry:
|
||
wait = 2 ** attempt # 指数退避
|
||
logger.warning(
|
||
f"⚠️ 连接失败 (attempt {attempt + 1}/{self.cfg.retry + 1}),"
|
||
f"{wait}s 后重试: {e}"
|
||
)
|
||
time.sleep(wait)
|
||
if self._transport:
|
||
try:
|
||
self._transport.close()
|
||
except Exception:
|
||
pass
|
||
self._transport = None
|
||
|
||
raise ConnectionError(
|
||
f"❌ MCP Skill [{self.cfg.name}] 连接失败(已重试 {self.cfg.retry} 次)\n"
|
||
f" 最后错误: {last_err}"
|
||
)
|
||
|
||
def _handshake(self) -> None:
|
||
"""执行 MCP initialize 握手"""
|
||
result = self._transport.send_request(
|
||
_METHOD_INITIALIZE,
|
||
{
|
||
"protocolVersion": _PROTOCOL_VERSION,
|
||
"capabilities": {"tools": {}},
|
||
"clientInfo": _CLIENT_INFO,
|
||
},
|
||
)
|
||
server_info = result.get("serverInfo", {})
|
||
server_version = result.get("protocolVersion", "unknown")
|
||
logger.info(
|
||
f"🤝 MCP 握手成功 [{self.cfg.name}]\n"
|
||
f" 服务端: {server_info.get('name', 'unknown')} "
|
||
f"v{server_info.get('version', '?')}\n"
|
||
f" 协议版本: {server_version}"
|
||
)
|
||
|
||
def close(self) -> None:
|
||
if self._transport:
|
||
self._transport.close()
|
||
self._transport = None
|
||
self._initialized = False
|
||
logger.debug(f"🔌 MCP Skill [{self.cfg.name}] 已断开")
|
||
|
||
def __enter__(self):
|
||
self.connect()
|
||
return self
|
||
|
||
def __exit__(self, *_):
|
||
self.close()
|
||
|
||
# ── 工具发现 ──────────────────────────────────────────────
|
||
|
||
def list_tools(self, force_refresh: bool = False) -> list[RemoteTool]:
|
||
"""
|
||
获取远端工具列表(带缓存)
|
||
|
||
Args:
|
||
force_refresh: 强制重新拉取,忽略缓存
|
||
|
||
Returns:
|
||
经过 include/exclude 过滤后的 RemoteTool 列表
|
||
"""
|
||
if self._tools and not force_refresh:
|
||
return self._tools
|
||
|
||
self._ensure_connected()
|
||
result = self._transport.send_request(_METHOD_LIST_TOOLS)
|
||
raw_tools = result.get("tools", [])
|
||
|
||
tools = []
|
||
for t in raw_tools:
|
||
name = t.get("name", "")
|
||
if not name:
|
||
continue
|
||
# include / exclude 过滤(来自 config.yaml)
|
||
if not self.cfg.is_tool_allowed(name):
|
||
logger.debug(
|
||
f" ⏭ 跳过工具 [{name}](被 include/exclude 过滤)"
|
||
)
|
||
continue
|
||
tools.append(RemoteTool(
|
||
name=name,
|
||
description=t.get("description", ""),
|
||
parameters=t.get("inputSchema", {"type": "object", "properties": {}}),
|
||
skill_name=self.cfg.name,
|
||
))
|
||
|
||
self._tools = tools
|
||
logger.info(
|
||
f"📦 MCP Skill [{self.cfg.name}] 工具列表:\n"
|
||
+ "\n".join(f" • {t.name}: {t.description[:60]}" for t in tools)
|
||
)
|
||
return tools
|
||
|
||
# ── 工具调用 ──────────────────────────────────────────────
|
||
|
||
def call_tool(
|
||
self,
|
||
tool_name: str,
|
||
arguments: dict[str, Any],
|
||
) -> ToolCallResult:
|
||
"""
|
||
调用远端工具
|
||
|
||
Args:
|
||
tool_name : 工具名称
|
||
arguments : 工具参数字典
|
||
|
||
Returns:
|
||
ToolCallResult 实例
|
||
"""
|
||
self._ensure_connected()
|
||
start = time.time()
|
||
logger.info(
|
||
f"🔧 调用远端工具: [{self.cfg.name}] / {tool_name}\n"
|
||
f" 参数: {json.dumps(arguments, ensure_ascii=False)[:200]}"
|
||
)
|
||
|
||
try:
|
||
result = self._transport.send_request(
|
||
_METHOD_CALL_TOOL,
|
||
{"name": tool_name, "arguments": arguments},
|
||
)
|
||
elapsed = time.time() - start
|
||
content = self._extract_content(result)
|
||
logger.info(
|
||
f"✅ 工具调用成功: {tool_name} 耗时={elapsed:.2f}s\n"
|
||
f" 结果: {content[:150]}"
|
||
)
|
||
return ToolCallResult(
|
||
tool_name=tool_name,
|
||
skill_name=self.cfg.name,
|
||
success=True,
|
||
content=content,
|
||
elapsed_sec=elapsed,
|
||
)
|
||
|
||
except Exception as e:
|
||
elapsed = time.time() - start
|
||
logger.error(f"❌ 工具调用失败: {tool_name} {e}")
|
||
return ToolCallResult(
|
||
tool_name=tool_name,
|
||
skill_name=self.cfg.name,
|
||
success=False,
|
||
error=str(e),
|
||
elapsed_sec=elapsed,
|
||
)
|
||
|
||
# ── 私有工具方法 ──────────────────────────────────────────
|
||
|
||
def _ensure_connected(self) -> None:
|
||
if not self._initialized or not self._transport:
|
||
raise RuntimeError(
|
||
f"MCP Skill [{self.cfg.name}] 未连接,请先调用 connect()"
|
||
)
|
||
|
||
@staticmethod
|
||
def _extract_content(result: dict) -> str:
|
||
"""
|
||
从 MCP tools/call 响应中提取文本内容
|
||
|
||
MCP 响应格式:
|
||
{"content": [{"type": "text", "text": "..."}]}
|
||
或
|
||
{"content": [{"type": "image", "data": "...", "mimeType": "..."}]}
|
||
"""
|
||
content_list = result.get("content", [])
|
||
if not content_list:
|
||
return json.dumps(result, ensure_ascii=False)
|
||
|
||
parts = []
|
||
for item in content_list:
|
||
match item.get("type"):
|
||
case "text":
|
||
parts.append(item.get("text", ""))
|
||
case "image":
|
||
parts.append(f"[图片: {item.get('mimeType', 'image')}]")
|
||
case "resource":
|
||
parts.append(f"[资源: {item.get('uri', '')}]")
|
||
case _:
|
||
parts.append(json.dumps(item, ensure_ascii=False))
|
||
return "\n".join(parts) |