base_agent/mcp/mcp_skill_client.py

614 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
mcp/mcp_skill_client.py
在线 MCP Server 客户端
负责连接单个远端 MCP Server获取其工具列表并代理调用工具。
支持三种传输协议:
- sse : Server-Sent Events最常见的在线 MCP 形式)
- http : Streamable HTTP
- stdio : 本地子进程(通过 stdin/stdout 通信)
依赖:
pip install httpx>=0.27.0 httpx-sse>=0.4.0
"""
import asyncio
import json
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Iterator
from config.settings import MCPSkillConfig
from utils.logger import get_logger
logger = get_logger("MCP.SkillClient")
try:
import httpx
_HTTPX_AVAILABLE = True
except ImportError:
_HTTPX_AVAILABLE = False
logger.warning("⚠️ httpx 未安装,请执行: pip install httpx>=0.27.0")
try:
from httpx_sse import connect_sse
_SSE_AVAILABLE = True
except ImportError:
_SSE_AVAILABLE = False
logger.warning("⚠️ httpx-sse 未安装,请执行: pip install httpx-sse>=0.4.0")
# ════════════════════════════════════════════════════════════════
# MCP JSON-RPC 协议常量
# ════════════════════════════════════════════════════════════════
# JSON-RPC version string placed in the "jsonrpc" field of every request envelope.
_JSONRPC = "2.0"
# MCP method names used when building requests.
_METHOD_INITIALIZE = "initialize"
_METHOD_LIST_TOOLS = "tools/list"
_METHOD_CALL_TOOL = "tools/call"
_METHOD_PING = "ping"
# Client identity sent as "clientInfo" during the initialize handshake.
_CLIENT_INFO = {
"name": "agent-demo",
"version": "1.0.0",
}
# MCP protocol revision sent as "protocolVersion" during the initialize handshake.
_PROTOCOL_VERSION = "2024-11-05"
# ════════════════════════════════════════════════════════════════
# 数据结构
# ════════════════════════════════════════════════════════════════
@dataclass
class RemoteTool:
    """Description of a single tool exposed by a remote MCP Server."""

    name: str
    description: str
    parameters: dict[str, Any]  # JSON Schema of the tool's arguments
    skill_name: str  # name of the owning skill group (from config.yaml)

    def to_function_schema(self) -> dict:
        """Return the tool as an OpenAI function-calling schema dict.

        Only ``name``, ``description`` and ``parameters`` are exported, in
        that order; ``skill_name`` is deliberately left out.
        """
        exported_fields = ("name", "description", "parameters")
        return {key: getattr(self, key) for key in exported_fields}
@dataclass
class ToolCallResult:
    """Outcome of one remote tool invocation."""

    tool_name: str
    skill_name: str
    success: bool
    content: str = ""
    error: str = ""
    elapsed_sec: float = 0.0

    def __str__(self) -> str:
        """Render the content on success, or a one-line failure message otherwise."""
        if not self.success:
            return f"❌ [{self.skill_name}/{self.tool_name}] 调用失败: {self.error}"
        return self.content
# ════════════════════════════════════════════════════════════════
# 传输层基类
# ════════════════════════════════════════════════════════════════
class BaseTransport:
    """Base class for MCP transports.

    Holds the skill configuration and provides the JSON-RPC helpers shared
    by the concrete transports; subclasses implement ``send_request``.
    """

    def __init__(self, cfg: MCPSkillConfig):
        self.cfg = cfg

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send one JSON-RPC request and return its result; must be overridden."""
        raise NotImplementedError

    def close(self):
        """Release transport resources; the base implementation does nothing."""
        return None

    def _make_request(self, method: str, params: dict | None = None) -> dict:
        """Build a JSON-RPC request envelope with a fresh UUID string as its id."""
        request_id = str(uuid.uuid4())
        # A missing or empty ``params`` is always sent as an empty object.
        payload = params if params else {}
        return {
            "jsonrpc": _JSONRPC,
            "id": request_id,
            "method": method,
            "params": payload,
        }

    def _check_response(self, resp: dict, method: str) -> dict:
        """Return the ``result`` member of a JSON-RPC response.

        Raises:
            RuntimeError: the response carries an ``error`` member.
        """
        if "error" not in resp:
            return resp.get("result", {})
        err = resp["error"]
        raise RuntimeError(
            f"MCP 错误 [{method}]: "
            f"code={err.get('code')} msg={err.get('message')}"
        )
