Coverage for openhcs/processing/backends/assemblers/assemble_stack_cpu.py: 61.9%

183 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2CPU implementation of image assembly functions with fixed blending. 

3""" 

4from __future__ import annotations 

5 

6import logging 

7from typing import TYPE_CHECKING, List, Tuple, Union 

8 

9from openhcs.core.memory.decorators import numpy as numpy_func 

10from openhcs.core.pipeline.function_contracts import special_inputs 

11 

12# For type checking only 

13if TYPE_CHECKING: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true

14 import numpy as np 

15 from scipy.ndimage import shift as subpixel_shift 

16 

17# Import NumPy 

18import numpy as np # type: ignore 

19from scipy.ndimage import shift as subpixel_shift # type: ignore 

20 

21logger = logging.getLogger(__name__) 

22 

23 

24def _get_all_overlapping_pairs(positions: "np.ndarray", tile_shape: tuple) -> list: 

25 """ 

26 Find ALL overlapping tile pairs with edge directions. 

27 [Keep this exactly as it was - it works fine] 

28 """ 

29 height, width = tile_shape 

30 N = positions.shape[0] 

31 

32 if N <= 1: 32 ↛ 33line 32 didn't jump to line 33 because the condition on line 32 was never true

33 return [] 

34 

35 # Vectorized computation of ALL pairwise overlaps 

36 pos_i = positions[:, np.newaxis, :] 

37 pos_j = positions[np.newaxis, :, :] 

38 

39 xi, yi = pos_i[:, :, 0], pos_i[:, :, 1] 

40 xj, yj = pos_j[:, :, 0], pos_j[:, :, 1] 

41 

42 left_i, right_i = xi, xi + width 

43 top_i, bottom_i = yi, yi + height 

44 left_j, right_j = xj, xj + width 

45 top_j, bottom_j = yj, yj + height 

46 

47 x_overlap = np.maximum(0, np.minimum(right_i, right_j) - np.maximum(left_i, left_j)) 

48 y_overlap = np.maximum(0, np.minimum(bottom_i, bottom_j) - np.maximum(top_i, top_j)) 

49 

50 valid_overlap = (x_overlap > 0) & (y_overlap > 0) & (np.arange(N)[:, None] != np.arange(N)[None, :]) 

51 

52 edge_pairs = [] 

53 overlapping_pairs = np.where(valid_overlap) 

54 

55 for idx in range(len(overlapping_pairs[0])): 

56 i, j = overlapping_pairs[0][idx], overlapping_pairs[1][idx] 

57 

58 x_overlap_val = float(x_overlap[i, j]) 

59 y_overlap_val = float(y_overlap[i, j]) 

60 

61 xi_val, yi_val = positions[i, 0], positions[i, 1] 

62 xj_val, yj_val = positions[j, 0], positions[j, 1] 

63 

64 if x_overlap_val > 0: 64 ↛ 70line 64 didn't jump to line 70 because the condition on line 64 was always true

65 if xj_val < xi_val: 

66 edge_pairs.append((i, j, 'left', x_overlap_val)) 

67 elif xj_val > xi_val: 

68 edge_pairs.append((i, j, 'right', x_overlap_val)) 

69 

70 if y_overlap_val > 0: 70 ↛ 55line 70 didn't jump to line 55 because the condition on line 70 was always true

71 if yj_val < yi_val: 

72 edge_pairs.append((i, j, 'top', y_overlap_val)) 

73 elif yj_val > yi_val: 

74 edge_pairs.append((i, j, 'bottom', y_overlap_val)) 

75 

76 return edge_pairs 

77 

78 

79def _create_fixed_blend_mask( 

80 tile_shape: tuple, 

81 edge_overlaps: dict, 

82 margin_ratio: float = 0.1 

83) -> "np.ndarray": 

84 """ 

85 Create blend mask with FIXED margin ratio using WORKING logic from old version. 

86 CRITICAL: Uses endpoint=False like the old working version. 

