Coverage for openhcs/core/memory/gpu_cleanup.py: 27.4%

208 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

  1  """
  2  GPU memory cleanup utilities for different frameworks.
  3  
  4  This module provides unified GPU memory cleanup functions for PyTorch, CuPy,
  5  TensorFlow, JAX, and pyclesperanto. The cleanup functions are designed to be
  6  called after processing steps to free up GPU memory that's no longer needed.
  7  """

  8  
  9  import logging
 10  from typing import Optional
 11  from openhcs.core.utils import optional_import
 12  from openhcs.constants.constants import VALID_GPU_MEMORY_TYPES  # Import directly if always available
 13  
 14  logger = logging.getLogger(__name__)

 15  
 16  # --- Top-level optional imports for GPU frameworks ---
 17  torch = optional_import("torch")
 18  cupy = optional_import("cupy")
 19  tensorflow = optional_import("tensorflow")
 20  jax = optional_import("jax")
 21  pyclesperanto = optional_import("pyclesperanto")
 22  
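The `optional_import` helper itself lives in openhcs.core.utils and is not shown in this file; the `is None` checks below imply it returns the imported module, or None when the package is missing. A minimal sketch of that assumed contract (the name `optional_import_sketch` is hypothetical):

    import importlib
    from types import ModuleType
    from typing import Optional

    def optional_import_sketch(name: str) -> Optional[ModuleType]:
        """Hypothetical stand-in for openhcs.core.utils.optional_import."""
        try:
            return importlib.import_module(name)   # module object when installed
        except ImportError:
            return None                            # callers check for None, as below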

 23  # --- Cleanup functions ---
 24  

 25  def is_gpu_memory_type(memory_type: str) -> bool:
 26      """
 27      Check if a memory type is a GPU memory type.
 28  
 29      Args:
 30          memory_type: Memory type string
 31  
 32      Returns:
 33          True if it's a GPU memory type, False otherwise
 34      """
 35      # Using VALID_GPU_MEMORY_TYPES directly after top-level import.
 36      # If openhcs.constants.constants is itself optional, then this function
 37      # might need to revert to its try-except, or ensure that constants are
 38      # always available for core utilities. Assuming it's always available now.
 39      return memory_type in VALID_GPU_MEMORY_TYPES
 40  
 41  
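Assuming VALID_GPU_MEMORY_TYPES holds the GPU backends handled below (torch, cupy, tensorflow, jax, pyclesperanto), a caller might use the check like this (illustrative only):

    memory_type = "cupy"
    if is_gpu_memory_type(memory_type):      # True for GPU-backed memory types
        cleanup_gpu_memory_by_framework(memory_type, device_id=0)
    else:
        cleanup_numpy_noop()                 # CPU memory: nothing to free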

 42  def cleanup_pytorch_gpu(device_id: Optional[int] = None) -> None:
 43      """
 44      Clean up PyTorch GPU memory.
 45  
 46      Args:
 47          device_id: Optional GPU device ID. If None, cleans all devices.
 48      """
 49      if torch is None:  [49 ↛ 50: condition on line 49 was never true]
 50          logger.debug("PyTorch not available, skipping PyTorch GPU cleanup")
 51          return
 52  
 53      try:
 54          if not torch.cuda.is_available():  [54 ↛ anywhere: line 54 always raised an exception]
 55              return
 56  
 57          if device_id is not None:
 58              # Clean specific device
 59              with torch.cuda.device(device_id):
 60                  torch.cuda.empty_cache()
 61                  torch.cuda.synchronize()
 62              logger.debug(f"🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for device {device_id}")
 63          else:
 64              # Clean all devices
 65              torch.cuda.empty_cache()
 66              torch.cuda.synchronize()
 67              logger.debug("🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for all devices")
 68  
 69      except Exception as e:
 70          logger.warning(f"Failed to cleanup PyTorch GPU memory: {e}")
 71  
 72  
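For illustration only (not from the module): wrapping a GPU-heavy step with this cleanup and PyTorch's own counters might look like the sketch below. It assumes PyTorch with CUDA is installed; `run_step` and `batch` are hypothetical names.

    before = torch.cuda.memory_reserved(0)
    result = run_step(batch)                  # hypothetical GPU-heavy work
    cleanup_pytorch_gpu(device_id=0)
    after = torch.cuda.memory_reserved(0)
    logger.debug(f"Reserved VRAM: {before / 1e9:.2f}GB -> {after / 1e9:.2f}GB")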

 73  def cleanup_cupy_gpu(device_id: Optional[int] = None) -> None:
 74      """
 75      Clean up CuPy GPU memory with aggressive defragmentation.
 76  
 77      Args:
 78          device_id: Optional GPU device ID. If None, cleans current device.
 79      """
 80      if cupy is None:  [80 ↛ 81: condition on line 80 was never true]
 81          logger.debug("CuPy not available, skipping CuPy GPU cleanup")
 82          return
 83  
 84      try:
 85          if device_id is not None:  [85 ↛ 87: condition on line 85 was never true]
 86              # Clean specific device
 87              with cupy.cuda.Device(device_id):
 88                  # Get memory info before cleanup
 89                  mempool = cupy.get_default_memory_pool()
 90                  used_before = mempool.used_bytes()
 91  
 92                  # Aggressive cleanup to defragment memory
 93                  cupy.get_default_memory_pool().free_all_blocks()
 94                  cupy.get_default_pinned_memory_pool().free_all_blocks()
 95  
 96                  # Force memory pool reset to defragment
 97                  cupy.cuda.runtime.deviceSynchronize()
 98  
 99                  used_after = mempool.used_bytes()
100                  freed_mb = (used_before - used_after) / 1e6
101  
102                  logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for device {device_id}, freed {freed_mb:.1f}MB")
103          else:
104              # Clean current device
105              mempool = cupy.get_default_memory_pool()
106              used_before = mempool.used_bytes()
107  
108              # Aggressive cleanup to defragment memory
109              cupy.get_default_memory_pool().free_all_blocks()
110              cupy.get_default_pinned_memory_pool().free_all_blocks()
111  
112              # Force memory pool reset to defragment
113              cupy.cuda.runtime.deviceSynchronize()
114  
115              used_after = mempool.used_bytes()
116              freed_mb = (used_before - used_after) / 1e6
117  
118              logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for current device, freed {freed_mb:.1f}MB")
119  
120      except Exception as e:
121          logger.warning(f"Failed to cleanup CuPy GPU memory: {e}")
122  
123  
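A side note on the pool accounting, offered as an observation rather than part of the module: free_all_blocks() returns cached-but-unused blocks to the driver, so the drop tends to show up in total_bytes() rather than used_bytes(). A hedged sketch of inspecting both counters around a cleanup:

    pool = cupy.get_default_memory_pool()
    logger.debug(f"used:  {pool.used_bytes() / 1e6:.1f} MB")   # bytes held by live arrays
    logger.debug(f"total: {pool.total_bytes() / 1e6:.1f} MB")  # bytes reserved by the pool
    cleanup_cupy_gpu()                                         # frees cached (unused) blocks
    logger.debug(f"total after cleanup: {pool.total_bytes() / 1e6:.1f} MB")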

124  def cleanup_tensorflow_gpu(device_id: Optional[int] = None) -> None:
125      """
126      Clean up TensorFlow GPU memory.
127  
128      Args:
129          device_id: Optional GPU device ID. If None, cleans all devices.
130      """
131      if tensorflow is None:  [131 ↛ 132: condition on line 131 was never true]
132          logger.debug("TensorFlow not available, skipping TensorFlow GPU cleanup")
133          return
134  
135      try:
136          # Get list of GPU devices
137          gpus = tensorflow.config.list_physical_devices('GPU')
138          if not gpus:
139              return
140  
141          if device_id is not None and device_id < len(gpus):
142              # Clean specific device - TensorFlow doesn't have per-device cleanup
143              # so we trigger garbage collection which helps with memory management
144              import gc
145              gc.collect()
146              logger.debug(f"🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPU {device_id}")
147          else:
148              # Clean all devices - trigger garbage collection
149              import gc
150              gc.collect()
151              logger.debug("🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPUs")
152  
153      except Exception as e:
154          logger.warning(f"Failed to cleanup TensorFlow GPU memory: {e}")
155  
156  

157  def cleanup_jax_gpu(device_id: Optional[int] = None) -> None:
158      """
159      Clean up JAX GPU memory.
160  
161      Args:
162          device_id: Optional GPU device ID. If None, cleans all devices.
163      """
164      if jax is None:  [164 ↛ 165: condition on line 164 was never true]
165          logger.debug("JAX not available, skipping JAX GPU cleanup")
166          return
167  
168      try:
169          # JAX doesn't have explicit memory cleanup like PyTorch/CuPy
170          # but we can trigger garbage collection and clear compilation cache
171          import gc
172          gc.collect()
173  
174          # Clear JAX compilation cache which can hold GPU memory
175          jax.clear_caches()
176  
177          if device_id is not None:  [177 ↛ 178: condition on line 177 was never true]
178              logger.debug(f"🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for device {device_id}")
179          else:
180              logger.debug("🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for all devices")
181  
182      except Exception as e:
183          logger.warning(f"Failed to cleanup JAX GPU memory: {e}")
184  
185  

186  def cleanup_pyclesperanto_gpu(device_id: Optional[int] = None) -> None:
187      """
188      Clean up pyclesperanto GPU memory.
189  
190      Args:
191          device_id: Optional GPU device ID. If None, cleans current device.
192      """
193      if pyclesperanto is None:  [193 ↛ 194: condition on line 193 was never true]
194          logger.debug("pyclesperanto not available, skipping pyclesperanto GPU cleanup")
195          return
196  
197      try:
198          import gc
199  
200          # pyclesperanto doesn't have explicit memory cleanup like PyTorch/CuPy
201          # but we can trigger garbage collection and clear any cached data
202  
203          if device_id is not None:  [203 ↛ 205: condition on line 203 was never true]
204              # Select the specific device
205              devices = pyclesperanto.list_available_devices()
206              if device_id < len(devices):
207                  pyclesperanto.select_device(device_id)
208                  logger.debug(f"🔥 GPU CLEANUP: Selected pyclesperanto device {device_id}")
209              else:
210                  logger.warning(f"🔥 GPU CLEANUP: Device {device_id} not available in pyclesperanto")
211  
212          # Trigger garbage collection to clean up any unreferenced GPU arrays
213          collected = gc.collect()
214  
215          # pyclesperanto uses OpenCL which manages memory automatically
216          # but we can help by ensuring Python objects are cleaned up
217          if device_id is not None:  [217 ↛ 218: condition on line 217 was never true]
218              logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto device {device_id}, collected {collected} objects")
219          else:
220              logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto current device, collected {collected} objects")
221  
222      except Exception as e:
223          logger.warning(f"Failed to cleanup pyclesperanto GPU memory: {e}")
224  
225  

226  def cleanup_gpu_memory_by_framework(memory_type: str, device_id: Optional[int] = None) -> None:
227      """
228      Clean up GPU memory based on the OpenHCS memory type.
229  
230      Args:
231          memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow", "jax", "pyclesperanto", "numpy")
232          device_id: Optional GPU device ID
233      """
234      # Handle exact OpenHCS memory type values
235      if memory_type == "torch":
236          cleanup_pytorch_gpu(device_id)
237      elif memory_type == "cupy":
238          cleanup_cupy_gpu(device_id)
239      elif memory_type == "tensorflow":
240          cleanup_tensorflow_gpu(device_id)
241      elif memory_type == "jax":
242          cleanup_jax_gpu(device_id)
243      elif memory_type == "pyclesperanto":
244          cleanup_pyclesperanto_gpu(device_id)
245      elif memory_type == "numpy":
246          # CPU memory type - no GPU cleanup needed
247          logger.debug(f"No GPU cleanup needed for CPU memory type: {memory_type}")
248      else:
249          # Fallback for unknown types - try pattern matching
250          memory_type_lower = memory_type.lower()
251          if "torch" in memory_type_lower or "pytorch" in memory_type_lower:
252              cleanup_pytorch_gpu(device_id)
253          elif "cupy" in memory_type_lower:
254              cleanup_cupy_gpu(device_id)
255          elif "tensorflow" in memory_type_lower or "tf" in memory_type_lower:
256              cleanup_tensorflow_gpu(device_id)
257          elif "jax" in memory_type_lower:
258              cleanup_jax_gpu(device_id)
259          elif "pyclesperanto" in memory_type_lower or "clesperanto" in memory_type_lower:
260              cleanup_pyclesperanto_gpu(device_id)
261          else:
262              logger.debug(f"Unknown memory type for GPU cleanup: {memory_type}")
263  
264  
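For illustration only (the step-output shape and names are hypothetical, not from the module), a pipeline might dispatch cleanup from a step's declared output memory type:

    # Hypothetical step result; "memory_type" uses the OpenHCS strings handled above.
    step_output = {"memory_type": "torch", "gpu_id": 0}

    if is_gpu_memory_type(step_output["memory_type"]):
        cleanup_gpu_memory_by_framework(step_output["memory_type"],
                                        device_id=step_output["gpu_id"])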

265  def cleanup_numpy_noop(device_id: Optional[int] = None) -> None:
266      """
267      No-op cleanup for numpy (CPU memory type).
268  
269      Args:
270          device_id: Optional GPU device ID (ignored for CPU)
271      """
272      logger.debug("🔥 GPU CLEANUP: No-op for numpy (CPU memory type)")
273  
274  

275  def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
276      """
277      Clean up GPU memory for all available frameworks.
278  
279      Args:
280          device_id: Optional GPU device ID
281      """
282      cleanup_pytorch_gpu(device_id)
283      cleanup_cupy_gpu(device_id)
284      cleanup_tensorflow_gpu(device_id)
285      cleanup_jax_gpu(device_id)
286      cleanup_pyclesperanto_gpu(device_id)
287  
288      # Also trigger Python garbage collection
289      import gc
290      gc.collect()
291  
292      logger.debug("🔥 GPU CLEANUP: Performed comprehensive cleanup for all GPU frameworks")
293  
294  

295  # Registry mapping memory types to their cleanup functions
296  MEMORY_TYPE_CLEANUP_REGISTRY = {
297      "torch": cleanup_pytorch_gpu,
298      "cupy": cleanup_cupy_gpu,
299      "tensorflow": cleanup_tensorflow_gpu,
300      "jax": cleanup_jax_gpu,
301      "pyclesperanto": cleanup_pyclesperanto_gpu,
302      "numpy": cleanup_numpy_noop,
303  }
304  
305  

306  def cleanup_memory_by_type(memory_type: str, device_id: Optional[int] = None) -> None:
307      """
308      Clean up memory using the registered cleanup function for the memory type.
309  
310      Args:
311          memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow", "jax", "pyclesperanto", "numpy")
312          device_id: Optional GPU device ID
313      """
314      cleanup_func = MEMORY_TYPE_CLEANUP_REGISTRY.get(memory_type)
315  
316      if cleanup_func:
317          cleanup_func(device_id)
318      else:
319          logger.warning(f"No cleanup function registered for memory type: {memory_type}")
320          logger.debug(f"Available memory types: {list(MEMORY_TYPE_CLEANUP_REGISTRY.keys())}")
321  
322  
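The registry makes the dispatch table-driven, so a new backend only needs a dict entry. A hedged sketch of exercising it; the "my_backend" registration is purely illustrative and not an OpenHCS feature:

    cleanup_memory_by_type("cupy", device_id=0)   # dispatches to cleanup_cupy_gpu
    cleanup_memory_by_type("numpy")               # dispatches to the no-op
    cleanup_memory_by_type("unknown")             # logs a warning, nothing else

    # Illustrative only: a hypothetical backend could be registered the same way.
    MEMORY_TYPE_CLEANUP_REGISTRY["my_backend"] = lambda device_id=None: None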

323  def check_gpu_memory_usage() -> None:
324      """
325      Check and log current GPU memory usage for all available frameworks.
326  
327      This is a utility function for debugging memory issues.
328      """
329      logger.debug("🔍 GPU Memory Usage Report:")
330  
331      # Check PyTorch
332      if torch is not None:
333          if torch.cuda.is_available():
334              for i in range(torch.cuda.device_count()):
335                  allocated = torch.cuda.memory_allocated(i) / 1024**3
336                  reserved = torch.cuda.memory_reserved(i) / 1024**3
337                  logger.debug(f"  PyTorch GPU {i}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
338          else:
339              logger.debug("  PyTorch: No CUDA available")
340      else:
341          logger.debug("  PyTorch: Not installed")
342  
343      # Check CuPy
344      if cupy is not None:
345          mempool = cupy.get_default_memory_pool()
346          used_bytes = mempool.used_bytes()
347          total_bytes = mempool.total_bytes()
348          logger.debug(f"  CuPy: {used_bytes / 1024**3:.2f}GB used, {total_bytes / 1024**3:.2f}GB total")
349      else:
350          logger.debug("  CuPy: Not installed")  # Added missing log for consistency
351  
352      # Note: TensorFlow and JAX don't have easy memory introspection
353      logger.debug("  TensorFlow/JAX: Memory usage not easily queryable. Check if installed:")
354      if tensorflow is None:
355          logger.debug("    TensorFlow: Not installed")
356      if jax is None:
357          logger.debug("    JAX: Not installed")
358  
359  

360  def log_gpu_memory_usage(context: str = "") -> None:
361      """
362      Log GPU memory usage with a specific context for tracking.
363  
364      Args:
365          context: Description of when/where this memory check is happening
366      """
367      context_str = f" ({context})" if context else ""
368  
369      if torch is not None:  [369 ↛ 382: condition on line 369 was always true]
370          try:  # Keep try-except for runtime CUDA availability check
371              if torch.cuda.is_available():  [371 ↛ anywhere: line 371 always raised an exception]
372                  for i in range(torch.cuda.device_count()):
373                      allocated = torch.cuda.memory_allocated(i) / 1024**3
374                      reserved = torch.cuda.memory_reserved(i) / 1024**3
375                      free_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3 - reserved
376                      logger.debug(f"🔍 VRAM{context_str} GPU {i}: {allocated:.2f}GB alloc, {reserved:.2f}GB reserved, {free_memory:.2f}GB free")
377              else:
378                  logger.debug(f"🔍 VRAM{context_str}: No CUDA available")
379          except Exception as e:
380              logger.warning(f"🔍 VRAM{context_str}: Error checking PyTorch memory - {e}")
381      else:
382          logger.debug(f"🔍 VRAM{context_str}: PyTorch not available")
383  
384  

385  def get_gpu_memory_summary() -> dict:
386      """
387      Get GPU memory usage as a dictionary for programmatic use.
388  
389      Returns:
390          Dictionary with memory usage information
391      """
392      memory_info = {
393          "pytorch": {"available": False, "devices": []},
394          "cupy": {"available": False, "used_gb": 0, "total_gb": 0}
395      }
396  
397      # Check PyTorch
398      if torch is not None:
399          try:  # Keep try-except for runtime CUDA availability check
400              if torch.cuda.is_available():
401                  memory_info["pytorch"]["available"] = True
402                  for i in range(torch.cuda.device_count()):
403                      allocated = torch.cuda.memory_allocated(i) / 1024**3
404                      reserved = torch.cuda.memory_reserved(i) / 1024**3
405                      total = torch.cuda.get_device_properties(i).total_memory / 1024**3
406                      memory_info["pytorch"]["devices"].append({
407                          "device_id": i,
408                          "allocated_gb": allocated,
409                          "reserved_gb": reserved,
410                          "total_gb": total,
411                          "free_gb": total - reserved
412                      })
413          except Exception:  # Catch exceptions related to CUDA operations if available
414              pass  # Suppress specific error details if main check is for availability
415  
416      # Check CuPy
417      if cupy is not None:
418          memory_info["cupy"]["available"] = True
419          mempool = cupy.get_default_memory_pool()
420          memory_info["cupy"]["used_gb"] = mempool.used_bytes() / 1024**3
421          memory_info["cupy"]["total_gb"] = mempool.total_bytes() / 1024**3
422  
423      return memory_info
424  
425  
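A hedged example of consuming the summary dict, e.g. to warn when reserved VRAM crosses a threshold (the 0.9 threshold is illustrative):

    summary = get_gpu_memory_summary()
    for dev in summary["pytorch"]["devices"]:
        if dev["total_gb"] > 0 and dev["reserved_gb"] / dev["total_gb"] > 0.9:
            logger.warning(f"GPU {dev['device_id']} is over 90% reserved, "
                           f"only {dev['free_gb']:.2f}GB free")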

426  def force_comprehensive_cleanup() -> None:
427      """
428      Force comprehensive GPU cleanup across all frameworks and trigger garbage collection.
429  
430      This is the nuclear option for clearing GPU memory when you suspect leaks.
431      """
432      logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Starting nuclear cleanup...")
433  
434      # Clean all GPU frameworks
435      cleanup_all_gpu_frameworks()
436  
437      # Multiple rounds of garbage collection
438      import gc
439      for i in range(3):
440          collected = gc.collect()
441          logger.debug(f"🧹 Garbage collection round {i+1}: collected {collected} objects")
442  
443      # Check memory usage after cleanup
444      check_gpu_memory_usage()
445  
446      logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Complete")
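Putting the pieces together, one plausible way to use this module around a processing step; the wrapper and its argument names are hypothetical, not part of gpu_cleanup.py:

    def run_step_with_cleanup(step, data, memory_type: str, gpu_id: int = 0):
        """Illustrative wrapper, not part of gpu_cleanup.py."""
        log_gpu_memory_usage(context=f"before {step.__name__}")
        try:
            return step(data)
        finally:
            # Free whatever backend the step declared, then log the result.
            cleanup_memory_by_type(memory_type, device_id=gpu_id)
            log_gpu_memory_usage(context=f"after {step.__name__}")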