# ════════════════════════════════════════════════════════════════
# SSE 传输层
# ════════════════════════════════════════════════════════════════
class SSETransport(BaseTransport):
    """
    Server-Sent Events transport.

    MCP over SSE flow, as implemented here:
    1. GET {url}           -> open the SSE stream; the server pushes an ``endpoint`` event
    2. POST {endpoint_url} -> send a JSON-RPC request
    3. SSE stream          -> the matching JSON-RPC response arrives as a ``message`` event

    A daemon thread reads the stream for the lifetime of the transport and
    hands responses to ``send_request`` through ``_pending``.
    """

    # Seconds to wait for the server's initial ``endpoint`` event.
    _ENDPOINT_WAIT_SEC = 10

    def __init__(self, cfg: MCPSkillConfig):
        """
        Create the HTTP client and block until the server announces its endpoint.

        Raises:
            RuntimeError: httpx / httpx-sse is not installed, or no
                ``endpoint`` event was received.
        """
        super().__init__(cfg)
        if not _HTTPX_AVAILABLE or not _SSE_AVAILABLE:
            raise RuntimeError("SSE 传输需要: pip install httpx httpx-sse")
        self._client = httpx.Client(
            headers=cfg.headers,
            timeout=cfg.timeout,
        )
        self._endpoint_url: str = ""
        self._pending: dict = {}  # request id -> response envelope
        self._awaiting: set[str] = set()  # ids that send_request is currently waiting on
        self._sse_thread: threading.Thread | None = None
        self._connected: bool = False
        self._closed: bool = False
        self._lock = threading.Lock()
        try:
            self._connect()
        except Exception:
            # Without this the client and the listener thread would be leaked
            # on every failed attempt, because the caller never receives the
            # half-built transport and so can never close it.
            self.close()
            raise

    def _connect(self) -> None:
        """Start the listener thread and wait for the ``endpoint`` event.

        Raises:
            RuntimeError: the endpoint was not announced before the deadline,
                or the listener thread stopped before announcing it.
        """
        logger.info(
            f"🔌 SSE 连接: {self.cfg.url} "
            f"timeout={self.cfg.timeout}s"
        )
        self._sse_thread = threading.Thread(
            target=self._sse_listener, daemon=True
        )
        self._sse_thread.start()
        # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
        deadline = time.monotonic() + self._ENDPOINT_WAIT_SEC
        while not self._endpoint_url and time.monotonic() < deadline:
            if not self._sse_thread.is_alive():
                # The listener already failed; waiting out the deadline is pointless.
                break
            time.sleep(0.05)
        if not self._endpoint_url:
            raise RuntimeError(
                f"SSE 连接超时:未收到 endpoint 事件\n"
                f" URL: {self.cfg.url}\n"
                f" 请检查 MCP Server 是否正常运行"
            )
        self._connected = True
        logger.info(f"✅ SSE 已连接endpoint: {self._endpoint_url}")

    def _sse_listener(self) -> None:
        """Background thread body: read the SSE stream until it ends or fails."""
        # The stream is long-lived and may stay silent between responses, so
        # the client's read timeout must not apply to it; the other timeout
        # phases keep the configured value.
        stream_timeout = httpx.Timeout(self.cfg.timeout, read=None)
        try:
            with connect_sse(
                self._client, "GET", self.cfg.url, timeout=stream_timeout
            ) as event_source:
                for event in event_source.iter_sse():
                    self._handle_sse_event(event)
        except Exception as e:
            if self._closed:
                # Expected: close() tore the connection down underneath us.
                logger.debug(f"🔌 SSE 监听已随连接关闭而结束: {e}")
            else:
                logger.error(f"❌ SSE 监听异常: {e}")
        finally:
            self._connected = False

    def _handle_sse_event(self, event) -> None:
        """Process one SSE event: record the endpoint or store an awaited response."""
        if event.event == "endpoint":
            # urljoin resolves relative paths against the base URL and returns
            # absolute URLs unchanged, so no prefix sniffing is needed.
            from urllib.parse import urljoin
            self._endpoint_url = urljoin(self.cfg.url, event.data.strip())
        elif event.event == "message":
            try:
                data = json.loads(event.data)
            except json.JSONDecodeError:
                return
            if not isinstance(data, dict):
                # Anything but a JSON object cannot be a response envelope;
                # ignoring it keeps the listener thread alive.
                return
            req_id = str(data.get("id", ""))
            with self._lock:
                # Keep only messages somebody is waiting for; notifications
                # and late replies would otherwise pile up in _pending.
                if req_id in self._awaiting:
                    self._pending[req_id] = data

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """Send a JSON-RPC request and wait for its response on the SSE stream.

        Args:
            method: JSON-RPC method name.
            params: method parameters; ``None`` is sent as an empty object.

        Returns:
            The ``result`` member of the matching response.

        Raises:
            TimeoutError: no response arrived within ``cfg.timeout`` seconds.
            ConnectionError: the SSE listener stopped while waiting.
            RuntimeError: the server answered with a JSON-RPC error.
        """
        req = self._make_request(method, params)
        req_id = req["id"]
        # Register before POSTing so a fast response cannot be discarded.
        with self._lock:
            self._awaiting.add(req_id)
        try:
            resp = self._client.post(
                self._endpoint_url,
                json=req,
                headers={"Content-Type": "application/json"},
            )
            resp.raise_for_status()
            deadline = time.monotonic() + self.cfg.timeout
            while True:
                with self._lock:
                    result = self._pending.pop(req_id, None)
                if result is not None:
                    return self._check_response(result, method)
                if time.monotonic() >= deadline:
                    break
                listener = self._sse_thread
                if listener is None or not listener.is_alive():
                    # Checked after the pending lookup so a response stored
                    # just before the stream died is still delivered.
                    raise ConnectionError(
                        f"SSE 连接已断开 method={method} skill={self.cfg.name}"
                    )
                time.sleep(0.02)
        finally:
            with self._lock:
                self._awaiting.discard(req_id)
                self._pending.pop(req_id, None)
        raise TimeoutError(
            f"等待 MCP 响应超时 (>{self.cfg.timeout}s) "
            f"method={method} skill={self.cfg.name}"
        )

    def close(self):
        """Close the HTTP client; the listener thread ends once its stream breaks."""
        # Set first so the listener reports the resulting error at debug level.
        self._closed = True
        self._client.close()
        self._connected = False
