Coverage for openhcs/core/memory/stack_utils.py: 34.4%

160 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Stack utilities module for OpenHCS. 

3 

4This module provides functions for stacking 2D slices into a 3D array 

5and unstacking a 3D array into 2D slices, with explicit memory type handling. 

6 

7This module enforces Clause 278 — Mandatory 3D Output Enforcement: 

8All functions must return a 3D array of shape [Z, Y, X], even when operating 

9on a single 2D slice. No logic may check, coerce, or infer rank at unstack time. 

10""" 

11 
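Clause 278 in practice: stacking even a single 2D slice yields a 3D [Z, Y, X] array with Z == 1, so callers always index along axis 0 instead of rank-checking. A minimal sketch of that contract (NumPy path only; shapes are illustrative):

```python
import numpy as np

from openhcs.constants.constants import MEMORY_TYPE_NUMPY
from openhcs.core.memory.stack_utils import stack_slices

slice_2d = np.zeros((512, 512), dtype=np.uint16)

# One 2D slice in -> one 3D [Z, Y, X] array out, never a bare 2D array.
stack = stack_slices([slice_2d], memory_type=MEMORY_TYPE_NUMPY, gpu_id=0)
assert stack.shape == (1, 512, 512)
```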

12import logging 

13from typing import Any, List 

14 

15import numpy as np 

16 

17from openhcs.constants.constants import (GPU_MEMORY_TYPES, MEMORY_TYPE_CUPY, 

18 MEMORY_TYPE_JAX, MEMORY_TYPE_NUMPY, 

19 MEMORY_TYPE_PYCLESPERANTO, MEMORY_TYPE_TENSORFLOW, 

20 MEMORY_TYPE_TORCH, MemoryType) 

21from openhcs.core.memory import MemoryWrapper 

22from openhcs.core.utils import optional_import 

23 

24logger = logging.getLogger(__name__) 

25 

26# 🔍 MEMORY CONVERSION LOGGING: Test log to verify logger is working 

27logger.debug("🔄 STACK_UTILS: Module loaded - memory conversion logging enabled") 

28 

29 

30def _is_2d(data: Any) -> bool: 

31 """ 

32 Check if data is a 2D array. 

33 

34 Args: 

35 data: Data to check 

36 

37 Returns: 

38 True if data is 2D, False otherwise 

39 """ 

40 # Check if data has a shape attribute 

41 if not hasattr(data, 'shape'):  [41 ↛ 42: condition never true]

42 return False 

43 

44 # Check if shape has length 2 

45 return len(data.shape) == 2 

46 

47 

48def _is_3d(data: Any) -> bool: 

49 """ 

50 Check if data is a 3D array. 

51 

52 Args: 

53 data: Data to check 

54 

55 Returns: 

56 True if data is 3D, False otherwise 

57 """ 

58 # Check if data has a shape attribute 

59 if not hasattr(data, 'shape'):  [59 ↛ 60: condition never true]

60 return False 

61 

62 # Check if shape has length 3 

63 return len(data.shape) == 3 

64 
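Both helpers are duck-typed: anything exposing a .shape tuple qualifies, so they behave identically for NumPy, CuPy, Torch, and the other supported containers. For example:

```python
import numpy as np

assert _is_2d(np.zeros((4, 4)))
assert not _is_2d(np.zeros((2, 4, 4)))
assert _is_3d(np.zeros((2, 4, 4)))
assert not _is_3d([1, 2, 3])  # no .shape attribute, so the hasattr guard returns False
```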

65 

66def _detect_memory_type(data: Any) -> str: 

67 """ 

68 Detect the memory type of the data. 

69 

70 STRICT VALIDATION: Fails loudly if the memory type cannot be detected. 

71 No automatic fallback to a default memory type. 

72 

73 Args: 

74 data: The data to detect the memory type of 

75 

76 Returns: 

77 The detected memory type 

78 

79 Raises: 

80 ValueError: If the memory type cannot be detected 

81 """ 

82 # Check if it's a MemoryWrapper 

83 if isinstance(data, MemoryWrapper):  [83 ↛ 84: condition never true]

84 return data.memory_type 

85 

86 # Check if it's a numpy array 

87 if isinstance(data, np.ndarray):  [87 ↛ 91: condition always true]

88 return MemoryType.NUMPY.value 

89 

90 # Check if it's a cupy array 

91 cp = optional_import("cupy") 

92 if cp is not None and isinstance(data, cp.ndarray): 

93 return MemoryType.CUPY.value 

94 

95 # Check if it's a torch tensor 

96 torch = optional_import("torch") 

97 if torch is not None and isinstance(data, torch.Tensor): 

98 return MemoryType.TORCH.value 

99 

100 # Check if it's a tensorflow tensor 

101 tf = optional_import("tensorflow") 

102 if tf is not None and isinstance(data, tf.Tensor): 

103 return MemoryType.TENSORFLOW.value 

104 

105 # Check if it's a JAX array 

106 jax = optional_import("jax") 

107 jnp = optional_import("jax.numpy") if jax is not None else None 

108 if jnp is not None and isinstance(data, jnp.ndarray): 

109 return MemoryType.JAX.value 

110 

111 # Check if it's a pyclesperanto array 

112 cle = optional_import("pyclesperanto") 

113 if cle is not None and hasattr(cle, 'Array') and isinstance(data, cle.Array): 

114 return MemoryType.PYCLESPERANTO.value 

115 

116 # Fail loudly if we can't detect the type 

117 raise ValueError(f"Could not detect memory type of {type(data)}") 

118 
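In use, the notable behavior is the loud failure: an unrecognized container raises ValueError instead of silently defaulting to a memory type. For example:

```python
import numpy as np

assert _detect_memory_type(np.zeros((2, 2))) == MemoryType.NUMPY.value

try:
    _detect_memory_type([1, 2, 3])  # plain lists match no backend
except ValueError as exc:
    print(exc)  # "Could not detect memory type of <class 'list'>"
```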

119 

120def _enforce_gpu_device_requirements(memory_type: str, gpu_id: int) -> None: 

121 """ 

122 Enforce GPU device requirements. 

123 

124 Args: 

125 memory_type: The memory type 

126 gpu_id: The GPU device ID 

127 

128 Raises: 

129 ValueError: If gpu_id is negative 

130 """ 

131 # For GPU memory types, validate gpu_id 

132 if memory_type in {mem_type.value for mem_type in GPU_MEMORY_TYPES}:  [132 ↛ 133: condition never true]

133 if gpu_id < 0: 

134 raise ValueError(f"Invalid GPU device ID: {gpu_id}. Must be a non-negative integer.") 

135 

136 
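Only GPU memory types trigger validation; CPU types pass through regardless of gpu_id. A minimal illustration, assuming MEMORY_TYPE_CUPY is a member of GPU_MEMORY_TYPES:

```python
_enforce_gpu_device_requirements(MEMORY_TYPE_NUMPY, gpu_id=-1)  # no-op: numpy is not a GPU type

try:
    _enforce_gpu_device_requirements(MEMORY_TYPE_CUPY, gpu_id=-1)
except ValueError as exc:
    print(exc)  # "Invalid GPU device ID: -1. Must be a non-negative integer."
```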

137def stack_slices(slices: List[Any], memory_type: str, gpu_id: int) -> Any: 

138 """ 

139 Stack 2D slices into a 3D array with the specified memory type. 

140 

141 STRICT VALIDATION: Assumes all slices are 2D arrays. 

142 No automatic handling of improper inputs. 

143 

144 Args: 

145 slices: List of 2D slices (numpy arrays, cupy arrays, torch tensors, etc.) 

146 memory_type: The memory type to use for the stacked array (REQUIRED) 

147 gpu_id: The target GPU device ID (REQUIRED) 

148 

149 Returns: 

150 A 3D array with the specified memory type of shape [Z, Y, X] 

151 

152 Raises: 

153 ValueError: If memory_type is not supported or slices is empty 

154 ValueError: If gpu_id is negative for GPU memory types 

155 ValueError: If slices are not 2D arrays 

156 MemoryConversionError: If conversion fails 

157 """ 

158 if not slices:  [158 ↛ 159: condition never true]

159 raise ValueError("Cannot stack empty list of slices") 

160 

161 # Verify all slices are 2D 

162 for i, slice_data in enumerate(slices): 

163 if not _is_2d(slice_data):  [163 ↛ 164: condition never true]

164 raise ValueError(f"Slice at index {i} is not a 2D array. All slices must be 2D.") 

165 

166 # Analyze input types for conversion planning (minimal logging) 

167 input_types = [_detect_memory_type(slice_data) for slice_data in slices] 

168 unique_input_types = set(input_types) 

169 needs_conversion = memory_type not in unique_input_types or len(unique_input_types) > 1 

170 

171 # Check GPU requirements 

172 _enforce_gpu_device_requirements(memory_type, gpu_id) 

173 

174 # Pre-allocate the final 3D array to avoid intermediate list and final stack operation 

175 first_slice = slices[0] 

176 stack_shape = (len(slices), first_slice.shape[0], first_slice.shape[1]) 

177 

178 # Create pre-allocated result array in target memory type 

179 if memory_type == MEMORY_TYPE_NUMPY:  [179 ↛ 198: condition always true]

180 import numpy as np 

181 

182 # Handle torch dtypes by converting a sample slice first 

183 first_slice_source_type = _detect_memory_type(first_slice) 

184 if first_slice_source_type == MEMORY_TYPE_TORCH:  [184 ↛ 186: condition never true]

185 # Convert torch tensor to numpy to get compatible dtype 

186 from openhcs.core.memory.converters import convert_memory 

187 sample_converted = convert_memory( 

188 data=first_slice, 

189 source_type=first_slice_source_type, 

190 target_type=memory_type, 

191 gpu_id=gpu_id, 

192 allow_cpu_roundtrip=True # Allow CPU roundtrip for numpy conversion 

193 ) 

194 result = np.empty(stack_shape, dtype=sample_converted.dtype) 

195 else: 

196 # Use dtype directly for non-torch types 

197 result = np.empty(stack_shape, dtype=first_slice.dtype) 

198 elif memory_type == MEMORY_TYPE_CUPY: 

199 cupy = optional_import("cupy") 

200 if cupy is None: 

201 raise ValueError(f"CuPy is required for memory type {memory_type}") 

202 with cupy.cuda.Device(gpu_id): 

203 result = cupy.empty(stack_shape, dtype=first_slice.dtype) 

204 elif memory_type == MEMORY_TYPE_TORCH: 

205 torch = optional_import("torch") 

206 if torch is None: 

207 raise ValueError(f"PyTorch is required for memory type {memory_type}") 

208 

209 # Convert first slice to get the correct torch dtype 

210 from openhcs.core.memory.converters import convert_memory 

211 first_slice_source_type = _detect_memory_type(first_slice) 

212 sample_converted = convert_memory( 

213 data=first_slice, 

214 source_type=first_slice_source_type, 

215 target_type=memory_type, 

216 gpu_id=gpu_id, 

217 allow_cpu_roundtrip=False 

218 ) 

219 

220 result = torch.empty(stack_shape, dtype=sample_converted.dtype, device=sample_converted.device) 

221 elif memory_type == MEMORY_TYPE_TENSORFLOW: 

222 tf = optional_import("tensorflow") 

223 if tf is None: 

224 raise ValueError(f"TensorFlow is required for memory type {memory_type}") 

225 with tf.device(f"/device:GPU:{gpu_id}"): 

226 result = tf.zeros(stack_shape, dtype=first_slice.dtype) # TF doesn't have empty() 

227 elif memory_type == MEMORY_TYPE_JAX: 

228 jax = optional_import("jax") 

229 if jax is None: 

230 raise ValueError(f"JAX is required for memory type {memory_type}") 

231 jnp = optional_import("jax.numpy") 

232 if jnp is None: 

233 raise ValueError(f"JAX is required for memory type {memory_type}") 

234 result = jnp.empty(stack_shape, dtype=first_slice.dtype) 

235 elif memory_type == MEMORY_TYPE_PYCLESPERANTO: 

236 cle = optional_import("pyclesperanto") 

237 if cle is None: 

238 raise ValueError(f"pyclesperanto is required for memory type {memory_type}") 

239 # For pyclesperanto, we'll build the result using concatenate_along_z 

240 # Don't pre-allocate here, we'll handle it in the loop below 

241 result = None 

242 else: 

243 raise ValueError(f"Unsupported memory type: {memory_type}") 

244 

245 # Convert each slice and assign to result array 

246 conversion_count = 0 

247 

248 # Special handling for pyclesperanto - build using concatenate_along_z 

249 if memory_type == MEMORY_TYPE_PYCLESPERANTO:  [249 ↛ 250: condition never true]

250 cle = optional_import("pyclesperanto") 

251 converted_slices = [] 

252 

253 for i, slice_data in enumerate(slices): 

254 source_type = _detect_memory_type(slice_data) 

255 

256 # Track conversions for batch logging 

257 if source_type != memory_type: 

258 conversion_count += 1 

259 

260 # Convert slice to pyclesperanto 

261 if source_type == memory_type: 

262 converted_data = slice_data 

263 else: 

264 from openhcs.core.memory.converters import convert_memory 

265 converted_data = convert_memory( 

266 data=slice_data, 

267 source_type=source_type, 

268 target_type=memory_type, 

269 gpu_id=gpu_id, 

270 allow_cpu_roundtrip=False 

271 ) 

272 

273 # Ensure slice is 2D, expand to 3D single slice if needed 

274 if converted_data.ndim == 2: 

275 # Convert 2D slice to 3D single slice using expand_dims equivalent 

276 converted_data = cle.push(cle.pull(converted_data)[None, ...]) 

277 

278 converted_slices.append(converted_data) 

279 

280 # Build 3D result using efficient batch concatenation 

281 if len(converted_slices) == 1: 

282 result = converted_slices[0] 

283 else: 

284 # Use divide-and-conquer approach for better performance 

285 # This reduces O(N²) copying to O(N log N) 

286 slices_to_concat = converted_slices[:] 

287 while len(slices_to_concat) > 1: 

288 new_slices = [] 

289 for i in range(0, len(slices_to_concat), 2): 

290 if i + 1 < len(slices_to_concat): 

291 # Concatenate pair 

292 combined = cle.concatenate_along_z(slices_to_concat[i], slices_to_concat[i + 1]) 

293 new_slices.append(combined) 

294 else: 

295 # Odd one out 

296 new_slices.append(slices_to_concat[i]) 

297 slices_to_concat = new_slices 

298 result = slices_to_concat[0] 

299 

300 else: 

301 # Standard handling for other memory types 

302 for i, slice_data in enumerate(slices): 

303 source_type = _detect_memory_type(slice_data) 

304 

305 # Track conversions for batch logging 

306 if source_type != memory_type:  [306 ↛ 307: condition never true]

307 conversion_count += 1 

308 

309 # Direct conversion without MemoryWrapper overhead 

310 if source_type == memory_type:  [310 ↛ 313: condition always true]

311 converted_data = slice_data 

312 else: 

313 from openhcs.core.memory.converters import convert_memory 

314 converted_data = convert_memory( 

315 data=slice_data, 

316 source_type=source_type, 

317 target_type=memory_type, 

318 gpu_id=gpu_id, 

319 allow_cpu_roundtrip=False 

320 ) 

321 

322 # Assign converted slice directly to pre-allocated result array 

323 # Handle JAX immutability 

324 if memory_type == MEMORY_TYPE_JAX:  [324 ↛ 325: condition never true]

325 result = result.at[i].set(converted_data) 

326 else: 

327 result[i] = converted_data 

328 

329 # 🔍 MEMORY CONVERSION LOGGING: Only log when conversions happen or issues occur 

330 if conversion_count > 0:  [330 ↛ 331: condition never true]

331 logger.debug(f"🔄 STACK_SLICES: Converted {conversion_count}/{len(slices)} slices to {memory_type}") 

332 # Silent success for no-conversion cases to reduce log pollution 

333 

334 return result 

335 
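The pre-allocate-then-assign loop above relies on in-place item assignment, which NumPy, CuPy, and Torch all support; JAX is routed through result.at[i].set(...) because JAX arrays are immutable. (TensorFlow tensors are immutable as well, so the tf.zeros branch would need tf.stack or tf.tensor_scatter_nd_update to behave the same way.) Typical homogeneous usage looks like this, a minimal sketch on the NumPy path:

```python
import numpy as np

from openhcs.constants.constants import MEMORY_TYPE_NUMPY

slices = [np.full((8, 8), z, dtype=np.float32) for z in range(4)]
stack = stack_slices(slices, memory_type=MEMORY_TYPE_NUMPY, gpu_id=0)

assert stack.shape == (4, 8, 8)   # [Z, Y, X]
assert stack[2, 0, 0] == 2.0      # slice order is preserved along Z
```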

336 
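The pyclesperanto branch's pairwise reduction generalizes beyond that backend: merging neighbors halves the list on every pass, so each element is copied O(log N) times rather than the O(N) copies a left-to-right fold of concatenations would incur. A standalone sketch of the same pattern, with numpy.concatenate standing in for cle.concatenate_along_z:

```python
import numpy as np

def pairwise_concat(blocks):
    """Reduce a list of [Z, Y, X] blocks to one array with O(N log N) total copying."""
    while len(blocks) > 1:
        merged = []
        for i in range(0, len(blocks), 2):
            if i + 1 < len(blocks):
                merged.append(np.concatenate([blocks[i], blocks[i + 1]], axis=0))
            else:
                merged.append(blocks[i])  # odd block carries over to the next pass
        blocks = merged
    return blocks[0]

assert pairwise_concat([np.zeros((1, 4, 4)) for _ in range(5)]).shape == (5, 4, 4)
```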

337def unstack_slices(array: Any, memory_type: str, gpu_id: int, validate_slices: bool = True) -> List[Any]: 

338 """ 

339 Split a 3D array into 2D slices along axis 0 and convert to the specified memory type. 

340 

341 STRICT VALIDATION: Input must be a 3D array. No automatic handling of improper inputs. 

342 

343 Args: 

344 array: 3D array to split - MUST BE 3D 

345 memory_type: The memory type to use for the output slices (REQUIRED) 

346 gpu_id: The target GPU device ID (REQUIRED) 

347 validate_slices: If True, validates that each extracted slice is 2D 

348 

349 Returns: 

350 List of 2D slices in the specified memory type 

351 

352 Raises: 

353 ValueError: If array is not 3D 

354 ValueError: If validate_slices is True and any extracted slice is not 2D 

355 ValueError: If gpu_id is negative for GPU memory types 

356 ValueError: If memory_type is not supported 

357 MemoryConversionError: If conversion fails 

358 """ 

359 # Detect input type and check if conversion is needed 

360 input_type = _detect_memory_type(array) 

361 input_shape = getattr(array, 'shape', 'unknown') 

362 needs_conversion = input_type != memory_type 

363 

364 # Verify the array is 3D - fail loudly if not 

365 if not _is_3d(array):  [365 ↛ 366: condition never true]

366 raise ValueError(f"Array must be 3D, got shape {getattr(array, 'shape', 'unknown')}") 

367 

368 # Check GPU requirements 

369 _enforce_gpu_device_requirements(memory_type, gpu_id) 

370 

371 # Convert to target memory type using direct convert_memory call 

372 # Bypass MemoryWrapper to eliminate object creation overhead 

373 source_type = input_type  # Reuse the memory type already detected above

374 

375 # Direct conversion without MemoryWrapper overhead 

376 if source_type == memory_type:  [376 ↛ 381: condition always true]

377 # No conversion needed - silent success to reduce log pollution 

378 pass 

379 else: 

380 # Use direct convert_memory call and log the conversion 

381 from openhcs.core.memory.converters import convert_memory 

382 logger.debug(f"🔄 UNSTACK_SLICES: Converting array - {source_type} → {memory_type}")

383 array = convert_memory( 

384 data=array, 

385 source_type=source_type, 

386 target_type=memory_type, 

387 gpu_id=gpu_id, 

388 allow_cpu_roundtrip=False 

389 ) 

390 

391 # Extract slices along axis 0 (already in the target memory type) 

392 slices = [array[i] for i in range(array.shape[0])] 

393 

394 # Validate that all extracted slices are 2D if requested 

395 if validate_slices:  [395 ↛ 401: condition always true]

396 for i, slice_data in enumerate(slices): 

397 if not _is_2d(slice_data):  [397 ↛ 398: condition never true]

398 raise ValueError(f"Extracted slice at index {i} is not 2D. This indicates a malformed 3D array.") 

399 

400 # 🔍 MEMORY CONVERSION LOGGING: Only log conversions or issues 

401 if source_type != memory_type:  [401 ↛ 402: condition never true]

402 logger.debug(f"🔄 UNSTACK_SLICES: Converted and extracted {len(slices)} slices") 

403 elif len(slices) == 0:  [403 ↛ 404: condition never true]

404 logger.warning(f"🔄 UNSTACK_SLICES: No slices extracted (empty array)") 

405 # Silent success for no-conversion cases to reduce log pollution 

406 

407 return slices
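Finally, a round trip through both functions. Conversion in unstack_slices happens once on the full 3D array before slicing, so the per-slice loop is pure indexing and the returned slices are views of the converted buffer. A minimal sketch on the NumPy path:

```python
import numpy as np

from openhcs.constants.constants import MEMORY_TYPE_NUMPY

volume = np.random.rand(3, 16, 16).astype(np.float32)

slices = unstack_slices(volume, memory_type=MEMORY_TYPE_NUMPY, gpu_id=0)
assert len(slices) == 3 and all(s.shape == (16, 16) for s in slices)

rebuilt = stack_slices(slices, memory_type=MEMORY_TYPE_NUMPY, gpu_id=0)
assert np.array_equal(rebuilt, volume)
```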