196 lines
5.5 KiB
Python
196 lines
5.5 KiB
Python
"""
Mapping graph model.

Defines the data structures and business logic of the mapping graph.
"""
|
||
from dataclasses import dataclass, field
|
||
from typing import Dict, List, Optional, Tuple
|
||
from models.field import Field
|
||
from models.mapping import Mapping, OperatorMapping, CodeMapping
|
||
|
||
|
||
@dataclass
class GraphNode:
    """A graph node: one field together with its X/Y coordinates.

    Node identity is defined by ``field.id`` alone: two nodes whose fields
    share the same id compare equal and hash alike, whatever their
    coordinates are.

    Attributes:
        field: The field associated with this node.
        x: X coordinate.
        y: Y coordinate.
    """

    field: Field
    x: float = 0.0
    y: float = 0.0

    def __hash__(self):
        """Hash by the field id so the hash stays consistent with ``__eq__``."""
        return hash(self.field.id)

    def __eq__(self, other):
        """Return True when ``other`` is a GraphNode with the same field id."""
        if not isinstance(other, GraphNode):
            # NotImplemented (rather than False) lets Python try the reflected
            # comparison on ``other``; ``node == non_node`` still ends up False.
            return NotImplemented
        return self.field.id == other.field.id
|
||
|
||
|
||
@dataclass
class GraphEdge:
    """A directed edge of the graph, linking a source field to a target field.

    Attributes:
        source_field_id: Id of the source field.
        target_field_id: Id of the target field.
        mapping: The mapping relation carried by this edge.
        id: Edge id; defaults to None.
        weight: Edge weight (intended for shortest-path computation);
            defaults to 1.0.
    """

    source_field_id: int
    target_field_id: int
    mapping: Mapping
    id: Optional[int] = None
    weight: float = 1.0

    def to_dict(self) -> dict:
        """Return the edge as a plain dictionary.

        The ``mapping`` entry is the result of ``mapping.to_dict()``, or None
        when the mapping is falsy.
        """
        if self.mapping:
            mapping_payload = self.mapping.to_dict()
        else:
            mapping_payload = None

        return dict(
            id=self.id,
            source_field_id=self.source_field_id,
            target_field_id=self.target_field_id,
            mapping=mapping_payload,
            weight=self.weight,
        )
|
||
|
||
|
||
@dataclass
class MappingGraph:
    """Directed graph whose nodes are fields and whose edges are mappings.

    Attributes:
        nodes: Node dictionary (field id -> node).
        edges: List of directed edges.
    """

    nodes: Dict[int, GraphNode] = field(default_factory=dict)
    edges: List[GraphEdge] = field(default_factory=list)

    def add_node(self, node: GraphNode) -> None:
        """Add a node, keyed by its field id.

        Nodes whose field has no id (``None``) are ignored. An existing node
        with the same field id is replaced.
        """
        # Compare against None explicitly: a truthiness test would silently
        # drop a perfectly valid field id of 0.
        if node.field.id is not None:
            self.nodes[node.field.id] = node

    def remove_node(self, field_id: int) -> None:
        """Remove a node and every edge that starts or ends at it.

        Does nothing when ``field_id`` is not a known node.
        """
        if field_id not in self.nodes:
            return
        del self.nodes[field_id]
        # Drop the edges attached to the removed node.
        self.edges = [
            e for e in self.edges
            if e.source_field_id != field_id and e.target_field_id != field_id
        ]

    def get_node(self, field_id: int) -> Optional[GraphNode]:
        """Return the node for ``field_id``, or None when there is none."""
        return self.nodes.get(field_id)

    def add_edge(self, edge: GraphEdge) -> None:
        """Add an edge.

        The edge is only stored when both its source and its target node are
        already in the graph; otherwise it is silently ignored.
        """
        if edge.source_field_id in self.nodes and edge.target_field_id in self.nodes:
            self.edges.append(edge)

    def remove_edge(self, edge_id: int) -> None:
        """Remove every edge whose id equals ``edge_id``."""
        self.edges = [e for e in self.edges if e.id != edge_id]

    def get_edges_from(self, field_id: int) -> List[GraphEdge]:
        """Return all edges that leave the given node."""
        return [e for e in self.edges if e.source_field_id == field_id]

    def get_edges_to(self, field_id: int) -> List[GraphEdge]:
        """Return all edges that point at the given node."""
        return [e for e in self.edges if e.target_field_id == field_id]

    def get_neighbors(self, field_id: int) -> List[int]:
        """Return the ids of the nodes reachable over one outgoing edge.

        The ids come in edge-list order; parallel edges yield repeated ids.
        """
        return [e.target_field_id for e in self.edges if e.source_field_id == field_id]

    def has_path(self, source_id: int, target_id: int) -> bool:
        """Check whether ``target_id`` is reachable from ``source_id``.

        Edges are followed in their direction only. Returns False when either
        id is not a node of the graph, and True when both ids are the same
        existing node.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return False
        if source_id == target_id:
            return True

        # Only reachability matters, so a stack is enough (list.pop() is O(1),
        # unlike list.pop(0)). Nodes are marked when discovered so that each
        # one is expanded at most once.
        visited = {source_id}
        pending = [source_id]
        while pending:
            current = pending.pop()
            for neighbor in self.get_neighbors(current):
                if neighbor == target_id:
                    return True
                if neighbor not in visited:
                    visited.add(neighbor)
                    pending.append(neighbor)
        return False

    def get_all_paths(self, source_id: int, target_id: int,
                      max_depth: int = 10) -> List[List[int]]:
        """Collect all paths from one node to another.

        A path never visits the same node twice. Neighbors are explored in
        edge-list order, which also fixes the order of the returned paths.

        Args:
            source_id: Id of the source node.
            target_id: Id of the target node.
            max_depth: Maximum search depth, i.e. the largest number of edges
                a returned path may contain.

        Returns:
            A list of paths, each path being a list of node ids that starts
            with ``source_id`` and ends with ``target_id``. Empty when either
            id is not a node of the graph.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return []

        all_paths: List[List[int]] = []

        def dfs(current: int, path: List[int], on_path: set, depth: int) -> None:
            if depth > max_depth:
                return
            if current == target_id:
                # ``path`` is mutated while backtracking, so store a snapshot.
                all_paths.append(path.copy())
                return

            on_path.add(current)
            for neighbor in self.get_neighbors(current):
                if neighbor not in on_path:
                    path.append(neighbor)
                    dfs(neighbor, path, on_path, depth + 1)
                    path.pop()
            # Backtrack: other branches may pass through this node again.
            on_path.remove(current)

        dfs(source_id, [source_id], set(), 0)
        return all_paths

    def to_dict(self) -> dict:
        """Return the graph as a plain dictionary with ``nodes`` and ``edges``."""
        return {
            'nodes': {
                fid: {'field': node.field.to_dict(), 'x': node.x, 'y': node.y}
                for fid, node in self.nodes.items()
            },
            'edges': [e.to_dict() for e in self.edges],
        }

    def get_statistics(self) -> dict:
        """Return node count, edge count and average degree of the graph.

        The average degree is edges divided by nodes, or 0 for an empty graph.
        """
        node_count = len(self.nodes)
        edge_count = len(self.edges)
        return {
            'node_count': node_count,
            'edge_count': edge_count,
            'avg_degree': edge_count / node_count if self.nodes else 0,
        }
|