196 lines
5.5 KiB
Python
196 lines
5.5 KiB
Python
|
|
"""
|
|||
|
|
映射图模型
|
|||
|
|
定义映射图的数据结构和业务逻辑
|
|||
|
|
"""
|
|||
|
|
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from models.field import Field
from models.mapping import Mapping, OperatorMapping, CodeMapping
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class GraphNode:
    """
    A node in the mapping graph, wrapping a single field.

    Node identity (equality and hashing) is based solely on ``field.id``;
    the layout coordinates ``x``/``y`` do not take part in it.

    Attributes:
        field: The field this node represents.
        x: X coordinate.
        y: Y coordinate.
    """

    field: Field
    x: float = 0.0
    y: float = 0.0

    def __hash__(self):
        """Hash by the wrapped field's ID so nodes can be used in sets / as dict keys.

        NOTE(review): the hash depends on ``self.field.id``; if that ID can be
        reassigned while the node sits in a set or dict, lookups will break.
        """
        return hash(self.field.id)

    def __eq__(self, other):
        """Two nodes are equal when their wrapped fields have the same ID.

        For operands that are not ``GraphNode`` this returns ``NotImplemented``
        (per the Python data model) so the other operand gets a chance to
        answer; if it cannot, Python falls back to identity, i.e. ``False``.
        """
        if not isinstance(other, GraphNode):
            return NotImplemented
        return self.field.id == other.field.id
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class GraphEdge:
    """
    A directed edge of the mapping graph.

    Attributes:
        id: Edge ID; ``None`` until one is assigned.
        source_field_id: ID of the field the edge starts from.
        target_field_id: ID of the field the edge points to.
        mapping: The mapping relation carried by this edge.
        weight: Edge weight (intended for shortest-path computation).
    """

    source_field_id: int
    target_field_id: int
    mapping: Mapping
    id: Optional[int] = None
    weight: float = 1.0

    def to_dict(self) -> dict:
        """Serialize the edge into a plain dictionary.

        The ``mapping`` entry holds ``self.mapping.to_dict()``, or ``None``
        when the attached mapping is falsy.
        """
        mapping_payload = None
        if self.mapping:
            mapping_payload = self.mapping.to_dict()
        return dict(
            id=self.id,
            source_field_id=self.source_field_id,
            target_field_id=self.target_field_id,
            mapping=mapping_payload,
            weight=self.weight,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class MappingGraph:
    """
    Directed mapping graph.

    Nodes are keyed by their field ID; edges are kept in a list in insertion
    order and point from ``source_field_id`` to ``target_field_id``.

    Attributes:
        nodes: Node dictionary (field ID -> node).
        edges: Edge list.
    """

    nodes: Dict[int, GraphNode] = field(default_factory=dict)
    edges: List[GraphEdge] = field(default_factory=list)

    def add_node(self, node: GraphNode):
        """Add (or replace) a node, keyed by ``node.field.id``.

        Nodes whose field has no ID (``None``) are silently ignored.
        """
        # Explicit None check: a truthiness test would also drop the
        # legitimate field ID 0.
        if node.field.id is not None:
            self.nodes[node.field.id] = node

    def remove_node(self, field_id: int):
        """Remove a node and every edge that starts or ends at it.

        Edges touching ``field_id`` are purged even if the node itself is not
        registered, so no dangling edges referring to it remain.
        """
        self.nodes.pop(field_id, None)
        self.edges = [
            e for e in self.edges
            if e.source_field_id != field_id and e.target_field_id != field_id
        ]

    def get_node(self, field_id: int) -> Optional[GraphNode]:
        """Return the node for ``field_id``, or ``None`` if it is not present."""
        return self.nodes.get(field_id)

    def add_edge(self, edge: GraphEdge):
        """Append ``edge`` if both its source and target nodes are present.

        Edges that reference an unknown source or target node are silently
        ignored.
        """
        if edge.source_field_id in self.nodes and edge.target_field_id in self.nodes:
            self.edges.append(edge)

    def remove_edge(self, edge_id: int):
        """Remove every edge whose ``id`` equals ``edge_id``."""
        self.edges = [e for e in self.edges if e.id != edge_id]

    def get_edges_from(self, field_id: int) -> List[GraphEdge]:
        """Return all edges leaving ``field_id``, in edge-list order."""
        return [e for e in self.edges if e.source_field_id == field_id]

    def get_edges_to(self, field_id: int) -> List[GraphEdge]:
        """Return all edges pointing at ``field_id``, in edge-list order."""
        return [e for e in self.edges if e.target_field_id == field_id]

    def get_neighbors(self, field_id: int) -> List[int]:
        """Return target IDs of the edges leaving ``field_id``.

        Order follows the edge list; parallel edges produce repeated IDs.
        """
        return [e.target_field_id for e in self.edges if e.source_field_id == field_id]

    def _adjacency(self) -> Dict[int, List[int]]:
        """Build a ``source ID -> [target IDs]`` map from the current edge list (edge order kept)."""
        adjacency: Dict[int, List[int]] = {}
        for edge in self.edges:
            adjacency.setdefault(edge.source_field_id, []).append(edge.target_field_id)
        return adjacency

    def has_path(self, source_id: int, target_id: int) -> bool:
        """Check whether ``target_id`` is reachable from ``source_id`` along edge direction.

        Returns ``False`` if either ID is not a node. A node is always
        reachable from itself.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return False
        if source_id == target_id:
            return True

        # Build adjacency once instead of rescanning every edge per node, and
        # use a deque so dequeuing is O(1) rather than list.pop(0)'s O(n).
        adjacency = self._adjacency()
        visited = {source_id}
        queue = deque([source_id])
        while queue:
            current = queue.popleft()
            for neighbor in adjacency.get(current, ()):
                if neighbor == target_id:
                    return True
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        return False

    def get_all_paths(self, source_id: int, target_id: int,
                      max_depth: int = 10) -> List[List[int]]:
        """
        Return all simple paths from ``source_id`` to ``target_id``.

        Args:
            source_id: Source node ID.
            target_id: Target node ID.
            max_depth: Maximum search depth, i.e. the maximum number of edges
                a returned path may contain.

        Returns:
            A list of paths, each a list of node IDs that starts with
            ``source_id`` and ends with ``target_id``. Empty if either ID is
            not a node. When ``source_id == target_id`` (and
            ``max_depth >= 0``) the result is ``[[source_id]]``.

        Note:
            Paths are expanded per edge, so parallel edges between the same
            pair of nodes yield the same node path more than once.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return []

        adjacency = self._adjacency()
        all_paths: List[List[int]] = []
        path = [source_id]
        on_path = set()  # nodes on the current path; prevents cycles

        def dfs(current: int, depth: int):
            if depth > max_depth:
                return
            if current == target_id:
                all_paths.append(path.copy())
                return
            if depth == max_depth:
                # Every neighbor would sit at depth max_depth + 1 and be rejected.
                return

            on_path.add(current)
            for neighbor in adjacency.get(current, ()):
                if neighbor not in on_path:
                    path.append(neighbor)
                    dfs(neighbor, depth + 1)
                    path.pop()
            # Backtrack so the node can appear on other branches' paths.
            on_path.remove(current)

        dfs(source_id, 0)
        return all_paths

    def to_dict(self) -> dict:
        """Serialize the graph.

        ``nodes`` maps each field ID to ``{'field', 'x', 'y'}``, where
        ``field`` is ``Field.to_dict()``. ``edges`` is the list of
        ``GraphEdge.to_dict()`` results in edge order.
        """
        return {
            'nodes': {fid: {'field': node.field.to_dict(), 'x': node.x, 'y': node.y}
                      for fid, node in self.nodes.items()},
            'edges': [e.to_dict() for e in self.edges],
        }

    def get_statistics(self) -> dict:
        """Return node count, edge count and ``avg_degree``.

        ``avg_degree`` is edges per node (the mean out-degree of this directed
        graph), or 0 for an empty graph.
        """
        return {
            'node_count': len(self.nodes),
            'edge_count': len(self.edges),
            'avg_degree': len(self.edges) / len(self.nodes) if self.nodes else 0,
        }
|