SIT/models/mapping_graph.py

196 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
映射图模型
定义映射图的数据结构和业务逻辑
"""
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from models.field import Field
from models.mapping import Mapping, OperatorMapping, CodeMapping
@dataclass
class GraphNode:
    """
    A node in the mapping graph.

    Equality and hashing depend only on ``field.id``; the layout coordinates
    ``x`` and ``y`` do not take part in identity.

    Attributes:
        field: The field this node represents.
        x: X coordinate.
        y: Y coordinate.
    """
    field: Field
    x: float = 0.0
    y: float = 0.0

    def __hash__(self):
        """Hash by the wrapped field's id so nodes can be used in sets / as dict keys."""
        return hash(self.field.id)

    def __eq__(self, other):
        """Return whether both nodes wrap fields with the same id.

        For operands that are not ``GraphNode`` this returns ``NotImplemented``
        (per the Python data model) so the other operand's ``__eq__`` gets a
        chance; if neither handles it, Python falls back to identity (``False``).
        """
        if not isinstance(other, GraphNode):
            return NotImplemented
        return self.field.id == other.field.id
@dataclass
class GraphEdge:
    """
    A directed edge between two fields in the mapping graph.

    Attributes:
        source_field_id: Id of the field the edge starts from.
        target_field_id: Id of the field the edge points to.
        mapping: The mapping relation carried by this edge.
        id: Edge id; defaults to ``None``.
        weight: Edge weight (described as intended for shortest-path
            calculations); defaults to ``1.0``.
    """
    source_field_id: int
    target_field_id: int
    mapping: Mapping
    id: Optional[int] = None
    weight: float = 1.0

    def to_dict(self) -> dict:
        """Serialize this edge to a plain dictionary.

        The mapping is serialized through its own ``to_dict()``; a falsy
        mapping (e.g. ``None``) is emitted as ``None``.
        """
        mapping_payload = None
        if self.mapping:
            mapping_payload = self.mapping.to_dict()
        return dict(
            id=self.id,
            source_field_id=self.source_field_id,
            target_field_id=self.target_field_id,
            mapping=mapping_payload,
            weight=self.weight,
        )
