Coverage for openhcs/core/memory/gpu_cleanup.py: 25.2% of 216 statements
(coverage.py v7.10.7, created at 2025-10-01 18:33 +0000)

1""" 

2GPU memory cleanup utilities for different frameworks. 

3 

4This module provides unified GPU memory cleanup functions for PyTorch, CuPy, 

5TensorFlow, and JAX. The cleanup functions are designed to be called after 

6processing steps to free up GPU memory that's no longer needed. 

7""" 

import logging
import os
from typing import Optional

from openhcs.core.utils import optional_import
from openhcs.constants.constants import VALID_GPU_MEMORY_TYPES

logger = logging.getLogger(__name__)

# Check if we're in subprocess runner mode and should skip GPU imports
if os.getenv('OPENHCS_SUBPROCESS_NO_GPU') == '1':
    # Subprocess runner mode - skip GPU imports
    torch = None
    cupy = None
    tensorflow = None
    jax = None
    pyclesperanto = None
    logger.info("Subprocess runner mode - skipping GPU library imports in gpu_cleanup")
else:
    # Normal mode - import GPU frameworks as optional dependencies
    torch = optional_import("torch")
    cupy = optional_import("cupy")
    tensorflow = optional_import("tensorflow")
    jax = optional_import("jax")
    pyclesperanto = optional_import("pyclesperanto")
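
# Illustrative sketch (not part of this module): a subprocess runner that wants
# to skip GPU imports would set OPENHCS_SUBPROCESS_NO_GPU in the child's
# environment before this module is imported. Where OpenHCS actually does this
# is not shown here; the module and runner names below are assumptions, and the
# snippet assumes `subprocess` and `sys` are imported.
#
#   env = dict(os.environ, OPENHCS_SUBPROCESS_NO_GPU="1")
#   subprocess.Popen([sys.executable, "-m", "my_runner_module"], env=env)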


# --- Cleanup functions ---


def is_gpu_memory_type(memory_type: str) -> bool:
    """
    Check if a memory type is a GPU memory type.

    Args:
        memory_type: Memory type string

    Returns:
        True if it's a GPU memory type, False otherwise
    """
    # VALID_GPU_MEMORY_TYPES is imported at module level; openhcs.constants.constants
    # is assumed to always be available for core utilities.
    return memory_type in VALID_GPU_MEMORY_TYPES
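
# Usage sketch: gate cleanup on a step's declared memory type. The accepted
# literals depend on VALID_GPU_MEMORY_TYPES, which is not reproduced here;
# "cupy" is shown only as a plausible member.
#
#   if is_gpu_memory_type("cupy"):
#       cleanup_gpu_memory_by_framework("cupy", device_id=0)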


def cleanup_pytorch_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up PyTorch GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if torch is None:
        logger.debug("PyTorch not available, skipping PyTorch GPU cleanup")
        return

    try:
        if not torch.cuda.is_available():
            return

        if device_id is not None:
            # Clean specific device
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.debug(f"🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for device {device_id}")
        else:
            # Clean all devices
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            logger.debug("🔥 GPU CLEANUP: Cleared PyTorch CUDA cache for all devices")

    except Exception as e:
        logger.warning(f"Failed to cleanup PyTorch GPU memory: {e}")
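
# Usage sketch (hypothetical pipeline code, not part of this module): release
# cached CUDA memory once a step's output has been moved off the GPU.
# `run_torch_step` and `batch` are illustrative names.
#
#   result = run_torch_step(batch)        # GPU work on device 0
#   result_cpu = result.cpu()             # keep only the host copy
#   del result
#   cleanup_pytorch_gpu(device_id=0)      # release cached CUDA blocks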


def cleanup_cupy_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up CuPy GPU memory with aggressive defragmentation.

    Args:
        device_id: Optional GPU device ID. If None, cleans current device.
    """
    if cupy is None:
        logger.debug("CuPy not available, skipping CuPy GPU cleanup")
        return

    try:
        if device_id is not None:
            # Clean specific device
            with cupy.cuda.Device(device_id):
                # Get memory info before cleanup
                mempool = cupy.get_default_memory_pool()
                used_before = mempool.used_bytes()

                # Aggressive cleanup to defragment memory
                cupy.get_default_memory_pool().free_all_blocks()
                cupy.get_default_pinned_memory_pool().free_all_blocks()

                # Synchronize so pending frees complete before measuring
                cupy.cuda.runtime.deviceSynchronize()

                used_after = mempool.used_bytes()
                freed_mb = (used_before - used_after) / 1e6

                logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for device {device_id}, freed {freed_mb:.1f}MB")
        else:
            # Clean current device
            mempool = cupy.get_default_memory_pool()
            used_before = mempool.used_bytes()

            # Aggressive cleanup to defragment memory
            cupy.get_default_memory_pool().free_all_blocks()
            cupy.get_default_pinned_memory_pool().free_all_blocks()

            # Synchronize so pending frees complete before measuring
            cupy.cuda.runtime.deviceSynchronize()

            used_after = mempool.used_bytes()
            freed_mb = (used_before - used_after) / 1e6

            logger.debug(f"🔥 GPU CLEANUP: Cleared CuPy memory pools for current device, freed {freed_mb:.1f}MB")

    except Exception as e:
        logger.warning(f"Failed to cleanup CuPy GPU memory: {e}")
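
# Note on interpreting the "freed" figure above: `used_bytes()` reports memory
# held by live CuPy arrays, so the biggest wins come from dropping array
# references before calling the cleanup. A hedged sketch (`tile` and
# `host_tile` are illustrative names, not part of this module):
#
#   tile = cupy.asarray(host_tile)        # hypothetical GPU work
#   host_result = cupy.asnumpy(tile)
#   del tile                              # drop the GPU reference first
#   cleanup_cupy_gpu()                    # then release pooled blocks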


def cleanup_tensorflow_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up TensorFlow GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if tensorflow is None:
        logger.debug("TensorFlow not available, skipping TensorFlow GPU cleanup")
        return

    try:
        # Get list of GPU devices
        gpus = tensorflow.config.list_physical_devices('GPU')
        if not gpus:
            return

        if device_id is not None and device_id < len(gpus):
            # Clean specific device - TensorFlow doesn't have per-device cleanup,
            # so we trigger garbage collection which helps with memory management
            import gc
            gc.collect()
            logger.debug(f"🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPU {device_id}")
        else:
            # Clean all devices - trigger garbage collection
            import gc
            gc.collect()
            logger.debug("🔥 GPU CLEANUP: Triggered garbage collection for TensorFlow GPUs")

    except Exception as e:
        logger.warning(f"Failed to cleanup TensorFlow GPU memory: {e}")


def cleanup_jax_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up JAX GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans all devices.
    """
    if jax is None:
        logger.debug("JAX not available, skipping JAX GPU cleanup")
        return

    try:
        # JAX doesn't have explicit memory cleanup like PyTorch/CuPy,
        # but we can trigger garbage collection and clear the compilation cache
        import gc
        gc.collect()

        # Clear JAX compilation cache which can hold GPU memory
        jax.clear_caches()

        if device_id is not None:
            logger.debug(f"🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for device {device_id}")
        else:
            logger.debug("🔥 GPU CLEANUP: Cleared JAX caches and triggered GC for all devices")

    except Exception as e:
        logger.warning(f"Failed to cleanup JAX GPU memory: {e}")


def cleanup_pyclesperanto_gpu(device_id: Optional[int] = None) -> None:
    """
    Clean up pyclesperanto GPU memory.

    Args:
        device_id: Optional GPU device ID. If None, cleans current device.
    """
    if pyclesperanto is None:
        logger.debug("pyclesperanto not available, skipping pyclesperanto GPU cleanup")
        return

    try:
        import gc

        # pyclesperanto doesn't have explicit memory cleanup like PyTorch/CuPy,
        # but we can trigger garbage collection and clear any cached data

        if device_id is not None:
            # Select the specific device
            devices = pyclesperanto.list_available_devices()
            if device_id < len(devices):
                pyclesperanto.select_device(device_id)
                logger.debug(f"🔥 GPU CLEANUP: Selected pyclesperanto device {device_id}")
            else:
                logger.warning(f"🔥 GPU CLEANUP: Device {device_id} not available in pyclesperanto")

        # Trigger garbage collection to clean up any unreferenced GPU arrays
        collected = gc.collect()

        # pyclesperanto uses OpenCL, which manages memory automatically,
        # but we can help by ensuring Python objects are cleaned up
        if device_id is not None:
            logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto device {device_id}, collected {collected} objects")
        else:
            logger.debug(f"🔥 GPU CLEANUP: Triggered GC for pyclesperanto current device, collected {collected} objects")

    except Exception as e:
        logger.warning(f"Failed to cleanup pyclesperanto GPU memory: {e}")


def cleanup_gpu_memory_by_framework(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clean up GPU memory based on the OpenHCS memory type.

    Args:
        memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow", "jax", "pyclesperanto", "numpy")
        device_id: Optional GPU device ID
    """
    # Handle exact OpenHCS memory type values
    if memory_type == "torch":
        cleanup_pytorch_gpu(device_id)
    elif memory_type == "cupy":
        cleanup_cupy_gpu(device_id)
    elif memory_type == "tensorflow":
        cleanup_tensorflow_gpu(device_id)
    elif memory_type == "jax":
        cleanup_jax_gpu(device_id)
    elif memory_type == "pyclesperanto":
        cleanup_pyclesperanto_gpu(device_id)
    elif memory_type == "numpy":
        # CPU memory type - no GPU cleanup needed
        logger.debug(f"No GPU cleanup needed for CPU memory type: {memory_type}")
    else:
        # Fallback for unknown types - try pattern matching
        memory_type_lower = memory_type.lower()
        if "torch" in memory_type_lower or "pytorch" in memory_type_lower:
            cleanup_pytorch_gpu(device_id)
        elif "cupy" in memory_type_lower:
            cleanup_cupy_gpu(device_id)
        elif "tensorflow" in memory_type_lower or "tf" in memory_type_lower:
            cleanup_tensorflow_gpu(device_id)
        elif "jax" in memory_type_lower:
            cleanup_jax_gpu(device_id)
        elif "pyclesperanto" in memory_type_lower or "clesperanto" in memory_type_lower:
            cleanup_pyclesperanto_gpu(device_id)
        else:
            logger.debug(f"Unknown memory type for GPU cleanup: {memory_type}")
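
# Dispatch sketch: a step that declares its memory type can hand cleanup off to
# this function instead of picking a framework-specific helper. `step` and its
# attributes are hypothetical; OpenHCS step objects are not defined in this
# module.
#
#   cleanup_gpu_memory_by_framework(step.memory_type, device_id=step.gpu_id)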


def cleanup_numpy_noop(device_id: Optional[int] = None) -> None:
    """
    No-op cleanup for numpy (CPU memory type).

    Args:
        device_id: Optional GPU device ID (ignored for CPU)
    """
    logger.debug("🔥 GPU CLEANUP: No-op for numpy (CPU memory type)")


def cleanup_all_gpu_frameworks(device_id: Optional[int] = None) -> None:
    """
    Clean up GPU memory for all available frameworks.

    Args:
        device_id: Optional GPU device ID
    """
    cleanup_pytorch_gpu(device_id)
    cleanup_cupy_gpu(device_id)
    cleanup_tensorflow_gpu(device_id)
    cleanup_jax_gpu(device_id)
    cleanup_pyclesperanto_gpu(device_id)

    # Also trigger Python garbage collection
    import gc
    gc.collect()

    logger.debug("🔥 GPU CLEANUP: Performed comprehensive cleanup for all GPU frameworks")


# Registry mapping memory types to their cleanup functions
MEMORY_TYPE_CLEANUP_REGISTRY = {
    "torch": cleanup_pytorch_gpu,
    "cupy": cleanup_cupy_gpu,
    "tensorflow": cleanup_tensorflow_gpu,
    "jax": cleanup_jax_gpu,
    "pyclesperanto": cleanup_pyclesperanto_gpu,
    "numpy": cleanup_numpy_noop,
}


def cleanup_memory_by_type(memory_type: str, device_id: Optional[int] = None) -> None:
    """
    Clean up memory using the registered cleanup function for the memory type.

    Args:
        memory_type: OpenHCS memory type string ("torch", "cupy", "tensorflow", "jax", "pyclesperanto", "numpy")
        device_id: Optional GPU device ID
    """
    cleanup_func = MEMORY_TYPE_CLEANUP_REGISTRY.get(memory_type)

    if cleanup_func:
        cleanup_func(device_id)
    else:
        logger.warning(f"No cleanup function registered for memory type: {memory_type}")
        logger.debug(f"Available memory types: {list(MEMORY_TYPE_CLEANUP_REGISTRY.keys())}")
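
# Extension sketch: the registry is a plain dict, so downstream code could
# register a cleanup hook for an additional memory type before dispatching
# through cleanup_memory_by_type. "dask_cuda" and the hook below are purely
# hypothetical, not types that OpenHCS defines.
#
#   def _cleanup_dask_cuda(device_id=None):
#       ...  # framework-specific cleanup
#
#   MEMORY_TYPE_CLEANUP_REGISTRY["dask_cuda"] = _cleanup_dask_cuda
#   cleanup_memory_by_type("dask_cuda", device_id=0)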


def check_gpu_memory_usage() -> None:
    """
    Check and log current GPU memory usage for all available frameworks.

    This is a utility function for debugging memory issues.
    """
    logger.debug("🔍 GPU Memory Usage Report:")

    # Check PyTorch
    if torch is not None:
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                reserved = torch.cuda.memory_reserved(i) / 1024**3
                logger.debug(f" PyTorch GPU {i}: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
        else:
            logger.debug(" PyTorch: No CUDA available")
    else:
        logger.debug(" PyTorch: Not installed")

    # Check CuPy
    if cupy is not None:
        mempool = cupy.get_default_memory_pool()
        used_bytes = mempool.used_bytes()
        total_bytes = mempool.total_bytes()
        logger.debug(f" CuPy: {used_bytes / 1024**3:.2f}GB used, {total_bytes / 1024**3:.2f}GB total")
    else:
        logger.debug(" CuPy: Not installed")

    # Note: TensorFlow and JAX don't have easy memory introspection
    logger.debug(" TensorFlow/JAX: Memory usage not easily queryable. Check if installed:")
    if tensorflow is None:
        logger.debug(" TensorFlow: Not installed")
    if jax is None:
        logger.debug(" JAX: Not installed")


def log_gpu_memory_usage(context: str = "") -> None:
    """
    Log GPU memory usage with a specific context for tracking.

    Args:
        context: Description of when/where this memory check is happening
    """
    context_str = f" ({context})" if context else ""

    if torch is not None:
        try:  # CUDA availability can only be checked at runtime
            if torch.cuda.is_available():
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    free_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3 - reserved
                    logger.debug(f"🔍 VRAM{context_str} GPU {i}: {allocated:.2f}GB alloc, {reserved:.2f}GB reserved, {free_memory:.2f}GB free")
            else:
                logger.debug(f"🔍 VRAM{context_str}: No CUDA available")
        except Exception as e:
            logger.warning(f"🔍 VRAM{context_str}: Error checking PyTorch memory - {e}")
    else:
        logger.debug(f"🔍 VRAM{context_str}: PyTorch not available")
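
# Logging sketch: bracketing a step with context strings makes leaks easier to
# spot in the debug log. The context names below are illustrative.
#
#   log_gpu_memory_usage("before stitching")
#   ...                                        # GPU work
#   cleanup_all_gpu_frameworks()
#   log_gpu_memory_usage("after stitching cleanup")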


def get_gpu_memory_summary() -> dict:
    """
    Get GPU memory usage as a dictionary for programmatic use.

    Returns:
        Dictionary with memory usage information
    """
    memory_info = {
        "pytorch": {"available": False, "devices": []},
        "cupy": {"available": False, "used_gb": 0, "total_gb": 0}
    }

    # Check PyTorch
    if torch is not None:
        try:  # CUDA queries can fail at runtime even when torch is importable
            if torch.cuda.is_available():
                memory_info["pytorch"]["available"] = True
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    total = torch.cuda.get_device_properties(i).total_memory / 1024**3
                    memory_info["pytorch"]["devices"].append({
                        "device_id": i,
                        "allocated_gb": allocated,
                        "reserved_gb": reserved,
                        "total_gb": total,
                        "free_gb": total - reserved
                    })
        except Exception:
            # Ignore CUDA query errors; return whatever was collected so far
            pass

    # Check CuPy
    if cupy is not None:
        memory_info["cupy"]["available"] = True
        mempool = cupy.get_default_memory_pool()
        memory_info["cupy"]["used_gb"] = mempool.used_bytes() / 1024**3
        memory_info["cupy"]["total_gb"] = mempool.total_bytes() / 1024**3

    return memory_info
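
# Consumption sketch: the summary dict can feed a resource monitor or a test
# assertion. The 1.0 GB threshold below is an arbitrary example value.
#
#   summary = get_gpu_memory_summary()
#   for dev in summary["pytorch"]["devices"]:
#       if dev["free_gb"] < 1.0:
#           logger.warning(f"GPU {dev['device_id']} is low on free VRAM")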


def force_comprehensive_cleanup() -> None:
    """
    Force comprehensive GPU cleanup across all frameworks and trigger garbage collection.

    This is the nuclear option for clearing GPU memory when you suspect leaks.
    """
    logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Starting nuclear cleanup...")

    # Clean all GPU frameworks
    cleanup_all_gpu_frameworks()

    # Multiple rounds of garbage collection
    import gc
    for i in range(3):
        collected = gc.collect()
        logger.debug(f"🧹 Garbage collection round {i+1}: collected {collected} objects")

    # Check memory usage after cleanup
    check_gpu_memory_usage()

    logger.debug("🧹 FORCE COMPREHENSIVE CLEANUP: Complete")
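
# Last-resort sketch: if VRAM keeps climbing despite per-step cleanup,
# force_comprehensive_cleanup() can be called between large batches of work.
# The repeated GC passes are deliberately expensive, so this is not intended
# for the per-image hot path.
#
#   force_comprehensive_cleanup()
#   log_gpu_memory_usage("after forced cleanup")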