240 lines
6.9 KiB
Python
240 lines
6.9 KiB
Python
|
|
"""
|
|||
|
|
图算法工具
|
|||
|
|
提供最短路径等图算法实现
|
|||
|
|
"""
|
|||
|
|
import heapq
from collections import deque
from typing import Dict, Iterator, List, Optional, Set, Tuple

from models.mapping_graph import MappingGraph, GraphEdge
from utils.logger import get_logger
|
|||
|
|
|
|||
|
|
# Module-level logger; GraphAlgorithms reports found/missing paths and cycle counts through it.
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GraphAlgorithms:
    """Stateless collection of graph algorithms operating on a MappingGraph.

    Every public method is a ``@staticmethod``. The graph is accessed only
    through ``graph.nodes`` (membership tests and iteration over node IDs),
    ``graph.get_edges_from(node_id)`` (edges exposing ``target_field_id`` and
    ``weight``) and ``graph.get_neighbors(node_id)`` (iterable of neighbor IDs).
    """

    @staticmethod
    def _reconstruct_path(predecessors: Dict[int, Optional[int]],
                          target_id: int) -> List[int]:
        """Follow predecessor links back from ``target_id``; return source -> target."""
        path: List[int] = []
        current: Optional[int] = target_id
        # The start node is the only node whose predecessor is None.
        while current is not None:
            path.append(current)
            current = predecessors[current]
        path.reverse()
        return path

    @staticmethod
    def dijkstra_shortest_path(graph: MappingGraph,
                               source_id: int,
                               target_id: int) -> Optional[Tuple[List[int], float]]:
        """Find the minimum-total-weight path using Dijkstra's algorithm.

        Edge weights are assumed to be non-negative; with negative weights
        Dijkstra's greedy settling order is not guaranteed to be optimal.

        Args:
            graph: Mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.

        Returns:
            ``(path, total_weight)`` where ``path`` lists node IDs from source
            to target inclusive, or ``None`` if either endpoint is missing
            from the graph or the target is unreachable.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            logger.warning(f"Source {source_id} or target {target_id} not in graph")
            return None

        inf = float('inf')
        # Tentative distances and the best-known predecessor of each node.
        distances: Dict[int, float] = {node_id: inf for node_id in graph.nodes}
        distances[source_id] = 0
        predecessors: Dict[int, Optional[int]] = {node_id: None for node_id in graph.nodes}

        # Min-heap of (tentative distance, node ID).
        pq: List[Tuple[float, int]] = [(0, source_id)]
        visited: Set[int] = set()

        while pq:
            current_dist, current_id = heapq.heappop(pq)

            # Stale heap entries for already-settled nodes are skipped lazily.
            if current_id in visited:
                continue
            visited.add(current_id)

            # Once the target is popped its distance is final.
            if current_id == target_id:
                break

            for edge in graph.get_edges_from(current_id):
                neighbor_id = edge.target_field_id
                if neighbor_id in visited:
                    continue

                new_dist = current_dist + edge.weight
                # .get(): an edge pointing at a node absent from graph.nodes
                # is relaxed like any other instead of raising KeyError.
                if new_dist < distances.get(neighbor_id, inf):
                    distances[neighbor_id] = new_dist
                    predecessors[neighbor_id] = current_id
                    heapq.heappush(pq, (new_dist, neighbor_id))

        if distances[target_id] == inf:
            logger.info(f"No path found from {source_id} to {target_id}")
            return None

        path = GraphAlgorithms._reconstruct_path(predecessors, target_id)
        logger.info(f"Found shortest path: {path} with total weight {distances[target_id]}")
        return path, distances[target_id]

    @staticmethod
    def bfs_shortest_path(graph: MappingGraph,
                          source_id: int,
                          target_id: int) -> Optional[List[int]]:
        """Find a path with the fewest edges using breadth-first search (weights ignored).

        Args:
            graph: Mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.

        Returns:
            List of node IDs from source to target inclusive, or ``None`` if
            either endpoint is missing from the graph or no path exists.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            return None

        if source_id == target_id:
            path = [source_id]
            logger.info(f"Found BFS path: {path}")
            return path

        # ``parents`` doubles as the discovered set. Marking nodes when they
        # are enqueued (not dequeued) keeps each node in the queue at most once,
        # and storing one parent link per node avoids copying a path per entry.
        parents: Dict[int, Optional[int]] = {source_id: None}
        queue = deque([source_id])

        while queue:
            current_id = queue.popleft()  # O(1), unlike list.pop(0)
            for neighbor_id in graph.get_neighbors(current_id):
                if neighbor_id in parents:
                    continue
                parents[neighbor_id] = current_id

                # The first discovery of the target already lies on a
                # fewest-edge path, so stop here instead of draining the queue.
                if neighbor_id == target_id:
                    path = GraphAlgorithms._reconstruct_path(parents, target_id)
                    logger.info(f"Found BFS path: {path}")
                    return path

                queue.append(neighbor_id)

        logger.info(f"No BFS path found from {source_id} to {target_id}")
        return None

    @staticmethod
    def find_all_paths(graph: MappingGraph,
                       source_id: int,
                       target_id: int,
                       max_length: int = 10) -> List[List[int]]:
        """Enumerate all simple paths (no repeated node) from source to target.

        Args:
            graph: Mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the destination node.
            max_length: Maximum number of nodes in a returned path (the source
                node counts). Paths longer than this are not explored.

        Returns:
            List of paths, each a list of node IDs from source to target. Empty
            if either endpoint is missing from the graph or nothing is found.
        """
        if source_id not in graph.nodes or target_id not in graph.nodes:
            return []

        all_paths: List[List[int]] = []

        def dfs(current: int, target: int, path: List[int], visited: Set[int]) -> None:
            # Only reachable at the root when max_length < 1; deeper calls are
            # pruned by the length check below before they are made.
            if len(path) > max_length:
                return

            if current == target:
                all_paths.append(path.copy())
                return

            # Any extension would exceed max_length nodes, so stop here rather
            # than recursing once more just to be rejected.
            if len(path) >= max_length:
                return

            visited.add(current)
            for neighbor in graph.get_neighbors(current):
                if neighbor not in visited:
                    path.append(neighbor)
                    dfs(neighbor, target, path, visited)
                    path.pop()
            # Un-mark on backtrack so other branches may pass through this node.
            visited.remove(current)

        dfs(source_id, target_id, [source_id], set())

        logger.info(f"Found {len(all_paths)} paths from {source_id} to {target_id}")
        return all_paths

    @staticmethod
    def get_path_edges(graph: MappingGraph, path: List[int]) -> List[GraphEdge]:
        """Return the edges connecting consecutive nodes of ``path``.

        For each consecutive pair, the first edge from ``graph.get_edges_from``
        whose ``target_field_id`` matches is used. Pairs with no connecting
        edge are skipped and a warning is logged, so the result may then be
        shorter than ``len(path) - 1``.

        Args:
            graph: Mapping graph containing the path.
            path: Node IDs in path order.

        Returns:
            List of edges in path order.
        """
        edges: List[GraphEdge] = []
        for source_id, target_id in zip(path, path[1:]):
            edge = next(
                (e for e in graph.get_edges_from(source_id) if e.target_field_id == target_id),
                None,
            )
            if edge is None:
                # Previously skipped silently, which hid misaligned results.
                logger.warning(f"No edge from {source_id} to {target_id} on path; skipping")
                continue
            edges.append(edge)
        return edges

    @staticmethod
    def detect_cycles(graph: MappingGraph) -> List[List[int]]:
        """Detect cycles via depth-first search.

        One cycle is reported per back edge met during a DFS over all nodes
        (in ``graph.nodes`` order). This proves whether the graph is cyclic and
        shows concrete cycles, but it is NOT an enumeration of every elementary
        cycle. The DFS uses an explicit stack, so deep graphs do not hit
        Python's recursion limit.

        Args:
            graph: Mapping graph to inspect.

        Returns:
            List of cycles; each is the list of node IDs on the DFS path from
            the back edge's target to the node that closed the cycle.
        """
        cycles: List[List[int]] = []
        visited: Set[int] = set()
        # Current DFS path, and each on-path node's index in it (replaces the
        # recursion-stack set plus the O(n) path.index lookup).
        path: List[int] = []
        position: Dict[int, int] = {}
        # Explicit DFS stack of (node, iterator over its not-yet-examined neighbors).
        stack: List[Tuple[int, Iterator[int]]] = []

        def enter(node_id: int) -> None:
            visited.add(node_id)
            position[node_id] = len(path)
            path.append(node_id)
            stack.append((node_id, iter(graph.get_neighbors(node_id))))

        for root_id in graph.nodes:
            if root_id in visited:
                continue
            enter(root_id)

            while stack:
                node_id, neighbors = stack[-1]
                for neighbor in neighbors:
                    if neighbor not in visited:
                        # Descend; this node's iterator resumes after the child finishes.
                        enter(neighbor)
                        break
                    if neighbor in position:
                        # Back edge to a node still on the current path: a cycle.
                        cycles.append(path[position[neighbor]:])
                else:
                    # All neighbors examined: backtrack.
                    stack.pop()
                    path.pop()
                    del position[node_id]

        logger.info(f"Detected {len(cycles)} cycles in graph")
        return cycles
|