# ════════════════════════════════════════════════════════════════
# HTTP transport layer (Streamable HTTP)
# ════════════════════════════════════════════════════════════════
class HTTPTransport(BaseTransport):
    """
    Streamable HTTP transport.

    Each JSON-RPC request is POSTed to the fixed ``cfg.url``; the reply is
    either a plain JSON body or an SSE-formatted body.

    NOTE(review): Streamable HTTP was introduced in MCP revision 2025-03-26,
    while this module negotiates "2024-11-05" in its handshake - confirm the
    target servers accept that combination.
    """

    # Header a server may use to assign a session during initialization; once
    # received it has to accompany every later request.
    _SESSION_HEADER = "Mcp-Session-Id"

    def __init__(self, cfg: MCPSkillConfig):
        """
        Create the HTTP client with JSON / SSE ``Accept`` headers.

        Raises:
            RuntimeError: httpx is not installed.
        """
        super().__init__(cfg)
        if not _HTTPX_AVAILABLE:
            raise RuntimeError("HTTP 传输需要: pip install httpx")
        self._client = httpx.Client(
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json, text/event-stream",
                # Configured headers win over the defaults; tolerate None.
                **(cfg.headers or {}),
            },
            timeout=cfg.timeout,
        )
        logger.info(f"🔌 HTTP 传输初始化: {cfg.url}")

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """POST one JSON-RPC request and return the ``result`` of its response.

        Raises:
            httpx.HTTPStatusError: the server replied with a 4xx/5xx status.
            RuntimeError: the server returned a JSON-RPC error, or an SSE body
                contained no response for this request.
        """
        req = self._make_request(method, params)
        resp = self._client.post(self.cfg.url, json=req)
        resp.raise_for_status()
        # Remember a server-assigned session id so subsequent requests carry
        # it; servers that use sessions reject requests that omit it.
        session_id = resp.headers.get(self._SESSION_HEADER)
        if session_id:
            self._client.headers[self._SESSION_HEADER] = session_id
        content_type = resp.headers.get("content-type", "")
        if "text/event-stream" in content_type:
            return self._parse_sse_response(resp.text, method, req["id"])
        return self._check_response(resp.json(), method)

    def _parse_sse_response(
        self, text: str, method: str, req_id: str | None = None
    ) -> dict:
        """Extract the JSON-RPC response from an SSE-formatted HTTP body.

        Args:
            text: the complete response body.
            method: JSON-RPC method name, used in error messages.
            req_id: id of the request being answered. When given, messages
                with a different id are skipped.

        Returns:
            The ``result`` member of the first matching response.

        Raises:
            RuntimeError: no matching response was found, or the response
                carries a JSON-RPC error.
        """
        for line in text.splitlines():
            if not line.startswith("data:"):
                continue
            raw = line[5:].strip()
            if not raw or raw == "[DONE]":
                continue
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            # The stream may carry notifications or server-initiated requests
            # ahead of the response; only a message with result/error is one.
            if "result" not in data and "error" not in data:
                continue
            if req_id is not None and str(data.get("id")) != req_id:
                continue
            return self._check_response(data, method)
        raise RuntimeError(f"无法解析 SSE 响应: {text[:200]}")

    def close(self):
        """Close the underlying HTTP client."""
        self._client.close()
