SIT/utils/graph_algorithms.py

240 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图算法工具
提供最短路径等图算法实现
"""
import heapq
from collections import deque
from typing import List, Optional, Dict, Set, Tuple

from models.mapping_graph import MappingGraph, GraphEdge
from utils.logger import get_logger
# Module-level logger created via the project's helper (utils.logger.get_logger),
# named after this module.
logger = get_logger(__name__)
class GraphAlgorithms:
    """Graph algorithm utilities for ``MappingGraph`` (all methods are static)."""

    @staticmethod
    def _reconstruct_path(predecessors: Dict[int, Optional[int]],
                          target_id: int) -> List[int]:
        """Follow predecessor links back from ``target_id``; return the path source-first.

        The walk stops at the first node whose predecessor is ``None``
        (the search's start node).
        """
        path: List[int] = []
        node: Optional[int] = target_id
        while node is not None:
            path.append(node)
            node = predecessors[node]
        path.reverse()
        return path

    @staticmethod
    def dijkstra_shortest_path(graph: MappingGraph,
                               source_id: int,
                               target_id: int) -> Optional[Tuple[List[int], float]]:
        """Find the minimum-total-weight path between two nodes with Dijkstra's algorithm.

        Edges come from ``graph.get_edges_from`` and their ``weight`` values are
        summed along the path. Dijkstra's algorithm only gives correct results
        for non-negative edge weights; this method does not check for negatives.

        Args:
            graph: The mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.

        Returns:
            ``(path, total_weight)``, where ``path`` lists node IDs from source
            to target inclusive, or ``None`` if either endpoint is missing
            from ``graph.nodes`` or the target is unreachable.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            logger.warning(f"Source {source_id} or target {target_id} not in graph")
            return None

        inf = float('inf')
        distances: Dict[int, float] = {node_id: inf for node_id in graph.nodes}
        # Seed with a float so the returned total weight is always a float.
        distances[source_id] = 0.0
        predecessors: Dict[int, Optional[int]] = {node_id: None for node_id in graph.nodes}

        # Min-heap of (tentative distance, node ID). A node can be pushed more
        # than once; entries for already-settled nodes are skipped when popped.
        heap: List[Tuple[float, int]] = [(0.0, source_id)]
        settled: Set[int] = set()

        while heap:
            current_dist, current_id = heapq.heappop(heap)
            if current_id in settled:
                continue
            settled.add(current_id)

            # Once the target is popped its distance is final; stop early.
            if current_id == target_id:
                break

            for edge in graph.get_edges_from(current_id):
                neighbor_id = edge.target_field_id
                if neighbor_id in settled:
                    continue
                new_dist = current_dist + edge.weight
                if new_dist < distances[neighbor_id]:
                    distances[neighbor_id] = new_dist
                    predecessors[neighbor_id] = current_id
                    heapq.heappush(heap, (new_dist, neighbor_id))

        if distances[target_id] == inf:
            logger.info(f"No path found from {source_id} to {target_id}")
            return None

        path = GraphAlgorithms._reconstruct_path(predecessors, target_id)
        logger.info(f"Found shortest path: {path} with total weight {distances[target_id]}")
        return path, distances[target_id]

    @staticmethod
    def bfs_shortest_path(graph: MappingGraph,
                          source_id: int,
                          target_id: int) -> Optional[List[int]]:
        """Find a path with the fewest edges (weights ignored) using breadth-first search.

        Neighbors come from ``graph.get_neighbors``. Among equally short paths,
        the one discovered first in BFS order is returned.

        Args:
            graph: The mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.

        Returns:
            List of node IDs from source to target inclusive (``[source_id]``
            when they are equal), or ``None`` if either endpoint is missing
            from ``graph.nodes`` or the target is unreachable.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            return None

        if source_id == target_id:
            path = [source_id]
            logger.info(f"Found BFS path: {path}")
            return path

        # Parent pointers double as the "discovered" set. Marking nodes when
        # they are first discovered (not when dequeued) keeps each node in the
        # queue at most once, and avoids copying a whole path per queue entry.
        parents: Dict[int, Optional[int]] = {source_id: None}
        queue = deque([source_id])

        while queue:
            current_id = queue.popleft()  # O(1), unlike list.pop(0)
            for neighbor_id in graph.get_neighbors(current_id):
                if neighbor_id in parents:
                    continue
                parents[neighbor_id] = current_id
                if neighbor_id == target_id:
                    path = GraphAlgorithms._reconstruct_path(parents, target_id)
                    logger.info(f"Found BFS path: {path}")
                    return path
                queue.append(neighbor_id)

        logger.info(f"No BFS path found from {source_id} to {target_id}")
        return None

    @staticmethod
    def find_all_paths(graph: MappingGraph,
                       source_id: int,
                       target_id: int,
                       max_length: int = 10) -> List[List[int]]:
        """Enumerate every simple path from ``source_id`` to ``target_id``.

        Uses a backtracking depth-first search over ``graph.get_neighbors``;
        no node appears twice in the same path.

        Args:
            graph: The mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.
            max_length: Maximum number of nodes (not edges) a path may
                contain. Partial paths longer than this are abandoned, which
                also bounds the recursion depth.

        Returns:
            Paths in DFS discovery order, each a list of node IDs with the
            source first. The list is empty if either endpoint is missing from
            ``graph.nodes`` or no path fits within ``max_length``.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            return []

        all_paths: List[List[int]] = []

        def dfs(current: int, target: int, path: List[int], visited: Set[int]) -> None:
            # `path` already ends with `current`; `visited` holds the nodes
            # earlier on the current path, so that paths stay simple.
            if len(path) > max_length:
                return
            if current == target:
                all_paths.append(path.copy())
                return
            visited.add(current)
            for neighbor in graph.get_neighbors(current):
                if neighbor not in visited:
                    path.append(neighbor)
                    dfs(neighbor, target, path, visited)
                    path.pop()
            # Backtrack so that other branches may route through this node.
            visited.remove(current)

        dfs(source_id, target_id, [source_id], set())
        logger.info(f"Found {len(all_paths)} paths from {source_id} to {target_id}")
        return all_paths

    @staticmethod
    def get_path_edges(graph: MappingGraph, path: List[int]) -> List[GraphEdge]:
        """Return the edge connecting each consecutive pair of nodes in ``path``.

        When several parallel edges join the same pair, the one with the
        smallest ``weight`` is chosen (the first one wins ties). That is the
        edge whose weight ``dijkstra_shortest_path`` would have counted.

        Args:
            graph: The mapping graph the path belongs to.
            path: Node IDs in path order.

        Returns:
            Edges in path order. A pair with no connecting edge is logged as a
            warning and skipped, so the result can then be shorter than
            ``len(path) - 1``.
        """
        edges: List[GraphEdge] = []
        for from_id, to_id in zip(path, path[1:]):
            candidates = [edge for edge in graph.get_edges_from(from_id)
                          if edge.target_field_id == to_id]
            if not candidates:
                logger.warning(f"No edge from {from_id} to {to_id}; path edges are incomplete")
                continue
            edges.append(min(candidates, key=lambda edge: edge.weight))
        return edges

    @staticmethod
    def detect_cycles(graph: MappingGraph) -> List[List[int]]:
        """Detect cycles with a depth-first search.

        Roots are tried in ``graph.nodes`` iteration order. Each back edge
        found during the search (an edge to a node still on the current DFS
        path) yields one cycle: the slice of the current path starting at
        that node. This reports the cycles that close those back edges; it
        does NOT enumerate every elementary cycle in the graph.

        The DFS is iterative (an explicit stack of neighbor iterators), so deep
        graphs cannot hit Python's recursion limit. The visit order matches a
        straightforward recursive DFS.

        Args:
            graph: The mapping graph to inspect.

        Returns:
            Cycles in discovery order, each a list of node IDs. A self-loop
            yields a single-node cycle.
        """
        cycles: List[List[int]] = []
        visited: Set[int] = set()
        on_stack: Set[int] = set()

        for root_id in graph.nodes:
            if root_id in visited:
                continue

            path: List[int] = [root_id]
            visited.add(root_id)
            on_stack.add(root_id)
            # One neighbor iterator per node on `path`; resuming an iterator
            # continues that node's neighbor loop where it left off.
            pending = [iter(graph.get_neighbors(root_id))]

            while pending:
                descended = False
                for neighbor in pending[-1]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        on_stack.add(neighbor)
                        path.append(neighbor)
                        pending.append(iter(graph.get_neighbors(neighbor)))
                        descended = True
                        break
                    if neighbor in on_stack:
                        # Back edge: the cycle runs from `neighbor` to the path's end.
                        cycle_start = path.index(neighbor)
                        cycles.append(path[cycle_start:])
                if not descended:
                    # All neighbors of the top node are exhausted: leave it.
                    pending.pop()
                    on_stack.remove(path.pop())

        logger.info(f"Detected {len(cycles)} cycles in graph")
        return cycles