Coverage for openhcs/processing/backends/pos_gen/mist/boruvka_mst.py: 11.2%
67 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Borůvka's Minimum Spanning Tree Algorithm for MIST
4GPU-accelerated MST construction using parallel Borůvka's algorithm.
5"""
7from __future__ import annotations
8from typing import TYPE_CHECKING
10from openhcs.core.utils import optional_import
11from .gpu_kernels import (
12 launch_reset_flatten_kernel,
13 launch_find_minimum_edges_kernel,
14 launch_union_components_kernel,
15 gpu_component_count
16)
18# For type checking only
19if TYPE_CHECKING: 19 ↛ 20line 19 didn't jump to line 20 because the condition on line 19 was never true
20 import cupy as cp
22# Import CuPy as an optional dependency
23cp = optional_import("cupy")
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """Raise ``TypeError`` unless *array* is an instance of ``cp.ndarray``.

    Args:
        array: Object to check.
        name: Label used for the object in the error message.

    Raises:
        TypeError: If *array* is not a ``cp.ndarray``.
    """
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def _validate_mst_inputs(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> None:
    """Validate MST input arrays.

    Checks, in order, that every connection array is a CuPy array, that all
    five arrays have the same length, and (when there is at least one edge)
    that every node index in ``connection_from`` / ``connection_to`` lies in
    ``[0, num_nodes)``.

    Args:
        connection_from: Source node index of each edge.
        connection_to: Destination node index of each edge.
        connection_dx: Per-edge dx value.
        connection_dy: Per-edge dy value.
        connection_quality: Per-edge quality value.
        num_nodes: Number of nodes; valid indices are ``0 .. num_nodes - 1``.

    Raises:
        TypeError: If any connection array is not a CuPy array.
        ValueError: If the arrays differ in length, or if any node index is
            negative or ``>= num_nodes``.
    """
    arrays = [connection_from, connection_to, connection_dx, connection_dy, connection_quality]
    names = ["connection_from", "connection_to", "connection_dx", "connection_dy", "connection_quality"]

    for array, name in zip(arrays, names):
        _validate_cupy_array(array, name)

    # Check all arrays have same length
    lengths = [len(arr) for arr in arrays]
    if any(length != lengths[0] for length in lengths):
        raise ValueError(f"All connection arrays must have same length, got {lengths}")

    # Nothing to range-check when there are no edges (min/max of an empty
    # array is undefined).
    if lengths[0] == 0:
        return

    # Lower bound: the caller allocates per-node arrays of length num_nodes,
    # so a negative index is never a valid node.
    # NOTE(review): the GPU kernels presumably index those arrays with these
    # values without bounds checks - confirm against gpu_kernels.
    min_node = min(int(cp.min(connection_from)), int(cp.min(connection_to)))
    if min_node < 0:
        raise ValueError(f"Node index {min_node} is negative; indices must be in [0, {num_nodes})")

    # Upper bound
    max_node = max(int(cp.max(connection_from)), int(cp.max(connection_to)))
    if max_node >= num_nodes:
        raise ValueError(f"Node index {max_node} exceeds num_nodes {num_nodes}")
def build_mst_gpu_boruvka(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_nodes: int
) -> dict:
    """
    Full GPU Borůvka's algorithm for minimum spanning tree.

    Runs a fixed number of rounds, each consisting of three kernel launches
    (reset/flatten, find minimum edge per component, union components), and
    then copies the selected edges back to the host in one bulk transfer per
    result array. No host readback happens inside the round loop.

    Args:
        connection_from: Source node index of each candidate edge.
        connection_to: Destination node index of each candidate edge.
        connection_dx: Per-edge dx value, carried through to the result.
        connection_dy: Per-edge dy value, carried through to the result.
        connection_quality: Per-edge quality value, passed to the
            minimum-edge kernel as the selection criterion.
        num_nodes: Number of nodes in the graph.

    Returns:
        ``{'edges': [...]}`` where each entry is a dict with keys ``'from'``,
        ``'to'`` (int), ``'dx'``, ``'dy'`` (float) and ``'quality'``.
        ``'quality'`` is always ``0.0`` because the per-edge quality is not
        stored in the MST result buffers. The list is empty when there are
        no input edges.

    Raises:
        TypeError: If any connection array is not a CuPy array.
        ValueError: If the connection arrays are inconsistent (see
            ``_validate_mst_inputs``).
    """
    import logging  # function-local so this block adds no file-level dependency

    logger = logging.getLogger(__name__)

    # Validate inputs
    _validate_mst_inputs(connection_from, connection_to, connection_dx,
                         connection_dy, connection_quality, num_nodes)

    num_edges = len(connection_from)
    if num_edges == 0:
        return {'edges': []}

    # A spanning tree over num_nodes nodes has at most num_nodes - 1 edges.
    max_edges = num_nodes - 1

    # Union-find structure (flattened for O(1) lookups)
    parent = cp.arange(num_nodes, dtype=cp.int32)
    rank = cp.zeros(num_nodes, dtype=cp.int32)

    # Component minimum edge tracking (use int32 for atomic operations)
    cheapest_edge_idx = cp.full(num_nodes, -1, dtype=cp.int32)
    cheapest_edge_weight_int = cp.full(num_nodes, 2147483647, dtype=cp.int32)  # Max int32 value

    # MST result storage
    mst_edges_from = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_to = cp.zeros(max_edges, dtype=cp.int32)
    mst_edges_dx = cp.zeros(max_edges, dtype=cp.float32)
    mst_edges_dy = cp.zeros(max_edges, dtype=cp.float32)
    mst_count = cp.array([0], dtype=cp.int32)

    # Sort edges by source vertex (intended for cache locality)
    sort_indices = cp.argsort(connection_from)
    edges_from = connection_from[sort_indices]
    edges_to = connection_to[sort_indices]
    edges_dx = connection_dx[sort_indices]
    edges_dy = connection_dy[sort_indices]
    edges_quality = connection_quality[sort_indices]

    # Main Borůvka's loop - O(log V) iterations.
    # (n - 1).bit_length() equals ceil(log2(n)) exactly for n >= 1 and is
    # computed on the host, avoiding a device computation plus readback just
    # to obtain a loop bound.
    max_iterations = int(num_nodes - 1).bit_length() + 1

    for iteration in range(max_iterations):
        logger.debug("🔥 Iteration %d: Starting...", iteration)

        # Kernel 1: Reset and flatten union-find trees
        logger.debug("🔥 Iteration %d: Launching reset/flatten kernel...", iteration)
        launch_reset_flatten_kernel(
            parent, rank, cheapest_edge_idx, cheapest_edge_weight_int, num_nodes
        )

        # Kernel 2: Find minimum edge per component (parallel)
        logger.debug("🔥 Iteration %d: Launching find minimum edges kernel...", iteration)
        launch_find_minimum_edges_kernel(
            edges_from, edges_to, edges_quality, parent,
            cheapest_edge_idx, cheapest_edge_weight_int, num_edges
        )

        # Kernel 3: Union components and update MST
        logger.debug("🔥 Iteration %d: Launching union components kernel...", iteration)
        launch_union_components_kernel(
            cheapest_edge_idx, edges_from, edges_to, edges_dx, edges_dy,
            parent, rank, mst_edges_from, mst_edges_to, mst_edges_dx, mst_edges_dy,
            mst_count, num_nodes
        )

        logger.debug("🔥 Iteration %d: Kernel launched (pure GPU)", iteration)

        # No dynamic termination check: a fixed iteration count is used so
        # the loop never has to read a value back from the device.

    # First host readback: number of edges the kernels recorded.
    final_mst_count = int(mst_count[0])

    logger.debug("Borůvka MST: %d edges constructed (expected: %d)",
                 final_mst_count, max_edges)

    # Bounds check so we never read past the result buffers.
    if final_mst_count > max_edges:
        logger.warning("🔥 WARNING: MST count %d exceeds maximum %d, clamping",
                       final_mst_count, max_edges)
        final_mst_count = max_edges

    # One bulk device-to-host copy per result array instead of four scalar
    # transfers per edge. tolist() yields Python ints / floats.
    from_nodes = mst_edges_from[:final_mst_count].tolist()
    to_nodes = mst_edges_to[:final_mst_count].tolist()
    dx_values = mst_edges_dx[:final_mst_count].tolist()
    dy_values = mst_edges_dy[:final_mst_count].tolist()

    selected_edges = [
        {
            'from': src,
            'to': dst,
            'dx': dx,
            'dy': dy,
            'quality': 0.0  # Could be stored if needed
        }
        for src, dst, dx, dy in zip(from_nodes, to_nodes, dx_values, dy_values)
    ]

    # Debug: log the first few edges
    for i, edge in enumerate(selected_edges[:3]):
        logger.debug("  Edge %d: %d -> %d, dx=%.3f, dy=%.3f",
                     i, edge['from'], edge['to'], edge['dx'], edge['dy'])

    return {'edges': selected_edges}