Coverage for openhcs/processing/backends/pos_gen/mist/gpu_kernels.py: 15.5%
80 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2GPU JIT Kernels for Borůvka's MST Algorithm
4All CUDA kernels for parallel MST construction using CuPy JIT.
5"""
6from __future__ import annotations
8from typing import TYPE_CHECKING
10from openhcs.core.utils import optional_import
12jit = optional_import("cupyx.jit")
13# For type checking only
14if TYPE_CHECKING: 14 ↛ 15line 14 didn't jump to line 15 because the condition on line 14 was never true
15 import cupy as cp
17# Import CuPy as an optional dependency
18cp = optional_import("cupy")
@jit.rawkernel() if jit else lambda f: f
def _reset_and_flatten_kernel(
    parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes
):
    """
    Kernel 1: Reset cheapest edge arrays and flatten union-find trees.

    One thread per node (``tid < num_nodes``). Each thread:
      * resets ``cheapest_edge_idx[tid]`` to -1 and
        ``cheapest_edge_weight_int[tid]`` to the int32 maximum, and
      * links ``tid`` directly to the root of its union-find tree, so that
        after the launch ``parent[i]`` is a root for every node ``i``.
        The following kernels rely on this and use ``parent[node]`` directly
        as the component id.

    ``rank`` is accepted for signature compatibility but is not used here.

    The walk stops only at a node with ``parent[x] == x``. It therefore
    assumes ``parent`` describes a forest; a cycle would make it spin
    forever.
    """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    if tid < num_nodes:
        # Reset cheapest edge tracking for this node
        cheapest_edge_idx[tid] = -1
        cheapest_edge_weight_int[tid] = 2147483647  # Max int32 value

        # Walk up to the root with path halving. The halving writes shorten
        # chains for threads traversing concurrently. Every write stores an
        # ancestor of the written node, and roots are never re-pointed here,
        # so each thread still reaches its true root.
        current = tid
        while parent[current] != current:
            parent[current] = parent[parent[current]]  # Compress one level at a time
            current = parent[current]

        # Halving alone can leave parent[tid] at an intermediate ancestor.
        # Link tid straight to the root it reached. parent[current] == current
        # here; reading it back keeps the stored value in parent's dtype.
        parent[tid] = parent[current]
@jit.rawkernel() if jit else lambda f: f
def _find_minimum_edges_kernel(
    edges_from, edges_to, edges_quality, parent,
    cheapest_edge_idx, cheapest_edge_weight_int, num_edges
):
    """
    Kernel 2: Find minimum weight edge for each component (using int32 for atomics).

    One thread per edge (``tid < num_edges``). For an edge whose endpoints
    have different ``parent`` values:
      * the edge quality is converted to an integer weight, where higher
        quality gives a smaller weight;
      * the weight is folded into ``cheapest_edge_weight_int`` for both
        endpoint components with ``atomic_min``;
      * the thread stores its edge index in ``cheapest_edge_idx`` for a
        component whenever the stored minimum equals its own weight.

    ``parent[node]`` is used directly as the component id. This assumes the
    union-find forest was fully flattened beforehand, i.e. every ``parent``
    entry is a root.

    NOTE(review): the ``atomic_min`` and the following read/write of
    ``cheapest_edge_idx`` are not one atomic step.
      * Another thread can lower the minimum between this thread's
        comparison and its index write. The stored index may then not belong
        to the stored minimum weight.
      * When several edges tie for the minimum, whichever thread writes last
        wins. The two components of a pair may then select different tied
        edges.
    Both effects look able to feed conflicting unions into the union
    kernel. A deterministic (weight, edge index) selection would avoid this,
    e.g. a second pass or a packed 64-bit atomic. TODO confirm.
    """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    if tid < num_edges:
        # Get edge endpoints and their components
        from_node = edges_from[tid]
        to_node = edges_to[tid]
        from_comp = parent[from_node]
        to_comp = parent[to_node]

        # Skip edges whose endpoints already share a component
        # (nested if instead of an early return)
        if from_comp != to_comp:
            # Get edge quality (higher is better, so negate for min comparison)
            edge_quality = edges_quality[tid]
            # Scale by 1000000 to keep six decimal places, then negate so a
            # higher quality becomes a smaller integer weight.
            # NOTE(review): assuming cheapest_edge_weight_int is int32 (as the
            # docstring says), this value must fit in int32, which holds only
            # while |edge_quality| < ~2147. Presumably qualities are
            # correlation-like scores near [-1, 1]. Confirm with the caller.
            edge_weight_int = int(-edge_quality * 1000000)

            # Atomic update cheapest edge for 'from' component
            jit.atomic_min(cheapest_edge_weight_int, from_comp, edge_weight_int)
            if cheapest_edge_weight_int[from_comp] == edge_weight_int:
                cheapest_edge_idx[from_comp] = tid

            # Atomic update cheapest edge for 'to' component
            jit.atomic_min(cheapest_edge_weight_int, to_comp, edge_weight_int)
            if cheapest_edge_weight_int[to_comp] == edge_weight_int:
                cheapest_edge_idx[to_comp] = tid
@jit.rawkernel() if jit else lambda f: f
def _union_components_kernel(
    cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
    parent, rank, mst_from, mst_to, mst_dx, mst_dy, mst_count, num_nodes
):
    """
    Kernel 3: Union components (no return statements).

    One thread per node (``tid < num_nodes``). A thread whose
    ``cheapest_edge_idx[tid]`` is a valid edge index (>= 0) does the
    following:
      * looks up that edge's endpoints and takes ``parent`` of each as the
        component root;
      * if the roots differ, tries to link one root under the other with
        ``atomic_cas``, using union-by-rank.
    Only a thread whose CAS succeeds appends the edge to the MST output
    arrays: endpoints plus ``edges_dx`` / ``edges_dy``. It reserves an output
    slot with ``atomic_add`` on ``mst_count[0]``.

    Assumes ``parent`` was flattened beforehand, so ``parent[node]`` is a
    root. Each CAS succeeds only while its target is still a root
    (``parent[x] == x``). Nothing in this kernel points a node back at
    itself, so a given root is re-parented at most once per launch.

    NOTE(review): ``rank`` is read with plain loads while other threads may
    ``atomic_add`` to it. Also, two roots can be linked through two different
    CAS targets (``parent[a]`` vs ``parent[b]``). The two threads handling a
    pair of roots may pick opposite directions, from a stale rank read or
    from tied edges chosen differently by the minimum-edge kernel. In that
    case both CASes appear able to succeed, leaving a 2-cycle in ``parent``
    (and a duplicate MST entry). A later root-finding loop would never exit.
    Worth confirming.
    """
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    if tid < num_nodes:
        edge_idx = cheapest_edge_idx[tid]

        # -1 means no candidate edge was recorded for this node
        # (nested if instead of an early return)
        if edge_idx >= 0:
            # Get edge details
            from_node = edges_from[edge_idx]
            to_node = edges_to[edge_idx]

            # Find current component roots (valid only if parent is flattened)
            from_root = parent[from_node]
            to_root = parent[to_node]

            # Only process if NOT already in same component (no return statement)
            if from_root != to_root:
                # Union by rank: hang the lower-rank root under the higher one
                rank1 = rank[from_root]
                rank2 = rank[to_root]

                union_success = False
                if rank1 < rank2:
                    # Make to_root the parent of from_root, but only if
                    # from_root is still a root (CAS returns the old value)
                    old_parent = jit.atomic_cas(parent, from_root, from_root, to_root)
                    union_success = (old_parent == from_root)
                elif rank1 > rank2:
                    # Make from_root the parent of to_root (same CAS guard)
                    old_parent = jit.atomic_cas(parent, to_root, to_root, from_root)
                    union_success = (old_parent == to_root)
                else:
                    # Equal ranks: make from_root parent and increment its rank
                    old_parent = jit.atomic_cas(parent, to_root, to_root, from_root)
                    if old_parent == to_root:
                        jit.atomic_add(rank, from_root, 1)
                        union_success = True

                # If union was successful, atomically add edge to MST
                if union_success:
                    # atomic_add returns the previous count, giving each
                    # successful union its own slot. Slots at or beyond
                    # num_nodes - 1 (the edge count of a spanning tree) are
                    # not written, although mst_count keeps incrementing.
                    mst_slot = jit.atomic_add(mst_count, 0, 1)
                    if mst_slot < num_nodes - 1:
                        mst_from[mst_slot] = from_node
                        mst_to[mst_slot] = to_node
                        mst_dx[mst_slot] = edges_dx[edge_idx]
                        mst_dy[mst_slot] = edges_dy[edge_idx]
def launch_reset_flatten_kernel(
    parent: "cp.ndarray",  # type: ignore
    rank: "cp.ndarray",  # type: ignore
    cheapest_edge_idx: "cp.ndarray",  # type: ignore
    cheapest_edge_weight_int: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """
    Launch the reset and flatten kernel with one thread per node.

    Args:
        parent: union-find parent array, updated in place by the kernel.
        rank: union-find rank array, passed through to the kernel.
        cheapest_edge_idx: per-node cheapest-edge index, reset by the kernel.
        cheapest_edge_weight_int: per-node cheapest-edge weight, reset by
            the kernel.
        num_nodes: number of nodes. When it is <= 0 nothing is launched.
    """
    if num_nodes <= 0:
        # No work to do, and CUDA rejects a launch with a zero-sized grid
        # as an invalid configuration.
        return

    threads_per_block = 256
    # Ceiling division so every node gets a thread
    blocks_per_grid = (num_nodes + threads_per_block - 1) // threads_per_block

    _reset_and_flatten_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes)
    )
def launch_find_minimum_edges_kernel(
    edges_from: "cp.ndarray",  # type: ignore
    edges_to: "cp.ndarray",  # type: ignore
    edges_quality: "cp.ndarray",  # type: ignore
    parent: "cp.ndarray",  # type: ignore
    cheapest_edge_idx: "cp.ndarray",  # type: ignore
    cheapest_edge_weight_int: "cp.ndarray",  # type: ignore
    num_edges: int
) -> None:
    """
    Launch the minimum edge finding kernel with one thread per edge.

    Args:
        edges_from, edges_to: edge endpoint arrays.
        edges_quality: per-edge quality (higher is better).
        parent: union-find parent array, read as component ids.
        cheapest_edge_idx, cheapest_edge_weight_int: per-component outputs
            updated by the kernel.
        num_edges: number of edges. When it is <= 0 nothing is launched, so
            a graph without edges leaves the output arrays untouched.
    """
    if num_edges <= 0:
        # No candidate edges, and CUDA rejects a launch with a zero-sized
        # grid as an invalid configuration.
        return

    threads_per_block = 256
    # Ceiling division so every edge gets a thread
    blocks_per_grid = (num_edges + threads_per_block - 1) // threads_per_block

    _find_minimum_edges_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (edges_from, edges_to, edges_quality, parent,
         cheapest_edge_idx, cheapest_edge_weight_int, num_edges)
    )
def launch_union_components_kernel(
    cheapest_edge_idx: "cp.ndarray",  # type: ignore
    edges_from: "cp.ndarray",  # type: ignore
    edges_to: "cp.ndarray",  # type: ignore
    edges_dx: "cp.ndarray",  # type: ignore
    edges_dy: "cp.ndarray",  # type: ignore
    parent: "cp.ndarray",  # type: ignore
    rank: "cp.ndarray",  # type: ignore
    mst_from: "cp.ndarray",  # type: ignore
    mst_to: "cp.ndarray",  # type: ignore
    mst_dx: "cp.ndarray",  # type: ignore
    mst_dy: "cp.ndarray",  # type: ignore
    mst_count: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """
    Launch the union components kernel - pure GPU, no CPU sync.

    One thread per node. All array arguments are passed straight to the
    kernel; ``parent``, ``rank``, the ``mst_*`` arrays and ``mst_count`` are
    updated in place on the device. When ``num_nodes <= 0`` nothing is
    launched.
    """
    if num_nodes <= 0:
        # No work to do, and CUDA rejects a launch with a zero-sized grid
        # as an invalid configuration.
        return

    # Launch kernel without CPU synchronization
    threads_per_block = 256
    # Ceiling division so every node gets a thread
    blocks_per_grid = (num_nodes + threads_per_block - 1) // threads_per_block

    _union_components_kernel(
        (blocks_per_grid,), (threads_per_block,),
        (cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
         parent, rank, mst_from, mst_to, mst_dx, mst_dy, mst_count, num_nodes)
    )
def gpu_component_count(parent: "cp.ndarray") -> "cp.ndarray":  # type: ignore
    """
    Return how many union-find roots ``parent`` contains, as a device array.

    Node ``i`` is a root exactly when ``parent[i] == i``, and every tree in
    the forest has one root, so this is the number of components. The count
    is produced by ``cp.sum`` and left on the GPU; nothing is copied back to
    the host here.
    """
    node_ids = cp.arange(len(parent))
    root_mask = parent == node_ids
    return cp.sum(root_mask)