# SIT/utils/graph_algorithms.py
#
# 240 lines
# 6.9 KiB
# Python
# Raw Normal View History
#
# 2026-01-29 09:22:54 +00:00
"""
图算法工具
提供最短路径等图算法实现
"""
import heapq
from collections import deque
from typing import List, Optional, Dict, Set, Tuple

from models.mapping_graph import MappingGraph, GraphEdge
from utils.logger import get_logger
logger = get_logger(__name__)
class GraphAlgorithms:
    """Stateless path-finding and cycle-detection helpers for a mapping graph.

    Every method is static. The only parts of the graph interface used here
    are ``graph.nodes`` (membership tests and iteration over node IDs),
    ``graph.get_edges_from(node_id)`` (edges exposing ``target_field_id`` and
    ``weight``) and ``graph.get_neighbors(node_id)`` (iterable of node IDs).
    """

    @staticmethod
    def dijkstra_shortest_path(graph: MappingGraph,
                               source_id: int,
                               target_id: int) -> Optional[Tuple[List[int], float]]:
        """Find the minimum-weight path between two nodes using Dijkstra.

        Edge weights are expected to be non-negative; Dijkstra's algorithm
        does not produce correct results for negative weights.

        Args:
            graph: Mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.

        Returns:
            ``(path, total_weight)`` where ``path`` lists the node IDs from
            source to target inclusive, or ``None`` when either endpoint is
            not in ``graph.nodes`` or the target cannot be reached.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            logger.warning(f"Source {source_id} or target {target_id} not in graph")
            return None

        infinity = float('inf')
        # Both maps are filled lazily: a missing key means "not reached yet".
        # Besides skipping an O(V) initialisation, this tolerates edges whose
        # target is absent from graph.nodes instead of raising KeyError.
        distances: Dict[int, float] = {source_id: 0}
        predecessors: Dict[int, int] = {}

        # Priority queue of (tentative distance, node ID).
        pq: List[Tuple[float, int]] = [(0, source_id)]
        settled: Set[int] = set()

        while pq:
            current_dist, current_id = heapq.heappop(pq)
            if current_id in settled:
                # Stale entry: the node was pushed again with a shorter
                # distance and has already been finalised.
                continue
            settled.add(current_id)

            # The target's distance is final once it is popped.
            if current_id == target_id:
                break

            for edge in graph.get_edges_from(current_id):
                neighbor_id = edge.target_field_id
                if neighbor_id in settled:
                    continue
                new_dist = current_dist + edge.weight
                if new_dist < distances.get(neighbor_id, infinity):
                    distances[neighbor_id] = new_dist
                    predecessors[neighbor_id] = current_id
                    heapq.heappush(pq, (new_dist, neighbor_id))

        if target_id not in distances:
            logger.info(f"No path found from {source_id} to {target_id}")
            return None

        # Walk the predecessor chain back from the target; the source is the
        # only node on the chain without a predecessor.
        path: List[int] = []
        current: Optional[int] = target_id
        while current is not None:
            path.append(current)
            current = predecessors.get(current)
        path.reverse()

        logger.info(f"Found shortest path: {path} with total weight {distances[target_id]}")
        return path, distances[target_id]

    @staticmethod
    def bfs_shortest_path(graph: MappingGraph,
                          source_id: int,
                          target_id: int) -> Optional[List[int]]:
        """Find a path with the fewest edges using breadth-first search.

        Edge weights are ignored.

        Args:
            graph: Mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.

        Returns:
            The node IDs from source to target inclusive, or ``None`` when
            either endpoint is not in ``graph.nodes`` or no path exists.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            return None

        # Maps each discovered node to the node it was discovered from.
        # It doubles as the "already enqueued" set, so every node enters the
        # queue at most once and no per-entry path copies are needed.
        parents: Dict[int, Optional[int]] = {source_id: None}
        queue = deque([source_id])

        while queue:
            current_id = queue.popleft()

            if current_id == target_id:
                path: List[int] = []
                node: Optional[int] = current_id
                while node is not None:
                    path.append(node)
                    node = parents[node]
                path.reverse()
                logger.info(f"Found BFS path: {path}")
                return path

            for neighbor_id in graph.get_neighbors(current_id):
                if neighbor_id not in parents:
                    parents[neighbor_id] = current_id
                    queue.append(neighbor_id)

        logger.info(f"No BFS path found from {source_id} to {target_id}")
        return None

    @staticmethod
    def find_all_paths(graph: MappingGraph,
                       source_id: int,
                       target_id: int,
                       max_length: int = 10) -> List[List[int]]:
        """Enumerate every simple path from source to target.

        A simple path visits no node more than once.

        Args:
            graph: Mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.
            max_length: Maximum number of nodes (not edges) allowed in a
                path, endpoints included.

        Returns:
            A list of paths, each a list of node IDs from source to target.
            Empty when either endpoint is not in ``graph.nodes`` or no path
            fits within ``max_length``.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            return []

        all_paths: List[List[int]] = []

        def dfs(current: int, path: List[int], on_path: Set[int]) -> None:
            # Invariant: len(path) <= max_length on entry.
            if current == target_id:
                all_paths.append(path.copy())
                return
            if len(path) >= max_length:
                # Any extension would exceed the limit; do not descend.
                return

            on_path.add(current)
            for neighbor in graph.get_neighbors(current):
                if neighbor not in on_path:
                    path.append(neighbor)
                    dfs(neighbor, path, on_path)
                    path.pop()
            # Backtrack so other branches may pass through this node.
            on_path.remove(current)

        # A limit below one node cannot hold even the start node.
        if max_length >= 1:
            dfs(source_id, [source_id], set())

        logger.info(f"Found {len(all_paths)} paths from {source_id} to {target_id}")
        return all_paths

    @staticmethod
    def get_path_edges(graph: MappingGraph, path: List[int]) -> List[GraphEdge]:
        """Collect the edges that connect consecutive nodes of a path.

        For each consecutive pair the first edge returned by
        ``graph.get_edges_from`` that points at the next node is used.

        Args:
            graph: Mapping graph the path belongs to.
            path: Node IDs in traversal order.

        Returns:
            The connecting edges in path order. A pair with no connecting
            edge contributes nothing to the result (a warning is logged), so
            the list may be shorter than ``len(path) - 1``.
        """
        edges: List[GraphEdge] = []
        for source_id, target_id in zip(path, path[1:]):
            connecting = next(
                (edge for edge in graph.get_edges_from(source_id)
                 if edge.target_field_id == target_id),
                None,
            )
            if connecting is None:
                # Surface the gap instead of silently returning fewer edges.
                logger.warning(f"No edge found from {source_id} to {target_id} on path")
                continue
            edges.append(connecting)
        return edges

    @staticmethod
    def detect_cycles(graph: MappingGraph) -> List[List[int]]:
        """Detect cycles with a depth-first traversal of the whole graph.

        One cycle is reported for every back edge met during the traversal
        (an edge to a node that is still on the current DFS path). This is
        not an enumeration of every elementary cycle in the graph.

        Args:
            graph: Mapping graph to inspect.

        Returns:
            A list of cycles. Each cycle lists node IDs in traversal order,
            beginning with the node the back edge points to; the closing node
            is not repeated at the end.
        """
        cycles: List[List[int]] = []
        visited: Set[int] = set()
        on_stack: Set[int] = set()

        for root in graph.nodes:
            if root in visited:
                continue

            # Explicit stack instead of recursion, so long chains cannot hit
            # the interpreter's recursion limit. path[i] is the node whose
            # neighbor iterator is pending[i].
            path: List[int] = [root]
            visited.add(root)
            on_stack.add(root)
            pending = [iter(graph.get_neighbors(root))]

            while pending:
                for neighbor in pending[-1]:
                    if neighbor not in visited:
                        # Descend; the parent's iterator resumes afterwards.
                        visited.add(neighbor)
                        on_stack.add(neighbor)
                        path.append(neighbor)
                        pending.append(iter(graph.get_neighbors(neighbor)))
                        break
                    if neighbor in on_stack:
                        # Back edge: the cycle is the tail of the current path.
                        cycles.append(path[path.index(neighbor):])
                else:
                    # All neighbors handled; leave this node.
                    pending.pop()
                    on_stack.remove(path.pop())

        logger.info(f"Detected {len(cycles)} cycles in graph")
        return cycles