240 lines
6.9 KiB
Python
240 lines
6.9 KiB
Python
"""
|
||
图算法工具
|
||
提供最短路径等图算法实现
|
||
"""
|
||
from typing import List, Optional, Dict, Set, Tuple
|
||
import heapq
|
||
from models.mapping_graph import MappingGraph, GraphEdge
|
||
from utils.logger import get_logger
|
||
|
||
# Module-level logger from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
|
||
|
||
|
||
class GraphAlgorithms:
    """Graph algorithm utilities for ``MappingGraph``.

    All methods are static and take the graph as their first argument.
    Node IDs are ints. The methods rely only on these parts of the graph
    API: ``graph.nodes`` (supports ``in`` and iteration over node IDs),
    ``graph.get_edges_from(node_id)`` (edges exposing ``target_field_id``
    and ``weight``) and ``graph.get_neighbors(node_id)`` (iterable of
    neighbor node IDs).
    """

    @staticmethod
    def _build_path(predecessors: Dict[int, Optional[int]], target_id: int) -> List[int]:
        """Walk predecessor links back from ``target_id``; return the path source-first.

        The walk stops at the first node whose predecessor is ``None`` (the source).
        """
        path: List[int] = []
        current: Optional[int] = target_id
        while current is not None:
            path.append(current)
            current = predecessors.get(current)
        path.reverse()
        return path

    @staticmethod
    def dijkstra_shortest_path(graph: MappingGraph,
                               source_id: int,
                               target_id: int) -> Optional[Tuple[List[int], float]]:
        """Find the minimum-weight path from ``source_id`` to ``target_id`` (Dijkstra).

        Edges are followed from a node to ``edge.target_field_id`` and cost
        ``edge.weight``. Dijkstra's algorithm assumes non-negative weights;
        this is not checked, so negative weights may give a non-optimal result.

        Args:
            graph: Mapping graph.
            source_id: ID of the start node.
            target_id: ID of the end node.

        Returns:
            ``(path, total_weight)``, where ``path`` lists node IDs from source
            to target inclusive, or ``None`` if either endpoint is not in the
            graph or the target is unreachable.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            logger.warning(f"Source {source_id} or target {target_id} not in graph")
            return None

        inf = float('inf')
        # Best known distance and predecessor for every node.
        distances: Dict[int, float] = {node_id: inf for node_id in graph.nodes}
        distances[source_id] = 0
        predecessors: Dict[int, Optional[int]] = {node_id: None for node_id in graph.nodes}

        # Priority queue of (distance, node ID); stale entries are skipped on pop.
        pq = [(0, source_id)]
        visited: Set[int] = set()

        while pq:
            current_dist, current_id = heapq.heappop(pq)
            if current_id in visited:
                continue
            visited.add(current_id)

            # Once the target is popped its distance is final; stop early.
            if current_id == target_id:
                break

            for edge in graph.get_edges_from(current_id):
                neighbor_id = edge.target_field_id
                if neighbor_id in visited:
                    continue

                new_dist = current_dist + edge.weight
                # .get(): an edge may point at a node absent from graph.nodes;
                # treat it as unreached instead of raising KeyError.
                if new_dist < distances.get(neighbor_id, inf):
                    distances[neighbor_id] = new_dist
                    predecessors[neighbor_id] = current_id
                    heapq.heappush(pq, (new_dist, neighbor_id))

        total = distances[target_id]
        if total == inf:
            logger.info(f"No path found from {source_id} to {target_id}")
            return None

        path = GraphAlgorithms._build_path(predecessors, target_id)
        logger.info(f"Found shortest path: {path} with total weight {total}")
        return path, total

    @staticmethod
    def bfs_shortest_path(graph: MappingGraph,
                          source_id: int,
                          target_id: int) -> Optional[List[int]]:
        """Find a path with the fewest edges (BFS, edge weights are ignored).

        Args:
            graph: Mapping graph.
            source_id: ID of the start node.
            target_id: ID of the end node.

        Returns:
            List of node IDs from source to target inclusive, or ``None`` if
            either endpoint is not in the graph or the target is unreachable.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            logger.warning(f"Source {source_id} or target {target_id} not in graph")
            return None

        if source_id == target_id:
            path = [source_id]
            logger.info(f"Found BFS path: {path}")
            return path

        # Parent pointers replace per-entry path copies. A node is recorded the
        # first time it is discovered, so it is queued and expanded only once.
        parents: Dict[int, Optional[int]] = {source_id: None}
        frontier: List[int] = [source_id]

        # Level-by-level expansion visits nodes in the same order as a FIFO queue.
        while frontier:
            next_frontier: List[int] = []
            for current_id in frontier:
                for neighbor_id in graph.get_neighbors(current_id):
                    if neighbor_id in parents:
                        continue
                    parents[neighbor_id] = current_id
                    if neighbor_id == target_id:
                        path = GraphAlgorithms._build_path(parents, target_id)
                        logger.info(f"Found BFS path: {path}")
                        return path
                    next_frontier.append(neighbor_id)
            frontier = next_frontier

        logger.info(f"No BFS path found from {source_id} to {target_id}")
        return None

    @staticmethod
    def find_all_paths(graph: MappingGraph,
                       source_id: int,
                       target_id: int,
                       max_length: int = 10) -> List[List[int]]:
        """Enumerate simple paths from ``source_id`` to ``target_id`` by DFS.

        Neighbors come from ``graph.get_neighbors``; no node repeats within a
        path, and a path ends as soon as it reaches the target.

        Args:
            graph: Mapping graph.
            source_id: ID of the start node.
            target_id: ID of the end node.
            max_length: Maximum number of nodes in a returned path, counting
                both endpoints. Longer partial paths are abandoned, which also
                bounds the recursion depth.

        Returns:
            List of paths (each a list of node IDs, source first). It is empty
            if either endpoint is not in the graph or no path is found. The
            result can be very large on dense graphs.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            logger.warning(f"Source {source_id} or target {target_id} not in graph")
            return []

        all_paths: List[List[int]] = []
        on_path: Set[int] = set()

        def dfs(current: int, path: List[int]) -> None:
            # ``path`` already ends with ``current``.
            if len(path) > max_length:
                return

            if current == target_id:
                all_paths.append(path.copy())
                return

            on_path.add(current)
            for neighbor in graph.get_neighbors(current):
                if neighbor not in on_path:
                    path.append(neighbor)
                    dfs(neighbor, path)
                    path.pop()
            # Backtrack so other branches may pass through ``current``.
            on_path.remove(current)

        dfs(source_id, [source_id])

        logger.info(f"Found {len(all_paths)} paths from {source_id} to {target_id}")
        return all_paths

    @staticmethod
    def get_path_edges(graph: MappingGraph, path: List[int]) -> List[GraphEdge]:
        """Return the edge used for each consecutive node pair in ``path``.

        For each pair, the first edge from ``graph.get_edges_from(a)`` whose
        ``target_field_id`` equals ``b`` is taken. Pairs with no such edge are
        skipped with a warning, so the result may then be shorter than
        ``len(path) - 1``.

        Args:
            graph: Mapping graph.
            path: Node IDs in path order.

        Returns:
            List of edges in path order.
        """
        edges: List[GraphEdge] = []
        for source_id, target_id in zip(path, path[1:]):
            edge = next(
                (e for e in graph.get_edges_from(source_id) if e.target_field_id == target_id),
                None,
            )
            if edge is None:
                logger.warning(f"No edge from {source_id} to {target_id} on path {path}")
                continue
            edges.append(edge)
        return edges

    @staticmethod
    def detect_cycles(graph: MappingGraph) -> List[List[int]]:
        """Detect cycles found by a depth-first traversal of the graph.

        Every node of ``graph.nodes`` is used as a DFS root unless it has
        already been visited, and each node is expanded once. Whenever an edge
        leads back to a node on the current DFS path, the path segment from
        that node to the current node is recorded as a cycle. Such an edge is
        a back edge, and a self-loop counts as one. At least one cycle is
        reported when the graph has any. However, not every simple cycle is
        enumerated, and which cycles are reported depends on iteration order.

        The traversal uses an explicit stack, so long chains do not hit
        Python's recursion limit.

        Args:
            graph: Mapping graph.

        Returns:
            List of cycles, each a list of node IDs in traversal order. The
            edge from the last node back to the first closes the cycle.
        """
        cycles: List[List[int]] = []
        visited: Set[int] = set()

        for root in graph.nodes:
            if root in visited:
                continue

            # path / on_path mirror the current DFS path; stack holds the
            # remaining neighbors still to examine for each node on it.
            visited.add(root)
            path: List[int] = [root]
            on_path: Set[int] = {root}
            stack = [(root, iter(graph.get_neighbors(root)))]

            while stack:
                node_id, neighbors = stack[-1]
                for neighbor in neighbors:
                    if neighbor not in visited:
                        # Descend; this node's iterator resumes after the child finishes.
                        visited.add(neighbor)
                        on_path.add(neighbor)
                        path.append(neighbor)
                        stack.append((neighbor, iter(graph.get_neighbors(neighbor))))
                        break
                    if neighbor in on_path:
                        # Back edge: the path from ``neighbor`` onward is a cycle.
                        cycles.append(path[path.index(neighbor):])
                else:
                    # All neighbors examined: leave this node.
                    stack.pop()
                    path.pop()
                    on_path.remove(node_id)

        logger.info(f"Detected {len(cycles)} cycles in graph")
        return cycles
|