87 """ 

88 height, width = tile_shape 

89 

90 # Create 1D weights 

91 y_weight = np.ones(height, dtype=np.float32) 

92 x_weight = np.ones(width, dtype=np.float32) 

93 

94 # Fixed margins (same as old working version) 

95 margin_pixels_y = int(height * margin_ratio) 

96 margin_pixels_x = int(width * margin_ratio) 

97 

98 # Apply gradients ONLY where there are overlaps (same as old working version) 

99 # CRITICAL: endpoint=False (this is what made the old version work!) 

100 if 'top' in edge_overlaps and margin_pixels_y > 0: 

101 y_weight[:margin_pixels_y] = np.linspace(0, 1, margin_pixels_y, endpoint=False) 

102 

103 if 'bottom' in edge_overlaps and margin_pixels_y > 0: 

104 y_weight[-margin_pixels_y:] = np.linspace(1, 0, margin_pixels_y, endpoint=False) 

105 

106 if 'left' in edge_overlaps and margin_pixels_x > 0: 

107 x_weight[:margin_pixels_x] = np.linspace(0, 1, margin_pixels_x, endpoint=False) 

108 

109 if 'right' in edge_overlaps and margin_pixels_x > 0: 

110 x_weight[-margin_pixels_x:] = np.linspace(1, 0, margin_pixels_x, endpoint=False) 

111 

112 # Use outer product (same as old working version) 

113 mask = np.outer(y_weight, x_weight) 

114 return mask.astype(np.float32) 

115 

116 

117def _create_dynamic_blend_mask( 

118 tile_shape: tuple, 

119 edge_overlaps: dict, 

120 overlap_fraction: float = 1.0 

121) -> "np.ndarray": 

122 """ 

123 Create blend mask based on actual overlap amounts using WORKING logic from old version. 

124 CRITICAL: Uses endpoint=False and same logic as old working version. 

