Coverage for src/arraybridge/stack_utils.py: 78%

103 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-03 05:09 +0000

1""" 

2Stack utilities module for OpenHCS. 

3 

4This module provides functions for stacking 2D slices into a 3D array 

5and unstacking a 3D array into 2D slices, with explicit memory type handling. 

6 

7This module enforces Clause 278 — Mandatory 3D Output Enforcement: 

8All functions must return a 3D array of shape [Z, Y, X], even when operating 

9on a single 2D slice. No logic may check, coerce, or infer rank at unstack time. 

10""" 

11 

12import logging 

13from typing import Any 

14 

15from arraybridge.converters import detect_memory_type 

16from arraybridge.framework_config import _FRAMEWORK_CONFIG 

17from arraybridge.types import GPU_MEMORY_TYPES, MemoryType 

18from arraybridge.utils import optional_import 

19 

20logger = logging.getLogger(__name__) 

21 

22# 🔍 MEMORY CONVERSION LOGGING: Test log to verify logger is working 

23logger.debug("🔄 STACK_UTILS: Module loaded - memory conversion logging enabled") 

24 

25 

26def _is_2d(data: Any) -> bool: 

27 """ 

28 Check if data is a 2D array. 

29 

30 Args: 

31 data: Data to check 

32 

33 Returns: 

34 True if data is 2D, False otherwise 

35 """ 

36 # Check if data has a shape attribute 

37 if not hasattr(data, "shape"): 

38 return False 

39 

40 # Check if shape has length 2 

41 return len(data.shape) == 2 

42 

43 

44def _is_3d(data: Any) -> bool: 

45 """ 

46 Check if data is a 3D array. 

47 

48 Args: 

49 data: Data to check 

50 

51 Returns: 

52 True if data is 3D, False otherwise 

53 """ 

54 # Check if data has a shape attribute 

55 if not hasattr(data, "shape"): 

56 return False 

57 

58 # Check if shape has length 3 

59 return len(data.shape) == 3 

60 

61 

def _enforce_gpu_device_requirements(memory_type: str, gpu_id: int) -> None:
    """
    Validate the GPU device ID when targeting a GPU-backed memory type.

    Args:
        memory_type: The memory type
        gpu_id: The GPU device ID

    Raises:
        ValueError: If gpu_id is negative for a GPU memory type
    """
    # Only GPU-backed memory types carry a device-ID requirement.
    gpu_type_names = {mem_type.value for mem_type in GPU_MEMORY_TYPES}
    if memory_type in gpu_type_names and gpu_id < 0:
        raise ValueError(f"Invalid GPU device ID: {gpu_id}. Must be a non-negative integer.")

77 

78 

79# NOTE: Allocation operations now defined in framework_config.py 

80# This eliminates the scattered _ALLOCATION_OPS dict 

81 

82 

def _allocate_stack_array(
    memory_type: str, stack_shape: tuple, first_slice: Any, gpu_id: int
) -> Any:
    """
    Allocate a 3D array for stacking slices using framework config.

    Looks up the framework-specific allocation expression in _FRAMEWORK_CONFIG
    and evaluates it with ``eval`` in a namespace of local variables prepared
    below.

    Args:
        memory_type: The target memory type
        stack_shape: The shape of the stack (Z, Y, X)
        first_slice: The first slice (used for dtype inference)
        gpu_id: The GPU device ID

    Returns:
        Pre-allocated array, or None when the framework defines no
        "allocate_stack" expression (pyclesperanto uses custom stacking)

    Raises:
        ValueError: If the framework module for memory_type cannot be imported
    """
    # Convert string to enum
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    allocate_expr = config["allocate_stack"]

    # Check if allocation is None (pyclesperanto uses custom stacking)
    if allocate_expr is None:
        return None

    # Import the module
    mod = optional_import(mem_type.value)
    if mod is None:
        raise ValueError(f"{mem_type.value} is required for memory type {memory_type}")

    # Handle dtype conversion if needed
    needs_conversion = config["needs_dtype_conversion"]
    if callable(needs_conversion):
        # It's a callable that determines if conversion is needed
        needs_conversion = needs_conversion(first_slice, detect_memory_type)

    # Initialize variables for eval expressions
    sample_converted = None
    if needs_conversion:
        # Local import avoids a circular dependency at module load time
        # (converters also imports from this package) — TODO confirm.
        from arraybridge.converters import convert_memory

        first_slice_source_type = detect_memory_type(first_slice)
        sample_converted = convert_memory(
            data=first_slice,
            source_type=first_slice_source_type,
            target_type=memory_type,
            gpu_id=gpu_id,
        )

    # Set up local variables for eval.
    # NOTE: eval() below executes trusted expressions coming from
    # _FRAMEWORK_CONFIG (not external input). Those expressions reference
    # these exact local names (np, cupy, torch, tf, jnp, dtype, stack_shape,
    # mod, gpu_id) — do NOT rename them.
    np = optional_import("numpy")  # noqa: F841
    cupy = mod if mem_type == MemoryType.CUPY else None  # noqa: F841
    torch = mod if mem_type == MemoryType.TORCH else None  # noqa: F841
    tf = mod if mem_type == MemoryType.TENSORFLOW else None  # noqa: F841
    jnp = optional_import("jax.numpy") if mem_type == MemoryType.JAX else None  # noqa: F841
    # dtype is used in allocate_expr eval below (for numpy framework);
    # prefer the dtype of the converted sample when conversion happened.
    dtype = (  # noqa: F841
        sample_converted.dtype
        if sample_converted is not None
        else (first_slice.dtype if hasattr(first_slice, "dtype") else None)
    )

    # Execute allocation with context if needed (e.g. a device context
    # manager expression; semantics defined by the framework config)
    allocate_context = config.get("allocate_context")
    if allocate_context:
        context = eval(allocate_context)
        with context:
            return eval(allocate_expr)
    else:
        return eval(allocate_expr)

