Coverage for openhcs/processing/backends/pos_gen/mist/gpu_kernels.py: 15.5%

80 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2GPU JIT Kernels for Borůvka's MST Algorithm 

3 

4All CUDA kernels for parallel MST construction using CuPy JIT. 

5""" 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING 

9 

10from openhcs.core.utils import optional_import 

11 

12jit = optional_import("cupyx.jit") 

13# For type checking only 

14if TYPE_CHECKING: 14 ↛ 15line 14 didn't jump to line 15 because the condition on line 14 was never true

15 import cupy as cp 

16 

17# Import CuPy as an optional dependency 

18cp = optional_import("cupy") 

19 

20 

@jit.rawkernel() if jit else lambda f: f
def _reset_and_flatten_kernel(
    parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes
):
    """
    Kernel 1: Reset cheapest edge arrays and flatten union-find trees.

    Each thread with ``tid < num_nodes`` handles the node with index ``tid``:

    * ``cheapest_edge_idx[tid]`` is set to -1 ("no edge selected").
    * ``cheapest_edge_weight_int[tid]`` is set to 2147483647 so that any
      real weight wins the ``atomic_min`` in the edge-search kernel.
    * ``parent[tid]`` is rewritten to point directly at the root of the
      tree containing ``tid``.

    ``rank`` is accepted for signature compatibility but is not read or
    written here.

    The other kernels in this file read ``parent[node]`` as the component
    id of ``node``, so a fully flattened forest is required once this
    kernel has finished.
    """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    if tid < num_nodes:
        # Reset cheapest edge tracking for this node
        cheapest_edge_idx[tid] = -1
        cheapest_edge_weight_int[tid] = 2147483647  # Max int32 value

        # Walk towards the root, shortening the path on the way.  Every
        # store replaces a parent pointer with another ancestor of the same
        # node, so concurrent threads cannot detach a node from its tree.
        current = tid
        while parent[current] != current:
            parent[current] = parent[parent[current]]  # Compress one level at a time
            current = parent[current]

        # The loop above only halves the path: on its own it leaves
        # parent[tid] pointing at the former grandparent, so chains deeper
        # than two would survive the launch.  `current` is the root here
        # (parent[current] == current); store it so that parent[tid] is the
        # component id.  The value is read back through `parent` so it
        # carries the array's own dtype.
        parent[tid] = parent[current]

41 

42 

@jit.rawkernel() if jit else lambda f: f
def _find_minimum_edges_kernel(
    edges_from, edges_to, edges_quality, parent,
    cheapest_edge_idx, cheapest_edge_weight_int, num_edges
):
    """
    Kernel 2: Find minimum weight edge for each component (using int32 for atomics).

    Each thread with ``tid < num_edges`` handles the edge with index ``tid``.
    The component of an endpoint is read as ``parent[endpoint]`` with a
    single lookup and no root walk.

    For an edge whose endpoints lie in different components, the quality is
    turned into an integer weight (``int(-quality * 1000000)``, so higher
    quality gives a smaller weight) and offered to both components:
    ``cheapest_edge_weight_int[comp]`` is lowered with ``atomic_min`` and,
    if it then equals this edge's weight, ``cheapest_edge_idx[comp]`` is set
    to ``tid``.

    NOTE(review): presumably ``parent`` has been flattened and the cheapest
    arrays reset (to -1 / max int) before this launch - confirm against the
    caller.
    """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    if tid < num_edges:
        # Get edge endpoints and their components
        from_node = edges_from[tid]
        to_node = edges_to[tid]
        from_comp = parent[from_node]
        to_comp = parent[to_node]

        # Only process if NOT a self-edge (no return statement)
        if from_comp != to_comp:
            # Get edge quality (higher is better, so negate for min comparison)
            edge_quality = edges_quality[tid]
            # Convert to integer by scaling and negating (higher quality = lower int value)
            # Scale by 1000000 to preserve precision, then negate
            # NOTE(review): assumes abs(quality) * 1e6 fits the integer weight
            # array (about 2147 for int32) - TODO confirm the quality range.
            edge_weight_int = int(-edge_quality * 1000000)

            # Atomic update cheapest edge for 'from' component
            # NOTE(review): only the atomic_min is atomic.  The read-back and
            # the index store below are separate operations, so another
            # thread can lower the weight or overwrite the index in between;
            # edges with equal integer weight also pass the same check.
            # Confirm the caller tolerates an index that does not match the
            # stored minimum.
            jit.atomic_min(cheapest_edge_weight_int, from_comp, edge_weight_int)
            if cheapest_edge_weight_int[from_comp] == edge_weight_int:
                cheapest_edge_idx[from_comp] = tid

            # Atomic update cheapest edge for 'to' component
            # (same pattern, same caveat as above)
            jit.atomic_min(cheapest_edge_weight_int, to_comp, edge_weight_int)
            if cheapest_edge_weight_int[to_comp] == edge_weight_int:
                cheapest_edge_idx[to_comp] = tid

77 

78 

@jit.rawkernel() if jit else lambda f: f
def _union_components_kernel(
    cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
    parent, rank, mst_from, mst_to, mst_dx, mst_dy, mst_count, num_nodes
):
    """
    Kernel 3: Union components (no return statements).

    Each thread with ``tid < num_nodes`` looks at ``cheapest_edge_idx[tid]``.
    If an edge was selected (index >= 0) and its endpoints still have
    different ``parent`` entries, the thread tries to link one root under
    the other with ``atomic_cas`` on ``parent``:

    * different ranks: the lower-ranked root is linked under the
      higher-ranked one;
    * equal ranks: the root with the larger index is linked under the root
      with the smaller index, and the surviving root's rank is incremented.

    Only the thread whose compare-and-swap succeeds appends the edge
    (endpoints plus ``edges_dx`` / ``edges_dy``) to the ``mst_*`` arrays,
    using ``mst_count[0]`` as an atomic slot counter.  Slots at or beyond
    ``num_nodes - 1`` are dropped.

    The equal-rank direction is chosen from the root indices rather than
    from the edge orientation.  Two components that selected different
    edges between them, with opposite from/to orientation, would otherwise
    each win a compare-and-swap on the other's root.  That creates a
    2-cycle in ``parent`` and records a redundant edge.

    NOTE(review): ranks are read non-atomically and can be incremented by
    other threads during this launch, so two threads may still see
    different rank orderings for the same pair of roots.  Confirm whether
    the caller needs a stronger guarantee (e.g. index-only ordering).
    """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    if tid < num_nodes:
        edge_idx = cheapest_edge_idx[tid]

        # Only process if valid edge exists (no return statement)
        if edge_idx >= 0:
            # Get edge details
            from_node = edges_from[edge_idx]
            to_node = edges_to[edge_idx]

            # Component ids as stored in parent (single lookup, no root walk)
            from_root = parent[from_node]
            to_root = parent[to_node]

            # Only process if NOT already in same component (no return statement)
            if from_root != to_root:
                # Atomic union operation with rank-based optimization
                rank1 = rank[from_root]
                rank2 = rank[to_root]

                union_success = False
                if rank1 < rank2:
                    # Make to_root the parent of from_root
                    old_parent = jit.atomic_cas(parent, from_root, from_root, to_root)
                    union_success = (old_parent == from_root)
                elif rank1 > rank2:
                    # Make from_root the parent of to_root
                    old_parent = jit.atomic_cas(parent, to_root, to_root, from_root)
                    union_success = (old_parent == to_root)
                else:
                    # Equal ranks: the smaller root index becomes the parent,
                    # independent of which endpoint the edge lists first, so
                    # every thread working on this pair attempts the same link.
                    if from_root < to_root:
                        old_parent = jit.atomic_cas(parent, to_root, to_root, from_root)
                        if old_parent == to_root:
                            jit.atomic_add(rank, from_root, 1)
                            union_success = True
                    else:
                        old_parent = jit.atomic_cas(parent, from_root, from_root, to_root)
                        if old_parent == from_root:
                            jit.atomic_add(rank, to_root, 1)
                            union_success = True

                # If union was successful, atomically add edge to MST
                if union_success:
                    # atomic_add returns the previous counter value, which
                    # serves as this thread's private output slot.
                    mst_slot = jit.atomic_add(mst_count, 0, 1)
                    if mst_slot < num_nodes - 1:
                        mst_from[mst_slot] = from_node
                        mst_to[mst_slot] = to_node
                        mst_dx[mst_slot] = edges_dx[edge_idx]
                        mst_dy[mst_slot] = edges_dy[edge_idx]

132 

133 

def launch_reset_flatten_kernel(
    parent: "cp.ndarray",  # type: ignore
    rank: "cp.ndarray",  # type: ignore
    cheapest_edge_idx: "cp.ndarray",  # type: ignore
    cheapest_edge_weight_int: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """
    Launch the reset and flatten kernel with appropriate grid/block dimensions.

    Uses a 1-D grid of 256-thread blocks, rounded up so that at least
    ``num_nodes`` threads are started.

    Args:
        parent: Union-find parent array, forwarded to the kernel.
        rank: Union-find rank array, forwarded to the kernel.
        cheapest_edge_idx: Per-node cheapest-edge index array, forwarded.
        cheapest_edge_weight_int: Per-node cheapest-edge weight array, forwarded.
        num_nodes: Number of nodes to process.  When it is zero or negative
            nothing is launched.
    """
    # A grid with zero blocks is not a valid launch configuration; with no
    # nodes there is nothing to reset or flatten anyway.
    if num_nodes <= 0:
        return

    threads_per_block = 256
    # Ceiling division: enough blocks to cover every node.
    blocks_per_grid = (num_nodes + threads_per_block - 1) // threads_per_block

    _reset_and_flatten_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes)
    )

151 

152 

def launch_find_minimum_edges_kernel(
    edges_from: "cp.ndarray",  # type: ignore
    edges_to: "cp.ndarray",  # type: ignore
    edges_quality: "cp.ndarray",  # type: ignore
    parent: "cp.ndarray",  # type: ignore
    cheapest_edge_idx: "cp.ndarray",  # type: ignore
    cheapest_edge_weight_int: "cp.ndarray",  # type: ignore
    num_edges: int
) -> None:
    """
    Launch the minimum edge finding kernel with appropriate dimensions.

    Uses a 1-D grid of 256-thread blocks, rounded up so that at least
    ``num_edges`` threads are started.

    Args:
        edges_from: Edge source-node array, forwarded to the kernel.
        edges_to: Edge target-node array, forwarded to the kernel.
        edges_quality: Per-edge quality array, forwarded to the kernel.
        parent: Union-find parent array, forwarded to the kernel.
        cheapest_edge_idx: Per-component cheapest-edge index array, forwarded.
        cheapest_edge_weight_int: Per-component cheapest-edge weight array,
            forwarded.
        num_edges: Number of edges to process.  When it is zero or negative
            nothing is launched and the cheapest-edge arrays are left as
            they are.
    """
    # A graph without edges would give a zero-block grid, which is not a
    # valid launch configuration; there is nothing to search in that case.
    if num_edges <= 0:
        return

    threads_per_block = 256
    # Ceiling division: enough blocks to cover every edge.
    blocks_per_grid = (num_edges + threads_per_block - 1) // threads_per_block

    _find_minimum_edges_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (edges_from, edges_to, edges_quality, parent,
         cheapest_edge_idx, cheapest_edge_weight_int, num_edges)
    )

173 

174 

def launch_union_components_kernel(
    cheapest_edge_idx: "cp.ndarray",  # type: ignore
    edges_from: "cp.ndarray",  # type: ignore
    edges_to: "cp.ndarray",  # type: ignore
    edges_dx: "cp.ndarray",  # type: ignore
    edges_dy: "cp.ndarray",  # type: ignore
    parent: "cp.ndarray",  # type: ignore
    rank: "cp.ndarray",  # type: ignore
    mst_from: "cp.ndarray",  # type: ignore
    mst_to: "cp.ndarray",  # type: ignore
    mst_dx: "cp.ndarray",  # type: ignore
    mst_dy: "cp.ndarray",  # type: ignore
    mst_count: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """
    Launch the union components kernel - pure GPU, no CPU sync.

    Uses a 1-D grid of 256-thread blocks, rounded up so that at least
    ``num_nodes`` threads are started.  This function only enqueues the
    kernel; it does not read any result back or synchronize explicitly.

    Args:
        cheapest_edge_idx: Per-component selected edge index array, forwarded.
        edges_from: Edge source-node array, forwarded to the kernel.
        edges_to: Edge target-node array, forwarded to the kernel.
        edges_dx: Per-edge dx array, forwarded to the kernel.
        edges_dy: Per-edge dy array, forwarded to the kernel.
        parent: Union-find parent array, forwarded to the kernel.
        rank: Union-find rank array, forwarded to the kernel.
        mst_from: Output array for MST edge source nodes.
        mst_to: Output array for MST edge target nodes.
        mst_dx: Output array for MST edge dx values.
        mst_dy: Output array for MST edge dy values.
        mst_count: Array whose element 0 is the MST edge counter.
        num_nodes: Number of nodes to process.  When it is zero or negative
            nothing is launched.
    """
    # A grid with zero blocks is not a valid launch configuration; with no
    # nodes there are no components to merge.
    if num_nodes <= 0:
        return

    # Launch kernel without CPU synchronization
    threads_per_block = 256
    # Ceiling division: enough blocks to cover every node.
    blocks_per_grid = (num_nodes + threads_per_block - 1) // threads_per_block

    _union_components_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
         parent, rank, mst_from, mst_to, mst_dx, mst_dy, mst_count, num_nodes)
    )

202 

203 

def gpu_component_count(parent: "cp.ndarray") -> "cp.ndarray":  # type: ignore
    """
    Count the root entries of a union-find parent array.

    An entry ``i`` is counted when ``parent[i] == i``.  For a flattened
    union-find forest this is the number of distinct components.

    The result of ``cp.sum`` is returned as-is; it is not converted to a
    Python int here.
    """
    # Position i is a root exactly when it is its own parent.
    node_ids = cp.arange(len(parent))
    is_root = parent == node_ids
    return cp.sum(is_root)