125 """ 

126 height, width = tile_shape 

127 

128 # Create 1D weights 

129 y_weight = np.ones(height, dtype=np.float32) 

130 x_weight = np.ones(width, dtype=np.float32) 

131 

132 # Process each edge based on actual overlap (same as old working version) 

133 # CRITICAL: endpoint=False (this is what made the old version work!) 

134 if 'top' in edge_overlaps: 

135 overlap_pixels = int(edge_overlaps['top'] * overlap_fraction) 

136 if overlap_pixels > 0: 

137 y_weight[:overlap_pixels] = np.linspace(0, 1, overlap_pixels, endpoint=False) 

138 

139 if 'bottom' in edge_overlaps: 

140 overlap_pixels = int(edge_overlaps['bottom'] * overlap_fraction) 

141 if overlap_pixels > 0: 

142 y_weight[-overlap_pixels:] = np.linspace(1, 0, overlap_pixels, endpoint=False) 

143 

144 if 'left' in edge_overlaps: 

145 overlap_pixels = int(edge_overlaps['left'] * overlap_fraction) 

146 if overlap_pixels > 0: 

147 x_weight[:overlap_pixels] = np.linspace(0, 1, overlap_pixels, endpoint=False) 

148 

149 if 'right' in edge_overlaps: 

150 overlap_pixels = int(edge_overlaps['right'] * overlap_fraction) 

151 if overlap_pixels > 0: 

152 x_weight[-overlap_pixels:] = np.linspace(1, 0, overlap_pixels, endpoint=False) 

153 

154 # Use outer product (same as old working version) 

155 mask = np.outer(y_weight, x_weight) 

156 return mask.astype(np.float32) 

157 

158 

@special_inputs("positions")
@numpy_func
def assemble_stack_cpu(
    image_tiles: "np.ndarray",
    positions: Union[List[Tuple[float, float]], "np.ndarray"],
    blend_method: str = "fixed",
    fixed_margin_ratio: float = 0.1,
    overlap_blend_fraction: float = 1.0
) -> "np.ndarray":
    """
    Assemble overlapping tiles into one stitched image with feathered blending.

    Each tile is placed at its (x, y) position on a canvas spanning all tiles.
    Sub-pixel offsets are applied with linear interpolation, and each tile is
    weighted by a blend mask. The result is the weighted sum of tiles divided
    by the summed weights, clipped to [0, 65535] and cast to uint16.

    Args:
        image_tiles: 3D array of tiles, shape (N, H, W).
        positions: (x, y) of each tile's top-left corner. Accepted forms are a
            list of length-2 sequences (tuples or lists) or an array of shape
            (N, 2). Other array types are converted with ``to_numpy``.
        blend_method: One of:
            - "none": uniform weights.
            - "fixed": ramps of relative width ``fixed_margin_ratio`` on every
              overlapping edge.
            - "dynamic": ramps sized by the measured overlap times
              ``overlap_blend_fraction``.
        fixed_margin_ratio: Ramp width for "fixed" as a fraction of the tile
            dimension (e.g. 0.1 = 10%).
        overlap_blend_fraction: For "dynamic", the fraction of each measured
            overlap that is blended.

    Returns:
        uint16 array of shape (1, canvas_height, canvas_width). Degenerate
        inputs return placeholders: shape (1, 1, 0) when there are no tiles,
        and shape (0,) when the computed canvas is empty.

    Raises:
        TypeError: ``image_tiles`` is not a 3D ndarray, or ``positions`` is
            not a list of (x, y) pairs / an (N, 2) array.
        ValueError: Tile and position counts differ, ``blend_method`` is
            unknown, or ``positions`` has a type ``to_numpy`` cannot convert.
    """
    # --- 1. Validate inputs ---
    if not isinstance(image_tiles, np.ndarray) or image_tiles.ndim != 3:
        raise TypeError("image_tiles must be a 3D NumPy ndarray of shape (N, H, W).")

    if image_tiles.shape[0] == 0:
        logger.warning("image_tiles array is empty (0 tiles).")
        return np.array([[[]]], dtype=np.uint16)

    if isinstance(positions, list):
        # Accept any length-2 sequence per tile (tuples or lists, e.g. after a
        # JSON round trip) and validate every entry, not just the first.
        try:
            positions = np.asarray(positions, dtype=np.float32)
        except (TypeError, ValueError) as err:
            raise TypeError("positions must be a list of (x, y) tuples.") from err
        if positions.ndim != 2 or positions.shape[0] == 0 or positions.shape[1] != 2:
            raise TypeError("positions must be a list of (x, y) tuples.")
    else:
        if not isinstance(positions, np.ndarray):
            positions = to_numpy(positions)
        if positions.ndim != 2 or positions.shape[1] != 2:
            raise TypeError("positions must be an array of shape [N, 2].")

    if image_tiles.shape[0] != positions.shape[0]:
        raise ValueError(f"Mismatch: {image_tiles.shape[0]} tiles vs {positions.shape[0]} positions.")

    num_tiles, tile_h, tile_w = image_tiles.shape
    tile_shape = (tile_h, tile_w)
    image_tiles_float = image_tiles.astype(np.float32)

    # --- 2. Compute canvas bounds ---
    # The canvas starts at floor(min position) and ends at ceil(max position +
    # tile size), so every tile fits once placed at its true position.
    min_x = np.floor(np.min(positions[:, 0])).astype(int)
    min_y = np.floor(np.min(positions[:, 1])).astype(int)
    max_x = np.ceil(np.max(positions[:, 0]) + tile_w).astype(int)
    max_y = np.ceil(np.max(positions[:, 1]) + tile_h).astype(int)

    canvas_width = max_x - min_x
    canvas_height = max_y - min_y

    if canvas_width <= 0 or canvas_height <= 0:
        logger.warning(f"Invalid canvas size: {canvas_height}x{canvas_width}")
        return np.array([], dtype=np.uint16)

    composite_accum = np.zeros((canvas_height, canvas_width), dtype=np.float32)
    weight_accum = np.zeros((canvas_height, canvas_width), dtype=np.float32)

    # --- 3. Create blend masks ---
    # Validate before the O(N^2) overlap search rather than inside the loop.
    if blend_method not in ("none", "fixed", "dynamic"):
        raise ValueError(f"Unknown blend_method: {blend_method}")

    if blend_method == "none":
        blend_masks = [np.ones(tile_shape, dtype=np.float32) for _ in range(num_tiles)]
    else:
        # Record, per tile, the largest overlap seen on each edge.
        tile_overlaps = [{} for _ in range(num_tiles)]
        for tile_i, _tile_j, edge_direction, pixel_overlap in _get_all_overlapping_pairs(
            positions, tile_shape
        ):
            edges = tile_overlaps[tile_i]
            edges[edge_direction] = max(edges.get(edge_direction, pixel_overlap), pixel_overlap)

        if blend_method == "fixed":
            blend_masks = [
                _create_fixed_blend_mask(tile_shape, overlaps, margin_ratio=fixed_margin_ratio)
                for overlaps in tile_overlaps
            ]
        else:  # "dynamic"
            blend_masks = [
                _create_dynamic_blend_mask(
                    tile_shape, overlaps, overlap_fraction=overlap_blend_fraction
                )
                for overlaps in tile_overlaps
            ]

    # --- 4. Place tiles ---
    for i in range(num_tiles):
        pos_x, pos_y = positions[i]
        target_x = pos_x - min_x
        target_y = pos_y - min_y

        # Split into an integer canvas offset and a sub-pixel remainder in [0, 1).
        x_int = int(np.floor(target_x))
        y_int = int(np.floor(target_y))
        x_frac = target_x - x_int
        y_frac = target_y - y_int

        # scipy.ndimage.shift samples out[k] = in[k - shift], so a positive
        # shift moves content towards higher indices. The tile is written
        # starting at the integer offset, so its content must move forward by
        # the fractional remainder to land at (target_x, target_y). Integer
        # positions need no interpolation.
        tile = image_tiles_float[i]
        if x_frac or y_frac:
            tile = subpixel_shift(
                tile,
                shift=(y_frac, x_frac),
                order=1,
                mode='constant',
                cval=0.0
            )

        blend_mask = blend_masks[i]
        blended_tile = tile * blend_mask

        # Destination window on the canvas and matching source window in the
        # tile, clipped to the canvas (defensive; bounds above should suffice).
        y_start, y_end = y_int, y_int + tile_h
        x_start, x_end = x_int, x_int + tile_w
        tile_y_start, tile_y_end = 0, tile_h
        tile_x_start, tile_x_end = 0, tile_w

        if y_start < 0:
            tile_y_start = -y_start
            y_start = 0
        if x_start < 0:
            tile_x_start = -x_start
            x_start = 0
        if y_end > canvas_height:
            tile_y_end -= (y_end - canvas_height)
            y_end = canvas_height
        if x_end > canvas_width:
            tile_x_end -= (x_end - canvas_width)
            x_end = canvas_width

        if (tile_y_start >= tile_y_end or tile_x_start >= tile_x_end or
                y_start >= y_end or x_start >= x_end):
            continue

        composite_accum[y_start:y_end, x_start:x_end] += \
            blended_tile[tile_y_start:tile_y_end, tile_x_start:tile_x_end]
        weight_accum[y_start:y_end, x_start:x_end] += \
            blend_mask[tile_y_start:tile_y_end, tile_x_start:tile_x_end]

    # --- 5. Normalize ---
    # epsilon avoids division by zero where no tile contributed weight.
    epsilon = 1e-7
    stitched = composite_accum / (weight_accum + epsilon)

    stitched_uint16 = np.clip(stitched, 0, 65535).astype(np.uint16)
    return stitched_uint16.reshape(1, canvas_height, canvas_width)

338 

339 

def to_numpy(tensor):
    """
    Return ``tensor`` as a NumPy array, unwrapping common array-library types.

    Conversion is chosen by duck typing, in this priority order:

    1. An object with a ``dtype`` whose class lives in the ``numpy`` module is
       returned unchanged.
    2. An object with a ``get`` attribute (CuPy style) returns ``tensor.get()``.
    3. An object with ``detach`` (PyTorch style) returns
       ``tensor.detach().cpu().numpy()``.
    4. An object with both ``numpy`` and ``device`` (TensorFlow style) returns
       ``tensor.numpy()``.

    NOTE(review): rule 2 matches any object exposing ``get`` (e.g. a dict),
    not only CuPy arrays. Confirm callers never pass such objects.

    Raises:
        ValueError: if none of the rules applies.
    """
    if hasattr(tensor, 'dtype') and tensor.__class__.__module__ == 'numpy':
        return tensor

    # (predicate, converter) pairs, tried in priority order.
    strategies = (
        (lambda t: hasattr(t, 'get'), lambda t: t.get()),                                # CuPy
        (lambda t: hasattr(t, 'detach'), lambda t: t.detach().cpu().numpy()),            # PyTorch
        (lambda t: hasattr(t, 'numpy') and hasattr(t, 'device'), lambda t: t.numpy()),   # TF
    )
    for matches, convert in strategies:
        if matches(tensor):
            return convert(tensor)

    raise ValueError(f"Unsupported tensor type: {type(tensor)}")