# ════════════════════════════════════════════════════════════════
# stdio 传输层
# ════════════════════════════════════════════════════════════════
class StdioTransport(BaseTransport):
    """
    stdio transport: spawns a local subprocess and exchanges newline-delimited
    JSON-RPC messages with it over stdin/stdout.
    """

    def __init__(self, cfg: MCPSkillConfig):
        """
        Start the subprocess described by ``cfg.command`` / ``cfg.args``.

        Raises:
            ValueError: ``cfg.command`` is empty.
            OSError: the subprocess could not be started.
        """
        super().__init__(cfg)
        if not cfg.command:
            raise ValueError(
                f"stdio 传输需要配置 command\n"
                f" skill: {cfg.name}\n"
                f" 请在 config.yaml mcp_skills[{cfg.name}].command 中设置"
            )
        import os as _os
        # Configured variables override the inherited environment; a missing
        # env / args setting is treated as empty.
        env = {**_os.environ, **(cfg.env or {})}
        cmd = [cfg.command, *(cfg.args or [])]
        logger.info(f"🔌 stdio 启动子进程: {' '.join(cmd)}")
        self._proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
            encoding="utf-8",
        )
        self._lock = threading.Lock()
        # An unread stderr pipe eventually fills up and blocks the child on
        # its next write, which would stall every request; drain it.
        self._stderr_thread = threading.Thread(
            target=self._drain_stderr, daemon=True
        )
        self._stderr_thread.start()
        logger.info(f"✅ stdio 子进程已启动 PID={self._proc.pid}")

    def _drain_stderr(self) -> None:
        """Background thread body: forward the child's stderr to the debug log."""
        stream = self._proc.stderr
        if stream is None:
            return
        # Read the raw bytes and decode leniently so a stray invalid byte
        # cannot stop the draining.
        raw_stream = getattr(stream, "buffer", stream)
        try:
            for raw in raw_stream:
                if isinstance(raw, bytes):
                    raw = raw.decode("utf-8", errors="replace")
                text = raw.rstrip()
                if text:
                    logger.debug(f"[{self.cfg.name}] stderr: {text}")
        except (OSError, ValueError):
            # The pipe was closed during shutdown.
            return

    def send_request(self, method: str, params: dict | None = None) -> dict:
        """Write one request line and read stdout until its response arrives.

        Lines that are blank, not JSON, or not the response to this request
        (notifications, server-initiated requests, stale replies) are skipped
        rather than being mistaken for the answer.

        Raises:
            RuntimeError: the child closed stdout before answering, or it
                answered with a JSON-RPC error.
        """
        req = self._make_request(method, params)
        req_id = req["id"]
        line = json.dumps(req, ensure_ascii=False) + "\n"
        # One request/response exchange at a time: the pipe pair is shared.
        with self._lock:
            self._proc.stdin.write(line)
            self._proc.stdin.flush()
            while True:
                resp_line = self._proc.stdout.readline()
                if not resp_line:
                    raise RuntimeError(
                        f"stdio 子进程无响应 skill={self.cfg.name} method={method}"
                    )
                resp_line = resp_line.strip()
                if not resp_line:
                    continue
                try:
                    data = json.loads(resp_line)
                except json.JSONDecodeError:
                    logger.debug(
                        f"[{self.cfg.name}] stdout 非 JSON 行已忽略: {resp_line[:200]}"
                    )
                    continue
                if (
                    isinstance(data, dict)
                    and str(data.get("id")) == req_id
                    and ("result" in data or "error" in data)
                ):
                    break
        return self._check_response(data, method)

    def close(self):
        """Terminate the subprocess (killing it after 5 s) and close its pipes."""
        if self._proc and self._proc.poll() is None:
            self._proc.terminate()
            try:
                self._proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self._proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                self._proc.wait()
        if self._proc:
            drain = getattr(self, "_stderr_thread", None)
            if drain is not None:
                drain.join(timeout=1)
            streams = [self._proc.stdin, self._proc.stdout]
            # Only close stderr once the drain thread is done with it; closing
            # a stream another thread is blocked on could hang here.
            if drain is None or not drain.is_alive():
                streams.append(self._proc.stderr)
            for stream in streams:
                if stream is None:
                    continue
                try:
                    stream.close()
                except (OSError, ValueError):
                    # Best effort: a broken pipe at shutdown is not an error.
                    pass
        logger.debug(f"🔌 stdio 子进程已关闭 skill={self.cfg.name}")
