Coverage for openhcs/processing/backends/pos_gen/mist/boruvka_mst.py: 11.2%
67 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
1"""
2Borůvka's Minimum Spanning Tree Algorithm for MIST
4GPU-accelerated MST construction using parallel Borůvka's algorithm.
5"""
7from __future__ import annotations
8from typing import TYPE_CHECKING
10from openhcs.core.utils import optional_import
11from .gpu_kernels import (
12 launch_reset_flatten_kernel,
13 launch_find_minimum_edges_kernel,
14 launch_union_components_kernel
15)
17# For type checking only
18if TYPE_CHECKING: 18 ↛ 19line 18 didn't jump to line 19 because the condition on line 18 was never true
19 import cupy as cp
21# Import CuPy as an optional dependency
22cp = optional_import("cupy")
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """
    Raise ``TypeError`` unless *array* is an instance of ``cp.ndarray``.

    Args:
        array: The object to check.
        name: Label for the object, used in the error message.

    Raises:
        TypeError: If *array* is not a ``cp.ndarray``.
    """
    expected_type = cp.ndarray
    if isinstance(array, expected_type):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def _validate_mst_inputs(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """
    Validate the edge arrays handed to the MST builder.

    Args:
        connection_from: Source node index of each edge.
        connection_to: Destination node index of each edge.
        connection_dx: Per-edge x displacement.
        connection_dy: Per-edge y displacement.
        connection_quality: Per-edge quality value.
        num_nodes: Number of nodes; valid node indices are ``0 .. num_nodes - 1``.

    Raises:
        TypeError: If any array is not a CuPy array.
        ValueError: If the arrays differ in length, or any node index in
            ``connection_from`` / ``connection_to`` is negative or
            ``>= num_nodes``.
    """
    arrays = (connection_from, connection_to, connection_dx, connection_dy, connection_quality)
    names = ("connection_from", "connection_to", "connection_dx", "connection_dy", "connection_quality")

    for array, name in zip(arrays, names):
        _validate_cupy_array(array, name)

    # Every edge attribute array must describe the same set of edges.
    lengths = [len(arr) for arr in arrays]
    if any(length != lengths[0] for length in lengths):
        raise ValueError(f"All connection arrays must have same length, got {lengths}")

    # Nothing further to check for an empty edge list.
    if lengths[0] == 0:
        return

    # Negative indices would slip past the upper-bound check below and reach
    # the kernels as invalid indices into the num_nodes-sized union-find arrays.
    min_node = min(int(cp.min(connection_from)), int(cp.min(connection_to)))
    if min_node < 0:
        raise ValueError(f"Node index {min_node} is negative")

    max_node = max(int(cp.max(connection_from)), int(cp.max(connection_to)))
    if max_node >= num_nodes:
        raise ValueError(f"Node index {max_node} exceeds num_nodes {num_nodes}")
def build_mst_gpu_boruvka(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> dict:
    """
    Build a spanning tree over ``num_nodes`` nodes with GPU Borůvka's algorithm.

    Each iteration launches three kernels: reset/flatten the union-find
    arrays, find the cheapest edge per component, and union components while
    appending the chosen edges to the MST output buffers. A fixed number of
    iterations, ``ceil(log2(num_nodes)) + 1``, is run. No device data is read
    back on the host inside the loop.

    Args:
        connection_from: Source node index of each candidate edge.
        connection_to: Destination node index of each candidate edge.
        connection_dx: X displacement carried by each edge.
        connection_dy: Y displacement carried by each edge.
        connection_quality: Per-edge value given to the minimum-edge kernel.
            The kernel defines how it maps this to the int32 weight it
            minimises.
        num_nodes: Number of nodes in the graph.

    Returns:
        ``{'edges': [...]}``. Each edge is a dict with keys ``'from'``,
        ``'to'``, ``'dx'``, ``'dy'`` and ``'quality'``. ``'quality'`` is
        always ``0.0`` because the union kernel does not record it. At most
        ``num_nodes - 1`` edges are returned.

    Raises:
        TypeError: If any connection array is not a CuPy array.
        ValueError: If the arrays differ in length, a node index is out of
            range, or ``num_nodes < 1`` while edges are supplied.
    """
    import logging

    logger = logging.getLogger(__name__)

    _validate_mst_inputs(connection_from, connection_to, connection_dx,
                         connection_dy, connection_quality, num_nodes)

    if len(connection_from) == 0:
        return {'edges': []}

    # Edges require at least one node; otherwise the MST buffers below would
    # be allocated with a negative size.
    if num_nodes < 1:
        raise ValueError(f"num_nodes must be >= 1 when edges are given, got {num_nodes}")

    num_edges = len(connection_from)
    # A spanning tree on V nodes has at most V - 1 edges.
    max_edges = num_nodes - 1

    # Union-find structure.
    parent = cp.arange(num_nodes, dtype=cp.int32)
    rank = cp.zeros(num_nodes, dtype=cp.int32)

    # Per-component cheapest-edge tracking. int32 is used so the kernels can
    # update it with atomic operations; 2147483647 (INT32_MAX) means "none yet".
    cheapest_edge_idx = cp.full(num_nodes, -1, dtype=cp.int32)
    cheapest_edge_weight_int = cp.full(num_nodes, 2147483647, dtype=cp.int32)

    # MST output buffers filled by the union kernel; mst_count[0] is its write cursor.
    mst_edges_from = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_to = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_dx = cp.zeros(max_edges, dtype=cp.float32)
    mst_edges_dy = cp.zeros(max_edges, dtype=cp.float32)
    mst_count = cp.array([0], dtype=cp.int32)

    # Sort edges by source vertex for cache locality.
    sort_indices = cp.argsort(connection_from)
    edges_from = connection_from[sort_indices]
    edges_to = connection_to[sort_indices]
    edges_dx = connection_dx[sort_indices]
    edges_dy = connection_dy[sort_indices]
    edges_quality = connection_quality[sort_indices]

    # ceil(log2(V)) + 1 iterations, computed exactly on the host:
    # (V - 1).bit_length() == ceil(log2(V)) for V >= 1.
    max_iterations = int(num_nodes - 1).bit_length() + 1

    # The iteration count is fixed rather than checked dynamically, so the
    # loop needs no device-to-host synchronisation.
    for iteration in range(max_iterations):
        logger.debug("Borůvka iteration %d/%d: launching kernels", iteration + 1, max_iterations)

        # Kernel 1: reset cheapest-edge state and flatten union-find trees.
        launch_reset_flatten_kernel(
            parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes
        )

        # Kernel 2: find the minimum edge per component.
        launch_find_minimum_edges_kernel(
            edges_from, edges_to, edges_quality, parent,
            cheapest_edge_idx, cheapest_edge_weight_int, num_edges
        )

        # Kernel 3: union components and append the chosen edges to the MST buffers.
        launch_union_components_kernel(
            cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
            parent, rank, mst_edges_from, mst_edges_to, mst_edges_dx, mst_edges_dy,
            mst_count, num_nodes
        )

    # One host sync to read how many edges the kernels appended.
    final_mst_count = int(mst_count[0])
    logger.debug("Borůvka MST: %d edges constructed (expected: %d)", final_mst_count, max_edges)

    # Bounds check so the slices below never exceed the allocated buffers.
    if final_mst_count > max_edges:
        logger.warning("MST count %d exceeds maximum %d, clamping", final_mst_count, max_edges)
        final_mst_count = max_edges
    # Clamp at 0 so that a[:n] behaves like range(n) would for n < 0.
    final_mst_count = max(final_mst_count, 0)

    # Copy each output buffer to the host once. Indexing the device arrays
    # element by element would cost one synchronisation per scalar.
    from_list = cp.asnumpy(mst_edges_from[:final_mst_count]).tolist()
    to_list = cp.asnumpy(mst_edges_to[:final_mst_count]).tolist()
    dx_list = cp.asnumpy(mst_edges_dx[:final_mst_count]).tolist()
    dy_list = cp.asnumpy(mst_edges_dy[:final_mst_count]).tolist()

    selected_edges = [
        {
            'from': edge_from,
            'to': edge_to,
            'dx': edge_dx,
            'dy': edge_dy,
            'quality': 0.0,  # not tracked by the union kernel
        }
        for edge_from, edge_to, edge_dx, edge_dy in zip(from_list, to_list, dx_list, dy_list)
    ]

    for i, edge in enumerate(selected_edges[:3]):
        logger.debug(" Edge %d: %d -> %d, dx=%.3f, dy=%.3f",
                     i, edge['from'], edge['to'], edge['dx'], edge['dy'])

    return {'edges': selected_edges}