Coverage for openhcs/processing/backends/pos_gen/mist/boruvka_mst.py: 11.2%

67 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Borůvka's Minimum Spanning Tree Algorithm for MIST 

3 

4GPU-accelerated MST construction using parallel Borůvka's algorithm. 

5""" 

6 

7from __future__ import annotations 

8from typing import TYPE_CHECKING 

9 

10from openhcs.core.utils import optional_import 

11from .gpu_kernels import ( 

12 launch_reset_flatten_kernel, 

13 launch_find_minimum_edges_kernel, 

14 launch_union_components_kernel, 

15 gpu_component_count 

16) 

17 

18# For type checking only 

19if TYPE_CHECKING: 19 ↛ 20line 19 didn't jump to line 20 because the condition on line 19 was never true

20 import cupy as cp 

21 

22# Import CuPy as an optional dependency 

23cp = optional_import("cupy") 

24 

25 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is a ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label for *array* used in the error message.

    Raises:
        TypeError: If *array* is not an instance of ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

30 

31 

def _validate_mst_inputs(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """Validate the edge-list arrays passed to the Borůvka MST builder.

    Args:
        connection_from: Source node index of each edge.
        connection_to: Destination node index of each edge.
        connection_dx: Per-edge x displacement.
        connection_dy: Per-edge y displacement.
        connection_quality: Per-edge quality value.
        num_nodes: Number of nodes; valid node indices are ``[0, num_nodes)``.

    Raises:
        TypeError: If any of the arrays is not a CuPy array.
        ValueError: If the arrays do not all have the same length, or if any
            node index in ``connection_from``/``connection_to`` lies outside
            ``[0, num_nodes)``.
    """
    named_arrays = {
        "connection_from": connection_from,
        "connection_to": connection_to,
        "connection_dx": connection_dx,
        "connection_dy": connection_dy,
        "connection_quality": connection_quality,
    }

    for name, array in named_arrays.items():
        _validate_cupy_array(array, name)

    # Every edge attribute is indexed in lockstep, so all arrays must match.
    lengths = [len(arr) for arr in named_arrays.values()]
    if any(length != lengths[0] for length in lengths):
        raise ValueError(f"All connection arrays must have same length, got {lengths}")

    # Nothing to range-check for an empty edge list (cp.max would fail on it).
    if lengths[0] == 0:
        return

    # Node indices are used alongside per-node buffers of length num_nodes
    # (e.g. parent = cp.arange(num_nodes) in build_mst_gpu_boruvka), so they
    # must be bounded on both sides; the upper check alone lets negatives in.
    max_node = max(int(cp.max(connection_from)), int(cp.max(connection_to)))
    if max_node >= num_nodes:
        raise ValueError(f"Node index {max_node} exceeds num_nodes {num_nodes}")

    min_node = min(int(cp.min(connection_from)), int(cp.min(connection_to)))
    if min_node < 0:
        raise ValueError(f"Node index {min_node} is negative")

59 

60 

def build_mst_gpu_boruvka(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> dict:
    """
    Full GPU Borůvka's algorithm for minimum spanning tree.

    Uses JIT kernels with atomic operations for true parallel execution.
    The main loop runs a fixed number of iterations (ceil(log2(num_nodes)) + 1),
    so the host never waits on the device inside it. The resulting MST arrays
    are copied back to the host in one bulk transfer per array at the end.

    Args:
        connection_from: Source node index of each candidate edge.
        connection_to: Destination node index of each candidate edge.
        connection_dx: Per-edge x displacement (copied into the MST result).
        connection_dy: Per-edge y displacement (copied into the MST result).
        connection_quality: Per-edge quality, passed to the minimum-edge
            kernel as the edge weight.
        num_nodes: Number of nodes in the graph.

    Returns:
        ``{'edges': [...]}``. Each edge is a dict with keys ``'from'``,
        ``'to'``, ``'dx'``, ``'dy'`` and ``'quality'``. ``quality`` is always
        ``0.0`` because the union kernel does not record it. The list is
        empty when no candidate edges are given. It holds at most
        ``num_nodes - 1`` edges.

    Raises:
        TypeError, ValueError: Propagated from ``_validate_mst_inputs``.
    """
    import logging

    logger = logging.getLogger(__name__)

    _validate_mst_inputs(connection_from, connection_to, connection_dx,
                         connection_dy, connection_quality, num_nodes)

    if len(connection_from) == 0:
        return {'edges': []}

    num_edges = len(connection_from)
    max_edges = num_nodes - 1

    # Union-find structure (flattened for O(1) lookups)
    parent = cp.arange(num_nodes, dtype=cp.int32)
    rank = cp.zeros(num_nodes, dtype=cp.int32)

    # Component minimum edge tracking (int32 so the kernels can use atomics);
    # 2147483647 is the int32 maximum, i.e. "no edge seen yet".
    cheapest_edge_idx = cp.full(num_nodes, -1, dtype=cp.int32)
    cheapest_edge_weight_int = cp.full(num_nodes, 2147483647, dtype=cp.int32)

    # MST result storage (a spanning tree has at most num_nodes - 1 edges)
    mst_edges_from = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_to = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_dx = cp.zeros(max_edges, dtype=cp.float32)
    mst_edges_dy = cp.zeros(max_edges, dtype=cp.float32)
    mst_count = cp.array([0], dtype=cp.int32)

    # Sort edges by source vertex for cache locality
    sort_indices = cp.argsort(connection_from)
    edges_from = connection_from[sort_indices]
    edges_to = connection_to[sort_indices]
    edges_dx = connection_dx[sort_indices]
    edges_dy = connection_dy[sort_indices]
    edges_quality = connection_quality[sort_indices]

    # Borůvka needs O(log V) rounds. For V >= 1, (V - 1).bit_length() equals
    # ceil(log2(V)) exactly, computed on the host without float math or a
    # device round-trip.
    max_iterations = (num_nodes - 1).bit_length() + 1

    # Fixed iteration count instead of a dynamic convergence check, so the
    # loop never forces a CPU-GPU synchronization.
    for iteration in range(max_iterations):
        logger.debug("Borůvka iteration %d/%d", iteration + 1, max_iterations)

        # Kernel 1: Reset and flatten union-find trees
        launch_reset_flatten_kernel(
            parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes
        )

        # Kernel 2: Find minimum edge per component (parallel)
        launch_find_minimum_edges_kernel(
            edges_from, edges_to, edges_quality, parent,
            cheapest_edge_idx, cheapest_edge_weight_int, num_edges
        )

        # Kernel 3: Union components and update MST
        launch_union_components_kernel(
            cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
            parent, rank, mst_edges_from, mst_edges_to, mst_edges_dx, mst_edges_dy,
            mst_count, num_nodes
        )

    # Single scalar sync to learn how many MST edges were written.
    final_mst_count = int(mst_count[0])
    logger.debug("Borůvka MST: %d edges constructed (expected: %d)",
                 final_mst_count, max_edges)

    # Bounds check so slicing never reads past the result buffers.
    if final_mst_count > max_edges:
        logger.warning("MST count %d exceeds maximum %d, clamping",
                       final_mst_count, max_edges)
        final_mst_count = max_edges

    # One bulk device->host copy per array instead of one sync per element.
    froms = cp.asnumpy(mst_edges_from[:final_mst_count]).tolist()
    tos = cp.asnumpy(mst_edges_to[:final_mst_count]).tolist()
    dxs = cp.asnumpy(mst_edges_dx[:final_mst_count]).tolist()
    dys = cp.asnumpy(mst_edges_dy[:final_mst_count]).tolist()

    selected_edges = [
        {'from': f, 'to': t, 'dx': dx, 'dy': dy, 'quality': 0.0}
        for f, t, dx, dy in zip(froms, tos, dxs, dys)
    ]

    for i, edge in enumerate(selected_edges[:3]):
        logger.debug(" Edge %d: %d -> %d, dx=%.3f, dy=%.3f",
                     i, edge['from'], edge['to'], edge['dx'], edge['dy'])

    return {'edges': selected_edges}

84 # Initialize GPU data structures 

85 num_edges = len(connection_from) 

86 

87 # Union-find structure (flattened for O(1) lookups) 

88 parent = cp.arange(num_nodes, dtype=cp.int32) 

89 rank = cp.zeros(num_nodes, dtype=cp.int32) 

90 

91 # Component minimum edge tracking (use int32 for atomic operations) 

92 cheapest_edge_idx = cp.full(num_nodes, -1, dtype=cp.int32) 

93 cheapest_edge_weight_int = cp.full(num_nodes, 2147483647, dtype=cp.int32) # Max int32 value 

94 

95 # MST result storage 

96 mst_edges_from = cp.zeros(num_nodes - 1, dtype=cp.int32) 

97 mst_edges_to = cp.zeros(num_nodes - 1, dtype=cp.int32) 

98 mst_edges_dx = cp.zeros(num_nodes - 1, dtype=cp.float32) 

99 mst_edges_dy = cp.zeros(num_nodes - 1, dtype=cp.float32) 

100 mst_count = cp.array([0], dtype=cp.int32) 

101 

102 # Sort edges by source vertex for cache locality 

103 sort_indices = cp.argsort(connection_from) 

104 edges_from = connection_from[sort_indices] 

105 edges_to = connection_to[sort_indices] 

106 edges_dx = connection_dx[sort_indices] 

107 edges_dy = connection_dy[sort_indices] 

108 edges_quality = connection_quality[sort_indices] 

109 

110 # Main Borůvka's loop - O(log V) iterations 

111 max_iterations = int(cp.ceil(cp.log2(num_nodes))) + 1 

112 

113 

114 for iteration in range(max_iterations): 

115 print(f"🔥 Iteration {iteration}: Starting...") 

116 

117 # Kernel 1: Reset and flatten union-find trees 

118 print(f"🔥 Iteration {iteration}: Launching reset/flatten kernel...") 

119 launch_reset_flatten_kernel( 

120 parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes 

121 ) 

122 

123 # Kernel 2: Find minimum edge per component (parallel) 

124 print(f"🔥 Iteration {iteration}: Launching find minimum edges kernel...") 

125 launch_find_minimum_edges_kernel( 

126 edges_from, edges_to, edges_quality, parent, 

127 cheapest_edge_idx, cheapest_edge_weight_int, num_edges 

128 ) 

129 

130 # Kernel 3: Union components and update MST 

131 print(f"🔥 Iteration {iteration}: Launching union components kernel...") 

132 launch_union_components_kernel( 

133 cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy, 

134 parent, rank, mst_edges_from, mst_edges_to, mst_edges_dx, mst_edges_dy, 

135 mst_count, num_nodes 

136 ) 

137 

138 print(f"🔥 Iteration {iteration}: Kernel launched (pure GPU)") 

139 

140 # Pure GPU termination check - no CPU sync 

141 # Use fixed iteration count instead of dynamic checking 

142 # This eliminates CPU-GPU synchronization bottleneck 

143 

144 # Convert result to expected format 

145 final_mst_count = int(mst_count[0]) 

146 selected_edges = [] 

147 

148 # Debug: Print MST construction results 

149 print(f"Borůvka MST: {final_mst_count} edges constructed (expected: {num_nodes-1})") 

150 

151 # Bounds check to prevent crash 

152 max_edges = num_nodes - 1 

153 if final_mst_count > max_edges: 

154 print(f"🔥 WARNING: MST count {final_mst_count} exceeds maximum {max_edges}, clamping") 

155 final_mst_count = max_edges 

156 

157 for i in range(final_mst_count): 

158 edge = { 

159 'from': int(mst_edges_from[i]), 

160 'to': int(mst_edges_to[i]), 

161 'dx': float(mst_edges_dx[i]), 

162 'dy': float(mst_edges_dy[i]), 

163 'quality': 0.0 # Could be stored if needed 

164 } 

165 selected_edges.append(edge) 

166 

167 # Debug: Print first few edges 

168 if i < 3: 

169 print(f" Edge {i}: {edge['from']} -> {edge['to']}, dx={edge['dx']:.3f}, dy={edge['dy']:.3f}") 

170 

171 return {'edges': selected_edges}