152 

153 

def stack_slices(slices: list[Any], memory_type: str, gpu_id: int) -> Any:
    """
    Stack 2D slices into a 3D array with the specified memory type.

    STRICT VALIDATION: Assumes all slices are 2D arrays.
    No automatic handling of improper inputs.

    Args:
        slices: List of 2D slices (numpy arrays, cupy arrays, torch tensors, etc.)
        memory_type: The memory type to use for the stacked array (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)

    Returns:
        A 3D array with the specified memory type of shape [Z, Y, X]

    Raises:
        ValueError: If memory_type is not supported or slices is empty
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If slices are not 2D arrays
        MemoryConversionError: If conversion fails
    """
    if not slices:
        raise ValueError("Cannot stack empty list of slices")

    # Verify all slices are 2D
    for i, slice_data in enumerate(slices):
        if not _is_2d(slice_data):
            raise ValueError(f"Slice at index {i} is not a 2D array. All slices must be 2D.")

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    # Pre-allocate the final 3D array to avoid intermediate list and final stack operation
    first_slice = slices[0]
    stack_shape = (len(slices), first_slice.shape[0], first_slice.shape[1])

    # Create pre-allocated result array in target memory type using enum dispatch
    result = _allocate_stack_array(memory_type, stack_shape, first_slice, gpu_id)

    # Track conversions for batch logging
    conversion_count = 0

    # Check for custom stack handler (pyclesperanto)
    mem_type = MemoryType(memory_type)
    config = _FRAMEWORK_CONFIG[mem_type]
    stack_handler = config.get("stack_handler")

    if stack_handler:
        # Use custom stack handler
        mod = optional_import(mem_type.value)
        result = stack_handler(slices, memory_type, gpu_id, mod)
    else:
        # Standard stacking logic.
        # Loop invariants hoisted out of the per-slice loop: the converter
        # import and the assignment-handler lookup do not change per slice.
        from arraybridge.converters import convert_memory

        assign_handler = config.get("assign_slice")

        for i, slice_data in enumerate(slices):
            source_type = detect_memory_type(slice_data)

            if source_type == memory_type:
                # Already in the target memory type — no conversion needed
                converted_data = slice_data
            else:
                conversion_count += 1
                converted_data = convert_memory(
                    data=slice_data, source_type=source_type, target_type=memory_type, gpu_id=gpu_id
                )

            # Assign converted slice using framework-specific handler if available
            if assign_handler:
                # Custom assignment (JAX immutability)
                result = assign_handler(result, i, converted_data)
            else:
                # Standard assignment
                result[i] = converted_data

    # 🔍 MEMORY CONVERSION LOGGING: Only log when conversions happen or issues occur
    if conversion_count > 0:
        logger.debug(
            f"🔄 STACK_SLICES: Converted {conversion_count}/{len(slices)} "
            f"slices to {memory_type}"
        )
    # Silent success for no-conversion cases to reduce log pollution

    return result

242 

243 

def unstack_slices(
    array: Any, memory_type: str, gpu_id: int, validate_slices: bool = True
) -> list[Any]:
    """
    Split a 3D array into 2D slices along axis 0 and convert to the specified memory type.

    STRICT VALIDATION: Input must be a 3D array. No automatic handling of improper inputs.

    Args:
        array: 3D array to split - MUST BE 3D
        memory_type: The memory type to use for the output slices (REQUIRED)
        gpu_id: The target GPU device ID (REQUIRED)
        validate_slices: If True, validates that each extracted slice is 2D

    Returns:
        List of 2D slices in the specified memory type

    Raises:
        ValueError: If array is not 3D
        ValueError: If validate_slices is True and any extracted slice is not 2D
        ValueError: If gpu_id is negative for GPU memory types
        ValueError: If memory_type is not supported
        MemoryConversionError: If conversion fails
    """
    # Detect input type and check if conversion is needed
    input_type = detect_memory_type(array)

    # Verify the array is 3D - fail loudly if not
    if not _is_3d(array):
        raise ValueError(f"Array must be 3D, got shape {getattr(array, 'shape', 'unknown')}")

    # Check GPU requirements
    _enforce_gpu_device_requirements(memory_type, gpu_id)

    # Convert to target memory type (reuse the already-detected type);
    # no conversion needed when source matches target — silent success
    # to reduce log pollution.
    source_type = input_type
    if source_type != memory_type:
        from arraybridge.converters import convert_memory

        logger.debug(f"🔄 UNSTACK_SLICES: Converting array - {source_type} → {memory_type}")
        array = convert_memory(
            data=array, source_type=source_type, target_type=memory_type, gpu_id=gpu_id
        )

    # Extract slices along axis 0 (already in the target memory type)
    slices = [array[i] for i in range(array.shape[0])]

    # Validate that all extracted slices are 2D if requested
    if validate_slices:
        for i, slice_data in enumerate(slices):
            if not _is_2d(slice_data):
                raise ValueError(
                    f"Extracted slice at index {i} is not 2D. "
                    f"This indicates a malformed 3D array."
                )

    # 🔍 MEMORY CONVERSION LOGGING: Only log conversions or issues
    if source_type != memory_type:
        logger.debug(f"🔄 UNSTACK_SLICES: Converted and extracted {len(slices)} slices")
    elif len(slices) == 0:
        logger.warning("🔄 UNSTACK_SLICES: No slices extracted (empty array)")
    # Silent success for no-conversion cases to reduce log pollution

    return slices