# SIT/models/mapping_graph.py
# 2026-01-29 09:08:31 +00:00
"""
映射图模型
定义映射图的数据结构和业务逻辑
"""
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from models.field import Field
from models.mapping import Mapping, OperatorMapping, CodeMapping
@dataclass
class GraphNode:
    """
    Graph node that wraps a field and carries a layout position.

    Node identity is defined solely by the wrapped field's ID. Two nodes are
    equal, and hash alike, when their ``field.id`` values are equal,
    regardless of their coordinates. Nodes whose fields all have
    ``id is None`` therefore compare equal to one another.

    Attributes:
        field: the associated field
        x: X coordinate
        y: Y coordinate
    """
    field: Field
    x: float = 0.0
    y: float = 0.0

    def __hash__(self):
        """Hash by field ID so nodes can be stored in sets and used as dict keys."""
        return hash(self.field.id)

    def __eq__(self, other):
        """Compare two nodes by their field IDs."""
        if isinstance(other, GraphNode):
            return self.field.id == other.field.id
        # Per the data-model convention, defer to the other operand rather
        # than asserting inequality; Python falls back to an identity check
        # (i.e. False) if the other side does not handle it either.
        return NotImplemented
@dataclass
class GraphEdge:
    """
    Directed edge from one field to another, carrying the mapping between them.

    Attributes:
        source_field_id: ID of the field the edge starts at
        target_field_id: ID of the field the edge points to
        mapping: the mapping relation this edge represents
        id: edge ID; defaults to None
        weight: edge weight, intended for shortest-path computation; defaults to 1.0
    """
    source_field_id: int
    target_field_id: int
    mapping: Mapping
    id: Optional[int] = None
    weight: float = 1.0

    def to_dict(self) -> dict:
        """
        Return a plain-dict representation of this edge.

        The mapping is serialized through its own ``to_dict()``. A falsy
        mapping is emitted as None.
        """
        mapping_repr = None
        if self.mapping:
            mapping_repr = self.mapping.to_dict()

        result = {'id': self.id}
        result['source_field_id'] = self.source_field_id
        result['target_field_id'] = self.target_field_id
        result['mapping'] = mapping_repr
        result['weight'] = self.weight
        return result
