Coverage for openhcs/core/memory/framework_config.py: 18.0%

65 statements  

coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Single source of truth for ALL framework-specific behavior. 

3 

4This module consolidates all framework-specific logic that was previously 

5scattered across utils.py, stack_utils.py, gpu_cleanup.py, dtype_scaling.py, 

6and framework_ops.py. 

7 

8Architecture: 

9- Framework handlers: Custom logic for special cases (pyclesperanto, JAX, TensorFlow) 

10- Unified config: Single _FRAMEWORK_CONFIG dict with all framework metadata 

11- Polymorphic dispatch: Handlers can be callables or eval expressions 

12""" 

import gc
import logging
from typing import Any, Optional, Callable
from openhcs.constants.constants import MemoryType

logger = logging.getLogger(__name__)


# ============================================================================
# FRAMEWORK HANDLERS - All special-case logic lives here
# ============================================================================

def _pyclesperanto_get_device_id(data: Any, mod: Any) -> int:
    """Get device ID for pyclesperanto array."""
    try:
        current_device = mod.get_device()
        if hasattr(current_device, 'id'):
            return current_device.id
        devices = mod.list_available_devices()
        for i, device in enumerate(devices):
            if str(device) == str(current_device):
                return i
        return 0
    except Exception as e:
        logger.warning(f"Failed to get device ID for pyclesperanto: {e}")
        return 0


def _pyclesperanto_set_device(device_id: int, mod: Any) -> None:
    """Set device for pyclesperanto."""
    devices = mod.list_available_devices()
    if device_id >= len(devices):
        raise ValueError(f"Device {device_id} not available. Available: {len(devices)}")
    mod.select_device(device_id)


def _pyclesperanto_move_to_device(data: Any, device_id: int, mod: Any, memory_type: str) -> Any:
    """Move pyclesperanto array to device."""
    # Import here to avoid circular dependency
    from openhcs.core.memory.utils import _get_device_id

    current_device_id = _get_device_id(data, memory_type)

    if current_device_id != device_id:
        mod.select_device(device_id)
        result = mod.create_like(data)
        mod.copy(data, result)
        return result
    return data


def _pyclesperanto_stack_slices(slices: list, memory_type: str, gpu_id: int, mod: Any) -> Any:
    """Stack slices using pyclesperanto's concatenate_along_z."""
    from openhcs.core.memory.converters import convert_memory, detect_memory_type

    converted_slices = []
    conversion_count = 0

    for slice_data in slices:
        source_type = detect_memory_type(slice_data)

        if source_type == memory_type:
            converted_slices.append(slice_data)
        else:
            conversion_count += 1
            converted = convert_memory(slice_data, source_type, memory_type, gpu_id)
            converted_slices.append(converted)

    # Log batch conversion
    if conversion_count > 0:
        logger.debug(
            f"🔄 MEMORY CONVERSION: Converted {conversion_count}/{len(slices)} slices "
            f"to {memory_type} for pyclesperanto stacking"
        )

    return mod.concatenate_along_z(converted_slices)


def _jax_assign_slice(result: Any, index: int, slice_data: Any) -> Any:
    """Assign slice to JAX array (immutable)."""
    return result.at[index].set(slice_data)
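

# Illustrative sketch (hypothetical, not part of the original module): unlike the
# frameworks that use the standard in-place assignment `result[i] = slice_2d`,
# JAX arrays are immutable, so the handler above must return a *new* array built
# with JAX's functional `.at[...].set(...)` update API.
def _example_jax_assign_usage():  # never called by OpenHCS; assumes JAX is installed
    import jax.numpy as jnp
    stack = jnp.zeros((3, 4, 4), dtype=jnp.float32)
    # Write a plane of ones into index 0; the original `stack` is left unchanged.
    return _jax_assign_slice(stack, 0, jnp.ones((4, 4), dtype=jnp.float32))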


def _tensorflow_validate_dlpack(obj: Any, mod: Any) -> bool:
    """Validate TensorFlow DLPack support."""
    # Check version
    major, minor = map(int, mod.__version__.split('.')[:2])
    if major < 2 or (major == 2 and minor < 12):
        raise RuntimeError(
            f"TensorFlow {mod.__version__} does not support stable DLPack. "
            f"Version 2.12.0+ required. "
            f"Clause 88 violation: Cannot infer DLPack capability."
        )

    # Check GPU
    device_str = obj.device.lower()
    if "gpu" not in device_str:
        raise RuntimeError(
            "TensorFlow tensor on CPU cannot use DLPack operations reliably. "
            "Only GPU tensors are supported for DLPack operations. "
            "Clause 88 violation: Cannot infer GPU capability."
        )

    # Check module
    if not hasattr(mod.experimental, "dlpack"):
        raise RuntimeError(
            "TensorFlow installation missing experimental.dlpack module. "
            "Clause 88 violation: Cannot infer DLPack capability."
        )

    return True


def _numpy_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
    """Check if NumPy needs dtype conversion (only for torch sources)."""
    source_type = detect_memory_type_func(first_slice)
    return source_type == MemoryType.TORCH.value


def _torch_dtype_conversion_needed(first_slice: Any, detect_memory_type_func: Callable) -> bool:
    """Torch always needs dtype conversion to get correct torch dtype."""
    return True
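

# Illustrative sketch (hypothetical, not part of the original module): the
# 'needs_dtype_conversion' entries in the config below are polymorphic, holding
# either a plain bool or a callable with the signature of the two helpers above.
# A consumer might resolve such an entry like this:
def _example_resolve_needs_dtype_conversion(entry, first_slice, detect_memory_type_func):
    if callable(entry):
        return entry(first_slice, detect_memory_type_func)
    return bool(entry)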


# ============================================================================
# UNIFIED FRAMEWORK CONFIGURATION
# ============================================================================

_FRAMEWORK_CONFIG = {
    MemoryType.NUMPY: {
        # Metadata
        'import_name': 'numpy',
        'display_name': 'NumPy',
        'is_gpu': False,

        # Device operations
        'get_device_id': None,  # CPU
        'set_device': None,  # CPU
        'move_to_device': None,  # CPU

        # Stack operations
        'allocate_stack': 'np.empty(stack_shape, dtype=dtype)',
        'allocate_context': None,
        'needs_dtype_conversion': _numpy_dtype_conversion_needed,  # Callable
        'assign_slice': None,  # Standard: result[i] = slice
        'stack_handler': None,  # Standard stacking

        # Dtype scaling
        'scaling_ops': {
            'min': 'result.min()',
            'max': 'result.max()',
            'astype': 'result.astype(target_dtype)',
            'check_float': 'np.issubdtype(result.dtype, np.floating)',
            'check_int': 'target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]',
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'data',
            'from_numpy': 'data',
            'from_dlpack': None,
            'move_to_device': 'data',
        },

        # DLPack
        'supports_dlpack': False,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': None,
        'gpu_check': None,
        'stream_context': None,
        'device_context': None,
        'cleanup_ops': None,
        'has_oom_recovery': False,
        'oom_exception_types': [],
        'oom_string_patterns': ['cannot allocate memory', 'memory exhausted'],
        'oom_clear_cache': 'import gc; gc.collect()',
    },

    MemoryType.CUPY: {
        # Metadata
        'import_name': 'cupy',
        'display_name': 'CuPy',
        'is_gpu': True,

        # Device operations (eval expressions)
        'get_device_id': 'data.device.id',
        'get_device_id_fallback': '0',
        'set_device': '{mod}.cuda.Device(device_id).use()',
        'move_to_device': 'data.copy() if data.device.id != device_id else data',
        'move_context': '{mod}.cuda.Device(device_id)',

        # Stack operations
        'allocate_stack': 'cupy.empty(stack_shape, dtype=first_slice.dtype)',
        'allocate_context': 'cupy.cuda.Device(gpu_id)',
        'needs_dtype_conversion': False,
        'assign_slice': None,  # Standard
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'mod.min(result)',
            'max': 'mod.max(result)',
            'astype': 'result.astype(target_dtype)',
            'check_float': 'mod.issubdtype(result.dtype, mod.floating)',
            'check_int': 'not mod.issubdtype(target_dtype, mod.floating)',
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'data.get()',
            'from_numpy': '({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]',
            'from_dlpack': '{mod}.from_dlpack(data)',
            'move_to_device': 'data if data.device.id == gpu_id else ({mod}.cuda.Device(gpu_id), {mod}.array(data))[1]',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': '_get_cupy',
        'gpu_check': '{mod} is not None and hasattr({mod}, "cuda")',
        'stream_context': '{mod}.cuda.Stream()',
        'device_context': '{mod}.cuda.Device({device_id})',
        'cleanup_ops': '{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()',
        'has_oom_recovery': True,
        'oom_exception_types': ['{mod}.cuda.memory.OutOfMemoryError', '{mod}.cuda.runtime.CUDARuntimeError'],
        'oom_string_patterns': ['out of memory', 'cuda_error_out_of_memory'],
        'oom_clear_cache': '{mod}.get_default_memory_pool().free_all_blocks(); {mod}.get_default_pinned_memory_pool().free_all_blocks(); {mod}.cuda.runtime.deviceSynchronize()',
    },

    MemoryType.TORCH: {
        # Metadata
        'import_name': 'torch',
        'display_name': 'PyTorch',
        'is_gpu': True,

        # Device operations
        'get_device_id': 'data.device.index if data.is_cuda else None',
        'get_device_id_fallback': 'None',
        'set_device': None,  # PyTorch handles device at tensor creation
        'move_to_device': 'data.to(f"cuda:{device_id}") if (not data.is_cuda or data.device.index != device_id) else data',

        # Stack operations
        'allocate_stack': 'torch.empty(stack_shape, dtype=sample_converted.dtype, device=sample_converted.device)',
        'allocate_context': None,
        'needs_dtype_conversion': _torch_dtype_conversion_needed,  # Callable
        'assign_slice': None,  # Standard
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'result.min()',
            'max': 'result.max()',
            'astype': 'result.to(target_dtype_mapped)',
            'check_float': 'result.dtype in [mod.float16, mod.float32, mod.float64]',
            'check_int': 'target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]',
            'needs_dtype_map': True,
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'data.cpu().numpy()',
            'from_numpy': '{mod}.from_numpy(data).cuda(gpu_id)',
            'from_dlpack': '{mod}.from_dlpack(data)',
            'move_to_device': 'data if data.device.index == gpu_id else data.cuda(gpu_id)',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': '_get_torch',
        'gpu_check': '{mod} is not None and hasattr({mod}, "cuda") and {mod}.cuda.is_available()',
        'stream_context': '{mod}.cuda.Stream()',
        'device_context': '{mod}.cuda.device({device_id})',
        'cleanup_ops': '{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()',
        'has_oom_recovery': True,
        'oom_exception_types': ['{mod}.cuda.OutOfMemoryError'],
        'oom_string_patterns': ['out of memory', 'cuda_error_out_of_memory'],
        'oom_clear_cache': '{mod}.cuda.empty_cache(); {mod}.cuda.synchronize()',
    },
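
    # Illustrative note (added, not from the original source): for frameworks with
    # 'needs_dtype_map': True, the NumPy-style target dtype requested by callers is
    # presumably translated to a framework-native dtype ('target_dtype_mapped')
    # before the 'astype' expression runs, e.g. numpy.float32 -> torch.float32 and
    # numpy.uint8 -> torch.uint8 for the PyTorch entry above.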

    MemoryType.TENSORFLOW: {
        # Metadata
        'import_name': 'tensorflow',
        'display_name': 'TensorFlow',
        'is_gpu': True,

        # Device operations
        'get_device_id': 'int(data.device.lower().split(":")[-1]) if "gpu" in data.device.lower() else None',
        'get_device_id_fallback': 'None',
        'set_device': None,  # TensorFlow handles device at tensor creation
        'move_to_device': '{mod}.identity(data)',
        'move_context': '{mod}.device(f"/device:GPU:{device_id}")',

        # Stack operations
        'allocate_stack': 'tf.zeros(stack_shape, dtype=first_slice.dtype)',  # TF doesn't have empty()
        'allocate_context': 'tf.device(f"/device:GPU:{gpu_id}")',
        'needs_dtype_conversion': False,
        'assign_slice': None,  # Standard
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'mod.reduce_min(result)',
            'max': 'mod.reduce_max(result)',
            'astype': 'mod.cast(result, target_dtype_mapped)',
            'check_float': 'result.dtype in [mod.float16, mod.float32, mod.float64]',
            'check_int': 'target_dtype_mapped in [mod.uint8, mod.int8, mod.int16, mod.int32, mod.int64]',
            'needs_dtype_map': True,
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'data.numpy()',
            'from_numpy': '{mod}.convert_to_tensor(data)',
            'from_dlpack': '{mod}.experimental.dlpack.from_dlpack(data)',
            'move_to_device': 'data',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': _tensorflow_validate_dlpack,  # Custom validation

        # GPU/Cleanup
        'lazy_getter': '_get_tensorflow',
        'gpu_check': '{mod} is not None and {mod}.config.list_physical_devices("GPU")',
        'stream_context': None,  # TensorFlow manages streams internally
        'device_context': '{mod}.device("/GPU:0")',
        'cleanup_ops': None,  # TensorFlow has no explicit cache clearing API
        'has_oom_recovery': True,
        'oom_exception_types': ['{mod}.errors.ResourceExhaustedError', '{mod}.errors.InvalidArgumentError'],
        'oom_string_patterns': ['out of memory', 'resource_exhausted'],
        'oom_clear_cache': None,  # TensorFlow has no explicit cache clearing API
    },

    MemoryType.JAX: {
        # Metadata
        'import_name': 'jax',
        'display_name': 'JAX',
        'is_gpu': True,

        # Device operations
        'get_device_id': 'int(str(data.device).lower().split(":")[-1]) if "gpu" in str(data.device).lower() else None',
        'get_device_id_fallback': 'None',
        'set_device': None,  # JAX handles device at array creation
        'move_to_device': '{mod}.device_put(data, {mod}.devices("gpu")[device_id])',

        # Stack operations
        'allocate_stack': 'jnp.empty(stack_shape, dtype=first_slice.dtype)',
        'allocate_context': None,
        'needs_dtype_conversion': False,
        'assign_slice': _jax_assign_slice,  # Custom handler for immutability
        'stack_handler': None,  # Standard

        # Dtype scaling
        'scaling_ops': {
            'min': 'jnp.min(result)',
            'max': 'jnp.max(result)',
            'astype': 'result.astype(target_dtype_mapped)',
            'check_float': 'result.dtype in [jnp.float16, jnp.float32, jnp.float64]',
            'check_int': 'target_dtype_mapped in [jnp.uint8, jnp.int8, jnp.int16, jnp.int32, jnp.int64]',
            'needs_dtype_map': True,
            'extra_import': 'jax.numpy',
        },

        # Conversion operations
        'conversion_ops': {
            'to_numpy': 'np.asarray(data)',
            'from_numpy': '{mod}.device_put(data, {mod}.devices()[gpu_id])',
            'from_dlpack': '{mod}.dlpack.from_dlpack(data)',
            'move_to_device': 'data',
        },

        # DLPack
        'supports_dlpack': True,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': '_get_jax',
        'gpu_check': '{mod} is not None and any(d.platform == "gpu" for d in {mod}.devices())',
        'stream_context': None,  # JAX/XLA manages streams internally
        'device_context': '{mod}.default_device([d for d in {mod}.devices() if d.platform == "gpu"][0])',
        'cleanup_ops': '{mod}.clear_caches()',
        'has_oom_recovery': True,
        'oom_exception_types': [],
        'oom_string_patterns': ['out of memory', 'oom when allocating', 'allocation failure'],
        'oom_clear_cache': '{mod}.clear_caches()',
    },

    MemoryType.PYCLESPERANTO: {
        # Metadata
        'import_name': 'pyclesperanto',
        'display_name': 'pyclesperanto',
        'is_gpu': True,

        # Device operations (custom handlers)
        'get_device_id': _pyclesperanto_get_device_id,  # Callable
        'get_device_id_fallback': '0',
        'set_device': _pyclesperanto_set_device,  # Callable
        'move_to_device': _pyclesperanto_move_to_device,  # Callable

        # Stack operations (custom handler)
        'allocate_stack': None,  # Uses concatenate_along_z
        'allocate_context': None,
        'needs_dtype_conversion': False,
        'assign_slice': None,  # Not used (custom stacking)
        'stack_handler': _pyclesperanto_stack_slices,  # Custom stacking

        # Conversion operations
        'conversion_ops': {
            'to_numpy': '{mod}.pull(data)',
            'from_numpy': '{mod}.push(data)',
            'from_dlpack': None,
            'move_to_device': 'data',
        },

        # Dtype scaling (custom implementation in dtype_scaling.py)
        'scaling_ops': None,  # Custom _scale_pyclesperanto function

        # DLPack
        'supports_dlpack': False,
        'validate_dlpack': None,

        # GPU/Cleanup
        'lazy_getter': None,
        'gpu_check': None,  # pyclesperanto always uses GPU if available
        'stream_context': None,  # OpenCL manages streams internally
        'device_context': None,  # OpenCL device selection is global
        'cleanup_ops': None,  # pyclesperanto/OpenCL has no explicit cache clearing API
        'has_oom_recovery': True,
        'oom_exception_types': [],
        'oom_string_patterns': ['cl_mem_object_allocation_failure', 'cl_out_of_resources', 'out of memory'],
        'oom_clear_cache': None,  # pyclesperanto/OpenCL has no explicit cache clearing API
    },
}
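

# Illustrative sketch of the "Polymorphic dispatch" idea from the module docstring
# (hypothetical helper, not the dispatcher the rest of OpenHCS actually uses): an
# entry is either a callable with a framework-specific signature, or an expression
# string whose '{mod}' placeholder stands for the imported framework module.
def _example_dispatch_set_device(memory_type: MemoryType, device_id: int, mod: Any) -> None:
    entry = _FRAMEWORK_CONFIG[memory_type]['set_device']
    if entry is None:        # CPU frameworks, or frameworks that bind device at creation
        return
    if callable(entry):      # custom handler, e.g. _pyclesperanto_set_device
        entry(device_id, mod)
        return
    # eval-expression form, e.g. CuPy's '{mod}.cuda.Device(device_id).use()'
    eval(entry.format(mod='mod'), {}, {'mod': mod, 'device_id': device_id})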

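
# Another hypothetical sketch (not part of the original module): how a caller might
# use the per-framework 'oom_string_patterns' lists to decide whether a caught
# exception looks like an out-of-memory failure before attempting recovery.
def _example_looks_like_oom(memory_type: MemoryType, exc: Exception) -> bool:
    patterns = _FRAMEWORK_CONFIG[memory_type]['oom_string_patterns']
    message = str(exc).lower()
    return any(pattern in message for pattern in patterns)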