Coverage for src/arraybridge/stack_utils.py: 78%
103 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
"""
Stack utilities module for OpenHCS.

This module provides functions for stacking 2D slices into a 3D array
and unstacking a 3D array into 2D slices, with explicit memory type handling.

This module enforces Clause 278 — Mandatory 3D Output Enforcement:
All functions must return a 3D array of shape [Z, Y, X], even when operating
on a single 2D slice. No logic may check, coerce, or infer rank at unstack time.
"""
12import logging
13from typing import Any
15from arraybridge.converters import detect_memory_type
16from arraybridge.framework_config import _FRAMEWORK_CONFIG
17from arraybridge.types import GPU_MEMORY_TYPES, MemoryType
18from arraybridge.utils import optional_import
20logger = logging.getLogger(__name__)
22# 🔍 MEMORY CONVERSION LOGGING: Test log to verify logger is working
23logger.debug("🔄 STACK_UTILS: Module loaded - memory conversion logging enabled")
26def _is_2d(data: Any) -> bool:
27 """
28 Check if data is a 2D array.
30 Args:
31 data: Data to check
33 Returns:
34 True if data is 2D, False otherwise
35 """
36 # Check if data has a shape attribute
37 if not hasattr(data, "shape"):
38 return False
40 # Check if shape has length 2
41 return len(data.shape) == 2
44def _is_3d(data: Any) -> bool:
45 """
46 Check if data is a 3D array.
48 Args:
49 data: Data to check
51 Returns:
52 True if data is 3D, False otherwise
53 """
54 # Check if data has a shape attribute
55 if not hasattr(data, "shape"):
56 return False
58 # Check if shape has length 3
59 return len(data.shape) == 3
def _enforce_gpu_device_requirements(memory_type: str, gpu_id: int) -> None:
    """
    Validate the GPU device ID for GPU-backed memory types.

    CPU memory types are accepted unconditionally; only GPU memory types
    require a valid (non-negative) device index.

    Args:
        memory_type: The memory type being targeted.
        gpu_id: The GPU device ID.

    Raises:
        ValueError: If ``memory_type`` is a GPU type and ``gpu_id`` is negative.
    """
    gpu_type_values = {mem_type.value for mem_type in GPU_MEMORY_TYPES}
    if memory_type in gpu_type_values and gpu_id < 0:
        raise ValueError(f"Invalid GPU device ID: {gpu_id}. Must be a non-negative integer.")
