Coverage for openhcs/processing/backends/pos_gen/mist/position_reconstruction.py: 12.7%
53 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Position Reconstruction from MST for MIST Algorithm
4Functions for rebuilding tile positions from minimum spanning tree.
5"""
6from __future__ import annotations
8from typing import TYPE_CHECKING
10from openhcs.core.utils import optional_import
12# For type checking only
13if TYPE_CHECKING: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true
14 import cupy as cp
16# Import CuPy as an optional dependency
17cp = optional_import("cupy")
def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """
    Ensure ``array`` is a CuPy ndarray.

    Args:
        array: Object to check.
        name: Label used for the object in the error message.

    Raises:
        TypeError: If ``array`` is not an instance of ``cp.ndarray``.
    """
    # Guard clause: the accepted case returns immediately.
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")
def rebuild_positions_from_mst_gpu(
    initial_positions: "cp.ndarray",  # type: ignore
    mst_edges: dict,
    num_tiles: int,
    anchor_tile_index: int = 0
) -> "cp.ndarray":  # type: ignore
    """
    Rebuild tile positions from MST edges using GPU operations.

    The anchor tile is pinned at (0, 0). Every tile reachable from it through
    the MST is placed by accumulating the per-edge (dx, dy) offsets along a
    breadth-first walk that starts at the anchor. Tiles that cannot be
    reached keep their entry from ``initial_positions``.

    Args:
        initial_positions: Initial position estimates, (num_tiles, 2) CuPy array
        mst_edges: Dictionary with an 'edges' list; each edge is a mapping
            with 'from', 'to', 'dx' and 'dy' keys, where (dx, dy) is the
            offset added to the 'from' tile position to get the 'to' tile
        num_tiles: Number of tiles
        anchor_tile_index: Index of anchor tile (fixed at origin)

    Returns:
        Reconstructed positions as a (num_tiles, 2) CuPy array. When edges
        are present the array is float32; when 'edges' is missing or empty
        a copy of ``initial_positions`` is returned unchanged.

    Raises:
        TypeError: If ``initial_positions`` is not a CuPy array.
        ValueError: If ``initial_positions`` is not shaped (num_tiles, 2).
    """
    # Local import, matching the function-scope import style used elsewhere
    # in this module; deque gives O(1) pops from the front of the BFS queue.
    from collections import deque

    _validate_cupy_array(initial_positions, "initial_positions")

    if initial_positions.shape != (num_tiles, 2):
        raise ValueError(f"Initial positions must be ({num_tiles}, 2), got {initial_positions.shape}")

    edges = mst_edges.get('edges', [])
    if not edges:
        print("🔥 WARNING: No MST edges provided, returning initial positions")
        return initial_positions.copy()

    print(f"Position reconstruction: {len(edges)} MST edges, {num_tiles} tiles")

    # Positions live on the GPU; traversal bookkeeping stays on the host so
    # that each "already visited?" check is a plain list lookup instead of a
    # device-scalar-to-bool conversion.
    new_positions = cp.zeros((num_tiles, 2), dtype=cp.float32)
    visited = [False] * num_tiles

    # Set anchor tile position
    new_positions[anchor_tile_index] = cp.array([0.0, 0.0])
    visited[anchor_tile_index] = True

    print(f"Anchor tile {anchor_tile_index}: (0.0, 0.0)")

    # Build adjacency list for efficient traversal. Each edge is stored in
    # both directions; the reverse direction carries the negated offset.
    adjacency = [[] for _ in range(num_tiles)]
    for edge in edges:
        from_idx = edge['from']
        to_idx = edge['to']
        dx = edge['dx']
        dy = edge['dy']

        adjacency[from_idx].append((to_idx, dx, dy))
        adjacency[to_idx].append((from_idx, -dx, -dy))

    # Breadth-first traversal to set positions
    queue = deque([anchor_tile_index])

    while queue:
        current_tile = queue.popleft()
        current_pos = new_positions[current_tile]

        for neighbor_tile, dx, dy in adjacency[current_tile]:
            if visited[neighbor_tile]:
                continue

            # Neighbor position = current position + edge offset
            new_positions[neighbor_tile] = current_pos + cp.array([dx, dy])
            visited[neighbor_tile] = True
            queue.append(neighbor_tile)

    # Check if all tiles were visited
    unvisited_count = visited.count(False)
    if unvisited_count > 0:
        print(f"🔥 WARNING: {unvisited_count} tiles not reachable from anchor tile")

        # For unvisited tiles, use initial positions.
        # NOTE(review): these values are copied as-is and are not re-based
        # onto the anchor-at-origin frame used for the visited tiles —
        # confirm this mix of frames is intended.
        unvisited_mask = ~cp.asarray(visited, dtype=cp.bool_)
        new_positions[unvisited_mask] = initial_positions[unvisited_mask]

    return new_positions
def build_mst_gpu(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_tiles: int
) -> dict:
    """
    Build the MST by delegating to the GPU Borůvka implementation.

    All arguments are forwarded positionally, in the same order, to
    ``build_mst_gpu_boruvka`` from the sibling ``boruvka_mst`` module, and
    its result is returned unchanged.
    """
    # Imported at call time rather than at module import time.
    from .boruvka_mst import build_mst_gpu_boruvka

    mst = build_mst_gpu_boruvka(
        connection_from,
        connection_to,
        connection_dx,
        connection_dy,
        connection_quality,
        num_tiles,
    )
    return mst