"""
mcp/mcp_server.py

MCP Server:从配置读取 server_name、transport、enabled_tools
支持按配置动态过滤注册工具

MCP server that reads server_name, transport and enabled_tools from its
configuration, and filters which tools get registered according to that
configuration.
"""

from typing import Type
|
||
|
||
from config.settings import MCPConfig, settings
|
||
from mcp.mcp_protocol import MCPMethod, MCPRequest, MCPResponse, ToolSchema
|
||
from tools.base_tool import BaseTool, ToolResult
|
||
from utils.logger import get_logger
|
||
|
||
|
||
class MCPServer:
    """Config-driven core class of the MCP server.

    Configuration fields read from ``cfg``:
        - server_name: server name (used in log messages and ``repr``).
        - transport: communication mode (stdio / http / websocket). This
          class only logs the value; it does not open a transport itself.
        - enabled_tools: whitelist; only tools whose ``name`` appears in it
          are registered.

    Example:
        server = MCPServer()                 # read config from global settings
        server = MCPServer(cfg=custom_cfg)   # use a custom config
        server.register_tool(CalculatorTool)
        response = server.handle_request(request)
    """

    def __init__(self, cfg: MCPConfig | None = None):
        """Create a server and log its effective configuration.

        Args:
            cfg: MCPConfig instance; when None, ``settings.mcp`` is used.
        """
        # Explicit ``is None`` check: only a missing config falls back to the
        # global one, so a config object that happens to be falsy is honoured.
        self.cfg = cfg if cfg is not None else settings.mcp
        self.logger = get_logger("MCP")
        # Tool name -> tool instance, filled by register_tool().
        self._registry: dict[str, BaseTool] = {}

        self.logger.info(f"🚀 MCP Server [{self.cfg.server_name}] 启动")
        self.logger.info(f" transport = {self.cfg.transport}")
        self.logger.info(f" enabled_tools = {self.cfg.enabled_tools}")

    # ── Tool registration ──────────────────────────────────────

    def register_tool(self, tool_class: Type[BaseTool]) -> None:
        """Instantiate ``tool_class`` and register it if whitelisted.

        Tools whose name is not in ``cfg.enabled_tools`` are skipped with a
        warning. Registering a name that is already registered replaces the
        previous instance (last one wins) and logs a warning.

        Args:
            tool_class: a BaseTool subclass; it is instantiated with no
                arguments.

        Raises:
            ValueError: if the instance has no ``name`` or an empty one.
        """
        instance = tool_class()
        # getattr so a tool lacking the attribute entirely still gets the
        # descriptive ValueError rather than an AttributeError.
        name = getattr(instance, "name", None)
        if not name:
            raise ValueError(f"工具类 {tool_class.__name__} 未设置 name 属性")

        # Whitelist filter.
        if name not in self.cfg.enabled_tools:
            self.logger.warning(
                f"⏭ 工具 [{name}] 不在 enabled_tools 白名单中,跳过注册"
            )
            return

        if name in self._registry:
            # Overwriting used to happen silently; keep it but make it visible.
            self.logger.warning(f"⚠ 工具 [{name}] 已注册,将被新实例覆盖")

        self._registry[name] = instance
        self.logger.info(f"📌 注册工具: [{name}] — {instance.description}")

    def register_tools(self, *tool_classes: Type[BaseTool]) -> None:
        """Register several tool classes, in order, via register_tool()."""
        for cls in tool_classes:
            self.register_tool(cls)

    # ── Request handling ───────────────────────────────────────

    def handle_request(self, request: MCPRequest) -> MCPResponse:
        """Single entry point for MCP requests.

        Dispatches on ``request.method`` to the tools/list or tools/call
        handler. Unknown methods get an error response with code -32601
        (JSON-RPC "method not found").
        """
        self.logger.info(
            f"📨 收到请求 id={request.id} method={request.method} "
            f"transport={self.cfg.transport}"
        )
        handlers = {
            MCPMethod.TOOLS_LIST: self._handle_tools_list,
            MCPMethod.TOOLS_CALL: self._handle_tools_call,
        }
        handler = handlers.get(request.method)
        if handler is None:
            return self._error_response(request.id, -32601, f"未知方法: {request.method}")
        return handler(request)

    def _handle_tools_list(self, request: MCPRequest) -> MCPResponse:
        """Return ``{"tools": [...]}`` with the schema dict of every registered tool."""
        schemas = [tool.get_schema().to_dict() for tool in self._registry.values()]
        self.logger.info(f"📋 返回工具列表,共 {len(schemas)} 个")
        return MCPResponse(id=request.id, result={"tools": schemas})

    def _handle_tools_call(self, request: MCPRequest) -> MCPResponse:
        """Run the registered tool named by ``params["name"]``.

        ``params["arguments"]`` is forwarded as keyword arguments to the
        tool's ``safe_execute``. A missing or null ``arguments`` means no
        arguments. An unknown tool name, or an ``arguments`` value that is
        not a mapping, yields a -32602 error response instead of an
        exception. A failed ToolResult yields a -32000 error response
        carrying the tool's output.
        """
        from collections.abc import Mapping

        # Tolerate requests whose params are None.
        params = request.params or {}
        tool_name = params.get("name")
        arguments = params.get("arguments")
        if arguments is None:
            # Covers both a missing key and an explicit ``null``.
            arguments = {}

        # Registry keys are str; an unhashable name (e.g. a list) would make
        # dict.get raise TypeError, so treat non-str names as "not found".
        tool = self._registry.get(tool_name) if isinstance(tool_name, str) else None
        if tool is None:
            return self._error_response(
                request.id, -32602,
                f"工具 [{tool_name}] 不存在,可用: {list(self._registry.keys())}"
            )

        if not isinstance(arguments, Mapping):
            # ``**arguments`` would raise TypeError for non-mappings.
            return self._error_response(
                request.id, -32602,
                f"工具 [{tool_name}] 的 arguments 必须是对象(dict),"
                f"实际类型: {type(arguments).__name__}"
            )

        result: ToolResult = tool.safe_execute(**arguments)
        if result.success:
            return MCPResponse(
                id=request.id,
                result={"content": [{"type": "text", "text": result.output}],
                        "metadata": result.metadata},
            )
        return self._error_response(request.id, -32000, result.output)

    # ── Helpers ────────────────────────────────────────────────

    def get_tool_schemas(self) -> list[ToolSchema]:
        """Return the ToolSchema of every registered tool, in registration order."""
        return [tool.get_schema() for tool in self._registry.values()]

    def list_tools(self) -> list[str]:
        """Return the names of all registered tools, in registration order."""
        return list(self._registry.keys())

    @staticmethod
    def _error_response(req_id: str, code: int, message: str) -> MCPResponse:
        """Build an MCPResponse whose ``error`` is ``{"code", "message"}``."""
        return MCPResponse(id=req_id, error={"code": code, "message": message})

    def __repr__(self) -> str:
        return (
            f"MCPServer(name={self.cfg.server_name!r}, "
            f"transport={self.cfg.transport!r}, "
            f"tools={self.list_tools()})"
        )