SIT/services/graph_service.py

156 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Graph algorithm service.
Exposes graph algorithms such as shortest-path search as a service layer.
"""

from typing import List, Optional, Tuple
from models.mapping_graph import MappingGraph, GraphEdge
from utils.graph_algorithms import GraphAlgorithms
from utils.logger import get_logger

# Module-level logger used by every GraphService method below.
logger = get_logger(__name__)


class GraphService:
    """Service layer over ``GraphAlgorithms`` for mapping-graph queries.

    The query methods (shortest path, all paths, path mappings, cycle
    detection) log their outcome and, if the underlying algorithm raises,
    log the error together with its traceback and return an empty result
    (``None`` or ``[]``) instead of propagating the exception.
    ``validate_graph`` does not wrap its own graph accesses in a try block.
    """

    def __init__(self):
        """Create the service with its own ``GraphAlgorithms`` instance."""
        self.algorithms = GraphAlgorithms()

    def find_shortest_path(self, graph: MappingGraph,
                           source_id: int,
                           target_id: int,
                           use_weight: bool = True) -> Optional[Tuple[List[int], float]]:
        """Find the shortest path from ``source_id`` to ``target_id``.

        Args:
            graph: The mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the end node.
            use_weight: If True, use Dijkstra (weighted search); if False,
                use BFS.

        Returns:
            ``(path_node_ids, total_weight)`` when a path is found, otherwise
            ``None``. The BFS branch always reports a total weight of
            ``0.0``. ``None`` is also returned when the search raises.
        """
        try:
            if use_weight:
                result = self.algorithms.dijkstra_shortest_path(graph, source_id, target_id)
                if result:
                    # Unpacking stays inside the try so a malformed result is
                    # logged and turned into None like any other failure.
                    path, weight = result
                    logger.info(f"Found shortest path (Dijkstra): {path}, weight: {weight}")
                    return path, weight
            else:
                path = self.algorithms.bfs_shortest_path(graph, source_id, target_id)
                if path:
                    logger.info(f"Found shortest path (BFS): {path}")
                    return path, 0.0
            return None
        except Exception as e:
            # logger.exception logs at ERROR level *with* the traceback,
            # which logger.error(f"...{e}") silently discarded.
            logger.exception(f"Failed to find shortest path: {e}")
            return None

    def find_all_paths(self, graph: MappingGraph,
                       source_id: int,
                       target_id: int,
                       max_length: int = 10) -> List[List[int]]:
        """Enumerate all paths from ``source_id`` to ``target_id``.

        Args:
            graph: The mapping graph to search.
            source_id: ID of the start node.
            target_id: ID of the end node.
            max_length: Maximum path length. It is passed straight through to
                ``GraphAlgorithms.find_all_paths``; whether it counts nodes or
                edges is defined there.

        Returns:
            A list of paths, each a list of node IDs. Returns ``[]`` if the
            search raises.
        """
        try:
            paths = self.algorithms.find_all_paths(graph, source_id, target_id, max_length)
            logger.info(f"Found {len(paths)} paths from {source_id} to {target_id}")
            return paths
        except Exception as e:
            logger.exception(f"Failed to find all paths: {e}")
            return []

    def get_path_mappings(self, graph: MappingGraph,
                          path: List[int]) -> List[GraphEdge]:
        """Return the edges (mapping relations) along ``path``.

        Args:
            graph: The mapping graph the path belongs to.
            path: The path as a list of node IDs.

        Returns:
            The edges returned by ``GraphAlgorithms.get_path_edges``. Returns
            ``[]`` if that call raises.
        """
        try:
            edges = self.algorithms.get_path_edges(graph, path)
            logger.info(f"Got {len(edges)} mappings for path")
            return edges
        except Exception as e:
            logger.exception(f"Failed to get path mappings: {e}")
            return []

    def detect_cycles(self, graph: MappingGraph) -> List[List[int]]:
        """Detect cycles in ``graph``.

        Args:
            graph: The mapping graph to inspect.

        Returns:
            A list of cycles. A warning is logged when any are found. Returns
            ``[]`` when there are none, when the algorithm returns a falsy
            value such as ``None``, or when detection raises.
        """
        try:
            # Normalise a falsy (e.g. None) result so callers always get a list,
            # as the return annotation promises.
            cycles = self.algorithms.detect_cycles(graph) or []
            if cycles:
                logger.warning(f"Detected {len(cycles)} cycles in graph")
            return cycles
        except Exception as e:
            logger.exception(f"Failed to detect cycles: {e}")
            return []

    def validate_graph(self, graph: MappingGraph) -> Tuple[bool, List[str]]:
        """Check ``graph`` for structural problems.

        Three checks run in this order, and their messages are appended in
        the same order:

        1. Isolated nodes: nodes for which both ``get_edges_from`` and
           ``get_edges_to`` return nothing.
        2. Cycles, found via ``detect_cycles``.
        3. Dangling edges: edges whose ``source_field_id`` or
           ``target_field_id`` is not in ``graph.nodes``.

        Args:
            graph: The mapping graph to validate.

        Returns:
            ``(is_valid, errors)``. ``errors`` holds one (Chinese) message per
            problem found. ``is_valid`` is True only when ``errors`` is empty.
        """
        errors: List[str] = []

        # 1. Isolated nodes.
        isolated_nodes = [
            node_id for node_id in graph.nodes
            if not graph.get_edges_from(node_id) and not graph.get_edges_to(node_id)
        ]
        if isolated_nodes:
            errors.append(f"发现 {len(isolated_nodes)} 个孤立节点")

        # 2. Cycles.
        # NOTE(review): detect_cycles returns [] when detection itself fails,
        # so such a failure is indistinguishable from "no cycles" here.
        cycles = self.detect_cycles(graph)
        if cycles:
            errors.append(f"发现 {len(cycles)} 个环路")

        # 3. Dangling edges: report the source before the target for each edge.
        for edge in graph.edges:
            if edge.source_field_id not in graph.nodes:
                errors.append(f"{edge.id} 的源节点不存在")
            if edge.target_field_id not in graph.nodes:
                errors.append(f"{edge.id} 的目标节点不存在")

        is_valid = not errors
        return is_valid, errors