# ════════════════════════════════════════════════════════════════
# 传输层工厂
# ════════════════════════════════════════════════════════════════
def _create_transport(cfg: MCPSkillConfig) -> BaseTransport:
    """Instantiate the transport named by ``cfg.transport`` (case-insensitive).

    Raises:
        ValueError: the transport name is not one of sse / http / stdio.
    """
    kind = cfg.transport.lower()
    if kind == "sse":
        return SSETransport(cfg)
    if kind == "http":
        return HTTPTransport(cfg)
    if kind == "stdio":
        return StdioTransport(cfg)
    raise ValueError(
        f"不支持的传输协议: {cfg.transport}\n"
        f" skill: {cfg.name}\n"
        f" 可选值: sse | http | stdio"
    )
# ════════════════════════════════════════════════════════════════
# MCP Skill 客户端
# ════════════════════════════════════════════════════════════════
class MCPSkillClient:
"""
单个在线 MCP Server 的客户端
职责:
1. 建立连接SSE / HTTP / stdio
2. 执行 MCP initialize 握手
3. 获取工具列表list_tools
4. 代理调用工具call_tool
5. 支持 include/exclude 过滤
6. 支持失败重试retry 次数来自 config.yaml
用法:
client = MCPSkillClient(skill_cfg)
client.connect()
tools = client.list_tools()
result = client.call_tool("tool_name", {"arg": "value"})
client.close()
"""
def __init__(self, cfg: MCPSkillConfig):
self.cfg: MCPSkillConfig = cfg
self._transport: BaseTransport | None = None
self._tools: list[RemoteTool] = []
self._initialized: bool = False
# ── 连接管理 ──────────────────────────────────────────────
def connect(self) -> None:
"""建立连接并完成 MCP 握手"""
logger.info(
f"🌐 连接在线 MCP Skill: [{self.cfg.name}]\n"
f" 传输协议: {self.cfg.transport}\n"
f" 地址 : {self.cfg.url or self.cfg.command}\n"
f" 超时 : {self.cfg.timeout}s\n"
f" 重试 : {self.cfg.retry}"
)
last_err: Exception | None = None
for attempt in range(self.cfg.retry + 1):
try:
self._transport = _create_transport(self.cfg)
self._handshake()
self._initialized = True
logger.info(f"✅ MCP Skill [{self.cfg.name}] 连接成功")
return
except Exception as e:
last_err = e
if attempt < self.cfg.retry:
wait = 2 ** attempt # 指数退避
logger.warning(
f"⚠️ 连接失败 (attempt {attempt + 1}/{self.cfg.retry + 1})"
f"{wait}s 后重试: {e}"
)
time.sleep(wait)
if self._transport:
try:
self._transport.close()
except Exception:
pass
self._transport = None
raise ConnectionError(
f"❌ MCP Skill [{self.cfg.name}] 连接失败(已重试 {self.cfg.retry} 次)\n"
f" 最后错误: {last_err}"
)
def _handshake(self) -> None:
"""执行 MCP initialize 握手"""
result = self._transport.send_request(
_METHOD_INITIALIZE,
{
"protocolVersion": _PROTOCOL_VERSION,
"capabilities": {"tools": {}},
"clientInfo": _CLIENT_INFO,
},
)
server_info = result.get("serverInfo", {})
server_version = result.get("protocolVersion", "unknown")
logger.info(
f"🤝 MCP 握手成功 [{self.cfg.name}]\n"
f" 服务端: {server_info.get('name', 'unknown')} "
f"v{server_info.get('version', '?')}\n"
f" 协议版本: {server_version}"
)
def close(self) -> None:
if self._transport:
self._transport.close()
self._transport = None
self._initialized = False
logger.debug(f"🔌 MCP Skill [{self.cfg.name}] 已断开")
def __enter__(self):
self.connect()
return self
def __exit__(self, *_):
self.close()
# ── 工具发现 ──────────────────────────────────────────────
def list_tools(self, force_refresh: bool = False) -> list[RemoteTool]:
"""
获取远端工具列表(带缓存)
Args:
force_refresh: 强制重新拉取,忽略缓存
Returns:
经过 include/exclude 过滤后的 RemoteTool 列表
"""
if self._tools and not force_refresh:
return self._tools
self._ensure_connected()
result = self._transport.send_request(_METHOD_LIST_TOOLS)
raw_tools = result.get("tools", [])
tools = []
for t in raw_tools:
name = t.get("name", "")
if not name:
continue
# include / exclude 过滤(来自 config.yaml
if not self.cfg.is_tool_allowed(name):
logger.debug(
f" ⏭ 跳过工具 [{name}](被 include/exclude 过滤)"
)
continue
tools.append(RemoteTool(
name=name,
description=t.get("description", ""),
parameters=t.get("inputSchema", {"type": "object", "properties": {}}),
skill_name=self.cfg.name,
))
self._tools = tools
logger.info(
f"📦 MCP Skill [{self.cfg.name}] 工具列表:\n"
+ "\n".join(f"{t.name}: {t.description[:60]}" for t in tools)
)
return tools
# ── 工具调用 ──────────────────────────────────────────────
def call_tool(
self,
tool_name: str,
arguments: dict[str, Any],
) -> ToolCallResult:
"""
调用远端工具
Args:
tool_name : 工具名称
arguments : 工具参数字典
Returns:
ToolCallResult 实例
"""
self._ensure_connected()
start = time.time()
logger.info(
f"🔧 调用远端工具: [{self.cfg.name}] / {tool_name}\n"
f" 参数: {json.dumps(arguments, ensure_ascii=False)[:200]}"
)
try:
result = self._transport.send_request(
_METHOD_CALL_TOOL,
{"name": tool_name, "arguments": arguments},
)
elapsed = time.time() - start
content = self._extract_content(result)
logger.info(
f"✅ 工具调用成功: {tool_name} 耗时={elapsed:.2f}s\n"
f" 结果: {content[:150]}"
)
return ToolCallResult(
tool_name=tool_name,
skill_name=self.cfg.name,
success=True,
content=content,
elapsed_sec=elapsed,
)
except Exception as e:
elapsed = time.time() - start
logger.error(f"❌ 工具调用失败: {tool_name} {e}")
return ToolCallResult(
tool_name=tool_name,
skill_name=self.cfg.name,
success=False,
error=str(e),
elapsed_sec=elapsed,
)
# ── 私有工具方法 ──────────────────────────────────────────
def _ensure_connected(self) -> None:
if not self._initialized or not self._transport:
raise RuntimeError(
f"MCP Skill [{self.cfg.name}] 未连接,请先调用 connect()"
)
@staticmethod
def _extract_content(result: dict) -> str:
"""
从 MCP tools/call 响应中提取文本内容
MCP 响应格式:
{"content": [{"type": "text", "text": "..."}]}
{"content": [{"type": "image", "data": "...", "mimeType": "..."}]}
"""
content_list = result.get("content", [])
if not content_list:
return json.dumps(result, ensure_ascii=False)
parts = []
for item in content_list:
match item.get("type"):
case "text":
parts.append(item.get("text", ""))
case "image":
parts.append(f"[图片: {item.get('mimeType', 'image')}]")
case "resource":
parts.append(f"[资源: {item.get('uri', '')}]")
case _:
parts.append(json.dumps(item, ensure_ascii=False))
return "\n".join(parts)