""" 映射图模型 定义映射图的数据结构和业务逻辑 """ from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple from models.field import Field from models.mapping import Mapping, OperatorMapping, CodeMapping @dataclass class GraphNode: """ 图节点 Attributes: field: 关联的字段 x: X坐标 y: Y坐标 """ field: Field x: float = 0.0 y: float = 0.0 def __hash__(self): """使节点可哈希""" return hash(self.field.id) def __eq__(self, other): """节点相等性比较""" if isinstance(other, GraphNode): return self.field.id == other.field.id return False @dataclass class GraphEdge: """ 图边 Attributes: id: 边ID source_field_id: 源字段ID target_field_id: 目标字段ID mapping: 映射关系 weight: 权重(用于最短路径计算) """ source_field_id: int target_field_id: int mapping: Mapping id: Optional[int] = None weight: float = 1.0 def to_dict(self) -> dict: """转换为字典""" return { 'id': self.id, 'source_field_id': self.source_field_id, 'target_field_id': self.target_field_id, 'mapping': self.mapping.to_dict() if self.mapping else None, 'weight': self.weight, } @dataclass class MappingGraph: """ 映射图 Attributes: nodes: 节点字典(字段ID -> 节点) edges: 边列表 """ nodes: Dict[int, GraphNode] = field(default_factory=dict) edges: List[GraphEdge] = field(default_factory=list) def add_node(self, node: GraphNode): """添加节点""" if node.field.id: self.nodes[node.field.id] = node def remove_node(self, field_id: int): """移除节点及相关的边""" if field_id in self.nodes: del self.nodes[field_id] # 移除相关的边 self.edges = [e for e in self.edges if e.source_field_id != field_id and e.target_field_id != field_id] def get_node(self, field_id: int) -> Optional[GraphNode]: """获取节点""" return self.nodes.get(field_id) def add_edge(self, edge: GraphEdge): """添加边""" # 确保源节点和目标节点存在 if edge.source_field_id in self.nodes and edge.target_field_id in self.nodes: self.edges.append(edge) def remove_edge(self, edge_id: int): """移除边""" self.edges = [e for e in self.edges if e.id != edge_id] def get_edges_from(self, field_id: int) -> List[GraphEdge]: """获取从指定节点出发的所有边""" return [e for e in self.edges if e.source_field_id == field_id] def get_edges_to(self, field_id: int) -> List[GraphEdge]: """获取指向指定节点的所有边""" return [e for e in self.edges if e.target_field_id == field_id] def get_neighbors(self, field_id: int) -> List[int]: """获取节点的所有邻居节点ID""" return [e.target_field_id for e in self.edges if e.source_field_id == field_id] def has_path(self, source_id: int, target_id: int) -> bool: """检查两个节点之间是否存在路径""" if source_id not in self.nodes or target_id not in self.nodes: return False visited = set() queue = [source_id] while queue: current = queue.pop(0) if current == target_id: return True if current in visited: continue visited.add(current) neighbors = self.get_neighbors(current) queue.extend([n for n in neighbors if n not in visited]) return False def get_all_paths(self, source_id: int, target_id: int, max_depth: int = 10) -> List[List[int]]: """ 获取两个节点之间的所有路径 Args: source_id: 源节点ID target_id: 目标节点ID max_depth: 最大搜索深度 Returns: 路径列表,每个路径是节点ID列表 """ if source_id not in self.nodes or target_id not in self.nodes: return [] all_paths = [] def dfs(current: int, target: int, path: List[int], visited: set, depth: int): if depth > max_depth: return if current == target: all_paths.append(path.copy()) return visited.add(current) neighbors = self.get_neighbors(current) for neighbor in neighbors: if neighbor not in visited: path.append(neighbor) dfs(neighbor, target, path, visited, depth + 1) path.pop() visited.remove(current) dfs(source_id, target_id, [source_id], set(), 0) return all_paths def to_dict(self) -> dict: """转换为字典""" return { 'nodes': {fid: {'field': node.field.to_dict(), 'x': node.x, 'y': node.y} for fid, node in self.nodes.items()}, 'edges': [e.to_dict() for e in self.edges], } def get_statistics(self) -> dict: """获取图的统计信息""" return { 'node_count': len(self.nodes), 'edge_count': len(self.edges), 'avg_degree': len(self.edges) / len(self.nodes) if self.nodes else 0, }