Coverage for openhcs/core/memory/stack_utils.py: 34.4%
160 statements
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
"""
Stack utilities module for OpenHCS.

This module provides functions for stacking 2D slices into a 3D array
and unstacking a 3D array into 2D slices, with explicit memory type handling.

This module enforces Clause 278 — Mandatory 3D Output Enforcement:
All functions must return a 3D array of shape [Z, Y, X], even when operating
on a single 2D slice. No logic may check, coerce, or infer rank at unstack time.
"""

import logging
from typing import Any, List

import numpy as np

from openhcs.constants.constants import (GPU_MEMORY_TYPES, MEMORY_TYPE_CUPY,
                                         MEMORY_TYPE_JAX, MEMORY_TYPE_NUMPY,
                                         MEMORY_TYPE_PYCLESPERANTO, MEMORY_TYPE_TENSORFLOW,
                                         MEMORY_TYPE_TORCH, MemoryType)
from openhcs.core.memory import MemoryWrapper
from openhcs.core.utils import optional_import

logger = logging.getLogger(__name__)

# 🔍 MEMORY CONVERSION LOGGING: Test log to verify logger is working
logger.debug("🔄 STACK_UTILS: Module loaded - memory conversion logging enabled")
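
# A minimal usage sketch (illustrative only; the shapes and gpu_id value below are
# assumptions, not OpenHCS defaults): stack five 2D numpy frames into a [Z, Y, X]
# volume, then split it back into per-plane slices.
#
#     frames = [np.zeros((512, 512), dtype=np.uint16) for _ in range(5)]
#     volume = stack_slices(frames, memory_type=MemoryType.NUMPY.value, gpu_id=0)
#     # volume.shape == (5, 512, 512)
#     planes = unstack_slices(volume, memory_type=MemoryType.NUMPY.value, gpu_id=0)
#     # len(planes) == 5, each of shape (512, 512)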


def _is_2d(data: Any) -> bool:
    """
    Check if data is a 2D array.

    Args:
        data: Data to check

    Returns:
        True if data is 2D, False otherwise
    """
    # Check if data has a shape attribute
    if not hasattr(data, 'shape'):
        return False

    # Check if shape has length 2
    return len(data.shape) == 2


def _is_3d(data: Any) -> bool:
    """
    Check if data is a 3D array.

    Args:
        data: Data to check

    Returns:
        True if data is 3D, False otherwise
    """
    # Check if data has a shape attribute
    if not hasattr(data, 'shape'):
        return False

    # Check if shape has length 3
    return len(data.shape) == 3
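
# Illustrative behaviour of the rank helpers (assumed examples): any object that
# exposes a ``shape`` attribute is accepted, so the checks apply uniformly to
# numpy, cupy, torch, tensorflow, JAX and pyclesperanto arrays.
#
#     _is_2d(np.zeros((4, 4)))       # True
#     _is_2d(np.zeros((2, 4, 4)))    # False
#     _is_3d(np.zeros((2, 4, 4)))    # True
#     _is_3d([1, 2, 3])              # False - plain lists have no .shape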


def _detect_memory_type(data: Any) -> str:
    """
    Detect the memory type of the data.

    STRICT VALIDATION: Fails loudly if the memory type cannot be detected.
    No automatic fallback to a default memory type.

    Args:
        data: The data to detect the memory type of

    Returns:
        The detected memory type

    Raises:
        ValueError: If the memory type cannot be detected
    """
    # Check if it's a MemoryWrapper
    if isinstance(data, MemoryWrapper):
        return data.memory_type

    # Check if it's a numpy array
    if isinstance(data, np.ndarray):
        return MemoryType.NUMPY.value

    # Check if it's a cupy array
    cp = optional_import("cupy")
    if cp is not None and isinstance(data, cp.ndarray):
        return MemoryType.CUPY.value

    # Check if it's a torch tensor
    torch = optional_import("torch")
    if torch is not None and isinstance(data, torch.Tensor):
        return MemoryType.TORCH.value

    # Check if it's a tensorflow tensor
    tf = optional_import("tensorflow")
    if tf is not None and isinstance(data, tf.Tensor):
        return MemoryType.TENSORFLOW.value

    # Check if it's a JAX array
    jax = optional_import("jax")
    jnp = optional_import("jax.numpy") if jax is not None else None
    if jnp is not None and isinstance(data, jnp.ndarray):
        return MemoryType.JAX.value

    # Check if it's a pyclesperanto array
    cle = optional_import("pyclesperanto")
    if cle is not None and hasattr(cle, 'Array') and isinstance(data, cle.Array):
        return MemoryType.PYCLESPERANTO.value

    # Fail loudly if we can't detect the type
    raise ValueError(f"Could not detect memory type of {type(data)}")
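
# Example outcomes of detection (illustrative; only numpy is a hard dependency here,
# the other frameworks are resolved lazily via optional_import):
#
#     _detect_memory_type(np.ones((4, 4)))   # -> MemoryType.NUMPY.value
#     _detect_memory_type([[1, 2], [3, 4]])  # raises ValueError: a plain list
#                                            # matches none of the known types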


def _enforce_gpu_device_requirements(memory_type: str, gpu_id: int) -> None:
    """
    Enforce GPU device requirements.

    Args:
        memory_type: The memory type
        gpu_id: The GPU device ID

    Raises:
        ValueError: If gpu_id is negative
    """
    # For GPU memory types, validate gpu_id
    if memory_type in {mem_type.value for mem_type in GPU_MEMORY_TYPES}:
        if gpu_id < 0:
            raise ValueError(f"Invalid GPU device ID: {gpu_id}. Must be a non-negative integer.")
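
# Example (illustrative, assuming "cupy" is registered in GPU_MEMORY_TYPES while
# "numpy" is not): the gpu_id check only applies to GPU-backed memory types.
#
#     _enforce_gpu_device_requirements(MemoryType.NUMPY.value, -1)  # no error
#     _enforce_gpu_device_requirements(MemoryType.CUPY.value, -1)   # ValueError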


def stack_slices(slices: List[Any], memory_type: str, gpu_id: int) -> Any:
    """
    Stack 2D slices into a 3D array with the specified memory type.

    STRICT VALIDATION: Assumes all slices are 2D arrays.
    No automatic handling of improper inputs.

    Args:
        slices: List of 2D slices (numpy arrays, cupy arrays, torch tensors, etc.)
        memory_type: The memory type to use for the stacked array (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)

    Returns:
        A 3D array with the specified memory type of shape [Z, Y, X]

    Raises:
        ValueError: If memory_type is not supported or slices is empty
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If slices are not 2D arrays
        MemoryConversionError: If conversion fails
    """
    if not slices:
        raise ValueError("Cannot stack empty list of slices")

    # Verify all slices are 2D
    for i, slice_data in enumerate(slices):
        if not _is_2d(slice_data):
            raise ValueError(f"Slice at index {i} is not a 2D array. All slices must be 2D.")

    # Analyze input types for conversion planning (minimal logging)
    input_types = [_detect_memory_type(slice_data) for slice_data in slices]
    unique_input_types = set(input_types)
    needs_conversion = memory_type not in unique_input_types or len(unique_input_types) > 1

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    # Pre-allocate the final 3D array to avoid intermediate list and final stack operation
    first_slice = slices[0]
    stack_shape = (len(slices), first_slice.shape[0], first_slice.shape[1])

    # Create pre-allocated result array in target memory type
    if memory_type == MEMORY_TYPE_NUMPY:
        import numpy as np

        # Handle torch dtypes by converting a sample slice first
        first_slice_source_type = _detect_memory_type(first_slice)
        if first_slice_source_type == MEMORY_TYPE_TORCH:
            # Convert torch tensor to numpy to get compatible dtype
            from openhcs.core.memory.converters import convert_memory
            sample_converted = convert_memory(
                data=first_slice,
                source_type=first_slice_source_type,
                target_type=memory_type,
                gpu_id=gpu_id,
                allow_cpu_roundtrip=True  # Allow CPU roundtrip for numpy conversion
            )
            result = np.empty(stack_shape, dtype=sample_converted.dtype)
        else:
            # Use dtype directly for non-torch types
            result = np.empty(stack_shape, dtype=first_slice.dtype)
    elif memory_type == MEMORY_TYPE_CUPY:
        cupy = optional_import("cupy")
        if cupy is None:
            raise ValueError(f"CuPy is required for memory type {memory_type}")
        with cupy.cuda.Device(gpu_id):
            result = cupy.empty(stack_shape, dtype=first_slice.dtype)
    elif memory_type == MEMORY_TYPE_TORCH:
        torch = optional_import("torch")
        if torch is None:
            raise ValueError(f"PyTorch is required for memory type {memory_type}")

        # Convert first slice to get the correct torch dtype
        from openhcs.core.memory.converters import convert_memory
        first_slice_source_type = _detect_memory_type(first_slice)
        sample_converted = convert_memory(
            data=first_slice,
            source_type=first_slice_source_type,
            target_type=memory_type,
            gpu_id=gpu_id,
            allow_cpu_roundtrip=False
        )

        result = torch.empty(stack_shape, dtype=sample_converted.dtype, device=sample_converted.device)
    elif memory_type == MEMORY_TYPE_TENSORFLOW:
        tf = optional_import("tensorflow")
        if tf is None:
            raise ValueError(f"TensorFlow is required for memory type {memory_type}")
        with tf.device(f"/device:GPU:{gpu_id}"):
            result = tf.zeros(stack_shape, dtype=first_slice.dtype)  # TF doesn't have empty()
    elif memory_type == MEMORY_TYPE_JAX:
        jax = optional_import("jax")
        if jax is None:
            raise ValueError(f"JAX is required for memory type {memory_type}")
        jnp = optional_import("jax.numpy")
        if jnp is None:
            raise ValueError(f"JAX is required for memory type {memory_type}")
        result = jnp.empty(stack_shape, dtype=first_slice.dtype)
    elif memory_type == MEMORY_TYPE_PYCLESPERANTO:
        cle = optional_import("pyclesperanto")
        if cle is None:
            raise ValueError(f"pyclesperanto is required for memory type {memory_type}")
        # For pyclesperanto, we'll build the result using concatenate_along_z
        # Don't pre-allocate here, we'll handle it in the loop below
        result = None
    else:
        raise ValueError(f"Unsupported memory type: {memory_type}")

    # Convert each slice and assign to result array
    conversion_count = 0

    # Special handling for pyclesperanto - build using concatenate_along_z
    if memory_type == MEMORY_TYPE_PYCLESPERANTO:
        cle = optional_import("pyclesperanto")
        converted_slices = []

        for i, slice_data in enumerate(slices):
            source_type = _detect_memory_type(slice_data)

            # Track conversions for batch logging
            if source_type != memory_type:
                conversion_count += 1

            # Convert slice to pyclesperanto
            if source_type == memory_type:
                converted_data = slice_data
            else:
                from openhcs.core.memory.converters import convert_memory
                converted_data = convert_memory(
                    data=slice_data,
                    source_type=source_type,
                    target_type=memory_type,
                    gpu_id=gpu_id,
                    allow_cpu_roundtrip=False
                )

            # Ensure slice is 2D, expand to 3D single slice if needed
            if converted_data.ndim == 2:
                # Convert 2D slice to 3D single slice using expand_dims equivalent
                converted_data = cle.push(cle.pull(converted_data)[None, ...])

            converted_slices.append(converted_data)

        # Build 3D result using efficient batch concatenation
        if len(converted_slices) == 1:
            result = converted_slices[0]
        else:
            # Use divide-and-conquer approach for better performance
            # This reduces O(N²) copying to O(N log N)
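            # Illustrative walk-through (assumed, not taken from profiling data):
            # with five slices A..E the rounds are [AB, CD, E] -> [ABCD, E] -> [ABCDE],
            # i.e. about ceil(log2(N)) concatenation rounds instead of N - 1 sequential
            # appends that each recopy the growing stack.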
            slices_to_concat = converted_slices[:]
            while len(slices_to_concat) > 1:
                new_slices = []
                for i in range(0, len(slices_to_concat), 2):
                    if i + 1 < len(slices_to_concat):
                        # Concatenate pair
                        combined = cle.concatenate_along_z(slices_to_concat[i], slices_to_concat[i + 1])
                        new_slices.append(combined)
                    else:
                        # Odd one out
                        new_slices.append(slices_to_concat[i])
                slices_to_concat = new_slices
            result = slices_to_concat[0]

    else:
        # Standard handling for other memory types
        for i, slice_data in enumerate(slices):
            source_type = _detect_memory_type(slice_data)

            # Track conversions for batch logging
            if source_type != memory_type:
                conversion_count += 1

            # Direct conversion without MemoryWrapper overhead
            if source_type == memory_type:
                converted_data = slice_data
            else:
                from openhcs.core.memory.converters import convert_memory
                converted_data = convert_memory(
                    data=slice_data,
                    source_type=source_type,
                    target_type=memory_type,
                    gpu_id=gpu_id,
                    allow_cpu_roundtrip=False
                )

            # Assign converted slice directly to pre-allocated result array
            # Handle JAX immutability
            if memory_type == MEMORY_TYPE_JAX:
                result = result.at[i].set(converted_data)
            else:
                result[i] = converted_data

    # 🔍 MEMORY CONVERSION LOGGING: Only log when conversions happen or issues occur
    if conversion_count > 0:
        logger.debug(f"🔄 STACK_SLICES: Converted {conversion_count}/{len(slices)} slices to {memory_type}")
    # Silent success for no-conversion cases to reduce log pollution

    return result
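
# Mixed-source example (illustrative; assumes torch is installed): slices from
# different frameworks are converted individually, so a torch tensor and a numpy
# array can be stacked into a single numpy volume via convert_memory.
#
#     mixed = [torch.zeros(64, 64), np.zeros((64, 64), dtype=np.float32)]
#     volume = stack_slices(mixed, memory_type=MemoryType.NUMPY.value, gpu_id=0)
#     # volume is a (2, 64, 64) numpy array; one conversion is logged at DEBUG level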


def unstack_slices(array: Any, memory_type: str, gpu_id: int, validate_slices: bool = True) -> List[Any]:
    """
    Split a 3D array into 2D slices along axis 0 and convert to the specified memory type.

    STRICT VALIDATION: Input must be a 3D array. No automatic handling of improper inputs.

    Args:
        array: 3D array to split - MUST BE 3D
        memory_type: The memory type to use for the output slices (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)
        validate_slices: If True, validates that each extracted slice is 2D

    Returns:
        List of 2D slices in the specified memory type

    Raises:
        ValueError: If array is not 3D
        ValueError: If validate_slices is True and any extracted slice is not 2D
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If memory_type is not supported
        MemoryConversionError: If conversion fails
    """
    # Detect input type and check if conversion is needed
    input_type = _detect_memory_type(array)
    input_shape = getattr(array, 'shape', 'unknown')
    needs_conversion = input_type != memory_type

    # Verify the array is 3D - fail loudly if not
    if not _is_3d(array):
        raise ValueError(f"Array must be 3D, got shape {getattr(array, 'shape', 'unknown')}")

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    # Convert to target memory type using direct convert_memory call
    # Bypass MemoryWrapper to eliminate object creation overhead
    source_type = input_type  # Reuse the input type already detected above

    # Direct conversion without MemoryWrapper overhead
    if source_type == memory_type:
        # No conversion needed - silent success to reduce log pollution
        pass
    else:
        # Use direct convert_memory call and log the conversion
        from openhcs.core.memory.converters import convert_memory
        logger.debug(f"🔄 UNSTACK_SLICES: Converting array - {source_type} → {memory_type}")
        array = convert_memory(
            data=array,
            source_type=source_type,
            target_type=memory_type,
            gpu_id=gpu_id,
            allow_cpu_roundtrip=False
        )

    # Extract slices along axis 0 (already in the target memory type)
    slices = [array[i] for i in range(array.shape[0])]

    # Validate that all extracted slices are 2D if requested
    if validate_slices:
        for i, slice_data in enumerate(slices):
            if not _is_2d(slice_data):
                raise ValueError(f"Extracted slice at index {i} is not 2D. This indicates a malformed 3D array.")

    # 🔍 MEMORY CONVERSION LOGGING: Only log conversions or issues
    if source_type != memory_type:
        logger.debug(f"🔄 UNSTACK_SLICES: Converted and extracted {len(slices)} slices")
    elif len(slices) == 0:
        logger.warning("🔄 UNSTACK_SLICES: No slices extracted (empty array)")
    # Silent success for no-conversion cases to reduce log pollution

    return slices
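

# Minimal round-trip sketch (an illustrative assumption, not part of the OpenHCS
# test suite): guarded so it only runs when the module is executed directly.
if __name__ == "__main__":
    demo_slices = [np.full((8, 8), z, dtype=np.uint16) for z in range(3)]
    demo_volume = stack_slices(demo_slices, memory_type=MemoryType.NUMPY.value, gpu_id=0)
    assert demo_volume.shape == (3, 8, 8)
    demo_planes = unstack_slices(demo_volume, memory_type=MemoryType.NUMPY.value, gpu_id=0)
    assert len(demo_planes) == 3 and all(p.shape == (8, 8) for p in demo_planes)
    logger.debug("🔄 STACK_UTILS: round-trip sketch completed")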