Coverage for openhcs/core/memory/stack_utils.py: 67.1%

108 statements  

coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Stack utilities module for OpenHCS. 

3 

4This module provides functions for stacking 2D slices into a 3D array 

5and unstacking a 3D array into 2D slices, with explicit memory type handling. 

6 

7This module enforces Clause 278 — Mandatory 3D Output Enforcement: 

8All functions must return a 3D array of shape [Z, Y, X], even when operating 

9on a single 2D slice. No logic may check, coerce, or infer rank at unstack time. 

10""" 

11 

12import logging 

13from typing import Any, List 

14 

15import numpy as np 

16 

17from openhcs.constants.constants import GPU_MEMORY_TYPES, MemoryType 

18from openhcs.core.memory.converters import detect_memory_type 

19from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG 

20from openhcs.core.utils import optional_import 

21 

22logger = logging.getLogger(__name__) 

23 

24# 🔍 MEMORY CONVERSION LOGGING: Test log to verify logger is working 

25logger.debug("🔄 STACK_UTILS: Module loaded - memory conversion logging enabled") 

26 

27 

28def _is_2d(data: Any) -> bool: 

29 """ 

30 Check if data is a 2D array. 

31 

32 Args: 

33 data: Data to check 

34 

35 Returns: 

36 True if data is 2D, False otherwise 

37 """ 

38 # Check if data has a shape attribute 

39 if not hasattr(data, 'shape'): 39 ↛ 40line 39 didn't jump to line 40 because the condition on line 39 was never true

40 return False 

41 

42 # Check if shape has length 2 

43 return len(data.shape) == 2 

44 

45 

46def _is_3d(data: Any) -> bool: 

47 """ 

48 Check if data is a 3D array. 

49 

50 Args: 

51 data: Data to check 

52 

53 Returns: 

54 True if data is 3D, False otherwise 

55 """ 

56 # Check if data has a shape attribute 

57 if not hasattr(data, 'shape'): 57 ↛ 58line 57 didn't jump to line 58 because the condition on line 57 was never true

58 return False 

59 

60 # Check if shape has length 3 

61 return len(data.shape) == 3 

62 

63 
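For reference, a minimal sketch of how these rank helpers behave on plain NumPy inputs (illustrative only, not part of the measured source):

    import numpy as np

    _is_2d(np.zeros((64, 64)))       # True  -> shape has length 2
    _is_2d(np.zeros((3, 64, 64)))    # False
    _is_3d(np.zeros((3, 64, 64)))    # True  -> shape has length 3
    _is_3d("not an array")           # False -> no .shape attribute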

 64  def _enforce_gpu_device_requirements(memory_type: str, gpu_id: int) -> None:
 65      """
 66      Enforce GPU device requirements.
 67
 68      Args:
 69          memory_type: The memory type
 70          gpu_id: The GPU device ID
 71
 72      Raises:
 73          ValueError: If gpu_id is negative
 74      """
 75      # For GPU memory types, validate gpu_id
 76      if memory_type in {mem_type.value for mem_type in GPU_MEMORY_TYPES}:    [branch 76 ↛ 77 not taken: condition was never true]
 77          if gpu_id < 0:
 78              raise ValueError(f"Invalid GPU device ID: {gpu_id}. Must be a non-negative integer.")
 79
 80
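A quick illustration of the contract enforced above, assuming "cupy" is one of the GPU memory types and "numpy" is not (illustrative only, not part of the measured source):

    # Non-GPU memory types skip the gpu_id check entirely.
    _enforce_gpu_device_requirements("numpy", gpu_id=-1)   # no error raised

    # GPU memory types must receive a non-negative device ID
    # (assumes "cupy" is among GPU_MEMORY_TYPES).
    try:
        _enforce_gpu_device_requirements("cupy", gpu_id=-1)
    except ValueError as exc:
        print(exc)   # Invalid GPU device ID: -1. Must be a non-negative integer.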

 81  # NOTE: Allocation operations now defined in framework_config.py
 82  # This eliminates the scattered _ALLOCATION_OPS dict
 83
 84
 85  def _allocate_stack_array(memory_type: str, stack_shape: tuple, first_slice: Any, gpu_id: int) -> Any:
 86      """
 87      Allocate a 3D array for stacking slices using framework config.
 88
 89      Args:
 90          memory_type: The target memory type
 91          stack_shape: The shape of the stack (Z, Y, X)
 92          first_slice: The first slice (used for dtype inference)
 93          gpu_id: The GPU device ID
 94
 95      Returns:
 96          Pre-allocated array or None for pyclesperanto
 97      """
 98      # Convert string to enum
 99      mem_type = MemoryType(memory_type)
100      config = _FRAMEWORK_CONFIG[mem_type]
101      allocate_expr = config['allocate_stack']
102
103      # Check if allocation is None (pyclesperanto uses custom stacking)
104      if allocate_expr is None:    [branch 104 ↛ 105 not taken: condition was never true]
105          return None
106
107      # Import the module
108      mod = optional_import(mem_type.value)
109      if mod is None:    [branch 109 ↛ 110 not taken: condition was never true]
110          raise ValueError(f"{mem_type.value} is required for memory type {memory_type}")
111
112      # Handle dtype conversion if needed
113      needs_conversion = config['needs_dtype_conversion']
114      if callable(needs_conversion):    [branch 114 ↛ 118 not taken: condition was always true]
115          # It's a callable that determines if conversion is needed
116          needs_conversion = needs_conversion(first_slice, detect_memory_type)
117
118      if needs_conversion:    [branch 118 ↛ 119 not taken: condition was never true]
119          from openhcs.core.memory.converters import convert_memory
120          first_slice_source_type = detect_memory_type(first_slice)
121          sample_converted = convert_memory(  # noqa: F841 (used in eval)
122              data=first_slice,
123              source_type=first_slice_source_type,
124              target_type=memory_type,
125              gpu_id=gpu_id
126          )
127          dtype = sample_converted.dtype  # noqa: F841 (used in eval)
128      else:
129          dtype = first_slice.dtype if hasattr(first_slice, 'dtype') else None  # noqa: F841 (used in eval)
130
131      # Set up local variables for eval
132      np = optional_import("numpy")  # noqa: F841 (used in eval)
133      cupy = mod if mem_type == MemoryType.CUPY else None  # noqa: F841 (used in eval)
134      torch = mod if mem_type == MemoryType.TORCH else None  # noqa: F841 (used in eval)
135      tf = mod if mem_type == MemoryType.TENSORFLOW else None  # noqa: F841 (used in eval)
136      jnp = optional_import("jax.numpy") if mem_type == MemoryType.JAX else None  # noqa: F841 (used in eval)
137
138      # Execute allocation with context if needed
139      allocate_context = config.get('allocate_context')
140      if allocate_context:    [branch 140 ↛ 141 not taken: condition was never true]
141          context = eval(allocate_context)
142          with context:
143              return eval(allocate_expr)
144      else:
145          return eval(allocate_expr)
146
147
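The allocation path above is driven entirely by _FRAMEWORK_CONFIG, defined in framework_config.py. A hypothetical NumPy entry, showing only the keys this module reads ('allocate_stack', 'needs_dtype_conversion', 'allocate_context', 'stack_handler', 'assign_slice'), might look like the sketch below; the values and the MemoryType.NUMPY member are assumptions for illustration, not the actual configuration:

    from openhcs.constants.constants import MemoryType

    _FRAMEWORK_CONFIG_SKETCH = {
        MemoryType.NUMPY: {                    # hypothetical enum member
            # string eval'd by _allocate_stack_array with np, stack_shape and dtype in scope
            'allocate_stack': "np.empty(stack_shape, dtype=dtype)",
            # callable deciding whether the first slice must be converted before its dtype is usable
            'needs_dtype_conversion': lambda first_slice, detect: detect(first_slice) != "numpy",
            'allocate_context': None,          # no device context needed on CPU
            'stack_handler': None,             # use the standard per-slice loop in stack_slices
            'assign_slice': None,              # result[i] = slice works in place
        },
    }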

148  def stack_slices(slices: List[Any], memory_type: str, gpu_id: int) -> Any:
149      """
150      Stack 2D slices into a 3D array with the specified memory type.
151
152      STRICT VALIDATION: Assumes all slices are 2D arrays.
153      No automatic handling of improper inputs.
154
155      Args:
156          slices: List of 2D slices (numpy arrays, cupy arrays, torch tensors, etc.)
157          memory_type: The memory type to use for the stacked array (REQUIRED)
158          gpu_id: The target GPU device ID (REQUIRED)
159
160      Returns:
161          A 3D array with the specified memory type of shape [Z, Y, X]
162
163      Raises:
164          ValueError: If memory_type is not supported or slices is empty
165          ValueError: If gpu_id is negative for GPU memory types
166          ValueError: If slices are not 2D arrays
167          MemoryConversionError: If conversion fails
168      """
169      if not slices:    [branch 169 ↛ 170 not taken: condition was never true]
170          raise ValueError("Cannot stack empty list of slices")
171
172      # Verify all slices are 2D
173      for i, slice_data in enumerate(slices):
174          if not _is_2d(slice_data):    [branch 174 ↛ 175 not taken: condition was never true]
175              raise ValueError(f"Slice at index {i} is not a 2D array. All slices must be 2D.")
176
177      # Analyze input types for conversion planning (minimal logging)
178      input_types = [detect_memory_type(slice_data) for slice_data in slices]
179      unique_input_types = set(input_types)
180      needs_conversion = memory_type not in unique_input_types or len(unique_input_types) > 1
181
182      # Check GPU requirements
183      _enforce_gpu_device_requirements(memory_type, gpu_id)
184
185      # Pre-allocate the final 3D array to avoid intermediate list and final stack operation
186      first_slice = slices[0]
187      stack_shape = (len(slices), first_slice.shape[0], first_slice.shape[1])
188
189      # Create pre-allocated result array in target memory type using enum dispatch
190      result = _allocate_stack_array(memory_type, stack_shape, first_slice, gpu_id)
191
192      # Convert each slice and assign to result array
193      conversion_count = 0
194
195      # Check for custom stack handler (pyclesperanto)
196      mem_type = MemoryType(memory_type)
197      config = _FRAMEWORK_CONFIG[mem_type]
198      stack_handler = config.get('stack_handler')
199
200      if stack_handler:    [branch 200 ↛ 202 not taken: condition was never true]
201          # Use custom stack handler
202          mod = optional_import(mem_type.value)
203          result = stack_handler(slices, memory_type, gpu_id, mod)
204      else:
205          # Standard stacking logic
206          for i, slice_data in enumerate(slices):
207              source_type = detect_memory_type(slice_data)
208
209              # Track conversions for batch logging
210              if source_type != memory_type:    [branch 210 ↛ 211 not taken: condition was never true]
211                  conversion_count += 1
212
213              # Direct conversion
214              if source_type == memory_type:    [branch 214 ↛ 217 not taken: condition was always true]
215                  converted_data = slice_data
216              else:
217                  from openhcs.core.memory.converters import convert_memory
218                  converted_data = convert_memory(
219                      data=slice_data,
220                      source_type=source_type,
221                      target_type=memory_type,
222                      gpu_id=gpu_id
223                  )
224
225              # Assign converted slice using framework-specific handler if available
226              assign_handler = config.get('assign_slice')
227              if assign_handler:    [branch 227 ↛ 229 not taken: condition was never true]
228                  # Custom assignment (JAX immutability)
229                  result = assign_handler(result, i, converted_data)
230              else:
231                  # Standard assignment
232                  result[i] = converted_data
233
234      # 🔍 MEMORY CONVERSION LOGGING: Only log when conversions happen or issues occur
235      if conversion_count > 0:    [branch 235 ↛ 236 not taken: condition was never true]
236          logger.debug(f"🔄 STACK_SLICES: Converted {conversion_count}/{len(slices)} slices to {memory_type}")
237      # Silent success for no-conversion cases to reduce log pollution
238
239      return result
240
241
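A minimal usage sketch for stack_slices, assuming "numpy" is a valid memory_type value (illustrative only, not part of the measured source). Per Clause 278, even a single 2D slice comes back as a 3D [Z, Y, X] array:

    import numpy as np
    from openhcs.core.memory.stack_utils import stack_slices

    slices = [np.random.rand(256, 256).astype(np.float32) for _ in range(5)]

    stack = stack_slices(slices, memory_type="numpy", gpu_id=0)
    assert stack.shape == (5, 256, 256)     # [Z, Y, X]

    single = stack_slices([slices[0]], memory_type="numpy", gpu_id=0)
    assert single.shape == (1, 256, 256)    # Clause 278: still 3D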

242  def unstack_slices(array: Any, memory_type: str, gpu_id: int, validate_slices: bool = True) -> List[Any]:
243      """
244      Split a 3D array into 2D slices along axis 0 and convert to the specified memory type.
245
246      STRICT VALIDATION: Input must be a 3D array. No automatic handling of improper inputs.
247
248      Args:
249          array: 3D array to split - MUST BE 3D
250          memory_type: The memory type to use for the output slices (REQUIRED)
251          gpu_id: The target GPU device ID (REQUIRED)
252          validate_slices: If True, validates that each extracted slice is 2D
253
254      Returns:
255          List of 2D slices in the specified memory type
256
257      Raises:
258          ValueError: If array is not 3D
259          ValueError: If validate_slices is True and any extracted slice is not 2D
260          ValueError: If gpu_id is negative for GPU memory types
261          ValueError: If memory_type is not supported
262          MemoryConversionError: If conversion fails
263      """
264      # Detect input type and check if conversion is needed
265      input_type = detect_memory_type(array)
266      input_shape = getattr(array, 'shape', 'unknown')
267      needs_conversion = input_type != memory_type
268
269      # Verify the array is 3D - fail loudly if not
270      if not _is_3d(array):    [branch 270 ↛ 271 not taken: condition was never true]
271          raise ValueError(f"Array must be 3D, got shape {getattr(array, 'shape', 'unknown')}")
272
273      # Check GPU requirements
274      _enforce_gpu_device_requirements(memory_type, gpu_id)
275
276      # Convert to target memory type
277      source_type = input_type  # Reuse already detected type
278
279      # Direct conversion
280      if source_type == memory_type:    [branch 280 ↛ 285 not taken: condition was always true]
281          # No conversion needed - silent success to reduce log pollution
282          pass
283      else:
284          # Convert and log the conversion
285          from openhcs.core.memory.converters import convert_memory
286          logger.debug(f"🔄 UNSTACK_SLICES: Converting array - {source_type} → {memory_type}")
287          array = convert_memory(
288              data=array,
289              source_type=source_type,
290              target_type=memory_type,
291              gpu_id=gpu_id
292          )
293
294      # Extract slices along axis 0 (already in the target memory type)
295      slices = [array[i] for i in range(array.shape[0])]
296
297      # Validate that all extracted slices are 2D if requested
298      if validate_slices:    [branch 298 ↛ 304 not taken: condition was always true]
299          for i, slice_data in enumerate(slices):
300              if not _is_2d(slice_data):    [branch 300 ↛ 301 not taken: condition was never true]
301                  raise ValueError(f"Extracted slice at index {i} is not 2D. This indicates a malformed 3D array.")
302
303      # 🔍 MEMORY CONVERSION LOGGING: Only log conversions or issues
304      if source_type != memory_type:    [branch 304 ↛ 305 not taken: condition was never true]
305          logger.debug(f"🔄 UNSTACK_SLICES: Converted and extracted {len(slices)} slices")
306      elif len(slices) == 0:    [branch 306 ↛ 307 not taken: condition was never true]
307          logger.warning("🔄 UNSTACK_SLICES: No slices extracted (empty array)")
308      # Silent success for no-conversion cases to reduce log pollution
309
310      return slices
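And the corresponding round trip through unstack_slices, under the same "numpy" assumption (illustrative only, not part of the measured source):

    import numpy as np
    from openhcs.core.memory.stack_utils import stack_slices, unstack_slices

    stack = stack_slices(
        [np.zeros((256, 256), dtype=np.uint16) for _ in range(5)],
        memory_type="numpy",
        gpu_id=0,
    )
    planes = unstack_slices(stack, memory_type="numpy", gpu_id=0)

    assert len(planes) == 5
    assert all(p.shape == (256, 256) for p in planes)   # each extracted slice is 2D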