""" 图算法工具 提供最短路径等图算法实现 """ from typing import List, Optional, Dict, Set, Tuple import heapq from models.mapping_graph import MappingGraph, GraphEdge from utils.logger import get_logger logger = get_logger(__name__) class GraphAlgorithms: """图算法工具类""" @staticmethod def dijkstra_shortest_path(graph: MappingGraph, source_id: int, target_id: int) -> Optional[Tuple[List[int], float]]: """ 使用Dijkstra算法查找最短路径 Args: graph: 映射图 source_id: 源节点ID target_id: 目标节点ID Returns: (路径节点ID列表, 总权重) 或 None """ if source_id not in graph.nodes or target_id not in graph.nodes: logger.warning(f"Source {source_id} or target {target_id} not in graph") return None # 初始化距离和前驱节点 distances: Dict[int, float] = {node_id: float('inf') for node_id in graph.nodes} distances[source_id] = 0 predecessors: Dict[int, Optional[int]] = {node_id: None for node_id in graph.nodes} # 优先队列:(距离, 节点ID) pq = [(0, source_id)] visited: Set[int] = set() while pq: current_dist, current_id = heapq.heappop(pq) if current_id in visited: continue visited.add(current_id) # 找到目标节点 if current_id == target_id: break # 检查所有邻居 for edge in graph.get_edges_from(current_id): neighbor_id = edge.target_field_id if neighbor_id in visited: continue # 计算新距离 new_dist = current_dist + edge.weight if new_dist < distances[neighbor_id]: distances[neighbor_id] = new_dist predecessors[neighbor_id] = current_id heapq.heappush(pq, (new_dist, neighbor_id)) # 重建路径 if distances[target_id] == float('inf'): logger.info(f"No path found from {source_id} to {target_id}") return None path = [] current = target_id while current is not None: path.append(current) current = predecessors[current] path.reverse() logger.info(f"Found shortest path: {path} with total weight {distances[target_id]}") return path, distances[target_id] @staticmethod def bfs_shortest_path(graph: MappingGraph, source_id: int, target_id: int) -> Optional[List[int]]: """ 使用BFS算法查找最短路径(不考虑权重) Args: graph: 映射图 source_id: 源节点ID target_id: 目标节点ID Returns: 路径节点ID列表 或 None """ if source_id not in graph.nodes or target_id not in graph.nodes: return None visited: Set[int] = set() queue: List[Tuple[int, List[int]]] = [(source_id, [source_id])] while queue: current_id, path = queue.pop(0) if current_id == target_id: logger.info(f"Found BFS path: {path}") return path if current_id in visited: continue visited.add(current_id) # 访问所有邻居 neighbors = graph.get_neighbors(current_id) for neighbor_id in neighbors: if neighbor_id not in visited: new_path = path + [neighbor_id] queue.append((neighbor_id, new_path)) logger.info(f"No BFS path found from {source_id} to {target_id}") return None @staticmethod def find_all_paths(graph: MappingGraph, source_id: int, target_id: int, max_length: int = 10) -> List[List[int]]: """ 查找所有路径 Args: graph: 映射图 source_id: 源节点ID target_id: 目标节点ID max_length: 最大路径长度 Returns: 所有路径列表 """ if source_id not in graph.nodes or target_id not in graph.nodes: return [] all_paths = [] def dfs(current: int, target: int, path: List[int], visited: Set[int]): if len(path) > max_length: return if current == target: all_paths.append(path.copy()) return visited.add(current) neighbors = graph.get_neighbors(current) for neighbor in neighbors: if neighbor not in visited: path.append(neighbor) dfs(neighbor, target, path, visited) path.pop() visited.remove(current) dfs(source_id, target_id, [source_id], set()) logger.info(f"Found {len(all_paths)} paths from {source_id} to {target_id}") return all_paths @staticmethod def get_path_edges(graph: MappingGraph, path: List[int]) -> List[GraphEdge]: """ 获取路径上的所有边 Args: graph: 映射图 path: 路径节点ID列表 Returns: 边列表 """ edges = [] for i in range(len(path) - 1): source_id = path[i] target_id = path[i + 1] # 查找连接这两个节点的边 for edge in graph.get_edges_from(source_id): if edge.target_field_id == target_id: edges.append(edge) break return edges @staticmethod def detect_cycles(graph: MappingGraph) -> List[List[int]]: """ 检测图中的环 Args: graph: 映射图 Returns: 所有环的列表 """ cycles = [] visited = set() rec_stack = set() def dfs_cycle(node_id: int, path: List[int]): visited.add(node_id) rec_stack.add(node_id) path.append(node_id) neighbors = graph.get_neighbors(node_id) for neighbor in neighbors: if neighbor not in visited: dfs_cycle(neighbor, path) elif neighbor in rec_stack: # 找到环 cycle_start = path.index(neighbor) cycle = path[cycle_start:] cycles.append(cycle) path.pop() rec_stack.remove(node_id) for node_id in graph.nodes: if node_id not in visited: dfs_cycle(node_id, []) logger.info(f"Detected {len(cycles)} cycles in graph") return cycles