Coverage for openhcs/processing/backends/pos_gen/mist/position_reconstruction.py: 12.7%

53 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Position Reconstruction from MST for MIST Algorithm 

3 

4Functions for rebuilding tile positions from minimum spanning tree. 

5""" 

6from __future__ import annotations 

7 

8from typing import TYPE_CHECKING 

9 

10from openhcs.core.utils import optional_import 

11 

12# For type checking only 

13if TYPE_CHECKING: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true

14 import cupy as cp 

15 

16# Import CuPy as an optional dependency 

17cp = optional_import("cupy") 

18 

19 

def _validate_cupy_array(array, name: str = "input") -> None:  # type: ignore
    """
    Ensure ``array`` is a CuPy ndarray.

    Args:
        array: Object to check.
        name: Label used for the object in the error message.

    Raises:
        TypeError: If ``array`` is not an instance of ``cp.ndarray``.
    """
    # `cp` is the module-level result of optional_import("cupy").
    # NOTE(review): what optional_import returns when CuPy is not installed
    # is not visible from this file - confirm `cp.ndarray` is safe then.
    if isinstance(array, cp.ndarray):
        return
    raise TypeError(f"{name} must be a CuPy array, got {type(array)}")

24 

25 

def rebuild_positions_from_mst_gpu(
    initial_positions: "cp.ndarray",  # type: ignore
    mst_edges: dict,
    num_tiles: int,
    anchor_tile_index: int = 0
) -> "cp.ndarray":  # type: ignore
    """
    Rebuild tile positions by walking the MST outward from an anchor tile.

    The anchor tile is pinned at (0, 0). Every tile reachable from it through
    the MST edges is placed at its parent's position plus the edge offset
    (dx, dy); the offset is negated when an edge is followed from its 'to'
    end back to its 'from' end. Tiles that cannot be reached from the anchor
    keep their row from ``initial_positions``.

    Args:
        initial_positions: CuPy array of shape (num_tiles, 2) with the
            initial position estimates.
        mst_edges: Dictionary whose 'edges' entry is a list of edge dicts,
            each with the keys 'from', 'to', 'dx' and 'dy'. A missing or
            empty 'edges' entry is treated as "no edges".
        num_tiles: Number of tiles.
        anchor_tile_index: Index of the anchor tile (fixed at the origin).

    Returns:
        A new (num_tiles, 2) float32 CuPy array of reconstructed positions.
        When there are no edges, a copy of ``initial_positions`` is returned
        instead (the anchor is not re-based in that case).

    Raises:
        TypeError: If ``initial_positions`` is not a CuPy array.
        ValueError: If ``initial_positions`` is not shaped (num_tiles, 2).
        IndexError: If the anchor or an edge endpoint is out of range.
    """
    from collections import deque

    _validate_cupy_array(initial_positions, "initial_positions")

    if initial_positions.shape != (num_tiles, 2):
        raise ValueError(f"Initial positions must be ({num_tiles}, 2), got {initial_positions.shape}")

    edges = mst_edges.get('edges', [])
    if not edges:
        print("🔥 WARNING: No MST edges provided, returning initial positions")
        return initial_positions.copy()

    print(f"Position reconstruction: {len(edges)} MST edges, {num_tiles} tiles")

    # Positions live on the GPU. The array starts zeroed, so the anchor row
    # is already (0, 0) and needs no explicit write.
    new_positions = cp.zeros((num_tiles, 2), dtype=cp.float32)

    # Visited flags are kept on the host: the traversal below is a serial
    # Python loop, and testing a flag stored in a CuPy array would force a
    # device-to-host transfer for every neighbour check.
    visited = [False] * num_tiles
    visited[anchor_tile_index] = True

    print(f"Anchor tile {anchor_tile_index}: (0.0, 0.0)")

    # Adjacency list of (neighbour, dx, dy). Each edge is stored in both
    # directions, with the offset negated for the reverse direction.
    adjacency = [[] for _ in range(num_tiles)]
    for edge in edges:
        from_idx = edge['from']
        to_idx = edge['to']
        dx = edge['dx']
        dy = edge['dy']

        adjacency[from_idx].append((to_idx, dx, dy))
        adjacency[to_idx].append((from_idx, -dx, -dy))

    # Breadth-first traversal; deque gives O(1) pops from the front.
    pending = deque([anchor_tile_index])

    while pending:
        current_tile = pending.popleft()
        current_pos = new_positions[current_tile]

        for neighbor_tile, dx, dy in adjacency[current_tile]:
            if visited[neighbor_tile]:
                continue

            new_positions[neighbor_tile] = current_pos + cp.array([dx, dy])
            visited[neighbor_tile] = True
            pending.append(neighbor_tile)

    # Tiles in components not connected to the anchor fall back to their
    # initial estimates.
    unvisited = [idx for idx, seen in enumerate(visited) if not seen]
    if unvisited:
        print(f"🔥 WARNING: {len(unvisited)} tiles not reachable from anchor tile")

        unvisited_idx = cp.asarray(unvisited)
        new_positions[unvisited_idx] = initial_positions[unvisited_idx]

    return new_positions

109 

110 

def build_mst_gpu(
    connection_from: "cp.ndarray",  # type: ignore
    connection_to: "cp.ndarray",  # type: ignore
    connection_dx: "cp.ndarray",  # type: ignore
    connection_dy: "cp.ndarray",  # type: ignore
    connection_quality: "cp.ndarray",  # type: ignore
    num_tiles: int
) -> dict:
    """
    Build the MST by delegating to the GPU Borůvka implementation.

    Thin wrapper: every argument is forwarded positionally, in the order
    given, to ``build_mst_gpu_boruvka`` from the sibling ``boruvka_mst``
    module, and its result is returned unchanged.

    Args:
        connection_from: Forwarded as the 1st argument.
        connection_to: Forwarded as the 2nd argument.
        connection_dx: Forwarded as the 3rd argument.
        connection_dy: Forwarded as the 4th argument.
        connection_quality: Forwarded as the 5th argument.
        num_tiles: Forwarded as the 6th argument.

    Returns:
        Whatever ``build_mst_gpu_boruvka`` returns (annotated as ``dict``).
    """
    # Imported at call time rather than at module import time.
    from .boruvka_mst import build_mst_gpu_boruvka as _boruvka

    mst = _boruvka(
        connection_from,
        connection_to,
        connection_dx,
        connection_dy,
        connection_quality,
        num_tiles,
    )
    return mst