Coverage for openhcs/processing/backends/enhance/cupy_clahe.py: 7.0%

296 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1from __future__ import annotations 

2 

3import logging 

4from typing import Any, List, Optional, Tuple 

5 

6from openhcs.core.memory.decorators import cupy as cupy_func 

7from openhcs.core.utils import optional_import 

8 

# CuPy (and cupyx.scipy.ndimage) are optional: when either is missing the
# module still imports, leaving ``cp`` and/or ``ndimage`` set to None.
cp = optional_import("cupy")
ndimage = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy")
    ndimage = cupyx_scipy.ndimage if cupyx_scipy is not None else None

logger = logging.getLogger(__name__)

18 

@cupy_func
def clahe_2d(
    image: "cp.ndarray",
    clip_limit: float = 2.0,
    tile_grid_size: tuple = None,
    nbins: int = None,
    adaptive_bins: bool = True,
    adaptive_tiles: bool = True
) -> "cp.ndarray":
    """
    Optimized 2D CLAHE with vectorized bilinear interpolation.

    Every Z-slice of the input stack is equalized independently.

    Args:
        image: 3D CuPy array of shape (Z, Y, X); each ``image[z]`` is
            processed as a separate 2D image.
        clip_limit: Contrast-limiting factor. It is multiplied by the number
            of pixels per tile and divided by ``nbins`` to obtain the per-bin
            count limit used when clipping each tile histogram.
        tile_grid_size: ``(tile_rows, tile_cols)``. When None, the grid is
            derived from the slice size (``adaptive_tiles``) or set to (8, 8).
        nbins: Number of histogram bins. When None, it is derived from each
            slice's intensity range (``adaptive_bins``) or set to 256.
        adaptive_bins: If True and ``nbins`` is None, use
            ``clamp(int(sqrt(max - min)), 64, 512)`` bins for each slice.
        adaptive_tiles: If True and ``tile_grid_size`` is None, aim for tiles
            of roughly 80 pixels, with between 2 and 16 tiles per axis.

    Returns:
        Array with the same shape and dtype as ``image``.

    Raises:
        ValueError: If ``image`` is not 3D or a slice is too small for the
            tile grid.
    """
    if image.ndim != 3:
        raise ValueError(
            f"clahe_2d expects a 3D (Z, Y, X) stack, got {image.ndim}D input"
        )

    height, width = image.shape[1:]

    # The tile grid depends only on the slice shape, so resolve it once.
    if tile_grid_size is not None:
        tile_grid = tile_grid_size
    elif adaptive_tiles:
        target_tile_size = 80
        tile_grid = (
            max(2, min(16, height // target_tile_size)),
            max(2, min(16, width // target_tile_size)),
        )
    else:
        tile_grid = (8, 8)

    result = cp.zeros_like(image)

    for z in range(image.shape[0]):
        slice_2d = image[z]

        # The bin count may depend on each slice's own intensity range.
        if nbins is not None:
            slice_nbins = nbins
        elif adaptive_bins:
            data_range = float(cp.max(slice_2d) - cp.min(slice_2d))
            slice_nbins = min(512, max(64, int(cp.sqrt(data_range))))
        else:
            slice_nbins = 256

        result[z] = _clahe_2d_vectorized(
            slice_2d, clip_limit, tile_grid, slice_nbins
        )

    return result

64 

def _clahe_2d_vectorized(
    image: "cp.ndarray",
    clip_limit: float,
    tile_grid_size: tuple,
    nbins: int
) -> "cp.ndarray":
    """
    Vectorized CLAHE implementation for a single 2D image.

    Tile histograms are computed over the top-left region that the tiles
    cover exactly. The resulting per-tile mappings are then applied to
    *every* pixel of ``image``. Pixels in the leftover bottom/right margin
    keep their own intensities and go through the outermost tiles' mappings,
    because the interpolation clamps to the edge tiles. They are no longer
    overwritten by copies of the last processed row/column.

    Args:
        image: 2D array.
        clip_limit: Contrast-limiting factor; multiplied by pixels-per-tile
            and divided by ``nbins`` to get the per-bin count limit.
        tile_grid_size: ``(tile_rows, tile_cols)``, both >= 1.
        nbins: Number of histogram bins.

    Returns:
        Equalized image with the same shape and dtype as ``image``. Integer
        results are rounded to the nearest value before casting.

    Raises:
        ValueError: If ``image`` is not 2D, the tile grid is not positive, or
            the image is smaller than the tile grid.
    """
    if image.ndim != 2:
        raise ValueError("Input must be 2D array")

    height, width = image.shape
    tile_rows, tile_cols = tile_grid_size

    if tile_rows < 1 or tile_cols < 1:
        raise ValueError(
            f"Tile grid must be positive, got {tile_rows}x{tile_cols}"
        )

    # Calculate tile dimensions
    tile_height = height // tile_rows
    tile_width = width // tile_cols

    # Ensure we have valid tiles
    if tile_height < 1 or tile_width < 1:
        raise ValueError(f"Image too small for {tile_rows}x{tile_cols} tiles")

    # Histograms come from the region that the tiles cover exactly.
    crop_height = tile_height * tile_rows
    crop_width = tile_width * tile_cols
    image_crop = image[:crop_height, :crop_width]

    # Per-bin count limit for each tile histogram (at least 1).
    actual_clip_limit = max(1, int(clip_limit * tile_height * tile_width / nbins))

    # Use the full image's range so margin pixels outside the crop still fall
    # inside the bin range instead of being saturated to the first/last bin.
    min_val = float(cp.min(image))
    max_val = float(cp.max(image))

    if max_val <= min_val:
        return image.astype(image.dtype)  # Constant image

    tile_cdfs = _compute_tile_cdfs_2d(
        image_crop, tile_rows, tile_cols, tile_height, tile_width,
        nbins, actual_clip_limit, min_val, max_val
    )

    # Map the whole image (margins included) through the tile mappings.
    result = _apply_vectorized_interpolation_2d(
        image, tile_cdfs, tile_rows, tile_cols,
        tile_height, tile_width, nbins, min_val, max_val
    )

    if image.dtype.kind in "iu":
        # astype() truncates toward zero; round to the nearest level instead.
        result = cp.rint(result)

    return result.astype(image.dtype)

130 

def _compute_tile_cdfs_2d(
    image: "cp.ndarray",
    tile_rows: int,
    tile_cols: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    clip_limit: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Compute the clipped, intensity-scaled CDF of every tile.

    Args:
        image: 2D array covering exactly ``tile_rows * tile_height`` by
            ``tile_cols * tile_width`` pixels.
        tile_rows, tile_cols: Number of tiles along y and x.
        tile_height, tile_width: Tile size in pixels.
        nbins: Number of histogram bins spanning ``[min_val, max_val]``.
        clip_limit: Per-bin count limit passed to the histogram clipper.
        min_val, max_val: Intensity range covered by the bins.

    Returns:
        float32 array of shape ``(tile_rows, tile_cols, nbins)``. Entry
        ``[r, c, k]`` is ``min_val + CDF_k * (max_val - min_val)`` for tile
        ``(r, c)``, where ``CDF_k`` is the normalized cumulative count of the
        clipped histogram up to bin ``k``.
    """
    tile_cdfs = cp.zeros((tile_rows, tile_cols, nbins), dtype=cp.float32)

    # Shared bin edges so every tile uses the same intensity -> bin mapping.
    bin_edges = cp.linspace(min_val, max_val, nbins + 1, dtype=cp.float32)
    value_span = max_val - min_val

    for row in range(tile_rows):
        y_start = row * tile_height
        for col in range(tile_cols):
            x_start = col * tile_width
            tile = image[
                y_start:y_start + tile_height, x_start:x_start + tile_width
            ]

            hist, _ = cp.histogram(tile, bins=bin_edges)
            hist = _clip_histogram_optimized(hist, clip_limit)

            cdf = cp.cumsum(hist, dtype=cp.float32)
            if cdf[-1] > 0:
                # Normalize to [0, 1], then scale to the intensity range.
                tile_cdfs[row, col, :] = min_val + (cdf / cdf[-1]) * value_span
            else:
                tile_cdfs[row, col, :] = min_val

    return tile_cdfs

177 

def _apply_vectorized_interpolation_2d(
    image: "cp.ndarray",
    tile_cdfs: "cp.ndarray",
    tile_rows: int,
    tile_cols: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Map every pixel through its tiles' CDFs with bilinear interpolation.

    Each pixel's bin is looked up in the (up to) four tiles whose centres
    surround it, and the four mapped values are blended bilinearly. Pixels
    before the first or after the last tile centre along an axis (image
    borders, or margins beyond the tiled region) use the edge tiles.

    Indices and weights are computed per axis (1D) and broadcast, so no
    full-size coordinate or index grids are materialized.

    Args:
        image: 2D array of shape (height, width). It may extend beyond
            ``tile_rows * tile_height`` by ``tile_cols * tile_width``.
        tile_cdfs: Array of shape (tile_rows, tile_cols, nbins) holding the
            mapped intensity for each tile and bin.
        tile_rows, tile_cols: Number of tiles along y and x (>= 1).
        tile_height, tile_width: Tile size in pixels.
        nbins: Number of histogram bins.
        min_val, max_val: Intensity range spanned by the bins
            (``max_val > min_val``).

    Returns:
        Floating-point array of shape (height, width).
    """

    def _axis_weights(n_pixels, n_tiles, tile_size):
        """Per-position lower/upper tile index and upper-tile weight on one axis."""
        coords = cp.arange(n_pixels, dtype=cp.float32)
        centers = cp.arange(n_tiles, dtype=cp.float32) * tile_size + tile_size // 2
        # searchsorted - 1 is the last centre strictly below each coordinate;
        # clamp so (low, low + 1) is a valid pair, or (0, 0) for a single tile.
        low = cp.clip(cp.searchsorted(centers, coords) - 1, 0, max(n_tiles - 2, 0))
        high = cp.minimum(low + 1, n_tiles - 1)
        span = centers[high] - centers[low]
        has_span = span > 0
        # Zero span (single tile) -> weight 0, without dividing by zero.
        weight = cp.where(
            has_span, (coords - centers[low]) / cp.where(has_span, span, 1.0), 0.0
        )
        # Clamping makes positions outside the centre span use the edge tiles.
        return low, high, cp.clip(weight, 0.0, 1.0)

    height, width = image.shape
    y_low, y_high, wy = _axis_weights(height, tile_rows, tile_height)
    x_low, x_high, wx = _axis_weights(width, tile_cols, tile_width)

    # Column vectors along y, row vectors along x -> broadcast to (height, width).
    y_low, y_high, wy = y_low[:, None], y_high[:, None], wy[:, None]
    x_low, x_high, wx = x_low[None, :], x_high[None, :], wx[None, :]

    # The bin index must match cp.histogram over linspace(min, max, nbins + 1):
    # bin k covers [min + k*w, min + (k+1)*w); the last bin also holds max_val.
    normalized = (image - min_val) / (max_val - min_val)
    bin_indices = cp.clip((normalized * nbins).astype(cp.int32), 0, nbins - 1)

    # Mapped values from the four surrounding tiles (broadcast advanced indexing).
    val_tl = tile_cdfs[y_low, x_low, bin_indices]
    val_tr = tile_cdfs[y_low, x_high, bin_indices]
    val_bl = tile_cdfs[y_high, x_low, bin_indices]
    val_br = tile_cdfs[y_high, x_high, bin_indices]

    # Bilinear interpolation
    val_top = (1 - wx) * val_tl + wx * val_tr
    val_bottom = (1 - wx) * val_bl + wx * val_br
    return (1 - wy) * val_top + wy * val_bottom

251 

def _clip_histogram_optimized(hist: "cp.ndarray", clip_limit: int) -> "cp.ndarray":
    """
    Clip a histogram at ``clip_limit`` and redistribute the clipped excess.

    The counts above the limit are first spread uniformly over all bins. Bins
    pushed back above the limit are clipped again, and their overflow is
    handed to the unsaturated bins in proportion to their remaining headroom
    (at most three passes).

    Args:
        hist: 1D histogram (integer or floating counts).
        clip_limit: Maximum count per bin; values <= 0 disable clipping.

    Returns:
        The clipped histogram. Redistribution produces fractional counts, so
        integer histograms are returned as float32 and floating histograms
        keep their dtype. Casting back to an integer dtype would truncate,
        and silently drop, the redistributed mass. When clipping is disabled,
        ``hist`` itself is returned unchanged.
    """
    if clip_limit <= 0:
        return hist

    # Work in float so fractional redistribution is preserved.
    clipped = hist.astype(cp.float32)
    total_excess = cp.sum(cp.maximum(clipped - clip_limit, 0))
    clipped = cp.minimum(clipped, clip_limit)

    if total_excess > 0:
        # Redistribute excess uniformly
        clipped += total_excess / len(hist)

        # Bins that were near the limit may overflow again; re-clip and hand
        # the overflow to bins that still have headroom.
        for _ in range(3):
            total_overflow = cp.sum(cp.maximum(clipped - clip_limit, 0))
            if total_overflow < 1e-6:
                break

            clipped = cp.minimum(clipped, clip_limit)
            headroom = cp.maximum(clip_limit - clipped, 0)
            non_saturated = clipped < clip_limit
            if cp.any(non_saturated):
                available_space = cp.sum(headroom)
                if available_space > 0:
                    factor = min(1.0, total_overflow / available_space)
                    clipped += cp.where(non_saturated, factor * headroom, 0)

    out_dtype = hist.dtype if hist.dtype.kind == "f" else cp.float32
    return clipped.astype(out_dtype)

294 

@cupy_func
def clahe_3d(
    stack: "cp.ndarray",
    clip_limit: float = 2.0,
    tile_grid_size_3d: tuple = None,
    nbins: int = None,
    adaptive_bins: bool = True,
    adaptive_tiles: bool = True,
    memory_efficient: bool = True
) -> "cp.ndarray":
    """
    Optimized 3D CLAHE with vectorized trilinear interpolation.

    Args:
        stack: 3D CuPy array of shape (Z, Y, X).
        clip_limit: Contrast-limiting factor used to derive the per-bin
            histogram count limit of each 3D tile.
        tile_grid_size_3d: Requested tiles ``(z_tiles, y_tiles, x_tiles)``.
            When None, it is derived from the volume shape
            (``adaptive_tiles``) or set to ``(max(1, Z // 8), 4, 4)``.
        nbins: Number of histogram bins. When None, it is derived from the
            volume's intensity range (``adaptive_bins``) or set to 256.
        adaptive_bins: If True and ``nbins`` is None, use
            ``clamp(int(cbrt(64 * (max - min))), 128, 512)`` bins.
        adaptive_tiles: If True and ``tile_grid_size_3d`` is None, aim for
            tiles of about 48 voxels per axis (2-8 tiles in Y and X).
        memory_efficient: Process volumes larger than 512**3 voxels in
            Z-chunks instead of all at once.

    Returns:
        Array with the same shape and dtype as ``stack``.
    """
    depth, height, width = stack.shape

    # Histogram bin count
    if nbins is not None:
        bins = nbins
    elif not adaptive_bins:
        bins = 256
    else:
        value_range = float(cp.max(stack) - cp.min(stack))
        bins = min(512, max(128, int(cp.cbrt(value_range * 64))))

    # Tile grid
    if tile_grid_size_3d is not None:
        grid = tile_grid_size_3d
    elif not adaptive_tiles:
        grid = (max(1, depth // 8), 4, 4)
    else:
        target = 48
        grid = (
            max(1, min(depth // 4, depth // target)),
            max(2, min(8, height // target)),
            max(2, min(8, width // target)),
        )

    # Volumes above ~134M voxels (512**3) are chunked along Z when allowed.
    use_chunks = memory_efficient and depth * height * width > 512**3
    worker = _clahe_3d_chunked if use_chunks else _clahe_3d_vectorized
    return worker(stack, clip_limit, grid, bins)

348 

def _clahe_3d_vectorized(
    stack: "cp.ndarray",
    clip_limit: float,
    tile_grid_size_3d: tuple,
    nbins: int
) -> "cp.ndarray":
    """
    Full vectorized 3D CLAHE implementation for one volume.

    Tile histograms are computed from the sub-volume that the tiles cover
    exactly. The per-tile mappings are then applied to every voxel of
    ``stack``; voxels beyond the tiled region go through the outermost tiles,
    because the interpolation clamps to edge tiles. Trailing slices, rows and
    columns are therefore no longer replaced by copies of edge values.

    Args:
        stack: 3D array of shape (Z, Y, X).
        clip_limit: Contrast-limiting factor; multiplied by voxels-per-tile
            and divided by ``nbins`` to get the per-bin count limit.
        tile_grid_size_3d: Requested tiles ``(z, y, x)``, each >= 1. Tiles
            are at least 1 slice deep and 4 pixels high/wide, so fewer tiles
            than requested may be used.
        nbins: Number of histogram bins.

    Returns:
        Array with the same shape and dtype as ``stack``. Integer results are
        rounded to the nearest value before casting.

    Raises:
        ValueError: If the tile grid is not positive or the volume is smaller
            than one tile along some axis.
    """
    depth, height, width = stack.shape
    tile_z, tile_y, tile_x = tile_grid_size_3d

    if tile_z < 1 or tile_y < 1 or tile_x < 1:
        raise ValueError(
            f"Tile grid must be positive, got {tile_z}x{tile_y}x{tile_x}"
        )

    # Tile sizes: at least 1 slice deep and 4 pixels high/wide.
    tile_depth = max(1, depth // tile_z)
    tile_height = max(4, height // tile_y)
    tile_width = max(4, width // tile_x)

    # Number of whole tiles that actually fit along each axis.
    actual_tile_z = depth // tile_depth
    actual_tile_y = height // tile_height
    actual_tile_x = width // tile_width

    if actual_tile_z < 1 or actual_tile_y < 1 or actual_tile_x < 1:
        raise ValueError(f"Volume too small for {tile_z}x{tile_y}x{tile_x} tiles")

    # Histograms come from the sub-volume the tiles cover exactly.
    crop_depth = tile_depth * actual_tile_z
    crop_height = tile_height * actual_tile_y
    crop_width = tile_width * actual_tile_x
    stack_crop = stack[:crop_depth, :crop_height, :crop_width]

    # Per-bin count limit for each tile histogram (at least 1).
    voxels_per_tile = tile_depth * tile_height * tile_width
    actual_clip_limit = max(1, int(clip_limit * voxels_per_tile / nbins))

    # Full-volume range so voxels outside the crop stay inside the bin range.
    min_val = float(cp.min(stack))
    max_val = float(cp.max(stack))

    if max_val <= min_val:
        return stack.astype(stack.dtype)  # Constant volume

    tile_cdfs = _compute_tile_cdfs_3d(
        stack_crop, actual_tile_z, actual_tile_y, actual_tile_x,
        tile_depth, tile_height, tile_width,
        nbins, actual_clip_limit, min_val, max_val
    )

    # Map the whole volume (margins included) through the tile mappings.
    result = _apply_vectorized_trilinear_interpolation(
        stack, tile_cdfs, actual_tile_z, actual_tile_y, actual_tile_x,
        tile_depth, tile_height, tile_width, nbins, min_val, max_val
    )

    if stack.dtype.kind in "iu":
        # astype() truncates toward zero; round to the nearest level instead.
        result = cp.rint(result)

    return result.astype(stack.dtype)

416 

def _compute_tile_cdfs_3d(
    stack: "cp.ndarray",
    tile_z: int,
    tile_y: int,
    tile_x: int,
    tile_depth: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    clip_limit: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Build the intensity mapping (scaled, clipped CDF) of every 3D tile.

    Args:
        stack: 3D array covering the tiled region.
        tile_z, tile_y, tile_x: Number of tiles along each axis.
        tile_depth, tile_height, tile_width: Tile size in voxels.
        nbins: Number of histogram bins spanning ``[min_val, max_val]``.
        clip_limit: Per-bin count limit passed to the histogram clipper.
        min_val, max_val: Intensity range covered by the bins.

    Returns:
        float32 array of shape ``(tile_z, tile_y, tile_x, nbins)``. Entry
        ``[i, j, k, b]`` is ``min_val + CDF_b * (max_val - min_val)`` for
        tile ``(i, j, k)``, with the CDF taken after histogram clipping.
    """
    lut = cp.zeros((tile_z, tile_y, tile_x, nbins), dtype=cp.float32)
    edges = cp.linspace(min_val, max_val, nbins + 1, dtype=cp.float32)
    span = max_val - min_val

    # Walk the tiles in z-major, then y, then x order via one flat index.
    for flat in range(tile_z * tile_y * tile_x):
        zy, xi = divmod(flat, tile_x)
        zi, yi = divmod(zy, tile_y)

        z0 = zi * tile_depth
        y0 = yi * tile_height
        x0 = xi * tile_width
        tile = stack[z0:z0 + tile_depth, y0:y0 + tile_height, x0:x0 + tile_width]

        counts, _ = cp.histogram(tile.ravel(), bins=edges)
        counts = _clip_histogram_optimized(counts, clip_limit)
        cumulative = cp.cumsum(counts, dtype=cp.float32)

        total = cumulative[-1]
        if total > 0:
            lut[zi, yi, xi, :] = min_val + (cumulative / total) * span
        else:
            lut[zi, yi, xi, :] = min_val

    return lut

466 

def _apply_vectorized_trilinear_interpolation(
    stack: "cp.ndarray",
    tile_cdfs: "cp.ndarray",
    tile_z: int,
    tile_y: int,
    tile_x: int,
    tile_depth: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Vectorized trilinear interpolation for 3D CLAHE.

    Each voxel's bin is looked up in the (up to) eight tiles whose centres
    surround it, and the mapped values are blended trilinearly. Voxels before
    the first or after the last tile centre along an axis use the edge tiles.

    Indices and weights are computed per axis (1D) and broadcast, avoiding
    the full-size coordinate, index and weight grids a meshgrid would need.

    Args:
        stack: 3D array of shape (depth, height, width). It may extend beyond
            the tiled region.
        tile_cdfs: Array of shape (tile_z, tile_y, tile_x, nbins) holding the
            mapped intensity for each tile and bin.
        tile_z, tile_y, tile_x: Number of tiles along each axis (>= 1).
        tile_depth, tile_height, tile_width: Tile size in voxels.
        nbins: Number of histogram bins.
        min_val, max_val: Intensity range spanned by the bins
            (``max_val > min_val``).

    Returns:
        Floating-point array of shape (depth, height, width).
    """

    def _axis_weights(n_voxels, n_tiles, tile_size):
        """Per-position lower/upper tile index and upper-tile weight on one axis."""
        coords = cp.arange(n_voxels, dtype=cp.float32)
        centers = cp.arange(n_tiles, dtype=cp.float32) * tile_size + tile_size // 2
        # searchsorted - 1 is the last centre strictly below each coordinate;
        # clamp so (low, low + 1) is a valid pair, or (0, 0) for a single tile.
        low = cp.clip(cp.searchsorted(centers, coords) - 1, 0, max(n_tiles - 2, 0))
        high = cp.minimum(low + 1, n_tiles - 1)
        span = centers[high] - centers[low]
        has_span = span > 0
        # Zero span (single tile) -> weight 0, without dividing by zero.
        weight = cp.where(
            has_span, (coords - centers[low]) / cp.where(has_span, span, 1.0), 0.0
        )
        # Clamping makes positions outside the centre span use the edge tiles.
        return low, high, cp.clip(weight, 0.0, 1.0)

    depth, height, width = stack.shape
    z_lo, z_hi, wz = _axis_weights(depth, tile_z, tile_depth)
    y_lo, y_hi, wy = _axis_weights(height, tile_y, tile_height)
    x_lo, x_hi, wx = _axis_weights(width, tile_x, tile_width)

    # Shape the per-axis vectors so they broadcast to (depth, height, width).
    z_lo, z_hi, wz = z_lo[:, None, None], z_hi[:, None, None], wz[:, None, None]
    y_lo, y_hi, wy = y_lo[None, :, None], y_hi[None, :, None], wy[None, :, None]
    x_lo, x_hi, wx = x_lo[None, None, :], x_hi[None, None, :], wx[None, None, :]

    # The bin index must match cp.histogram over linspace(min, max, nbins + 1):
    # bin k covers [min + k*w, min + (k+1)*w); the last bin also holds max_val.
    normalized = (stack - min_val) / (max_val - min_val)
    bin_indices = cp.clip((normalized * nbins).astype(cp.int32), 0, nbins - 1)

    # The 8 surrounding transformation values (broadcast advanced indexing).
    val_000 = tile_cdfs[z_lo, y_lo, x_lo, bin_indices]
    val_001 = tile_cdfs[z_lo, y_lo, x_hi, bin_indices]
    val_010 = tile_cdfs[z_lo, y_hi, x_lo, bin_indices]
    val_011 = tile_cdfs[z_lo, y_hi, x_hi, bin_indices]
    val_100 = tile_cdfs[z_hi, y_lo, x_lo, bin_indices]
    val_101 = tile_cdfs[z_hi, y_lo, x_hi, bin_indices]
    val_110 = tile_cdfs[z_hi, y_hi, x_lo, bin_indices]
    val_111 = tile_cdfs[z_hi, y_hi, x_hi, bin_indices]

    # Interpolate along x, then y, then z.
    val_00 = (1 - wx) * val_000 + wx * val_001
    val_01 = (1 - wx) * val_010 + wx * val_011
    val_10 = (1 - wx) * val_100 + wx * val_101
    val_11 = (1 - wx) * val_110 + wx * val_111

    val_0 = (1 - wy) * val_00 + wy * val_01
    val_1 = (1 - wy) * val_10 + wy * val_11

    return (1 - wz) * val_0 + wz * val_1

577 

def _clahe_3d_chunked(
    stack: "cp.ndarray",
    clip_limit: float,
    tile_grid_size_3d: tuple,
    nbins: int,
    chunk_size: int = 128
) -> "cp.ndarray":
    """
    Memory-efficient chunked processing for very large 3D volumes.

    The volume is split along Z into consecutive, non-overlapping output
    chunks of ``chunk_size`` slices. Each chunk is equalized together with up
    to ``overlap`` (half a tile depth) extra context slices on either side,
    and only the chunk's own slices are written to the result.

    NOTE: each chunk is equalized with its own intensity range and tile
    histograms, so the output can differ from processing the whole volume
    at once.

    Args:
        stack: 3D array of shape (Z, Y, X).
        clip_limit: Contrast-limiting factor forwarded to each chunk.
        tile_grid_size_3d: Requested tiles ``(z, y, x)`` for the full volume;
            the Z tile count is reduced to what fits in each chunk.
        nbins: Number of histogram bins.
        chunk_size: Number of output slices per chunk (>= 1).

    Returns:
        Array with the same shape and dtype as ``stack``.

    Raises:
        ValueError: If ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    depth, _, _ = stack.shape
    result = cp.zeros_like(stack)

    # Context needed on each side for smooth transitions between chunks.
    tile_z, tile_y, tile_x = tile_grid_size_3d
    tile_depth = max(1, depth // tile_z)
    overlap = tile_depth // 2

    # Step by the full chunk size so every slice is produced exactly once.
    # Stepping by chunk_size - overlap breaks when overlap >= chunk_size:
    # range() then raises or yields nothing, leaving an all-zero result.
    for z_start in range(0, depth, chunk_size):
        z_end = min(z_start + chunk_size, depth)

        # Extract chunk with context
        chunk_start = max(0, z_start - overlap)
        chunk_end = min(depth, z_end + overlap)
        chunk = stack[chunk_start:chunk_end, :, :]

        # Adjust tile grid for chunk
        chunk_depth = chunk_end - chunk_start
        chunk_tile_z = max(1, min(tile_z, chunk_depth // tile_depth))
        chunk_result = _clahe_3d_vectorized(
            chunk, clip_limit, (chunk_tile_z, tile_y, tile_x), nbins
        )

        # Keep only the chunk's own slices (drop the context).
        offset = z_start - chunk_start
        result[z_start:z_end, :, :] = chunk_result[offset:offset + (z_end - z_start), :, :]

    return result

625 

def _fill_3d_boundaries(
    full_result: "cp.ndarray",
    cropped_result: "cp.ndarray",
    crop_depth: int,
    crop_height: int,
    crop_width: int,
    depth: int,
    height: int,
    width: int
) -> None:
    """
    Fill the part of ``full_result`` outside the cropped box by edge replication.

    Every voxel past the crop along one or more axes receives the value of
    ``cropped_result`` at the last index along those axes (the same index
    along the others). ``full_result`` is modified in place; the box
    ``[:crop_depth, :crop_height, :crop_width]`` itself is not written.
    """
    pad_z = crop_depth < depth
    pad_y = crop_height < height
    pad_x = crop_width < width

    # Inside / outside the crop along each axis, plus source selectors.
    z_in, z_out = slice(None, crop_depth), slice(crop_depth, None)
    y_in, y_out = slice(None, crop_height), slice(crop_height, None)
    x_in, x_out = slice(None, crop_width), slice(crop_width, None)
    last = slice(-1, None)  # length-1 edge slice, keeps the axis for broadcasting
    every = slice(None)

    # Beyond the crop in Z only
    if pad_z:
        full_result[z_out, y_in, x_in] = cropped_result[last, every, every]

    # Beyond the crop in Y (and optionally Z)
    if pad_y:
        full_result[z_in, y_out, x_in] = cropped_result[every, last, every]
        if pad_z:
            full_result[z_out, y_out, x_in] = cropped_result[last, last, every]

    # Beyond the crop in X (and optionally Y and/or Z)
    if pad_x:
        full_result[z_in, y_in, x_out] = cropped_result[every, every, last]
        if pad_y:
            full_result[z_in, y_out, x_out] = cropped_result[every, last, last]
        if pad_z:
            full_result[z_out, y_in, x_out] = cropped_result[last, every, last]
        if pad_z and pad_y:
            full_result[z_out, y_out, x_out] = cropped_result[-1, -1, -1]