@dataclass
class MappingGraph:
    """
    Directed mapping graph.

    Nodes are stored by field id. Edges are kept in insertion order and more
    than one edge may connect the same pair of nodes.

    Attributes:
        nodes: Node dictionary (field id -> node).
        edges: Edge list.
    """
    nodes: Dict[int, GraphNode] = field(default_factory=dict)
    edges: List[GraphEdge] = field(default_factory=list)

    def add_node(self, node: GraphNode):
        """Add a node, replacing any existing node with the same field id.

        Nodes whose ``field.id`` is falsy are silently ignored.
        """
        # NOTE(review): this is a truthiness test, so a field id of 0 is skipped
        # just like None. Kept as-is because the "unsaved" id sentinel used by
        # Field is not visible here -- confirm before tightening to `is not None`.
        if node.field.id:
            self.nodes[node.field.id] = node

    def remove_node(self, field_id: int):
        """Remove a node and every edge that starts or ends at it.

        Edges referencing ``field_id`` are dropped even if the node entry is
        already gone, so no dangling edges survive.
        """
        self.nodes.pop(field_id, None)
        self.edges = [
            e for e in self.edges
            if e.source_field_id != field_id and e.target_field_id != field_id
        ]

    def get_node(self, field_id: int) -> Optional[GraphNode]:
        """Return the node for ``field_id``, or ``None`` if there is none."""
        return self.nodes.get(field_id)

    def add_edge(self, edge: GraphEdge):
        """Append ``edge`` if both of its endpoints are existing nodes.

        Edges with a missing source or target node are silently ignored.
        """
        if edge.source_field_id in self.nodes and edge.target_field_id in self.nodes:
            self.edges.append(edge)

    def remove_edge(self, edge_id: int):
        """Remove every edge whose ``id`` equals ``edge_id``."""
        self.edges = [e for e in self.edges if e.id != edge_id]

    def get_edges_from(self, field_id: int) -> List[GraphEdge]:
        """Return all edges whose source is ``field_id``, in insertion order."""
        return [e for e in self.edges if e.source_field_id == field_id]

    def get_edges_to(self, field_id: int) -> List[GraphEdge]:
        """Return all edges whose target is ``field_id``, in insertion order."""
        return [e for e in self.edges if e.target_field_id == field_id]

    def get_neighbors(self, field_id: int) -> List[int]:
        """Return the target ids of the outgoing edges of ``field_id``.

        There is one entry per edge, so a target reached by several edges
        appears several times.
        """
        return [e.target_field_id for e in self.edges if e.source_field_id == field_id]

    def _adjacency(self) -> Dict[int, List[int]]:
        """Build a source id -> target ids map in a single pass over ``edges``.

        Each list keeps edge order and duplicates, i.e. ``adjacency[n]`` equals
        ``get_neighbors(n)`` for every node ``n`` with outgoing edges.
        """
        adjacency: Dict[int, List[int]] = {}
        for edge in self.edges:
            adjacency.setdefault(edge.source_field_id, []).append(edge.target_field_id)
        return adjacency

    def has_path(self, source_id: int, target_id: int) -> bool:
        """Return whether ``target_id`` is reachable from ``source_id``.

        Breadth-first search along outgoing edges. Returns ``False`` if either
        id is not a node, and ``True`` when ``source_id == target_id``.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return False
        # Build the adjacency map once (O(E)) rather than rescanning every edge
        # for each visited node, and use a deque for O(1) pops from the front.
        adjacency = self._adjacency()
        visited = {source_id}
        queue = deque([source_id])
        while queue:
            current = queue.popleft()
            if current == target_id:
                return True
            for neighbor in adjacency.get(current, ()):
                # Mark on enqueue so each node is queued at most once.
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        return False

    def get_all_paths(self, source_id: int, target_id: int,
                      max_depth: int = 10) -> List[List[int]]:
        """
        Return all simple paths from ``source_id`` to ``target_id``.

        Args:
            source_id: Source node id.
            target_id: Target node id.
            max_depth: Maximum number of edges a returned path may contain.

        Returns:
            A list of paths; each path is a list of node ids starting with
            ``source_id`` and ending with ``target_id``. Returns ``[]`` if
            either id is not a node, and ``[[source_id]]`` when the two ids are
            equal. Parallel edges between the same two nodes yield repeated
            (identical) paths, because each edge is followed separately.
        """
        if source_id not in self.nodes or target_id not in self.nodes:
            return []

        adjacency = self._adjacency()  # built once instead of per visited node
        all_paths: List[List[int]] = []
        path = [source_id]
        on_path = set()  # nodes on the current path; keeps paths simple (no cycles)

        def dfs(current: int, depth: int) -> None:
            if depth > max_depth:
                return
            if current == target_id:
                all_paths.append(path.copy())
                return
            on_path.add(current)
            for neighbor in adjacency.get(current, ()):
                if neighbor not in on_path:
                    path.append(neighbor)
                    dfs(neighbor, depth + 1)
                    path.pop()
            on_path.remove(current)

        dfs(source_id, 0)
        return all_paths

    def to_dict(self) -> dict:
        """Serialize the graph: nodes keyed by field id, plus the edge list."""
        return {
            'nodes': {fid: {'field': node.field.to_dict(), 'x': node.x, 'y': node.y}
                      for fid, node in self.nodes.items()},
            'edges': [e.to_dict() for e in self.edges],
        }

    def get_statistics(self) -> dict:
        """Return node count, edge count and average degree.

        ``avg_degree`` is edges per node (the mean out-degree of this directed
        graph), or 0 for a graph with no nodes.
        """
        return {
            'node_count': len(self.nodes),
            'edge_count': len(self.edges),
            'avg_degree': len(self.edges) / len(self.nodes) if self.nodes else 0,
        }