Coverage for openhcs/processing/backends/processors/torch_processor.py: 9.9%

262 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2PyTorch Image Processor Implementation 

3 

4This module implements the ImageProcessorInterface using PyTorch as the backend. 

5It leverages GPU acceleration for image processing operations. 

6 

7Doctrinal Clauses: 

8- Clause 3 — Declarative Primacy: All functions are pure and stateless 

9- Clause 88 — No Inferred Capabilities: Explicit PyTorch dependency 

10- Clause 106-A — Declared Memory Types: All methods specify PyTorch tensors 

11""" 

12from __future__ import annotations 

13 

14import logging 

15from typing import Any, List, Optional, Tuple 

16 

17from openhcs.core.utils import optional_import 

18from openhcs.core.memory.decorators import torch as torch_func 

19 

20# Import PyTorch as an optional dependency 

21torch = optional_import("torch") 

22F = optional_import("torch.nn.functional") if torch is not None else None 

23HAS_TORCH = torch is not None 

24 

25logger = logging.getLogger(__name__) 

26 

27 

28def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "torch.Tensor": 

29 """ 

30 Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center. 

31 

32 Args: 

33 height: Height of the mask 

34 width: Width of the mask 

35 margin_ratio: Ratio of the margin to the image size 

36 

37 Returns: 

38 2D PyTorch weight mask of shape (height, width) 

39 """ 

40 if torch is None: 

41 raise ImportError("PyTorch is required for TorchImageProcessor") 

42 

43 margin_y = int(torch.floor(torch.tensor(height * margin_ratio))) 

44 margin_x = int(torch.floor(torch.tensor(width * margin_ratio))) 

45 

46 weight_y = torch.ones(height, dtype=torch.float32) 

47 if margin_y > 0: 

48 ramp_top = torch.linspace(0, 1, margin_y, dtype=torch.float32) 

49 ramp_bottom = torch.linspace(1, 0, margin_y, dtype=torch.float32) 

50 weight_y[:margin_y] = ramp_top 

51 weight_y[-margin_y:] = ramp_bottom 

52 

53 weight_x = torch.ones(width, dtype=torch.float32) 

54 if margin_x > 0: 

55 ramp_left = torch.linspace(0, 1, margin_x, dtype=torch.float32) 

56 ramp_right = torch.linspace(1, 0, margin_x, dtype=torch.float32) 

57 weight_x[:margin_x] = ramp_left 

58 weight_x[-margin_x:] = ramp_right 

59 

60 # Create 2D weight mask using outer product 

61 weight_mask = torch.outer(weight_y, weight_x) 

62 

63 return weight_mask 
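A minimal check of the ramp behaviour, assuming PyTorch is installed: the mask is 0 at the corners and exactly 1 in the interior.

    import torch
    from openhcs.processing.backends.processors.torch_processor import create_linear_weight_mask

    mask = create_linear_weight_mask(height=100, width=200, margin_ratio=0.1)
    print(mask.shape)             # torch.Size([100, 200])
    print(float(mask[0, 0]))      # 0.0 at the corner
    print(float(mask[50, 100]))   # 1.0 in the interior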

64 

65 

66def _validate_3d_array(array: Any, name: str = "input") -> None: 

67 """ 

68 Validate that the input is a 3D PyTorch tensor. 

69 

70 Args: 

71 array: Array to validate 

72 name: Name of the array for error messages 

73 

74 Raises: 

75 TypeError: If the array is not a PyTorch tensor 

76 ValueError: If the array is not 3D 

77 ImportError: If PyTorch is not available 

78 """ 

79 if torch is None: 

80 raise ImportError("PyTorch is required for TorchImageProcessor") 

81 

82 if not isinstance(array, torch.Tensor): 

83 raise TypeError(f"{name} must be a PyTorch tensor, got {type(array)}. " 

84 f"No automatic conversion is performed to maintain explicit contracts.") 

85 

86 if array.ndim != 3: 

87 raise ValueError(f"{name} must be a 3D tensor, got {array.ndim}D") 

88 

89def _gaussian_blur(image: "torch.Tensor", sigma: float) -> "torch.Tensor": 

90 """ 

91 Apply Gaussian blur to a 2D image. 

92 

93 Args: 

94 image: 2D PyTorch tensor of shape (H, W) 

95 sigma: Standard deviation of the Gaussian kernel 

96 

97 Returns: 

98 Blurred 2D PyTorch tensor of shape (H, W) 

99 """ 

100 # Calculate kernel size based on sigma 

101 kernel_size = max(3, int(2 * 4 * sigma + 1)) 

102 if kernel_size % 2 == 0: 

103 kernel_size += 1 # Ensure odd kernel size 

104 

105 # Create 1D Gaussian kernel 

106 coords = torch.arange(kernel_size, dtype=torch.float32, device=image.device) 

107 coords -= (kernel_size - 1) / 2 

108 

109 # Calculate Gaussian values 

110 gauss = torch.exp(-(coords**2) / (2 * sigma**2)) 

111 kernel = gauss / gauss.sum() 

112 

113 # Reshape for 2D convolution 

114 kernel_x = kernel.view(1, 1, kernel_size, 1) 

115 kernel_y = kernel.view(1, 1, 1, kernel_size) 

116 

117 # Add batch and channel dimensions to image 

118 img = image.unsqueeze(0).unsqueeze(0) 

119 

120 # Apply separable convolution 

121 blurred = F.conv2d(img, kernel_x, padding=(kernel_size//2, 0)) 

122 blurred = F.conv2d(blurred, kernel_y, padding=(0, kernel_size//2)) 

123 

124 # Remove batch and channel dimensions 

125 return blurred.squeeze(0).squeeze(0) 
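A quick shape check for this module-private helper (imported here only for illustration): the separable convolutions preserve the (H, W) size because of the half-kernel padding.

    import torch
    from openhcs.processing.backends.processors.torch_processor import _gaussian_blur

    img = torch.rand(64, 64)
    blurred = _gaussian_blur(img, sigma=2.0)
    print(blurred.shape)          # torch.Size([64, 64])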

126 

127@torch_func 

128def sharpen(image: "torch.Tensor", radius: float = 1.0, amount: float = 1.0) -> "torch.Tensor": 

129 """ 

130 Sharpen a 3D image using unsharp masking. 

131 

132 This applies sharpening to each Z-slice independently. 

133 

134 Args: 

135 image: 3D PyTorch tensor of shape (Z, Y, X) 

136 radius: Radius of Gaussian blur 

137 amount: Sharpening strength 

138 

139 Returns: 

140 Sharpened 3D PyTorch tensor of shape (Z, Y, X) 

141 """ 

142 _validate_3d_array(image) 

143 

144 # Store original dtype 

145 dtype = image.dtype 

146 

147 # Process each Z-slice independently 

148 result = torch.zeros_like(image, dtype=torch.float32) 

149 

150 for z in range(image.shape[0]): 

151 # Convert to float for processing 

152 slice_float_raw = image[z].float() 

153 slice_float = slice_float_raw / torch.clamp(torch.max(slice_float_raw), min=1e-12)  # guard against an all-zero slice (0/0 -> NaN) 

154 

155 # Create blurred version for unsharp mask 

156 blurred = _gaussian_blur(slice_float, sigma=radius) 

157 

158 # Apply unsharp mask: original + amount * (original - blurred) 

159 sharpened = slice_float + amount * (slice_float - blurred) 

160 

161 # Clip to valid range 

162 sharpened = torch.clamp(sharpened, 0, 1.0) 

163 

164 # Rescale to the full 16-bit range [0, 65535] 

165 min_val = torch.min(sharpened) 

166 max_val = torch.max(sharpened) 

167 if max_val > min_val: 

168 sharpened = (sharpened - min_val) * 65535 / (max_val - min_val) 

169 

170 result[z] = sharpened 

171 

172 # Convert back to original dtype 

173 if dtype == torch.uint16: 

174 result = torch.clamp(result, 0, 65535).to(torch.uint16) 

175 else: 

176 result = result.to(dtype) 

177 

178 return result 
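A minimal usage sketch, assuming the torch_func decorator leaves the function directly callable; shapes and values are illustrative.

    import torch
    from openhcs.processing.backends.processors.torch_processor import sharpen

    stack = torch.rand(3, 128, 128) * 65535.0     # float (Z, Y, X) stack
    out = sharpen(stack, radius=1.0, amount=1.5)
    print(out.shape, out.dtype)                   # torch.Size([3, 128, 128]) torch.float32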

179 

180@torch_func 

181def percentile_normalize( 

182 image: "torch.Tensor", 

183 low_percentile: float = 1.0, 

184 high_percentile: float = 99.0, 

185 target_min: float = 0.0, 

186 target_max: float = 65535.0 

187) -> "torch.Tensor": 

188 """ 

189 Normalize a 3D image using percentile-based contrast stretching. 

190 

191 This applies normalization to each Z-slice independently. 

192 

193 Args: 

194 image: 3D PyTorch tensor of shape (Z, Y, X) 

195 low_percentile: Lower percentile (0-100) 

196 high_percentile: Upper percentile (0-100) 

197 target_min: Target minimum value 

198 target_max: Target maximum value 

199 

200 Returns: 

201 Normalized 3D PyTorch tensor of shape (Z, Y, X) 

202 """ 

203 _validate_3d_array(image) 

204 

205 # Process each Z-slice independently 

206 result = torch.zeros_like(image, dtype=torch.float32) 

207 

208 for z in range(image.shape[0]): 

209 # Get percentile values for this slice 

210 # Handle large slices that exceed PyTorch's quantile() size limits 

211 slice_float = image[z].float() 

212 slice_elements = slice_float.numel() 

213 

214 # PyTorch quantile() fails on very large tensors, so we use sampling for large slices 

215 max_elements_for_quantile = 10_000_000 # ~10M elements, conservative limit for quantile() 

216 

217 logger.debug(f"🔥 QUANTILE DEBUG: percentile_normalize slice {z} shape {image[z].shape}, {slice_elements:,} elements") 

218 

219 if slice_elements > max_elements_for_quantile: 

220 # Use random sampling for large slices to estimate percentiles 

221 sample_size = min(max_elements_for_quantile, slice_elements // 10) # Sample 10% or max size 

222 flat_slice = slice_float.flatten() 

223 

224 # Generate random indices for sampling (memory efficient) 

225 # Use torch.randint instead of torch.randperm to avoid creating huge tensors 

226 indices = torch.randint(0, slice_elements, (sample_size,), device=image.device) 

227 sampled_values = flat_slice[indices] 

228 

229 p_low = torch.quantile(sampled_values, low_percentile / 100.0) 

230 p_high = torch.quantile(sampled_values, high_percentile / 100.0) 

231 else: 

232 # Use full slice for smaller slices 

233 p_low = torch.quantile(slice_float, low_percentile / 100.0) 

234 p_high = torch.quantile(slice_float, high_percentile / 100.0) 

235 

236 # Avoid division by zero 

237 if p_high == p_low: 

238 result[z] = torch.ones_like(image[z], dtype=torch.float32) * target_min 

239 continue 

240 

241 # Clip and normalize to target range 

242 clipped = torch.clamp(image[z].float(), p_low, p_high) 

243 scale = (target_max - target_min) / (p_high - p_low) 

244 normalized = (clipped - p_low) * scale + target_min 

245 result[z] = normalized 

246 

247 # Convert to uint16 

248 result = torch.clamp(result, 0, 65535).to(torch.uint16) 

249 

250 return result 
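The sampling fallback above can be sketched in isolation; the exact quantile() size limit depends on the PyTorch build, so the 20M-element figure here is only an assumption for illustration.

    import torch

    x = torch.rand(20_000_000)                        # may exceed torch.quantile's size limit on some builds
    idx = torch.randint(0, x.numel(), (2_000_000,))   # ~10% uniform random sample, as in the code above
    p99 = torch.quantile(x[idx], 0.99)                # estimated 99th percentile, close to 0.99 for uniform data
    print(float(p99))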

251 

252@torch_func 

253def stack_percentile_normalize( 

254 stack: "torch.Tensor", 

255 low_percentile: float = 1.0, 

256 high_percentile: float = 99.0, 

257 target_min: float = 0.0, 

258 target_max: float = 65535.0 

259) -> "torch.Tensor": 

260 """ 

261 Normalize a stack using global percentile-based contrast stretching. 

262 

263 This ensures consistent normalization across all Z-slices by computing 

264 global percentiles across the entire stack. 

265 

266 Args: 

267 stack: 3D PyTorch tensor of shape (Z, Y, X) 

268 low_percentile: Lower percentile (0-100) 

269 high_percentile: Upper percentile (0-100) 

270 target_min: Target minimum value 

271 target_max: Target maximum value 

272 

273 Returns: 

274 Normalized 3D PyTorch tensor of shape (Z, Y, X) 

275 """ 

276 _validate_3d_array(stack) 

277 

278 # Calculate global percentiles across the entire stack 

279 # Handle large tensors that exceed PyTorch's quantile() size limits 

280 stack_float = stack.float() 

281 total_elements = stack_float.numel() 

282 

283 # PyTorch quantile() fails on very large tensors, so we use sampling for large stacks 

284 max_elements_for_quantile = 10_000_000 # ~10M elements, conservative limit for quantile() 

285 

286 logger.debug(f"🔥 QUANTILE DEBUG: stack_percentile_normalize called with tensor shape {stack.shape}, {total_elements:,} elements") 

287 

288 if total_elements > max_elements_for_quantile: 

289 # Use random sampling for large tensors to estimate percentiles 

290 sample_size = min(max_elements_for_quantile, total_elements // 10) # Sample 10% or max size 

291 flat_stack = stack_float.flatten() 

292 

293 # Generate random indices for sampling (memory efficient) 

294 # Use torch.randint instead of torch.randperm to avoid creating huge tensors 

295 indices = torch.randint(0, total_elements, (sample_size,), device=stack.device) 

296 sampled_values = flat_stack[indices] 

297 

298 p_low = torch.quantile(sampled_values, low_percentile / 100.0) 

299 p_high = torch.quantile(sampled_values, high_percentile / 100.0) 

300 

301 logger.debug(f"Used sampling ({sample_size:,} of {total_elements:,} elements) for percentile calculation due to large tensor size") 

302 else: 

303 # Use full tensor for smaller stacks 

304 p_low = torch.quantile(stack_float, low_percentile / 100.0) 

305 p_high = torch.quantile(stack_float, high_percentile / 100.0) 

306 

307 # Avoid division by zero 

308 if p_high == p_low: 

309 return torch.ones_like(stack) * target_min 

310 

311 # Clip and normalize to target range (match NumPy implementation exactly) 

312 clipped = torch.clamp(stack, p_low, p_high) 

313 normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min 

314 normalized = normalized.to(torch.uint16) 

315 

316 return normalized 
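A usage sketch for the global variant, assuming a PyTorch build that provides torch.uint16 and a directly callable decorated function.

    import torch
    from openhcs.processing.backends.processors.torch_processor import stack_percentile_normalize

    stack = torch.rand(4, 256, 256) * 1000.0
    out = stack_percentile_normalize(stack, low_percentile=1.0, high_percentile=99.0)
    print(out.shape, out.dtype)    # torch.Size([4, 256, 256]) torch.uint16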

317 

318@torch_func 

319def create_composite( 

320 stack: "torch.Tensor", weights: Optional[List[float]] = None 

321) -> "torch.Tensor": 

322 """ 

323 Create a composite image from a 3D stack of 2D images. 

324 

325 Args: 

326 stack: 3D PyTorch tensor of shape (N, Y, X) where N is number of images 

327 weights: List of weights for each image. If None, equal weights are used. 

328 

329 Returns: 

330 Composite 3D PyTorch tensor of shape (1, Y, X) 

331 """ 

332 # Validate input is 3D tensor 

333 _validate_3d_array(stack) 

334 

335 n_images, height, width = stack.shape 

336 

337 # Default weights if none provided 

338 if weights is None: 

339 # Equal weights for all images 

340 weights = [1.0 / n_images] * n_images 

341 elif not isinstance(weights, list): 

342 raise TypeError("weights must be a list of values") 

343 

344 # FAIL FAST: No fallback weights - weights must match exactly 

345 if len(weights) != n_images: 

346 raise ValueError( 

347 f"Weights list length ({len(weights)}) must exactly match number of images ({n_images}). " 

348 f"No automatic padding or truncation allowed." 

349 ) 

350 

351 dtype = stack.dtype 

352 device = stack.device 

353 

354 # Create empty composite 

355 composite = torch.zeros((height, width), dtype=torch.float32, device=device) 

356 total_weight = 0.0 

357 

358 # Add each image with its weight 

359 for i in range(n_images): 

360 weight = weights[i] 

361 if weight <= 0.0: 

362 continue 

363 

364 # Add to composite 

365 composite += stack[i].float() * weight 

366 total_weight += weight 

367 

368 # Normalize by total weight 

369 if total_weight > 0: 

370 composite /= total_weight 

371 

372 # Convert back to original dtype (usually uint16) 

373 if dtype in [torch.uint8, torch.uint16, torch.uint32, torch.int8, torch.int16, torch.int32, torch.int64]: 

374 # Get the maximum value for the specific integer dtype 

375 if dtype == torch.uint8: 

376 max_val = 255 

377 elif dtype == torch.uint16: 

378 max_val = 65535 

379 elif dtype == torch.uint32: 

380 max_val = 4294967295 

381 elif dtype == torch.int8: 

382 max_val = 127 

383 elif dtype == torch.int16: 

384 max_val = 32767 

385 elif dtype == torch.int32: 

386 max_val = 2147483647 

387 elif dtype == torch.int64: 

388 max_val = 9223372036854775807 

389 

390 composite = torch.clamp(composite, 0, max_val).to(dtype) 

391 else: 

392 composite = composite.to(dtype) 

393 

394 # Return as 3D tensor with shape (1, Y, X) 

395 return composite.reshape(1, height, width) 
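A worked example of the weighted average, assuming the decorated function is directly callable: with weights [0.25, 0.75], pixels of 100 and 300 blend to 0.25*100 + 0.75*300 = 250.

    import torch
    from openhcs.processing.backends.processors.torch_processor import create_composite

    channels = torch.stack([
        torch.full((64, 64), 100.0),
        torch.full((64, 64), 300.0),
    ])                                           # (2, Y, X)
    comp = create_composite(channels, weights=[0.25, 0.75])
    print(comp.shape)                            # torch.Size([1, 64, 64])
    print(float(comp[0, 0, 0]))                  # 250.0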

396 

397@torch_func 

398def apply_mask(image: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor": 

399 """ 

400 Apply a mask to a 3D image while maintaining 3D structure. 

401 

402 This applies the mask to each Z-slice independently if mask is 2D, 

403 or applies the 3D mask directly if mask is 3D. 

404 

405 Args: 

406 image: 3D PyTorch tensor of shape (Z, Y, X) 

407 mask: 3D PyTorch tensor of shape (Z, Y, X) or 2D PyTorch tensor of shape (Y, X) 

408 

409 Returns: 

410 Masked 3D PyTorch tensor of shape (Z, Y, X) - dimensionality preserved 

411 """ 

412 _validate_3d_array(image) 

413 

414 # Handle 2D mask (apply to each Z-slice) 

415 if isinstance(mask, torch.Tensor) and mask.ndim == 2: 

416 if mask.shape != image.shape[1:]: 

417 raise ValueError( 

418 f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}" 

419 ) 

420 

421 # Apply 2D mask to each Z-slice 

422 result = torch.zeros_like(image) 

423 for z in range(image.shape[0]): 

424 result[z] = image[z].float() * mask.float() 

425 

426 return result.to(image.dtype) 

427 

428 # Handle 3D mask 

429 if isinstance(mask, torch.Tensor) and mask.ndim == 3: 

430 if mask.shape != image.shape: 

431 raise ValueError( 

432 f"3D mask shape {mask.shape} doesn't match image shape {image.shape}" 

433 ) 

434 

435 # Apply 3D mask directly 

436 masked = image.float() * mask.float() 

437 return masked.to(image.dtype) 

438 

439 # If we get here, the mask is neither 2D nor 3D PyTorch tensor 

440 raise TypeError(f"mask must be a 2D or 3D PyTorch tensor, got {type(mask)}") 
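A sketch of the 2D-mask path using the blend mask defined earlier in this module, under the same callability assumption.

    import torch
    from openhcs.processing.backends.processors.torch_processor import apply_mask, create_weight_mask

    image = torch.rand(3, 64, 64)
    mask2d = create_weight_mask((64, 64), margin_ratio=0.1)   # (Y, X), applied to every Z-slice
    masked = apply_mask(image, mask2d)
    print(masked.shape)                                       # torch.Size([3, 64, 64])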

441 

442@torch_func 

443def create_weight_mask( 

444 shape: Tuple[int, int], margin_ratio: float = 0.1 

445) -> "torch.Tensor": 

446 """ 

447 Create a weight mask for blending images. 

448 

449 Args: 

450 shape: Shape of the mask (height, width) 

451 margin_ratio: Ratio of image size to use as margin 

452 

453 Returns: 

454 2D PyTorch weight mask of shape (Y, X) 

455 """ 

456 if not isinstance(shape, tuple) or len(shape) != 2: 

457 raise TypeError("shape must be a tuple of (height, width)") 

458 

459 height, width = shape 

460 return create_linear_weight_mask(height, width, margin_ratio) 

461 

462@torch_func 

463def max_projection(stack: "torch.Tensor") -> "torch.Tensor": 

464 """ 

465 Create a maximum intensity projection from a Z-stack. 

466 

467 Args: 

468 stack: 3D PyTorch tensor of shape (Z, Y, X) 

469 

470 Returns: 

471 3D PyTorch tensor of shape (1, Y, X) 

472 """ 

473 _validate_3d_array(stack) 

474 

475 # Store original dtype for conversion back 

476 original_dtype = stack.dtype 

477 

478 # Convert to float32 if needed for GPU operations 

479 if stack.dtype == torch.uint16: 

480 stack_float = stack.float() 

481 else: 

482 stack_float = stack 

483 

484 # Create max projection 

485 projection_2d = torch.max(stack_float, dim=0)[0] 

486 

487 # Convert back to original dtype 

488 projection_2d = projection_2d.to(original_dtype) 

489 

490 return projection_2d.reshape(1, projection_2d.shape[0], projection_2d.shape[1]) 

491 

492@torch_func 

493def mean_projection(stack: "torch.Tensor") -> "torch.Tensor": 

494 """ 

495 Create a mean intensity projection from a Z-stack. 

496 

497 Args: 

498 stack: 3D PyTorch tensor of shape (Z, Y, X) 

499 

500 Returns: 

501 3D PyTorch tensor of shape (1, Y, X) 

502 """ 

503 _validate_3d_array(stack) 

504 

505 # Store original dtype for conversion back 

506 original_dtype = stack.dtype 

507 

508 # Convert to float32 for mean calculation (always needed for mean) 

509 stack_float = stack.float() 

510 

511 # Create mean projection 

512 projection_2d = torch.mean(stack_float, dim=0) 

513 

514 # Convert back to original dtype 

515 projection_2d = projection_2d.to(original_dtype) 

516 

517 return projection_2d.reshape(1, projection_2d.shape[0], projection_2d.shape[1]) 
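The two projections can be compared directly; the max projection dominates the mean projection at every pixel.

    import torch
    from openhcs.processing.backends.processors.torch_processor import max_projection, mean_projection

    stack = torch.rand(5, 32, 32)
    mip = max_projection(stack)
    avg = mean_projection(stack)
    print(mip.shape, avg.shape)        # both torch.Size([1, 32, 32])
    print(bool((mip >= avg).all()))    # True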

518 

519@torch_func 

520def stack_equalize_histogram( 

521 stack: "torch.Tensor", 

522 bins: int = 65536, 

523 range_min: float = 0.0, 

524 range_max: float = 65535.0 

525) -> "torch.Tensor": 

526 """ 

527 Apply histogram equalization to an entire stack. 

528 

529 This ensures consistent contrast enhancement across all Z-slices by 

530 computing a global histogram across the entire stack. 

531 

532 Args: 

533 stack: 3D PyTorch tensor of shape (Z, Y, X) 

534 bins: Number of bins for histogram computation 

535 range_min: Minimum value for histogram range 

536 range_max: Maximum value for histogram range 

537 

538 Returns: 

539 Equalized 3D PyTorch tensor of shape (Z, Y, X) 

540 """ 

541 _validate_3d_array(stack) 

542 

543 # PyTorch doesn't have a direct histogram equalization function 

544 # We'll implement it manually using torch.histc for the histogram 

545 

546 # Flatten the entire stack to compute the global histogram 

547 flat_stack = stack.float().flatten() 

548 

549 # For very large stacks, use sampling to avoid memory issues 

550 max_elements_for_histogram = 50_000_000 # 50M elements limit 

551 if flat_stack.numel() > max_elements_for_histogram: 

552 # Use random sampling for histogram computation 

553 sample_size = max_elements_for_histogram 

554 indices = torch.randint(0, flat_stack.numel(), (sample_size,), device=stack.device) 

555 sampled_stack = flat_stack[indices] 

556 hist = torch.histc(sampled_stack, bins=bins, min=range_min, max=range_max) 

557 logger.debug(f"Used sampling ({sample_size:,} of {flat_stack.numel():,} elements) for histogram computation") 

558 else: 

559 # Use full stack for smaller stacks 

560 hist = torch.histc(flat_stack, bins=bins, min=range_min, max=range_max) 

561 

562 # We don't need bin edges for the lookup table approach 

563 

564 # Calculate cumulative distribution function (CDF) 

565 cdf = torch.cumsum(hist, dim=0) 

566 

567 # Normalize the CDF to the range [0, 65535] 

568 # Avoid division by zero 

569 if cdf[-1] > 0: 

570 cdf = 65535 * cdf / cdf[-1] 

571 

572 # PyTorch doesn't have a direct equivalent to numpy's interp 

573 # We'll use a lookup table approach 

574 

575 # Scale input values to bin indices 

576 indices = torch.clamp( 

577 ((flat_stack - range_min) / (range_max - range_min) * (bins - 1)).long(), 

578 0, bins - 1 

579 ) 

580 

581 # Look up CDF values 

582 equalized_flat = torch.gather(cdf, 0, indices) 

583 

584 # Reshape back to original shape 

585 equalized_stack = equalized_flat.reshape(stack.shape) 

586 

587 # Convert to uint16 

588 return equalized_stack.to(torch.uint16) 
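The histogram-to-CDF lookup used above can be seen on a tiny example (8 bins over [0, 8) for readability; the bin-index formula mirrors the one in the code).

    import torch

    values = torch.tensor([0., 1., 1., 2., 7.])
    hist = torch.histc(values, bins=8, min=0.0, max=8.0)     # per-bin counts
    cdf = torch.cumsum(hist, dim=0)
    cdf = 65535 * cdf / cdf[-1]                              # normalised cumulative distribution
    bin_idx = torch.clamp((values / 8.0 * 7).long(), 0, 7)   # (value - min) / (max - min) * (bins - 1)
    print(torch.gather(cdf, 0, bin_idx))                     # equalised values via table lookup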

589 

590@torch_func 

591def create_projection( 

592 stack: "torch.Tensor", method: str = "max_projection" 

593) -> "torch.Tensor": 

594 """ 

595 Create a projection from a stack using the specified method. 

596 

597 Args: 

598 stack: 3D PyTorch tensor of shape (Z, Y, X) 

599 method: Projection method (max_projection, mean_projection) 

600 

601 Returns: 

602 3D PyTorch tensor of shape (1, Y, X) 

603 """ 

604 _validate_3d_array(stack) 

605 

606 if method == "max_projection": 

607 return max_projection(stack) 

608 

609 if method == "mean_projection": 

610 return mean_projection(stack) 

611 

612 # FAIL FAST: No fallback projection methods 

613 raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection") 

614 

615@torch_func 

616def tophat( 

617 image: "torch.Tensor", 

618 selem_radius: int = 50, 

619 downsample_factor: int = 4 

620) -> "torch.Tensor": 

621 """ 

622 Apply white top-hat filter to a 3D image for background removal. 

623 

624 This applies the filter to each Z-slice independently using PyTorch's 

625 native operations. 

626 

627 Args: 

628 image: 3D PyTorch tensor of shape (Z, Y, X) 

629 selem_radius: Radius of the structuring element disk 

630 downsample_factor: Factor by which to downsample the image for processing 

631 

632 Returns: 

633 Filtered 3D PyTorch tensor of shape (Z, Y, X) 

634 """ 

635 _validate_3d_array(image) 

636 

637 # Store device for later use 

638 device = image.device 

639 

640 # Process each Z-slice independently 

641 result = torch.zeros_like(image) 

642 

643 # We'll create structuring elements for each slice as needed 

644 

645 for z in range(image.shape[0]): 

646 # Store original data type 

647 input_dtype = image[z].dtype 

648 

649 # 1) Downsample using PyTorch's interpolate function 

650 # First, add batch and channel dimensions for interpolate 

651 img_4d = image[z].float().unsqueeze(0).unsqueeze(0) 

652 

653 # Calculate new dimensions 

654 new_h = image[z].shape[0] // downsample_factor 

655 new_w = image[z].shape[1] // downsample_factor 

656 

657 # Resize using PyTorch's interpolate function 

658 image_small = F.interpolate( 

659 img_4d, 

660 size=(new_h, new_w), 

661 mode='bilinear', 

662 align_corners=False 

663 ).squeeze(0).squeeze(0) 

664 

665 # 2) Resize the structuring element to match the downsampled image 

666 small_selem_radius = max(1, selem_radius // downsample_factor) 

667 small_grid_size = 2 * small_selem_radius + 1 

668 small_grid_y, small_grid_x = torch.meshgrid( 

669 torch.arange(small_grid_size, device=device) - small_selem_radius, 

670 torch.arange(small_grid_size, device=device) - small_selem_radius, 

671 indexing='ij' 

672 ) 

673 small_mask = (small_grid_x.pow(2) + small_grid_y.pow(2)) <= small_selem_radius**2 

674 small_selem = small_mask.float() 

675 

676 # 3) Apply white top-hat using PyTorch's convolution operations 

677 # White top-hat is opening subtracted from the original image 

678 # Opening is erosion followed by dilation 

679 

680 # Implement erosion using min pooling with custom kernel 

681 # First, pad the image to handle boundary conditions 

682 pad_size = small_selem_radius 

683 padded = F.pad( 

684 image_small.unsqueeze(0).unsqueeze(0), 

685 (pad_size, pad_size, pad_size, pad_size), 

686 mode='reflect' 

687 ) 

688 

689 # Unfold the padded image into patches 

690 patches = F.unfold(padded, kernel_size=small_grid_size, stride=1) 

691 

692 # Reshape patches for processing 

693 patch_size = small_grid_size * small_grid_size 

694 patches = patches.reshape(1, patch_size, new_h, new_w) 

695 

696 # Apply the structuring element as a mask 

697 masked_patches = patches * small_selem.reshape(-1, 1, 1) 

698 

699 # Perform erosion (min pooling) 

700 eroded = torch.min( 

701 masked_patches + (1 - small_selem.reshape(-1, 1, 1)) * 1e9, 

702 dim=1 

703 )[0].squeeze(0)  # drop the leading batch dim so the eroded slice is 2D (new_h, new_w) for re-padding below 

704 

705 # Implement dilation using max pooling with custom kernel 

706 # Pad the eroded image 

707 padded_eroded = F.pad( 

708 eroded.unsqueeze(0).unsqueeze(0), 

709 (pad_size, pad_size, pad_size, pad_size), 

710 mode='reflect' 

711 ) 

712 

713 # Unfold the padded eroded image into patches 

714 patches_eroded = F.unfold(padded_eroded, kernel_size=small_grid_size, stride=1) 

715 

716 # Reshape patches for processing 

717 patch_size = small_grid_size * small_grid_size 

718 patches_eroded = patches_eroded.reshape(1, patch_size, new_h, new_w) 

719 

720 # Apply the structuring element as a mask 

721 masked_patches_eroded = patches_eroded * small_selem.reshape(-1, 1, 1) 

722 

723 # Perform dilation (max pooling) 

724 opened = torch.max(masked_patches_eroded, dim=1)[0].squeeze(0)  # back to a 2D (new_h, new_w) slice 

725 

726 # White top-hat is original minus opening 

727 tophat_small = image_small - opened 

728 

729 # 4) Calculate background 

730 background_small = image_small - tophat_small 

731 

732 # 5) Upscale background to original size 

733 background_4d = background_small.unsqueeze(0).unsqueeze(0) 

734 background_large = F.interpolate( 

735 background_4d, 

736 size=image[z].shape, 

737 mode='bilinear', 

738 align_corners=False 

739 ).squeeze(0).squeeze(0) 

740 

741 # 6) Subtract background and clip negative values 

742 slice_result = torch.clamp(image[z].float() - background_large, min=0.0) 

743 

744 # 7) Convert back to original data type 

745 result[z] = slice_result.to(input_dtype) 

746 

747 return result
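The erosion/dilation trick above (min- and max-reductions over unfolded patches) can be sketched with a plain 3x3 square element instead of a disk.

    import torch
    import torch.nn.functional as F

    img = torch.rand(1, 1, 16, 16)                                   # (N, C, H, W)
    patches = F.unfold(F.pad(img, (1, 1, 1, 1), mode='reflect'), kernel_size=3)
    patches = patches.reshape(1, 9, 16, 16)                          # 9 neighbours per pixel
    eroded = patches.min(dim=1)[0]                                    # local minimum = grey-scale erosion
    dilated = patches.max(dim=1)[0]                                   # local maximum = grey-scale dilation
    print(eroded.shape, dilated.shape)                                # both torch.Size([1, 16, 16])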