base_agent/mcp/mcp_skill_client.py

614 lines
22 KiB
Python
Raw Normal View History

2026-03-30 08:48:36 +00:00
"""
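mcp/mcp_skill_client.py

Online MCP server client.

Connects to a single remote MCP server, fetches its tool list, and proxies
tool calls to it. Three transport protocols are supported:
  - sse   : Server-Sent Events (the most common form of online MCP server)
  - http  : Streamable HTTP
  - stdio : a local subprocess, communicating over stdin/stdout

Dependencies (quote the specifiers so the shell does not treat ">" as a
redirection):
  pip install "httpx>=0.27.0" "httpx-sse>=0.4.0"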
"""
import asyncio
import json
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Iterator
from config.settings import MCPSkillConfig
from utils.logger import get_logger
logger = get_logger("MCP.SkillClient")

# Optional transport dependencies. The module must stay importable without
# them; the transports that need them check these flags and raise a clear
# error when they are constructed.
try:
    import httpx
except ImportError:
    _HTTPX_AVAILABLE = False
    logger.warning("⚠️ httpx 未安装,请执行: pip install httpx>=0.27.0")
else:
    _HTTPX_AVAILABLE = True

try:
    from httpx_sse import connect_sse
except ImportError:
    _SSE_AVAILABLE = False
    logger.warning("⚠️ httpx-sse 未安装,请执行: pip install httpx-sse>=0.4.0")
else:
    _SSE_AVAILABLE = True
# ════════════════════════════════════════════════════════════════
# MCP JSON-RPC 协议常量
# ════════════════════════════════════════════════════════════════
# JSON-RPC envelope version stamped on every outgoing MCP message.
_JSONRPC = "2.0"

# MCP method names used by this client.
_METHOD_INITIALIZE = "initialize"
_METHOD_LIST_TOOLS = "tools/list"
_METHOD_CALL_TOOL = "tools/call"
_METHOD_PING = "ping"

# Client identity reported to the server in the initialize handshake.
_CLIENT_INFO = dict(name="agent-demo", version="1.0.0")

# MCP protocol revision requested in the initialize handshake.
_PROTOCOL_VERSION = "2024-11-05"
# ════════════════════════════════════════════════════════════════
# 数据结构
# ════════════════════════════════════════════════════════════════
@dataclass
class RemoteTool:
    """A single tool advertised by a remote MCP server."""

    name: str  # tool identifier as reported by the server
    description: str  # human-readable description from the server
    parameters: dict[str, Any]  # JSON Schema of the tool's arguments
    skill_name: str  # owning skill group name (from config.yaml)

    def to_function_schema(self) -> dict:
        """Return this tool as an OpenAI function-calling schema.

        The ``parameters`` value is the tool's own schema dict, not a copy.
        """
        schema: dict = dict(name=self.name, description=self.description)
        schema["parameters"] = self.parameters
        return schema
@dataclass
class ToolCallResult:
    """Outcome of a single remote tool invocation."""

    tool_name: str  # name of the invoked tool
    skill_name: str  # skill group the tool belongs to
    success: bool  # True when the call produced a usable result
    content: str = ""  # text result on success
    error: str = ""  # failure description when success is False
    elapsed_sec: float = 0.0  # wall-clock duration of the call

    def __str__(self) -> str:
        """Return the content on success, otherwise a one-line failure message."""
        if not self.success:
            return f"❌ [{self.skill_name}/{self.tool_name}] 调用失败: {self.error}"
        return self.content
# ════════════════════════════════════════════════════════════════
# 传输层基类
# ════════════════════════════════════════════════════════════════
class BaseTransport:
    """
    Base class for MCP transports.

    Subclasses implement ``send_request`` (and, where supported,
    ``send_notification``) over a concrete channel. This class builds
    JSON-RPC envelopes and unwraps/validates responses.
    """

    def __init__(self, cfg: MCPSkillConfig):
        self.cfg = cfg

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request and return its ``result`` object."""
        raise NotImplementedError

    def send_notification(self, method: str, params: dict | None = None) -> None:
        """
        Send a JSON-RPC notification (a message with no id and no response).

        Raises:
            NotImplementedError: the transport does not support notifications.
        """
        raise NotImplementedError

    def close(self):
        """Release transport resources; the base implementation holds none."""

    def _make_request(self, method: str, params: dict | None = None) -> dict:
        """Build a JSON-RPC 2.0 request with a fresh UUID string id."""
        return {
            "jsonrpc": _JSONRPC,
            "id": str(uuid.uuid4()),
            "method": method,
            "params": params or {},
        }

    def _check_response(self, resp: dict, method: str) -> dict:
        """
        Unwrap a JSON-RPC response.

        Args:
            resp   : decoded JSON-RPC response message.
            method : method name, used only in error messages.

        Returns:
            The ``result`` object; ``{}`` when it is missing or null.

        Raises:
            RuntimeError: the message is not a JSON object, or carries a
                non-null ``error`` member.
        """
        if not isinstance(resp, dict):
            raise RuntimeError(f"MCP 响应格式无效 [{method}]: {str(resp)[:200]}")
        # `"error": null` means "no error": some JSON-RPC peers send it
        # alongside a result, and it must not be dereferenced.
        err = resp.get("error")
        if err is not None:
            if isinstance(err, dict):
                code, message = err.get("code"), err.get("message")
            else:
                # Non-conforming servers may send a bare string/number.
                code, message = None, err
            raise RuntimeError(
                f"MCP 错误 [{method}]: "
                f"code={code} msg={message}"
            )
        result = resp.get("result")
        # Callers immediately call .get() on the result, so never hand back None.
        return {} if result is None else result
# ════════════════════════════════════════════════════════════════
# SSE 传输层
# ════════════════════════════════════════════════════════════════
class SSETransport(BaseTransport):
    """
    Server-Sent Events transport.

    MCP-over-SSE flow:
      1. GET {url} opens a long-lived SSE stream; the server pushes an
         ``endpoint`` event carrying the URL that messages are POSTed to.
      2. Each JSON-RPC message is POSTed to that endpoint URL.
      3. Responses arrive as ``message`` events on the stream and are
         matched to their request by ``id``.

    A daemon thread (``_sse_listener``) consumes the stream and stores
    responses in ``_pending`` (guarded by ``_lock``); ``send_request``
    polls that dict from the calling thread.
    """

    def __init__(self, cfg: MCPSkillConfig):
        super().__init__(cfg)
        if not _HTTPX_AVAILABLE or not _SSE_AVAILABLE:
            raise RuntimeError("SSE 传输需要: pip install httpx httpx-sse")
        self._client = httpx.Client(
            headers=cfg.headers,
            timeout=cfg.timeout,
        )
        self._endpoint_url: str = ""
        self._pending: dict = {}  # request id -> JSON-RPC response message
        self._sse_thread: threading.Thread | None = None
        # True only while the stream is open AND the endpoint URL is known.
        self._connected: bool = False
        # Set by close(): lets the listener tell a deliberate shutdown apart
        # from a genuine stream failure.
        self._closed: bool = False
        # Signalled when the endpoint arrives or the listener exits, so that
        # _connect() stops waiting as soon as the outcome is known.
        self._endpoint_ready = threading.Event()
        self._lock = threading.Lock()
        self._connect()

    def _connect(self) -> None:
        """
        Start the listener thread and wait (up to 10 s) for the endpoint URL.

        Raises:
            RuntimeError: the endpoint did not arrive in time, or the stream
                failed first. The HTTP client is closed before raising, so a
                failed attempt releases its connection and the listener exits.
        """
        logger.info(
            f"🔌 SSE 连接: {self.cfg.url} "
            f"timeout={self.cfg.timeout}s"
        )
        self._sse_thread = threading.Thread(
            target=self._sse_listener, daemon=True
        )
        self._sse_thread.start()
        self._endpoint_ready.wait(timeout=10)
        if not self._connected:
            self.close()
            raise RuntimeError(
                f"SSE 连接超时:未收到 endpoint 事件\n"
                f" URL: {self.cfg.url}\n"
                f" 请检查 MCP Server 是否正常运行"
            )
        logger.info(f"✅ SSE 已连接endpoint: {self._endpoint_url}")

    def _sse_listener(self) -> None:
        """Background thread: consume SSE events until the stream ends or fails."""
        try:
            with connect_sse(
                self._client,
                "GET",
                self.cfg.url,
                # No read timeout on the long-lived stream: an idle SSE
                # connection must not be torn down after cfg.timeout seconds.
                # Connect/write/pool still use cfg.timeout.
                timeout=httpx.Timeout(self.cfg.timeout, read=None),
            ) as event_source:
                for event in event_source.iter_sse():
                    self._handle_sse_event(event)
        except Exception as e:
            # After close() an exception here is the expected consequence of
            # shutting the client down, not a failure worth reporting.
            if not self._closed:
                logger.error(f"❌ SSE 监听异常: {e}")
        finally:
            # No further responses can arrive: mark the transport unusable and
            # wake _connect() if it is still waiting for the endpoint.
            self._connected = False
            self._endpoint_ready.set()

    def _handle_sse_event(self, event) -> None:
        """
        Dispatch one SSE event.

        ``endpoint`` events carry the POST URL (absolute, or relative to
        cfg.url). ``message`` events carry JSON-RPC messages; only responses
        (an ``id`` and no ``method``) are stored in ``_pending``, keyed by
        ``str(id)``. Undecodable payloads are ignored.
        """
        if event.event == "endpoint":
            from urllib.parse import urljoin

            # urljoin leaves absolute URLs unchanged and resolves relative ones.
            self._endpoint_url = urljoin(self.cfg.url, event.data.strip())
            self._connected = True
            self._endpoint_ready.set()
        elif event.event == "message":
            try:
                data = json.loads(event.data)
            except json.JSONDecodeError:
                return
            if not isinstance(data, dict) or "method" in data or data.get("id") is None:
                # Server notifications/requests are never awaited by
                # send_request(); storing them would only grow _pending.
                return
            with self._lock:
                self._pending[str(data["id"])] = data

    def _disconnected_error(self, method: str) -> RuntimeError:
        """Build the error raised when the SSE stream cannot deliver responses."""
        return RuntimeError(
            f"SSE 连接已断开 method={method} skill={self.cfg.name}"
        )

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """
        POST a JSON-RPC request and wait for its response on the SSE stream.

        Raises:
            RuntimeError: the stream is not (or no longer) connected, or the
                server returned a JSON-RPC error.
            TimeoutError: no response within cfg.timeout seconds.
            httpx.HTTPStatusError: the POST was rejected.
        """
        if not self._connected:
            raise self._disconnected_error(method)
        req = self._make_request(method, params)
        req_id = req["id"]
        resp = self._client.post(
            self._endpoint_url,
            json=req,
            headers={"Content-Type": "application/json"},
        )
        resp.raise_for_status()
        # Wait for the SSE response (at most cfg.timeout seconds).
        deadline = time.monotonic() + self.cfg.timeout
        while time.monotonic() < deadline:
            # Read the liveness flag BEFORE checking _pending: the listener
            # stores any final response before it clears _connected, so a
            # response delivered just before the stream died is never lost.
            alive = self._connected
            with self._lock:
                result = self._pending.pop(req_id, None)
            if result is not None:
                return self._check_response(result, method)
            if not alive:
                raise self._disconnected_error(method)
            time.sleep(0.02)
        raise TimeoutError(
            f"等待 MCP 响应超时 (>{self.cfg.timeout}s) "
            f"method={method} skill={self.cfg.name}"
        )

    def send_notification(self, method: str, params: dict | None = None) -> None:
        """POST a JSON-RPC notification; no response is expected on the stream."""
        if not self._connected:
            raise self._disconnected_error(method)
        msg = self._make_request(method, params)
        del msg["id"]  # a notification is a request without an id
        resp = self._client.post(
            self._endpoint_url,
            json=msg,
            headers={"Content-Type": "application/json"},
        )
        resp.raise_for_status()

    def close(self):
        """Close the HTTP client (ending the SSE stream) and mark disconnected."""
        self._closed = True
        self._connected = False
        self._client.close()
# ════════════════════════════════════════════════════════════════
# HTTP 传输层Streamable HTTP
# ════════════════════════════════════════════════════════════════
class HTTPTransport(BaseTransport):
    """
    Streamable HTTP transport.

    Every JSON-RPC message is POSTed to the fixed ``cfg.url``. The server
    replies with either a JSON body or an SSE (``text/event-stream``) body.
    In the SSE case the response may be preceded by server notifications, so
    it is located by request id. If the server assigns a session through the
    ``Mcp-Session-Id`` response header, that id is sent on every later
    request, and close() asks the server to end the session (best effort).
    """

    _SESSION_HEADER = "Mcp-Session-Id"

    def __init__(self, cfg: MCPSkillConfig):
        super().__init__(cfg)
        if not _HTTPX_AVAILABLE:
            raise RuntimeError("HTTP 传输需要: pip install httpx")
        self._client = httpx.Client(
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json, text/event-stream",
                **cfg.headers,
            },
            timeout=cfg.timeout,
        )
        self._session_id: str | None = None
        logger.info(f"🔌 HTTP 传输初始化: {cfg.url}")

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """
        POST a JSON-RPC request and return its ``result``.

        Raises:
            httpx.HTTPStatusError: non-2xx HTTP status.
            RuntimeError: JSON-RPC error, or no matching response in an SSE body.
        """
        req = self._make_request(method, params)
        resp = self._post(req)
        content_type = resp.headers.get("content-type", "")
        if "text/event-stream" in content_type:
            return self._parse_sse_response(resp.text, method, req["id"])
        return self._check_response(resp.json(), method)

    def send_notification(self, method: str, params: dict | None = None) -> None:
        """POST a JSON-RPC notification; no response body is expected."""
        msg = self._make_request(method, params)
        del msg["id"]  # a notification is a request without an id
        self._post(msg)

    def _post(self, message: dict):
        """POST one JSON-RPC message, raise on HTTP errors, adopt any session id."""
        resp = self._client.post(self.cfg.url, json=message)
        resp.raise_for_status()
        session_id = resp.headers.get(self._SESSION_HEADER)
        if session_id and session_id != self._session_id:
            self._session_id = session_id
            # Client-level headers are merged into every subsequent request.
            self._client.headers[self._SESSION_HEADER] = session_id
        return resp

    @staticmethod
    def _iter_sse_data(text: str) -> Iterator[str]:
        """
        Yield the data payload of each event in an SSE body.

        Events are separated by blank lines; multiple ``data:`` lines of one
        event are joined with "\\n", and one leading space after the colon is
        dropped, as the SSE format specifies. Other fields and comments are
        ignored.
        """
        buf: list[str] = []
        for line in text.splitlines():
            if not line:
                if buf:
                    yield "\n".join(buf)
                    buf = []
            elif line.startswith("data:"):
                value = line[5:]
                buf.append(value[1:] if value.startswith(" ") else value)
        if buf:  # body may end without a trailing blank line
            yield "\n".join(buf)

    def _parse_sse_response(
        self, text: str, method: str, req_id: str | None = None
    ) -> dict:
        """
        Extract the JSON-RPC response from an SSE-formatted HTTP body.

        Args:
            text   : full response body.
            method : method name, used in error messages.
            req_id : id of the request; when given, only the message with
                     this id is accepted.

        Raises:
            RuntimeError: no response message found, or a JSON-RPC error.
        """
        for payload in self._iter_sse_data(text):
            raw = payload.strip()
            if not raw or raw == "[DONE]":
                continue
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict) or "method" in data:
                # Server notification/request interleaved before the response.
                continue
            if req_id is not None and str(data.get("id")) != str(req_id):
                continue
            return self._check_response(data, method)
        raise RuntimeError(f"无法解析 SSE 响应: {text[:200]}")

    def close(self):
        """Ask the server to end the session (best effort), then close the client."""
        session_id, self._session_id = self._session_id, None
        if session_id:
            try:
                self._client.delete(self.cfg.url)
            except httpx.HTTPError as e:
                logger.debug(f"🔌 结束 MCP 会话失败(已忽略): {e}")
        self._client.close()
# ════════════════════════════════════════════════════════════════
# stdio 传输层
# ════════════════════════════════════════════════════════════════
class StdioTransport(BaseTransport):
    """
    stdio transport: spawn a local subprocess and exchange newline-delimited
    JSON-RPC messages over its stdin/stdout.

    Two daemon threads keep the child's pipes drained. stdout lines are
    queued for send_request(), which can therefore enforce cfg.timeout and
    skip lines that are not the awaited response. stderr is forwarded to the
    debug log, so a chatty server cannot block on a full stderr pipe.
    """

    def __init__(self, cfg: MCPSkillConfig):
        super().__init__(cfg)
        if not cfg.command:
            raise ValueError(
                f"stdio 传输需要配置 command\n"
                f" skill: {cfg.name}\n"
                f" 请在 config.yaml mcp_skills[{cfg.name}].command 中设置"
            )
        import os as _os
        import queue as _queue

        env = {**_os.environ, **cfg.env}
        cmd = [cfg.command] + cfg.args
        logger.info(f"🔌 stdio 启动子进程: {' '.join(cmd)}")
        self._proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
            encoding="utf-8",
            # Undecodable bytes must not kill the reader threads.
            errors="replace",
        )
        self._lock = threading.Lock()
        # Lines read from the child's stdout; None marks end of stream.
        self._stdout_lines = _queue.Queue()
        for target in (self._pump_stdout, self._pump_stderr):
            threading.Thread(target=target, daemon=True).start()
        logger.info(f"✅ stdio 子进程已启动 PID={self._proc.pid}")

    def _pump_stdout(self) -> None:
        """Reader thread: queue every stdout line, then None at end of stream."""
        try:
            for line in self._proc.stdout:
                self._stdout_lines.put(line)
        except (OSError, ValueError):
            # ValueError: the pipe object was closed while being read.
            pass
        finally:
            self._stdout_lines.put(None)

    def _pump_stderr(self) -> None:
        """Reader thread: forward the child's stderr to the debug log."""
        try:
            for line in self._proc.stderr:
                logger.debug(f"[{self.cfg.name}] stderr: {line.rstrip()}")
        except (OSError, ValueError):
            pass

    def _write(self, message: dict) -> None:
        """Write one JSON-RPC message as a single line to the child's stdin."""
        self._proc.stdin.write(json.dumps(message, ensure_ascii=False) + "\n")
        self._proc.stdin.flush()

    @staticmethod
    def _parse_response_line(line: str) -> dict | None:
        """Decode one stdout line; return it only if it is a JSON-RPC response."""
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict) or "method" in data:
            # Notifications and server-initiated requests carry "method".
            return None
        return data

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """
        Write a request to the child's stdin and wait for the matching response.

        Lines that are not JSON, that carry a ``method`` (server notifications
        or requests), or whose id differs (e.g. a late reply to a timed-out
        request) are skipped.

        Raises:
            TimeoutError: no matching response within cfg.timeout seconds.
            RuntimeError: the child closed stdout, or returned a JSON-RPC error.
        """
        import queue as _queue

        req = self._make_request(method, params)
        req_id = req["id"]
        with self._lock:
            self._write(req)
            deadline = time.monotonic() + self.cfg.timeout
            while True:
                remaining = max(deadline - time.monotonic(), 0)
                try:
                    line = self._stdout_lines.get(timeout=remaining)
                except _queue.Empty:
                    raise TimeoutError(
                        f"等待 MCP 响应超时 (>{self.cfg.timeout}s) "
                        f"method={method} skill={self.cfg.name}"
                    ) from None
                if line is None:
                    # Put the EOF marker back so later calls also fail fast.
                    self._stdout_lines.put(None)
                    raise RuntimeError(
                        f"stdio 子进程无响应 skill={self.cfg.name} method={method}"
                    )
                data = self._parse_response_line(line)
                if data is not None and str(data.get("id")) == req_id:
                    return self._check_response(data, method)

    def send_notification(self, method: str, params: dict | None = None) -> None:
        """Write a JSON-RPC notification; no response is read."""
        msg = self._make_request(method, params)
        del msg["id"]  # a notification is a request without an id
        with self._lock:
            self._write(msg)

    def close(self):
        """Stop the child: terminate(), then kill() after 5 s, and reap it."""
        if self._proc and self._proc.poll() is None:
            self._proc.terminate()
            try:
                self._proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self._proc.kill()
                self._proc.wait()
        logger.debug(f"🔌 stdio 子进程已关闭 skill={self.cfg.name}")
# ════════════════════════════════════════════════════════════════
# 传输层工厂
# ════════════════════════════════════════════════════════════════
def _create_transport(cfg: MCPSkillConfig) -> BaseTransport:
    """
    Instantiate the transport named by ``cfg.transport`` (case-insensitive).

    Raises:
        ValueError: ``cfg.transport`` is not one of sse / http / stdio.
    """
    factories = {
        "sse": SSETransport,
        "http": HTTPTransport,
        "stdio": StdioTransport,
    }
    factory = factories.get(cfg.transport.lower())
    if factory is None:
        raise ValueError(
            f"不支持的传输协议: {cfg.transport}\n"
            f" skill: {cfg.name}\n"
            f" 可选值: sse | http | stdio"
        )
    return factory(cfg)
# ════════════════════════════════════════════════════════════════
# MCP Skill 客户端
# ════════════════════════════════════════════════════════════════
class MCPSkillClient:
"""
单个在线 MCP Server 的客户端
职责:
1. 建立连接SSE / HTTP / stdio
2. 执行 MCP initialize 握手
3. 获取工具列表list_tools
4. 代理调用工具call_tool
5. 支持 include/exclude 过滤
6. 支持失败重试retry 次数来自 config.yaml
用法:
client = MCPSkillClient(skill_cfg)
client.connect()
tools = client.list_tools()
result = client.call_tool("tool_name", {"arg": "value"})
client.close()
"""
def __init__(self, cfg: MCPSkillConfig):
self.cfg: MCPSkillConfig = cfg
self._transport: BaseTransport | None = None
self._tools: list[RemoteTool] = []
self._initialized: bool = False
# ── 连接管理 ──────────────────────────────────────────────
def connect(self) -> None:
"""建立连接并完成 MCP 握手"""
logger.info(
f"🌐 连接在线 MCP Skill: [{self.cfg.name}]\n"
f" 传输协议: {self.cfg.transport}\n"
f" 地址 : {self.cfg.url or self.cfg.command}\n"
f" 超时 : {self.cfg.timeout}s\n"
f" 重试 : {self.cfg.retry}"
)
last_err: Exception | None = None
for attempt in range(self.cfg.retry + 1):
try:
self._transport = _create_transport(self.cfg)
self._handshake()
self._initialized = True
logger.info(f"✅ MCP Skill [{self.cfg.name}] 连接成功")
return
except Exception as e:
last_err = e
if attempt < self.cfg.retry:
wait = 2 ** attempt # 指数退避
logger.warning(
f"⚠️ 连接失败 (attempt {attempt + 1}/{self.cfg.retry + 1})"
f"{wait}s 后重试: {e}"
)
time.sleep(wait)
if self._transport:
try:
self._transport.close()
except Exception:
pass
self._transport = None
raise ConnectionError(
f"❌ MCP Skill [{self.cfg.name}] 连接失败(已重试 {self.cfg.retry} 次)\n"
f" 最后错误: {last_err}"
)
def _handshake(self) -> None:
"""执行 MCP initialize 握手"""
result = self._transport.send_request(
_METHOD_INITIALIZE,
{
"protocolVersion": _PROTOCOL_VERSION,
"capabilities": {"tools": {}},
"clientInfo": _CLIENT_INFO,
},
)
server_info = result.get("serverInfo", {})
server_version = result.get("protocolVersion", "unknown")
logger.info(
f"🤝 MCP 握手成功 [{self.cfg.name}]\n"
f" 服务端: {server_info.get('name', 'unknown')} "
f"v{server_info.get('version', '?')}\n"
f" 协议版本: {server_version}"
)
def close(self) -> None:
if self._transport:
self._transport.close()
self._transport = None
self._initialized = False
logger.debug(f"🔌 MCP Skill [{self.cfg.name}] 已断开")
def __enter__(self):
self.connect()
return self
def __exit__(self, *_):
self.close()
# ── 工具发现 ──────────────────────────────────────────────
def list_tools(self, force_refresh: bool = False) -> list[RemoteTool]:
"""
获取远端工具列表带缓存
Args:
force_refresh: 强制重新拉取忽略缓存
Returns:
经过 include/exclude 过滤后的 RemoteTool 列表
"""
if self._tools and not force_refresh:
return self._tools
self._ensure_connected()
result = self._transport.send_request(_METHOD_LIST_TOOLS)
raw_tools = result.get("tools", [])
tools = []
for t in raw_tools:
name = t.get("name", "")
if not name:
continue
# include / exclude 过滤(来自 config.yaml
if not self.cfg.is_tool_allowed(name):
logger.debug(
f" ⏭ 跳过工具 [{name}](被 include/exclude 过滤)"
)
continue
tools.append(RemoteTool(
name=name,
description=t.get("description", ""),
parameters=t.get("inputSchema", {"type": "object", "properties": {}}),
skill_name=self.cfg.name,
))
self._tools = tools
logger.info(
f"📦 MCP Skill [{self.cfg.name}] 工具列表:\n"
+ "\n".join(f"{t.name}: {t.description[:60]}" for t in tools)
)
return tools
# ── 工具调用 ──────────────────────────────────────────────
def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
) -> ToolCallResult:
"""
调用远端工具
Args:
tool_name : 工具名称
arguments : 工具参数字典
Returns:
ToolCallResult 实例
"""
self._ensure_connected()
start = time.time()
logger.info(
f"🔧 调用远端工具: [{self.cfg.name}] / {tool_name}\n"
f" 参数: {json.dumps(arguments, ensure_ascii=False)[:200]}"
)
try:
result = self._transport.send_request(
_METHOD_CALL_TOOL,
{"name": tool_name, "arguments": arguments},
)
elapsed = time.time() - start
content = self._extract_content(result)
logger.info(
f"✅ 工具调用成功: {tool_name} 耗时={elapsed:.2f}s\n"
f" 结果: {content[:150]}"
)
return ToolCallResult(
tool_name=tool_name,
skill_name=self.cfg.name,
success=True,
content=content,
elapsed_sec=elapsed,
)
except Exception as e:
elapsed = time.time() - start
logger.error(f"❌ 工具调用失败: {tool_name} {e}")
return ToolCallResult(
tool_name=tool_name,
skill_name=self.cfg.name,
success=False,
error=str(e),
elapsed_sec=elapsed,
)
# ── 私有工具方法 ──────────────────────────────────────────
def _ensure_connected(self) -> None:
if not self._initialized or not self._transport:
raise RuntimeError(
f"MCP Skill [{self.cfg.name}] 未连接,请先调用 connect()"
)
@staticmethod
def _extract_content(result: dict) -> str:
"""
MCP tools/call 响应中提取文本内容
MCP 响应格式:
{"content": [{"type": "text", "text": "..."}]}
{"content": [{"type": "image", "data": "...", "mimeType": "..."}]}
"""
content_list = result.get("content", [])
if not content_list:
return json.dumps(result, ensure_ascii=False)
parts = []
for item in content_list:
match item.get("type"):
case "text":
parts.append(item.get("text", ""))
case "image":
parts.append(f"[图片: {item.get('mimeType', 'image')}]")
case "resource":
parts.append(f"[资源: {item.get('uri', '')}]")
case _:
parts.append(json.dumps(item, ensure_ascii=False))
return "\n".join(parts)