Coverage for openhcs/processing/backends/pos_gen/mist/boruvka_mst.py: 11.2%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Borůvka's Minimum Spanning Tree Algorithm for MIST 

3 

4GPU-accelerated MST construction using parallel Borůvka's algorithm. 

5""" 

6 

7from __future__ import annotations 

8from typing import TYPE_CHECKING 

9 

10from openhcs.core.utils import optional_import 

11from .gpu_kernels import ( 

12 launch_reset_flatten_kernel, 

13 launch_find_minimum_edges_kernel, 

14 launch_union_components_kernel 

15) 

16 

17# For type checking only 

18if TYPE_CHECKING: 18 ↛ 19line 18 didn't jump to line 19 because the condition on line 18 was never true

19 import cupy as cp 

20 

21# Import CuPy as an optional dependency 

22cp = optional_import("cupy") 

23 

24 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Ensure that ``array`` is a CuPy ndarray.

    Args:
        array: Object to check.
        name: Label used for the object in the error message.

    Raises:
        TypeError: If ``array`` is not an instance of ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

29 

30 

def _validate_mst_inputs(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """Validate the edge arrays handed to the MST builder.

    Checks, in order, that every connection array is a CuPy array, that all
    five arrays have the same length, and that every node index found in
    ``connection_from`` / ``connection_to`` lies in ``[0, num_nodes)``.
    Empty inputs pass validation.

    Args:
        connection_from: Source node index of each edge.
        connection_to: Destination node index of each edge.
        connection_dx: Per-edge dx value.
        connection_dy: Per-edge dy value.
        connection_quality: Per-edge quality value.
        num_nodes: Number of nodes; valid indices are ``0 .. num_nodes - 1``.

    Raises:
        TypeError: If any connection array is not a CuPy array.
        ValueError: If the arrays differ in length, or if a node index is
            negative or not smaller than ``num_nodes``.
    """
    named_arrays = {
        "connection_from": connection_from,
        "connection_to": connection_to,
        "connection_dx": connection_dx,
        "connection_dy": connection_dy,
        "connection_quality": connection_quality,
    }

    for name, array in named_arrays.items():
        _validate_cupy_array(array, name)

    # Check all arrays have same length
    lengths = [len(array) for array in named_arrays.values()]
    if len(set(lengths)) > 1:
        raise ValueError(f"All connection arrays must have same length, got {lengths}")

    # Reductions over empty arrays are undefined, so there is nothing to check.
    if lengths[0] == 0:
        return

    # Upper bound first, so inputs that were rejected before still produce
    # the same error message.
    max_node = max(int(cp.max(connection_from)), int(cp.max(connection_to)))
    if max_node >= num_nodes:
        raise ValueError(f"Node index {max_node} exceeds num_nodes {num_nodes}")

    # Lower bound: a negative index is never a valid node id.
    # NOTE(review): the indices are presumably used by the GPU kernels to
    # address per-node arrays of length num_nodes - confirm against gpu_kernels.
    min_node = min(int(cp.min(connection_from)), int(cp.min(connection_to)))
    if min_node < 0:
        raise ValueError(
            f"Node index {min_node} is negative; "
            f"indices must be in [0, {num_nodes})"
        )

58 

59 

def build_mst_gpu_boruvka(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> dict:
    """
    Full GPU Borůvka's algorithm for minimum spanning tree.

    Uses JIT kernels with atomic operations for true parallel execution.
    All operations remain on GPU with no CPU-GPU synchronization in inner loops.

    The three kernels (reset/flatten, find-minimum-edge, union) are launched
    a fixed number of times, ``ceil(log2(num_nodes)) + 1``, rather than
    polling a device-side termination flag.

    Args:
        connection_from: Source node index of each candidate edge.
        connection_to: Destination node index of each candidate edge.
        connection_dx: Per-edge dx value, copied into the selected edges.
        connection_dy: Per-edge dy value, copied into the selected edges.
        connection_quality: Per-edge quality, passed to the minimum-edge kernel.
        num_nodes: Number of nodes in the graph.

    Returns:
        ``{'edges': [...]}`` where each entry is a dict with keys ``'from'``,
        ``'to'``, ``'dx'``, ``'dy'`` and ``'quality'``. ``'quality'`` is
        always ``0.0`` because it is not stored alongside the selected edges.
        The list is empty when no connections are supplied.

    Raises:
        TypeError: If any connection array is not a CuPy array.
        ValueError: If the connection arrays are inconsistent with each other
            or with ``num_nodes``.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)

    # Validate inputs
    _validate_mst_inputs(connection_from, connection_to, connection_dx,
                         connection_dy, connection_quality, num_nodes)

    num_edges = len(connection_from)
    if num_edges == 0:
        return {'edges': []}

    # A spanning tree has at most num_nodes - 1 edges.
    max_edges = num_nodes - 1

    # Union-find structure (flattened for O(1) lookups)
    parent = cp.arange(num_nodes, dtype=cp.int32)
    rank = cp.zeros(num_nodes, dtype=cp.int32)

    # Component minimum edge tracking (use int32 for atomic operations)
    cheapest_edge_idx = cp.full(num_nodes, -1, dtype=cp.int32)
    cheapest_edge_weight_int = cp.full(num_nodes, 2147483647, dtype=cp.int32)  # Max int32 value

    # MST result storage
    mst_edges_from = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_to = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_dx = cp.zeros(max_edges, dtype=cp.float32)
    mst_edges_dy = cp.zeros(max_edges, dtype=cp.float32)
    mst_count = cp.array([0], dtype=cp.int32)

    # Sort edges by source vertex for cache locality
    sort_indices = cp.argsort(connection_from)
    edges_from = connection_from[sort_indices]
    edges_to = connection_to[sort_indices]
    edges_dx = connection_dx[sort_indices]
    edges_dy = connection_dy[sort_indices]
    edges_quality = connection_quality[sort_indices]

    # Main Borůvka's loop - O(log V) iterations.
    # For n >= 1, (n - 1).bit_length() == ceil(log2(n)); computing it with
    # host integers avoids a device round trip and float rounding.
    max_iterations = int(num_nodes - 1).bit_length() + 1

    for iteration in range(max_iterations):
        logger.debug("🔥 Iteration %d: Starting...", iteration)

        # Kernel 1: Reset and flatten union-find trees
        logger.debug("🔥 Iteration %d: Launching reset/flatten kernel...", iteration)
        launch_reset_flatten_kernel(
            parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes
        )

        # Kernel 2: Find minimum edge per component (parallel)
        logger.debug("🔥 Iteration %d: Launching find minimum edges kernel...", iteration)
        launch_find_minimum_edges_kernel(
            edges_from, edges_to, edges_quality, parent,
            cheapest_edge_idx, cheapest_edge_weight_int, num_edges
        )

        # Kernel 3: Union components and update MST
        logger.debug("🔥 Iteration %d: Launching union components kernel...", iteration)
        launch_union_components_kernel(
            cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
            parent, rank, mst_edges_from, mst_edges_to, mst_edges_dx, mst_edges_dy,
            mst_count, num_nodes
        )

        logger.debug("🔥 Iteration %d: Kernel launched (pure GPU)", iteration)

        # No termination check here: the fixed iteration count avoids a
        # CPU-GPU synchronization on every pass.

    # First host read of device data after the loop.
    final_mst_count = int(mst_count[0])
    logger.debug(
        "Borůvka MST: %d edges constructed (expected: %d)",
        final_mst_count, num_nodes - 1,
    )

    # Bounds check so the result never reads past the allocated storage.
    if final_mst_count > max_edges:
        logger.warning(
            "🔥 WARNING: MST count %d exceeds maximum %d, clamping",
            final_mst_count, max_edges,
        )
        final_mst_count = max_edges

    # Copy each result array to the host once; indexing element by element
    # would issue a separate blocking device read for every field of every edge.
    host_from = cp.asnumpy(mst_edges_from[:final_mst_count]).tolist()
    host_to = cp.asnumpy(mst_edges_to[:final_mst_count]).tolist()
    host_dx = cp.asnumpy(mst_edges_dx[:final_mst_count]).tolist()
    host_dy = cp.asnumpy(mst_edges_dy[:final_mst_count]).tolist()

    selected_edges = [
        {
            'from': src,
            'to': dst,
            'dx': dx,
            'dy': dy,
            'quality': 0.0  # Could be stored if needed
        }
        for src, dst, dx, dy in zip(host_from, host_to, host_dx, host_dy)
    ]

    # Debug: log the first few edges
    for i, edge in enumerate(selected_edges[:3]):
        logger.debug(
            "  Edge %d: %d -> %d, dx=%.3f, dy=%.3f",
            i, edge['from'], edge['to'], edge['dx'], edge['dy'],
        )

    return {'edges': selected_edges}