Coverage for openhcs/processing/backends/enhance/cupy_clahe.py: 6.8%

295 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1from __future__ import annotations 

2 

3import logging 

4 

5from openhcs.core.memory.decorators import cupy as cupy_func 

6from openhcs.core.utils import optional_import 

7 

# Import CuPy as an optional dependency.
# NOTE(review): the `is not None` checks below assume `optional_import`
# returns None when the package cannot be imported — confirm against
# openhcs.core.utils.
cp = optional_import("cupy")
# `ndimage` stays None unless both cupy and cupyx.scipy are importable.
ndimage = None
if cp is not None:
    cupyx_scipy = optional_import("cupyx.scipy")
    if cupyx_scipy is not None:
        ndimage = cupyx_scipy.ndimage

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

17 

@cupy_func
def clahe_2d(
    image: "cp.ndarray",
    clip_limit: float = 2.0,
    tile_grid_size: tuple = None,
    nbins: int = None,
    adaptive_bins: bool = True,
    adaptive_tiles: bool = True
) -> "cp.ndarray":
    """
    Apply 2D CLAHE independently to every slice of a stack.

    Args:
        image: 3D array; each ``image[z]`` is equalized on its own as a
            2D image.
        clip_limit: Contrast-limiting factor forwarded to the per-slice
            implementation.
        tile_grid_size: ``(tile_rows, tile_cols)``. When None, the grid is
            derived from the slice size (``adaptive_tiles=True``, roughly
            80-pixel tiles, between 2 and 16 per axis) or defaults to
            ``(8, 8)``.
        nbins: Number of histogram bins. When None, it is derived per slice
            from the slice's value range (``adaptive_bins=True``, clamped
            to [64, 512]) or defaults to 256.
        adaptive_bins: Enable the per-slice bin heuristic when ``nbins`` is
            None.
        adaptive_tiles: Enable the shape-based tile heuristic when
            ``tile_grid_size`` is None.

    Returns:
        Array with the same shape and dtype as ``image``.

    Raises:
        ValueError: If ``image`` is not 3D, or (from the per-slice
            implementation) if a slice is too small for the tile grid.
    """
    if image.ndim != 3:
        raise ValueError(
            f"clahe_2d expects a 3D stack of 2D slices, got {image.ndim}D input"
        )

    _, height, width = image.shape

    # The tile grid depends only on the slice shape, which is identical for
    # every slice, so it is resolved once instead of once per iteration.
    if tile_grid_size is not None:
        tile_grid = tile_grid_size
    elif adaptive_tiles:
        target_tile_size = 80
        tile_grid = (
            max(2, min(16, height // target_tile_size)),
            max(2, min(16, width // target_tile_size)),
        )
    else:
        tile_grid = (8, 8)

    result = cp.zeros_like(image)

    for z in range(image.shape[0]):
        slice_2d = image[z]

        # The bin count depends on the slice's own value range, so it has
        # to stay inside the loop.
        if nbins is not None:
            slice_nbins = nbins
        elif adaptive_bins:
            data_range = float(cp.max(slice_2d) - cp.min(slice_2d))
            slice_nbins = min(512, max(64, int(cp.sqrt(data_range))))
        else:
            slice_nbins = 256

        result[z] = _clahe_2d_vectorized(
            slice_2d, clip_limit, tile_grid, slice_nbins
        )

    return result

63 

def _clahe_2d_vectorized(
    image: "cp.ndarray",
    clip_limit: float,
    tile_grid_size: tuple,
    nbins: int
) -> "cp.ndarray":
    """
    Run tile-based CLAHE on a single 2D image.

    The image is cropped to a whole number of tiles, per-tile clipped CDFs
    are computed, and every pixel is mapped through an interpolation of the
    neighbouring tile CDFs. Rows/columns left over after cropping are filled
    by replicating the last equalized row/column.

    Args:
        image: 2D array to equalize.
        clip_limit: Contrast-limiting factor; scaled by pixels-per-tile and
            ``nbins`` into an absolute per-bin count (at least 1).
        tile_grid_size: ``(tile_rows, tile_cols)``.
        nbins: Number of histogram bins.

    Returns:
        Array with the shape and dtype of ``image``. A constant image is
        returned as an unchanged copy.

    Raises:
        ValueError: If ``image`` is not 2D or is smaller than the tile grid.
    """
    if image.ndim != 2:
        raise ValueError("Input must be 2D array")

    n_rows, n_cols = image.shape
    grid_rows, grid_cols = tile_grid_size

    # Integer tile size; the remainder is handled after equalization.
    tile_h = n_rows // grid_rows
    tile_w = n_cols // grid_cols
    if tile_h < 1 or tile_w < 1:
        raise ValueError(f"Image too small for {grid_rows}x{grid_cols} tiles")

    # Region covered by complete tiles.
    used_h = tile_h * grid_rows
    used_w = tile_w * grid_cols
    tiled_region = image[:used_h, :used_w]

    # Absolute per-bin clip count derived from the relative clip_limit.
    hist_clip = max(1, int(clip_limit * tile_h * tile_w / nbins))

    lowest = float(cp.min(tiled_region))
    highest = float(cp.max(tiled_region))
    if highest <= lowest:
        # Nothing to equalize in a constant region.
        return image.astype(image.dtype)

    cdfs = _compute_tile_cdfs_2d(
        tiled_region, grid_rows, grid_cols, tile_h, tile_w,
        nbins, hist_clip, lowest, highest
    )
    equalized = _apply_vectorized_interpolation_2d(
        tiled_region, cdfs, grid_rows, grid_cols,
        tile_h, tile_w, nbins, lowest, highest
    )

    if equalized.shape == image.shape:
        return equalized.astype(image.dtype)

    # Restore the original extent and replicate the last equalized
    # row/column/corner into the part that did not fit a whole tile.
    padded = cp.zeros_like(image, dtype=equalized.dtype)
    padded[:used_h, :used_w] = equalized

    extra_rows = used_h < n_rows
    extra_cols = used_w < n_cols
    if extra_rows:
        padded[used_h:, :used_w] = equalized[-1:, :]
    if extra_cols:
        padded[:used_h, used_w:] = equalized[:, -1:]
    if extra_rows and extra_cols:
        padded[used_h:, used_w:] = equalized[-1, -1]

    return padded.astype(image.dtype)

129 

def _compute_tile_cdfs_2d(
    image: "cp.ndarray",
    tile_rows: int,
    tile_cols: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    clip_limit: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Compute the clipped, intensity-scaled CDF of every tile.

    Args:
        image: 2D array covering exactly ``tile_rows * tile_height`` rows
            and ``tile_cols * tile_width`` columns (the caller crops it).
        tile_rows: Number of tiles along the first axis.
        tile_cols: Number of tiles along the second axis.
        tile_height: Tile size along the first axis, in pixels.
        tile_width: Tile size along the second axis, in pixels.
        nbins: Number of histogram bins between ``min_val`` and ``max_val``.
        clip_limit: Absolute per-bin count passed to the histogram clipper.
        min_val: Lower end of the histogram range and of the output mapping.
        max_val: Upper end of the histogram range and of the output mapping.

    Returns:
        float32 array of shape ``(tile_rows, tile_cols, nbins)`` whose last
        axis maps a bin index to an output intensity in
        ``[min_val, max_val]``.
    """
    tile_cdfs = cp.zeros((tile_rows, tile_cols, nbins), dtype=cp.float32)

    # Bin edges and the output span are shared by all tiles.
    bin_edges = cp.linspace(min_val, max_val, nbins + 1, dtype=cp.float32)
    value_span = max_val - min_val

    for row in range(tile_rows):
        y_start = row * tile_height
        y_end = y_start + tile_height

        for col in range(tile_cols):
            x_start = col * tile_width
            x_end = x_start + tile_width

            tile = image[y_start:y_end, x_start:x_end]

            # Histogram of the tile, then contrast limiting.
            hist, _ = cp.histogram(tile, bins=bin_edges)
            hist = _clip_histogram_optimized(hist, clip_limit)

            cdf = cp.cumsum(hist, dtype=cp.float32)
            if cdf[-1] > 0:
                # Normalize to [0, 1], then scale to the intensity range.
                cdf = cdf / cdf[-1]
                tile_cdfs[row, col, :] = min_val + cdf * value_span
            else:
                # Empty histogram: map every bin to the minimum value.
                tile_cdfs[row, col, :] = min_val

    return tile_cdfs

176 

def _apply_vectorized_interpolation_2d(
    image: "cp.ndarray",
    tile_cdfs: "cp.ndarray",
    tile_rows: int,
    tile_cols: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Map every pixel through a bilinear blend of the four nearest tile CDFs.

    Tile indices and interpolation weights depend on a single coordinate
    each, so they are computed on 1D axes and broadcast, instead of being
    evaluated on full ``height x width`` coordinate grids.

    Args:
        image: 2D array to transform.
        tile_cdfs: Array indexed as ``[tile_row, tile_col, bin]`` giving the
            output intensity for each bin of each tile.
        tile_rows: Number of tiles along the first axis.
        tile_cols: Number of tiles along the second axis.
        tile_height: Tile size along the first axis, in pixels.
        tile_width: Tile size along the second axis, in pixels.
        nbins: Number of bins along the last axis of ``tile_cdfs``.
        min_val: Value mapped to bin 0.
        max_val: Value mapped to the last bin; must differ from ``min_val``.

    Returns:
        Array with the shape of ``image`` holding the interpolated values.
    """

    def _axis_lookup(n_pixels, n_tiles, tile_size):
        """Return (low tile, high tile, weight of high tile) per pixel."""
        coords = cp.arange(n_pixels, dtype=cp.float32)
        centers = cp.arange(n_tiles, dtype=cp.float32) * tile_size + tile_size // 2

        # Tile whose center lies at or before the pixel, clamped so that
        # low/high is always a valid pair. With a single tile both indices
        # are 0 (no reliance on negative-index wrap-around).
        low = cp.searchsorted(centers, coords) - 1
        low = cp.clip(low, 0, max(n_tiles - 2, 0))
        high = cp.minimum(low + 1, n_tiles - 1)

        center_low = centers[low]
        span = centers[high] - center_low
        # Divide by 1 where the span is zero; those entries get weight 0.
        safe_span = cp.where(span > 0, span, cp.ones_like(span))
        weight = cp.where(span > 0, (coords - center_low) / safe_span, 0.0)
        return low, high, cp.clip(weight, 0.0, 1.0)

    height, width = image.shape

    row_low, row_high, wy = _axis_lookup(height, tile_rows, tile_height)
    col_low, col_high, wx = _axis_lookup(width, tile_cols, tile_width)

    # Column vectors for the row axis, row vectors for the column axis, so
    # that everything broadcasts to (height, width).
    row_low, row_high, wy = row_low[:, None], row_high[:, None], wy[:, None]
    col_low, col_high, wx = col_low[None, :], col_high[None, :], wx[None, :]

    # Convert pixel values to bin indices.
    normalized_values = (image - min_val) / (max_val - min_val)
    bin_indices = cp.clip(
        (normalized_values * (nbins - 1)).astype(cp.int32),
        0, nbins - 1
    )

    # Look up each pixel's bin in the four surrounding tiles.
    val_tl = tile_cdfs[row_low, col_low, bin_indices]
    val_tr = tile_cdfs[row_low, col_high, bin_indices]
    val_bl = tile_cdfs[row_high, col_low, bin_indices]
    val_br = tile_cdfs[row_high, col_high, bin_indices]

    # Bilinear interpolation: along columns first, then along rows.
    val_top = (1 - wx) * val_tl + wx * val_tr
    val_bottom = (1 - wx) * val_bl + wx * val_br
    result = (1 - wy) * val_top + wy * val_bottom

    return result

250 

def _clip_histogram_optimized(hist: "cp.ndarray", clip_limit: int) -> "cp.ndarray":
    """
    Clip a histogram at ``clip_limit`` and redistribute the clipped mass.

    The excess above ``clip_limit`` is first spread uniformly over all bins.
    Any overflow this creates is then moved into the remaining headroom of
    the bins, in proportion to that headroom (at most 3 passes).

    The result is returned as float32. Casting back to an integer histogram
    dtype would truncate the fractional redistribution and silently drop
    part of the histogram mass.

    Args:
        hist: 1D histogram counts.
        clip_limit: Maximum count per bin. Values <= 0 disable clipping.

    Returns:
        ``hist`` itself when ``clip_limit <= 0``; otherwise a new float32
        array of the same length. Its sum equals ``hist.sum()`` whenever
        the bins have enough headroom to absorb the excess.
    """
    if clip_limit <= 0:
        return hist

    # Work in float so fractional redistribution is representable.
    clipped_hist = hist.astype(cp.float32)

    total_excess = cp.sum(cp.maximum(clipped_hist - clip_limit, 0))
    clipped_hist = cp.minimum(clipped_hist, clip_limit)

    if total_excess > 0:
        # Spread the excess uniformly over all bins.
        clipped_hist += total_excess / len(hist)

        # The uniform spread pushes already-full bins over the limit again;
        # move that overflow into the bins that still have room.
        for _ in range(3):  # Max 3 iterations should be enough
            overflow = cp.maximum(clipped_hist - clip_limit, 0)
            total_overflow = cp.sum(overflow)

            if total_overflow < 1e-6:
                break

            clipped_hist = cp.minimum(clipped_hist, clip_limit)

            # Headroom is zero for saturated bins, so scaling it needs no
            # extra mask.
            headroom = cp.maximum(clip_limit - clipped_hist, 0)
            available_space = cp.sum(headroom)
            if available_space > 0:
                redistrib_factor = min(1.0, total_overflow / available_space)
                clipped_hist += redistrib_factor * headroom

    return clipped_hist

293 

@cupy_func
def clahe_3d(
    stack: "cp.ndarray",
    clip_limit: float = 2.0,
    tile_grid_size_3d: tuple = None,
    nbins: int = None,
    adaptive_bins: bool = True,
    adaptive_tiles: bool = True,
    memory_efficient: bool = True
) -> "cp.ndarray":
    """
    Volumetric CLAHE using 3D tiles and trilinear interpolation.

    Args:
        stack: 3D CuPy array of shape (Z, Y, X)
        clip_limit: Threshold for contrast limiting
        tile_grid_size_3d: Number of tiles (z_tiles, y_tiles, x_tiles);
            derived from the volume shape when None
        nbins: Number of histogram bins; derived from the data range
            (clamped to [128, 512]) or set to 256 when None
        adaptive_bins: Whether to adapt bins based on data range
        adaptive_tiles: Whether to adapt tile size based on volume dimensions
        memory_efficient: Use chunked processing for volumes with more than
            512**3 voxels

    Returns:
        The equalized volume produced by the chunked or the fully
        vectorized implementation.
    """
    depth, height, width = stack.shape

    # Resolve the number of histogram bins.
    if nbins is not None:
        bins = nbins
    elif adaptive_bins:
        data_range = float(cp.max(stack) - cp.min(stack))
        bins = min(512, max(128, int(cp.cbrt(data_range * 64))))
    else:
        bins = 256

    # Resolve the 3D tile grid.
    if tile_grid_size_3d is not None:
        grid = tile_grid_size_3d
    elif adaptive_tiles:
        target_tile_size = 48
        grid = (
            max(1, min(depth // 4, depth // target_tile_size)),
            max(2, min(8, height // target_tile_size)),
            max(2, min(8, width // target_tile_size)),
        )
    else:
        grid = (max(1, depth // 8), 4, 4)

    # Large volumes go through the chunked path when requested
    # (~134M voxels threshold).
    exceeds_threshold = depth * height * width > 512**3
    if memory_efficient and exceeds_threshold:
        return _clahe_3d_chunked(stack, clip_limit, grid, bins)
    return _clahe_3d_vectorized(stack, clip_limit, grid, bins)

347 

def _clahe_3d_vectorized(
    stack: "cp.ndarray",
    clip_limit: float,
    tile_grid_size_3d: tuple,
    nbins: int
) -> "cp.ndarray":
    """
    Full vectorized 3D CLAHE implementation.

    The volume is cropped to a whole number of 3D tiles, per-tile clipped
    CDFs are computed, and every voxel is mapped through a trilinear blend
    of the neighbouring tile CDFs. Voxels outside the cropped region are
    filled by replicating the nearest equalized values.

    Args:
        stack: 3D array to equalize.
        clip_limit: Contrast-limiting factor; scaled by voxels-per-tile and
            ``nbins`` into an absolute per-bin count (at least 1).
        tile_grid_size_3d: Requested ``(z_tiles, y_tiles, x_tiles)``. Tiles
            are at least 1 voxel deep and 4 voxels high/wide, so the number
            of tiles actually used can be smaller than requested.
        nbins: Number of histogram bins.

    Returns:
        Array with the shape and dtype of ``stack``. A constant volume is
        returned as an unchanged copy.

    Raises:
        ValueError: If the volume cannot hold at least one tile along every
            axis (e.g. fewer than 4 voxels in Y or X, or an empty axis).
    """
    depth, height, width = stack.shape
    tile_z, tile_y, tile_x = tile_grid_size_3d

    # Calculate 3D tile dimensions (with minimum tile sizes)
    tile_depth = max(1, depth // tile_z)
    tile_height = max(4, height // tile_y)
    tile_width = max(4, width // tile_x)

    # Recalculate actual number of tiles that fit with those sizes
    actual_tile_z = depth // tile_depth
    actual_tile_y = height // tile_height
    actual_tile_x = width // tile_width

    # The tile sizes are clamped to >= 1 above and can never be invalid, so
    # validity has to be checked on the tile counts: a minimum tile size
    # larger than the axis leaves zero tiles and an empty crop.
    if actual_tile_z < 1 or actual_tile_y < 1 or actual_tile_x < 1:
        raise ValueError(f"Volume too small for {tile_z}x{tile_y}x{tile_x} tiles")

    # Calculate crop dimensions
    crop_depth = tile_depth * actual_tile_z
    crop_height = tile_height * actual_tile_y
    crop_width = tile_width * actual_tile_x
    stack_crop = stack[:crop_depth, :crop_height, :crop_width]

    # Calculate actual clip limit
    voxels_per_tile = tile_depth * tile_height * tile_width
    actual_clip_limit = max(1, int(clip_limit * voxels_per_tile / nbins))

    # Get value range
    min_val = float(cp.min(stack_crop))
    max_val = float(cp.max(stack_crop))

    if max_val <= min_val:
        return stack.astype(stack.dtype)  # Constant volume

    # Compute 3D tile CDFs
    tile_cdfs = _compute_tile_cdfs_3d(
        stack_crop, actual_tile_z, actual_tile_y, actual_tile_x,
        tile_depth, tile_height, tile_width,
        nbins, actual_clip_limit, min_val, max_val
    )

    # Apply vectorized trilinear interpolation
    result = _apply_vectorized_trilinear_interpolation(
        stack_crop, tile_cdfs, actual_tile_z, actual_tile_y, actual_tile_x,
        tile_depth, tile_height, tile_width, nbins, min_val, max_val
    )

    # Handle original stack size
    if result.shape != stack.shape:
        full_result = cp.zeros_like(stack, dtype=result.dtype)
        full_result[:crop_depth, :crop_height, :crop_width] = result

        # Fill the regions outside the crop by edge replication
        _fill_3d_boundaries(full_result, result, crop_depth, crop_height, crop_width,
                            depth, height, width)
        result = full_result

    return result.astype(stack.dtype)

415 

def _compute_tile_cdfs_3d(
    stack: "cp.ndarray",
    tile_z: int,
    tile_y: int,
    tile_x: int,
    tile_depth: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    clip_limit: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Build the clipped, intensity-scaled CDF of every 3D tile.

    Args:
        stack: 3D array covering ``tile_z * tile_depth`` by
            ``tile_y * tile_height`` by ``tile_x * tile_width`` voxels.
        tile_z: Number of tiles along the first axis.
        tile_y: Number of tiles along the second axis.
        tile_x: Number of tiles along the third axis.
        tile_depth: Tile size along the first axis.
        tile_height: Tile size along the second axis.
        tile_width: Tile size along the third axis.
        nbins: Number of histogram bins between ``min_val`` and ``max_val``.
        clip_limit: Absolute per-bin count passed to the histogram clipper.
        min_val: Lower end of the histogram range and of the output mapping.
        max_val: Upper end of the histogram range and of the output mapping.

    Returns:
        float32 array of shape ``(tile_z, tile_y, tile_x, nbins)`` mapping a
        bin index to an output intensity for each tile.
    """
    cdfs = cp.zeros((tile_z, tile_y, tile_x, nbins), dtype=cp.float32)

    # Shared histogram bin edges for all tiles.
    edges = cp.linspace(min_val, max_val, nbins + 1, dtype=cp.float32)

    for z_idx in range(tile_z):
        z_span = slice(z_idx * tile_depth, (z_idx + 1) * tile_depth)
        for y_idx in range(tile_y):
            y_span = slice(y_idx * tile_height, (y_idx + 1) * tile_height)
            for x_idx in range(tile_x):
                x_span = slice(x_idx * tile_width, (x_idx + 1) * tile_width)

                # Flattened voxels of the current tile.
                tile_voxels = stack[z_span, y_span, x_span].ravel()

                counts, _ = cp.histogram(tile_voxels, bins=edges)
                counts = _clip_histogram_optimized(counts, clip_limit)

                running = cp.cumsum(counts, dtype=cp.float32)
                if running[-1] > 0:
                    # Normalize to [0, 1], then scale to the intensity range.
                    running = running / running[-1]
                    cdfs[z_idx, y_idx, x_idx, :] = (
                        min_val + running * (max_val - min_val)
                    )
                else:
                    # Empty histogram: every bin maps to the minimum value.
                    cdfs[z_idx, y_idx, x_idx, :] = min_val

    return cdfs

465 

def _apply_vectorized_trilinear_interpolation(
    stack: "cp.ndarray",
    tile_cdfs: "cp.ndarray",
    tile_z: int,
    tile_y: int,
    tile_x: int,
    tile_depth: int,
    tile_height: int,
    tile_width: int,
    nbins: int,
    min_val: float,
    max_val: float
) -> "cp.ndarray":
    """
    Map every voxel through a trilinear blend of the eight nearest tile CDFs.

    Tile indices and interpolation weights depend on a single coordinate
    each, so they are computed on 1D axes and broadcast. This avoids
    materialising three full ``depth x height x width`` coordinate grids
    (plus a stacked copy of them) just to derive per-axis quantities.

    Args:
        stack: 3D array to transform.
        tile_cdfs: Array indexed as ``[tile_z, tile_y, tile_x, bin]`` giving
            the output intensity for each bin of each tile.
        tile_z: Number of tiles along the first axis.
        tile_y: Number of tiles along the second axis.
        tile_x: Number of tiles along the third axis.
        tile_depth: Tile size along the first axis.
        tile_height: Tile size along the second axis.
        tile_width: Tile size along the third axis.
        nbins: Number of bins along the last axis of ``tile_cdfs``.
        min_val: Value mapped to bin 0.
        max_val: Value mapped to the last bin; must differ from ``min_val``.

    Returns:
        Array with the shape of ``stack`` holding the interpolated values.
    """

    def _axis_lookup(n_voxels, n_tiles, tile_size, axis):
        """Return (low tile, high tile, weight of high tile) along `axis`.

        The three arrays are shaped to broadcast against a 3D volume.
        """
        coords = cp.arange(n_voxels, dtype=cp.float32)
        centers = cp.arange(n_tiles, dtype=cp.float32) * tile_size + tile_size // 2

        # Tile whose center lies at or before the voxel, clamped so that
        # low/high is always a valid pair. With a single tile both indices
        # are 0 on every axis (not only along z).
        low = cp.searchsorted(centers, coords) - 1
        low = cp.clip(low, 0, max(n_tiles - 2, 0))
        high = cp.minimum(low + 1, n_tiles - 1)

        center_low = centers[low]
        span = centers[high] - center_low
        # Divide by 1 where the span is zero; those entries get weight 0.
        safe_span = cp.where(span > 0, span, cp.ones_like(span))
        weight = cp.where(span > 0, (coords - center_low) / safe_span, 0.0)
        weight = cp.clip(weight, 0.0, 1.0)

        shape = [1, 1, 1]
        shape[axis] = n_voxels
        return low.reshape(shape), high.reshape(shape), weight.reshape(shape)

    depth, height, width = stack.shape

    tile_z_low, tile_z_high, wz = _axis_lookup(depth, tile_z, tile_depth, 0)
    tile_y_low, tile_y_high, wy = _axis_lookup(height, tile_y, tile_height, 1)
    tile_x_low, tile_x_high, wx = _axis_lookup(width, tile_x, tile_width, 2)

    # Convert voxel values to bin indices.
    normalized_values = (stack - min_val) / (max_val - min_val)
    bin_indices = cp.clip(
        (normalized_values * (nbins - 1)).astype(cp.int32),
        0, nbins - 1
    )

    # Look up each voxel's bin in the 8 surrounding tiles; the index arrays
    # broadcast to (depth, height, width).
    val_000 = tile_cdfs[tile_z_low, tile_y_low, tile_x_low, bin_indices]
    val_001 = tile_cdfs[tile_z_low, tile_y_low, tile_x_high, bin_indices]
    val_010 = tile_cdfs[tile_z_low, tile_y_high, tile_x_low, bin_indices]
    val_011 = tile_cdfs[tile_z_low, tile_y_high, tile_x_high, bin_indices]
    val_100 = tile_cdfs[tile_z_high, tile_y_low, tile_x_low, bin_indices]
    val_101 = tile_cdfs[tile_z_high, tile_y_low, tile_x_high, bin_indices]
    val_110 = tile_cdfs[tile_z_high, tile_y_high, tile_x_low, bin_indices]
    val_111 = tile_cdfs[tile_z_high, tile_y_high, tile_x_high, bin_indices]

    # Trilinear interpolation: along x first ...
    val_00 = (1 - wx) * val_000 + wx * val_001
    val_01 = (1 - wx) * val_010 + wx * val_011
    val_10 = (1 - wx) * val_100 + wx * val_101
    val_11 = (1 - wx) * val_110 + wx * val_111

    # ... then along y ...
    val_0 = (1 - wy) * val_00 + wy * val_01
    val_1 = (1 - wy) * val_10 + wy * val_11

    # ... and finally along z.
    result = (1 - wz) * val_0 + wz * val_1

    return result

576 

def _clahe_3d_chunked(
    stack: "cp.ndarray",
    clip_limit: float,
    tile_grid_size_3d: tuple,
    nbins: int,
    chunk_size: int = 128
) -> "cp.ndarray":
    """
    Memory-efficient chunked processing for very large 3D volumes.

    The volume is processed in overlapping chunks along the first axis. Each
    chunk is equalized together with some neighbouring slices as context,
    and only the chunk's own slices are written to the result.

    Args:
        stack: 3D array to equalize.
        clip_limit: Contrast-limiting factor forwarded to the per-chunk
            implementation.
        tile_grid_size_3d: ``(z_tiles, y_tiles, x_tiles)`` for the whole
            volume; the z tile count is reduced per chunk.
        nbins: Number of histogram bins.
        chunk_size: Number of slices written per chunk; must be >= 1.

    Returns:
        Array with the shape and dtype of ``stack``.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    depth, _height, _width = stack.shape
    result = cp.zeros_like(stack)

    # Calculate overlap needed for smooth transitions
    tile_z, tile_y, tile_x = tile_grid_size_3d
    tile_depth = max(1, depth // tile_z)

    # Half a tile of context, capped at half a chunk. Without the cap, deep
    # tiles make `chunk_size - overlap` zero (range() raises) or negative
    # (the loop never runs and an all-zero volume is returned).
    overlap = min(tile_depth // 2, chunk_size // 2)
    step = chunk_size - overlap

    # Process volume in z-chunks
    for z_start in range(0, depth, step):
        z_end = min(z_start + chunk_size, depth)

        # Extract chunk with context on both sides
        chunk_start = max(0, z_start - overlap)
        chunk_end = min(depth, z_end + overlap)

        chunk = stack[chunk_start:chunk_end, :, :]

        # Adjust the z tile count to the chunk's depth
        chunk_depth = chunk_end - chunk_start
        chunk_tile_z = max(1, min(tile_z, chunk_depth // tile_depth))
        chunk_tile_grid = (chunk_tile_z, tile_y, tile_x)

        # Process chunk
        chunk_result = _clahe_3d_vectorized(
            chunk, clip_limit, chunk_tile_grid, nbins
        )

        # Keep only the chunk's own slices (drop the context)
        extract_start = z_start - chunk_start
        extract_end = extract_start + (z_end - z_start)

        result[z_start:z_end, :, :] = chunk_result[extract_start:extract_end, :, :]

    return result

624 

def _fill_3d_boundaries(
    full_result: "cp.ndarray",
    cropped_result: "cp.ndarray",
    crop_depth: int,
    crop_height: int,
    crop_width: int,
    depth: int,
    height: int,
    width: int
) -> None:
    """
    Fill the part of ``full_result`` outside the cropped region in place.

    Every voxel beyond the crop along one or more axes receives the value of
    the last cropped voxel along those axes (edge replication). The cropped
    region itself, ``[:crop_depth, :crop_height, :crop_width]``, is not
    written.

    Args:
        full_result: Output array of extent ``(depth, height, width)``;
            modified in place.
        cropped_result: Equalized data for the cropped region.
        crop_depth: Extent of the cropped region along the first axis.
        crop_height: Extent of the cropped region along the second axis.
        crop_width: Extent of the cropped region along the third axis.
        depth: Full extent along the first axis.
        height: Full extent along the second axis.
        width: Full extent along the third axis.
    """
    extents = (
        (crop_depth, depth),
        (crop_height, height),
        (crop_width, width),
    )

    # Enumerate the 8 inside/outside combinations of the three axes; the
    # all-inside combination is the cropped region and is skipped.
    for z_outside in (False, True):
        for y_outside in (False, True):
            for x_outside in (False, True):
                outside = (z_outside, y_outside, x_outside)
                if not any(outside):
                    continue

                # A region exists only if every "outside" axis really has
                # voxels beyond the crop.
                if any(
                    flag and crop >= full
                    for flag, (crop, full) in zip(outside, extents)
                ):
                    continue

                # Destination: beyond the crop on flagged axes, within it
                # on the others.
                target = tuple(
                    slice(crop, None) if flag else slice(None, crop)
                    for flag, (crop, _) in zip(outside, extents)
                )
                # Source: last cropped plane/line/voxel on flagged axes,
                # broadcast over the destination.
                source = tuple(
                    slice(-1, None) if flag else slice(None)
                    for flag in outside
                )
                full_result[target] = cropped_result[source]