@dataclass
class MappingGraph:
    """
    Directed mapping graph between fields.

    Attributes:
        nodes: node dict keyed by field ID (field ID -> GraphNode)
        edges: list of directed edges; several edges may join the same pair
            of nodes
    """
    nodes: Dict[int, GraphNode] = field(default_factory=dict)
    edges: List[GraphEdge] = field(default_factory=list)

    def add_node(self, node: GraphNode):
        """
        Add a node keyed by ``node.field.id``, replacing any node with that ID.

        A node whose field has no ID (``None``) cannot be keyed and is ignored.
        """
        # Explicit None check: a truthiness test would also silently drop a
        # field whose ID is 0.
        if node.field.id is not None:
            self.nodes[node.field.id] = node

    def remove_node(self, field_id: int):
        """Remove a node and every edge that starts or ends at it; no-op if absent."""
        if field_id in self.nodes:
            del self.nodes[field_id]
            # Drop the edges attached to the removed node so none dangle.
            self.edges = [e for e in self.edges
                          if e.source_field_id != field_id and e.target_field_id != field_id]

    def get_node(self, field_id: int) -> Optional[GraphNode]:
        """Return the node for ``field_id``, or None if it is not in the graph."""
        return self.nodes.get(field_id)

    def add_edge(self, edge: GraphEdge):
        """
        Append an edge, but only if both of its endpoint nodes are in the graph.

        Edges that reference an unknown node are silently ignored.
        """
        if edge.source_field_id in self.nodes and edge.target_field_id in self.nodes:
            self.edges.append(edge)

    def remove_edge(self, edge_id: int):
        """Remove every edge whose ``id`` equals ``edge_id``."""
        self.edges = [e for e in self.edges if e.id != edge_id]

    def get_edges_from(self, field_id: int) -> List[GraphEdge]:
        """Return all edges leaving ``field_id``, in insertion order."""
        return [e for e in self.edges if e.source_field_id == field_id]

    def get_edges_to(self, field_id: int) -> List[GraphEdge]:
        """Return all edges entering ``field_id``, in insertion order."""
        return [e for e in self.edges if e.target_field_id == field_id]

    def get_neighbors(self, field_id: int) -> List[int]:
        """
        Return the target IDs of all edges leaving ``field_id``.

        The list follows edge order and has one entry per edge, so a target
        joined by several parallel edges appears several times.
        """
        return [e.target_field_id for e in self.edges if e.source_field_id == field_id]

    def _adjacency(self) -> Dict[int, List[int]]:
        """Build a source ID -> [target IDs] map (edge order kept) in one pass over the edges."""
        adjacency: Dict[int, List[int]] = {}
        for edge in self.edges:
            adjacency.setdefault(edge.source_field_id, []).append(edge.target_field_id)
        return adjacency

    def has_path(self, source_id: int, target_id: int) -> bool:
        """
        Return True if ``target_id`` is reachable from ``source_id`` along directed edges.

        A node is reachable from itself. Returns False if either ID is not a
        node of the graph.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return False
        # Build successors once instead of rescanning every edge per visited
        # node; a deque gives O(1) pops from the front (list.pop(0) is O(n)).
        adjacency = self._adjacency()
        visited = {source_id}
        queue = deque([source_id])
        while queue:
            current = queue.popleft()
            if current == target_id:
                return True
            for neighbor in adjacency.get(current, ()):
                if neighbor not in visited:
                    # Mark on enqueue so each node enters the queue at most once.
                    visited.add(neighbor)
                    queue.append(neighbor)
        return False

    def get_all_paths(self, source_id: int, target_id: int,
                      max_depth: int = 10) -> List[List[int]]:
        """
        Enumerate all simple directed paths from ``source_id`` to ``target_id``.

        Args:
            source_id: start node ID
            target_id: end node ID
            max_depth: maximum number of edges a returned path may contain

        Returns:
            List of paths, each a list of node IDs that starts with
            ``source_id`` and ends with ``target_id``. A node-ID path is
            reported once, even if parallel edges connect some of its
            consecutive nodes. Returns [] if either ID is not a node of the
            graph or ``max_depth`` is negative. Returns [[source_id]] when
            ``source_id == target_id``.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return []
        # Precompute successors once rather than rescanning all edges at each
        # DFS step. dict.fromkeys drops duplicate targets from parallel edges
        # (keeping first-seen order); otherwise identical node-ID paths would
        # be emitted more than once.
        successors = {src: list(dict.fromkeys(targets))
                      for src, targets in self._adjacency().items()}
        all_paths: List[List[int]] = []

        def dfs(current: int, path: List[int], visited: set, depth: int):
            # depth = number of edges from source_id to current.
            if depth > max_depth:
                return
            if current == target_id:
                all_paths.append(path.copy())
                return
            # visited holds the nodes on the current path, keeping paths simple.
            visited.add(current)
            for neighbor in successors.get(current, ()):
                if neighbor not in visited:
                    path.append(neighbor)
                    dfs(neighbor, path, visited, depth + 1)
                    path.pop()
            visited.remove(current)

        dfs(source_id, [source_id], set(), 0)
        return all_paths

    def to_dict(self) -> dict:
        """
        Serialize the graph to a plain dict.

        Nodes become ``{field_id: {'field': ..., 'x': ..., 'y': ...}}`` and
        edges are serialized with ``GraphEdge.to_dict``.
        """
        return {
            'nodes': {fid: {'field': node.field.to_dict(), 'x': node.x, 'y': node.y}
                      for fid, node in self.nodes.items()},
            'edges': [e.to_dict() for e in self.edges],
        }

    def get_statistics(self) -> dict:
        """
        Return basic graph statistics.

        ``avg_degree`` is edges per node, which for this directed graph is the
        average out-degree. It is 0 for a graph with no nodes.
        """
        return {
            'node_count': len(self.nodes),
            'edge_count': len(self.edges),
            'avg_degree': len(self.edges) / len(self.nodes) if self.nodes else 0,
        }