79# NOTE: Allocation operations now defined in framework_config.py
80# This eliminates the scattered _ALLOCATION_OPS dict
def _allocate_stack_array(
    memory_type: str, stack_shape: tuple, first_slice: Any, gpu_id: int
) -> Any:
    """
    Allocate a 3D array for stacking slices using framework config.

    Dispatches on ``_FRAMEWORK_CONFIG`` to build an empty [Z, Y, X] array in
    the target framework, evaluating the framework's configured allocation
    expression with the appropriate module bound in local scope.

    Args:
        memory_type: The target memory type
        stack_shape: The shape of the stack (Z, Y, X)
        first_slice: The first slice (used for dtype inference)
        gpu_id: The GPU device ID

    Returns:
        Pre-allocated array or None for pyclesperanto

    Raises:
        ValueError: If the framework module for ``memory_type`` is not importable.
    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    allocate_expr = config["allocate_stack"]

    # Check if allocation is None (pyclesperanto uses custom stacking)
    if allocate_expr is None:
        return None

    # Import the module
    mod = optional_import(mem_type.value)
    if mod is None:
        raise ValueError(f"{mem_type.value} is required for memory type {memory_type}")

    # Handle dtype conversion if needed
    needs_conversion = config["needs_dtype_conversion"]
    if callable(needs_conversion):
        # It's a callable that determines if conversion is needed
        needs_conversion = needs_conversion(first_slice, detect_memory_type)

    # Initialize variables for eval expressions
    sample_converted = None
    if needs_conversion:
        # Local import avoids a circular import at module load time
        # (converters imports from this package too) — presumably; verify.
        from arraybridge.converters import convert_memory

        # Convert a single sample slice so its dtype reflects the target
        # framework's representation; used only for dtype inference below.
        first_slice_source_type = detect_memory_type(first_slice)
        sample_converted = convert_memory(
            data=first_slice,
            source_type=first_slice_source_type,
            target_type=memory_type,
            gpu_id=gpu_id,
        )

    # Set up local variables for eval.
    # NOTE(review): the allocation expressions come from _FRAMEWORK_CONFIG,
    # i.e. trusted package-internal strings, not user input — eval is safe
    # only under that assumption; confirm config is never user-supplied.
    np = optional_import("numpy")  # noqa: F841
    cupy = mod if mem_type == MemoryType.CUPY else None  # noqa: F841
    torch = mod if mem_type == MemoryType.TORCH else None  # noqa: F841
    tf = mod if mem_type == MemoryType.TENSORFLOW else None  # noqa: F841
    jnp = optional_import("jax.numpy") if mem_type == MemoryType.JAX else None  # noqa: F841
    # dtype is used in allocate_expr eval below (for numpy framework);
    # prefer the converted sample's dtype when a conversion was performed.
    dtype = (  # noqa: F841
        sample_converted.dtype
        if sample_converted is not None
        else (first_slice.dtype if hasattr(first_slice, "dtype") else None)
    )

    # Execute allocation with context if needed (e.g. a device-placement
    # context manager for GPU frameworks — assumed; see framework_config).
    allocate_context = config.get("allocate_context")
    if allocate_context:
        context = eval(allocate_context)
        with context:
            return eval(allocate_expr)
    else:
        return eval(allocate_expr)
def stack_slices(slices: list[Any], memory_type: str, gpu_id: int) -> Any:
    """
    Stack 2D slices into a 3D array with the specified memory type.

    STRICT VALIDATION: Assumes all slices are 2D arrays.
    No automatic handling of improper inputs.

    Args:
        slices: List of 2D slices (numpy arrays, cupy arrays, torch tensors, etc.)
        memory_type: The memory type to use for the stacked array (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)

    Returns:
        A 3D array with the specified memory type of shape [Z, Y, X]

    Raises:
        ValueError: If memory_type is not supported or slices is empty
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If slices are not 2D arrays
        MemoryConversionError: If conversion fails
    """
    if not slices:
        raise ValueError("Cannot stack empty list of slices")

    # Verify all slices are 2D
    for i, slice_data in enumerate(slices):
        if not _is_2d(slice_data):
            raise ValueError(f"Slice at index {i} is not a 2D array. All slices must be 2D.")

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    # Pre-allocate the final 3D array to avoid intermediate list and final stack operation
    first_slice = slices[0]
    stack_shape = (len(slices), first_slice.shape[0], first_slice.shape[1])

    # Create pre-allocated result array in target memory type using enum dispatch
    result = _allocate_stack_array(memory_type, stack_shape, first_slice, gpu_id)

    # Track conversions for batch logging
    conversion_count = 0

    # Check for custom stack handler (pyclesperanto)
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    stack_handler = config.get("stack_handler")

    if stack_handler:
        # Use custom stack handler
        mod = optional_import(mem_type.value)
        result = stack_handler(slices, memory_type, gpu_id, mod)
    else:
        # Standard stacking logic.
        # Hoisted out of the loop: the import and the assign-handler lookup
        # are loop-invariant (original re-did both on every iteration).
        from arraybridge.converters import convert_memory

        assign_handler = config.get("assign_slice")

        for i, slice_data in enumerate(slices):
            source_type = detect_memory_type(slice_data)

            if source_type == memory_type:
                # Already in the target memory type — use as-is
                converted_data = slice_data
            else:
                conversion_count += 1
                converted_data = convert_memory(
                    data=slice_data, source_type=source_type, target_type=memory_type, gpu_id=gpu_id
                )

            # Assign converted slice using framework-specific handler if available
            if assign_handler:
                # Custom assignment (JAX immutability)
                result = assign_handler(result, i, converted_data)
            else:
                # Standard in-place assignment
                result[i] = converted_data

    # 🔍 MEMORY CONVERSION LOGGING: Only log when conversions happen or issues occur
    if conversion_count > 0:
        logger.debug(
            f"🔄 STACK_SLICES: Converted {conversion_count}/{len(slices)} "
            f"slices to {memory_type}"
        )
    # Silent success for no-conversion cases to reduce log pollution

    return result
def unstack_slices(
    array: Any, memory_type: str, gpu_id: int, validate_slices: bool = True
) -> list[Any]:
    """
    Split a 3D array into 2D slices along axis 0 and convert to the specified memory type.

    STRICT VALIDATION: Input must be a 3D array. No automatic handling of improper inputs.

    Args:
        array: 3D array to split - MUST BE 3D
        memory_type: The memory type to use for the output slices (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)
        validate_slices: If True, validates that each extracted slice is 2D

    Returns:
        List of 2D slices in the specified memory type

    Raises:
        ValueError: If array is not 3D
        ValueError: If validate_slices is True and any extracted slice is not 2D
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If memory_type is not supported
        MemoryConversionError: If conversion fails
    """
    # Detect input type once; reused below as the conversion source type.
    # (Removed a dead `getattr(array, "shape", ...)` whose result was discarded.)
    source_type = detect_memory_type(array)

    # Verify the array is 3D - fail loudly if not
    if not _is_3d(array):
        raise ValueError(f"Array must be 3D, got shape {getattr(array, 'shape', 'unknown')}")

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    # Convert to target memory type (no-op when already matching, to reduce
    # log pollution and avoid an unnecessary conversion round-trip)
    if source_type != memory_type:
        from arraybridge.converters import convert_memory

        logger.debug(f"🔄 UNSTACK_SLICES: Converting array - {source_type} → {memory_type}")
        array = convert_memory(
            data=array, source_type=source_type, target_type=memory_type, gpu_id=gpu_id
        )

    # Extract slices along axis 0 (already in the target memory type)
    slices = [array[i] for i in range(array.shape[0])]

    # Validate that all extracted slices are 2D if requested
    if validate_slices:
        for i, slice_data in enumerate(slices):
            if not _is_2d(slice_data):
                raise ValueError(
                    f"Extracted slice at index {i} is not 2D. "
                    f"This indicates a malformed 3D array."
                )

    # 🔍 MEMORY CONVERSION LOGGING: Only log conversions or issues
    if source_type != memory_type:
        logger.debug(f"🔄 UNSTACK_SLICES: Converted and extracted {len(slices)} slices")
    elif len(slices) == 0:
        logger.warning("🔄 UNSTACK_SLICES: No slices extracted (empty array)")
    # Silent success for no-conversion cases to reduce log pollution

    return slices