Coverage for openhcs/processing/backends/processors/torch_processor.py: 10.5%

268 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-01 18:33 +0000

1""" 

2PyTorch Image Processor Implementation 

3 

4This module implements the ImageProcessorInterface using PyTorch as the backend. 

5It leverages GPU acceleration for image processing operations. 

6 

7Doctrinal Clauses: 

8- Clause 3 — Declarative Primacy: All functions are pure and stateless 

9- Clause 88 — No Inferred Capabilities: Explicit PyTorch dependency 

10- Clause 106-A — Declared Memory Types: All methods specify PyTorch tensors 

11""" 

12from __future__ import annotations 

13 

14import logging 

15import os 

16from typing import Any, List, Optional, Tuple 

17 

18from openhcs.core.utils import optional_import 

19from openhcs.core.memory.decorators import torch as torch_func 

20 

logger = logging.getLogger(__name__)

# The subprocess runner sets OPENHCS_SUBPROCESS_NO_GPU=1 so that this module
# can be imported without touching GPU libraries. Every other process imports
# PyTorch through the project's optional-import helper, which the checks below
# treat as returning None when the package is unavailable.
if os.getenv('OPENHCS_SUBPROCESS_NO_GPU') != '1':
    torch = optional_import("torch")
    # Only resolve torch.nn.functional when torch itself was found.
    F = None if torch is None else optional_import("torch.nn.functional")
    HAS_TORCH = torch is not None
else:
    torch = None
    F = None
    HAS_TORCH = False
    logger.info("Subprocess runner mode - skipping torch import")

35 

36 

def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "torch.Tensor":
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center.

    Each axis gets a 1D profile of ones. Its first and last
    ``floor(length * margin_ratio)`` samples are replaced by linear ramps:
    0 -> 1 at the start and 1 -> 0 at the end. The 2D mask is the outer
    product of the row and column profiles.

    The ramp length is capped at the axis length. Where the two ramps of an
    axis overlap (``margin_ratio > 0.5``), the smaller ramp value is kept, so
    the profile stays symmetric.

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Ratio of the margin to the image size

    Returns:
        2D float32 PyTorch weight mask of shape (height, width), created
        without an explicit device

    Raises:
        ImportError: If PyTorch is not available
    """
    if torch is None:
        raise ImportError("PyTorch is required for TorchImageProcessor")

    def _axis_profile(length: int) -> "torch.Tensor":
        # floor(length * margin_ratio), capped at the axis length so the ramp
        # slices below can never be longer than the profile itself.
        margin = min(int(torch.floor(torch.tensor(length * margin_ratio))), length)
        profile = torch.ones(length, dtype=torch.float32)
        if margin > 0:
            ramp_up = torch.linspace(0, 1, margin, dtype=torch.float32)
            ramp_down = torch.linspace(1, 0, margin, dtype=torch.float32)
            # minimum() is a no-op when the ramps don't overlap (ramp values
            # are <= 1), and keeps the profile symmetric when they do.
            profile[:margin] = torch.minimum(profile[:margin], ramp_up)
            profile[-margin:] = torch.minimum(profile[-margin:], ramp_down)
        return profile

    # Create 2D weight mask using outer product
    return torch.outer(_axis_profile(height), _axis_profile(width))

73 

74 

def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Check that ``array`` is a 3D ``torch.Tensor``; return ``None`` or raise.

    No conversion is attempted: anything that is not already a tensor is
    rejected.

    Args:
        array: Object to check
        name: Label used in the error messages

    Raises:
        ImportError: If PyTorch could not be imported
        TypeError: If ``array`` is not a ``torch.Tensor``
        ValueError: If ``array`` does not have exactly three dimensions
    """
    if torch is None:
        raise ImportError("PyTorch is required for TorchImageProcessor")

    is_tensor = isinstance(array, torch.Tensor)
    if not is_tensor:
        raise TypeError(
            f"{name} must be a PyTorch tensor, got {type(array)}. "
            "No automatic conversion is performed to maintain explicit contracts."
        )

    dims = array.ndim
    if dims != 3:
        raise ValueError(f"{name} must be a 3D tensor, got {dims}D")

97 

def _gaussian_blur(image: "torch.Tensor", sigma: float) -> "torch.Tensor":
    """
    Apply a separable Gaussian blur to a 2D image.

    The kernel length is ``int(2 * 4 * sigma + 1)``, at least 3 and forced
    odd. ``F.conv2d`` zero-pads the borders.

    Args:
        image: 2D floating-point PyTorch tensor of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel. ``sigma == 0``
            means "no blur" and returns a copy of ``image``.

    Returns:
        Blurred 2D PyTorch tensor of shape (H, W), on the same device as
        ``image`` and, for floating inputs, with the same dtype
    """
    # sigma == 0 would divide by zero below and fill the whole kernel with
    # NaN; treat it as the identity blur instead.
    if sigma == 0:
        return image.clone()

    # Calculate kernel size based on sigma
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1  # Ensure odd kernel size so the kernel is centred

    # Build the kernel in the image's floating dtype so conv2d sees matching
    # dtypes (a fixed float32 kernel fails against float64 images).
    kernel_dtype = image.dtype if image.is_floating_point() else torch.float32
    coords = torch.arange(kernel_size, dtype=kernel_dtype, device=image.device)
    coords -= (kernel_size - 1) / 2

    gauss = torch.exp(-(coords**2) / (2 * sigma**2))
    kernel = gauss / gauss.sum()

    # (k, 1) kernel runs along the rows (H), (1, k) along the columns (W)
    kernel_rows = kernel.view(1, 1, kernel_size, 1)
    kernel_cols = kernel.view(1, 1, 1, kernel_size)
    half = kernel_size // 2

    # conv2d expects (N, C, H, W)
    img = image.unsqueeze(0).unsqueeze(0)
    blurred = F.conv2d(img, kernel_rows, padding=(half, 0))
    blurred = F.conv2d(blurred, kernel_cols, padding=(0, half))

    return blurred.squeeze(0).squeeze(0)

135 

@torch_func
def sharpen(image: "torch.Tensor", radius: float = 1.0, amount: float = 1.0) -> "torch.Tensor":
    """
    Sharpen a 3D image using unsharp masking.

    Each Z-slice is processed independently, in four steps:

    1. Divide the slice by its maximum; an all-zero slice is left unscaled.
    2. Sharpen it as ``s + amount * (s - gaussian_blur(s, radius))``.
    3. Clip the result to [0, 1].
    4. Min-max stretch it to [0, 65535]; the stretch is skipped when the
       result is constant.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        radius: Sigma of the Gaussian blur used for the unsharp mask
        amount: Sharpening strength

    Returns:
        Sharpened 3D PyTorch tensor of shape (Z, Y, X) in the input dtype
        (uint16 results are clamped to [0, 65535] before the cast)

    Raises:
        ImportError, TypeError, ValueError: From input validation
    """
    _validate_3d_array(image)

    # Store original dtype
    dtype = image.dtype

    # Process each Z-slice independently
    result = torch.zeros_like(image, dtype=torch.float32)

    for z in range(image.shape[0]):
        slice_float = image[z].float()
        slice_max = torch.max(slice_float)
        # An all-zero slice would otherwise be divided by 0 and become NaN;
        # it is already within [0, 1], so leave it unscaled.
        if slice_max != 0:
            slice_float = slice_float / slice_max

        # Create blurred version for unsharp mask
        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Apply unsharp mask: original + amount * (original - blurred)
        sharpened = slice_float + amount * (slice_float - blurred)

        # Clip to valid range
        sharpened = torch.clamp(sharpened, 0, 1.0)

        # Stretch to the uint16 range
        min_val = torch.min(sharpened)
        max_val = torch.max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535 / (max_val - min_val)

        result[z] = sharpened

    # Convert back to original dtype
    if dtype == torch.uint16:
        result = torch.clamp(result, 0, 65535).to(torch.uint16)
    else:
        result = result.to(dtype)

    return result

188 

@torch_func
def percentile_normalize(
    image: "torch.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "torch.Tensor":
    """
    Contrast-stretch every Z-slice of a 3D image between its own percentiles.

    For each slice, the ``low_percentile`` and ``high_percentile`` values are
    found with ``torch.quantile``. Slices with more than 10M elements estimate
    them from a random sample drawn with replacement: 10% of the elements,
    capped at 10M.

    The slice is then clipped to that interval and mapped linearly onto
    [target_min, target_max]. A slice whose two percentiles coincide is
    filled with ``target_min``.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        uint16 3D PyTorch tensor of shape (Z, Y, X); values are clipped to
        [0, 65535] before the cast
    """
    _validate_3d_array(image)

    # torch.quantile() fails on very large inputs, so bigger slices are sampled
    quantile_limit = 10_000_000
    q_low = low_percentile / 100.0
    q_high = high_percentile / 100.0

    stretched = torch.zeros_like(image, dtype=torch.float32)

    for z, plane in enumerate(image):
        values = plane.float()
        count = values.numel()

        logger.debug(f"🔥 QUANTILE DEBUG: percentile_normalize slice {z} shape {plane.shape}, {count:,} elements")

        if count > quantile_limit:
            # randint (not randperm) avoids materialising a permutation of every index
            n_samples = min(quantile_limit, count // 10)
            flat_values = values.flatten()
            picks = torch.randint(0, count, (n_samples,), device=image.device)
            reference = flat_values[picks]
        else:
            reference = values

        lo = torch.quantile(reference, q_low)
        hi = torch.quantile(reference, q_high)

        # Degenerate interval: no stretch possible, fill with target_min
        if lo == hi:
            stretched[z] = torch.ones_like(plane, dtype=torch.float32) * target_min
            continue

        gain = (target_max - target_min) / (hi - lo)
        stretched[z] = (torch.clamp(values, lo, hi) - lo) * gain + target_min

    return torch.clamp(stretched, 0, 65535).to(torch.uint16)

260 

@torch_func
def stack_percentile_normalize(
    stack: "torch.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "torch.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    The percentiles are computed once over the entire stack, so every Z-slice
    gets the same mapping. Stacks with more than 10M elements estimate them
    from a random sample drawn with replacement: 10% of the elements, capped
    at 10M.

    The stack is clipped to [p_low, p_high], mapped linearly onto
    [target_min, target_max], clamped to [0, 65535], and cast to uint16.
    If the two percentiles coincide, every element becomes ``target_min``,
    clamped to [0, 65535].

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized uint16 3D PyTorch tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # All arithmetic below is done on the float32 copy
    stack_float = stack.float()
    total_elements = stack_float.numel()

    # PyTorch quantile() fails on very large tensors, so we use sampling for large stacks
    max_elements_for_quantile = 10_000_000  # ~10M elements, conservative limit for quantile()

    logger.debug(f"🔥 QUANTILE DEBUG: stack_percentile_normalize called with tensor shape {stack.shape}, {total_elements:,} elements")

    if total_elements > max_elements_for_quantile:
        sample_size = min(max_elements_for_quantile, total_elements // 10)  # Sample 10% or max size
        flat_stack = stack_float.flatten()

        # torch.randint instead of torch.randperm to avoid creating huge index tensors
        indices = torch.randint(0, total_elements, (sample_size,), device=stack.device)
        sampled_values = flat_stack[indices]

        p_low = torch.quantile(sampled_values, low_percentile / 100.0)
        p_high = torch.quantile(sampled_values, high_percentile / 100.0)

        logger.debug(f"Used sampling ({sample_size:,} of {total_elements:,} elements) for percentile calculation due to large tensor size")
    else:
        p_low = torch.quantile(stack_float, low_percentile / 100.0)
        p_high = torch.quantile(stack_float, high_percentile / 100.0)

    # Degenerate range: fill with target_min, returning the same uint16 dtype
    # as the normal path below
    if p_high == p_low:
        filled = torch.full_like(stack_float, target_min)
        return torch.clamp(filled, 0, 65535).to(torch.uint16)

    # Clip and normalize to target range
    clipped = torch.clamp(stack_float, p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min

    # Saturate before the cast so out-of-range targets don't wrap around
    return torch.clamp(normalized, 0, 65535).to(torch.uint16)

326 

@torch_func
def create_composite(
    stack: "torch.Tensor", weights: Optional[List[float]] = None
) -> "torch.Tensor":
    """
    Create a weighted-average composite image from a 3D stack of 2D images.

    Images whose weight is <= 0 are skipped. The remaining images are
    accumulated in float32, scaled by their weights, and the sum is divided
    by the total positive weight.

    Args:
        stack: 3D PyTorch tensor of shape (N, Y, X) where N is number of images
        weights: List of N weights, one per image. If None, equal weights are used.

    Returns:
        Composite 3D PyTorch tensor of shape (1, Y, X) in the input dtype.
        For integer dtypes the composite is clamped to the dtype's full
        representable range (including negative values for signed types)
        before the cast.

    Raises:
        TypeError: If ``weights`` is given but is not a list
        ValueError: If ``len(weights)`` differs from N
    """
    _validate_3d_array(stack)

    n_images, height, width = stack.shape

    # Default weights if none provided
    if weights is None:
        weights = [1.0 / n_images] * n_images
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    # FAIL FAST: No fallback weights - weights must match exactly
    if len(weights) != n_images:
        raise ValueError(
            f"Weights list length ({len(weights)}) must exactly match number of images ({n_images}). "
            f"No automatic padding or truncation allowed."
        )

    dtype = stack.dtype
    composite = torch.zeros((height, width), dtype=torch.float32, device=stack.device)
    total_weight = 0.0

    for image, weight in zip(stack, weights):
        if weight <= 0.0:
            continue
        composite += image.float() * weight
        total_weight += weight

    # Normalize by total weight
    if total_weight > 0:
        composite /= total_weight

    # Representable (min, max) of each integer dtype the composite may be
    # cast back to; signed types keep their negative range.
    integer_ranges = {
        torch.uint8: (0, 255),
        torch.uint16: (0, 65535),
        torch.uint32: (0, 4294967295),
        torch.int8: (-128, 127),
        torch.int16: (-32768, 32767),
        torch.int32: (-2147483648, 2147483647),
        torch.int64: (-9223372036854775808, 9223372036854775807),
    }
    if dtype in integer_ranges:
        low, high = integer_ranges[dtype]
        composite = torch.clamp(composite, low, high).to(dtype)
    else:
        composite = composite.to(dtype)

    # Return as 3D tensor with shape (1, Y, X)
    return composite.reshape(1, height, width)

405 

@torch_func
def apply_mask(image: "torch.Tensor", mask: "torch.Tensor") -> "torch.Tensor":
    """
    Multiply a 3D image by a mask, keeping the (Z, Y, X) shape.

    A 2D (Y, X) mask is broadcast over every Z-slice; a 3D (Z, Y, X) mask is
    applied element-wise. The product is computed in float32 and cast back
    to the image's dtype.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        mask: 3D PyTorch tensor of shape (Z, Y, X) or 2D PyTorch tensor of shape (Y, X)

    Returns:
        Masked 3D PyTorch tensor of shape (Z, Y, X) in the image's dtype

    Raises:
        ValueError: If the mask shape does not match the image (slice) shape
        TypeError: If ``mask`` is not a 2D or 3D PyTorch tensor
    """
    _validate_3d_array(image)

    mask_ndim = mask.ndim if isinstance(mask, torch.Tensor) else None

    if mask_ndim == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )
        # (Y, X) broadcasts against (Z, Y, X), so every slice is masked
        return (image.float() * mask.float()).to(image.dtype)

    if mask_ndim == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )
        return (image.float() * mask.float()).to(image.dtype)

    raise TypeError(f"mask must be a 2D or 3D PyTorch tensor, got {type(mask)}")

450 

@torch_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "torch.Tensor":
    """
    Build a 2D blending weight mask for an image of the given size.

    A thin wrapper around ``create_linear_weight_mask`` that only checks the
    ``shape`` argument.

    Args:
        shape: Shape of the mask (height, width)
        margin_ratio: Ratio of image size to use as margin

    Returns:
        2D PyTorch weight mask of shape (Y, X)

    Raises:
        TypeError: If ``shape`` is not a 2-element tuple
    """
    valid_shape = isinstance(shape, tuple) and len(shape) == 2
    if not valid_shape:
        raise TypeError("shape must be a tuple of (height, width)")

    rows, cols = shape
    return create_linear_weight_mask(rows, cols, margin_ratio)

470 

@torch_func
def max_projection(stack: "torch.Tensor") -> "torch.Tensor":
    """
    Create a maximum intensity projection from a Z-stack.

    uint16 stacks are reduced in float32 (the original code notes this is
    needed for GPU operations). float32 holds every uint16 value exactly, so
    the cast back is lossless. Other dtypes are reduced as-is.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)

    Returns:
        3D PyTorch tensor of shape (1, Y, X) in the input dtype
    """
    _validate_3d_array(stack)

    source_dtype = stack.dtype
    working = stack.float() if source_dtype == torch.uint16 else stack

    peak, _ = torch.max(working, dim=0)
    peak = peak.to(source_dtype)

    return peak.reshape(1, *peak.shape)

500 

@torch_func
def mean_projection(stack: "torch.Tensor") -> "torch.Tensor":
    """
    Create a mean intensity projection from a Z-stack.

    The mean is always computed in float32 and then cast back to the input
    dtype; integer dtypes therefore drop the fractional part.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)

    Returns:
        3D PyTorch tensor of shape (1, Y, X) in the input dtype
    """
    _validate_3d_array(stack)

    source_dtype = stack.dtype
    averaged = stack.float().mean(dim=0).to(source_dtype)

    return averaged.reshape(1, *averaged.shape)

527 

@torch_func
def stack_equalize_histogram(
    stack: "torch.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "torch.Tensor":
    """
    Apply histogram equalization to an entire stack.

    One global histogram is computed over all Z-slices, so every slice gets
    the same intensity mapping.

    Each value ``v`` is assigned to bin
    ``floor((v - range_min) * (bins - 1) / (range_max - range_min))``,
    clamped to ``[0, bins - 1]``. The same bin index is used both to build
    the histogram and to look up the result.

    The histogram counts only values inside ``[range_min, range_max]``. Every
    element is then replaced by the cumulative count of its bin, normalized
    to ``[0, 65535]``. Stacks with more than 50M elements estimate the
    histogram from 50M elements drawn at random with replacement.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range

    Returns:
        Equalized uint16 3D PyTorch tensor of shape (Z, Y, X)

    Raises:
        ValueError: If ``range_max <= range_min``
    """
    _validate_3d_array(stack)
    if range_max <= range_min:
        raise ValueError(
            f"range_max ({range_max}) must be greater than range_min ({range_min})"
        )

    flat_stack = stack.float().flatten()
    n_elements = flat_stack.numel()

    # One bin mapping shared by the histogram and the lookup, so every element
    # reads the CDF entry of the bin it was counted in.
    scale = (bins - 1) / (range_max - range_min)
    bin_idx = torch.clamp(((flat_stack - range_min) * scale).long(), 0, bins - 1)

    def _count(values: "torch.Tensor", idx: "torch.Tensor") -> "torch.Tensor":
        # Ignore values outside [range_min, range_max] (and NaN): route them
        # to an extra overflow bin and drop it.
        in_range = (values >= range_min) & (values <= range_max)
        return torch.bincount(idx.masked_fill(~in_range, bins), minlength=bins + 1)[:bins]

    # For very large stacks, use sampling to avoid memory issues
    max_elements_for_histogram = 50_000_000  # 50M elements limit
    if n_elements > max_elements_for_histogram:
        sample_size = max_elements_for_histogram
        sample = torch.randint(0, n_elements, (sample_size,), device=stack.device)
        hist = _count(flat_stack[sample], bin_idx[sample])
        logger.debug(f"Used sampling ({sample_size:,} of {n_elements:,} elements) for histogram computation")
    else:
        hist = _count(flat_stack, bin_idx)

    # Integer cumulative counts are exact; dividing before scaling maps the
    # last bin to exactly 65535.
    cdf = torch.cumsum(hist, dim=0)
    total = cdf[-1]
    if total > 0:
        lut = cdf.float() / total.float() * 65535
    else:
        lut = torch.zeros(bins, dtype=torch.float32, device=stack.device)

    return lut[bin_idx].reshape(stack.shape).to(torch.uint16)

598 

@torch_func
def create_projection(
    stack: "torch.Tensor", method: str = "max_projection"
) -> "torch.Tensor":
    """
    Project a Z-stack to a single plane with the named method.

    Args:
        stack: 3D PyTorch tensor of shape (Z, Y, X)
        method: Either "max_projection" or "mean_projection"

    Returns:
        3D PyTorch tensor of shape (1, Y, X)

    Raises:
        ValueError: If ``method`` is not one of the supported names
    """
    _validate_3d_array(stack)

    candidates = (
        ("max_projection", max_projection),
        ("mean_projection", mean_projection),
    )
    for label, projector in candidates:
        if method == label:
            return projector(stack)

    # FAIL FAST: No fallback projection methods
    raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")

623 

def _tophat_disk(radius: int, device) -> "torch.Tensor":
    """Return a boolean (2*radius+1, 2*radius+1) disk footprint centred on the middle pixel."""
    offsets = torch.arange(2 * radius + 1, device=device) - radius
    grid_y, grid_x = torch.meshgrid(offsets, offsets, indexing='ij')
    return (grid_x.pow(2) + grid_y.pow(2)) <= radius**2


def _tophat_filter_2d(plane: "torch.Tensor", footprint: "torch.Tensor", erode: bool) -> "torch.Tensor":
    """Grey-scale erosion (``erode=True``) or dilation of a 2D float tensor over a boolean footprint."""
    size = footprint.shape[0]
    pad = size // 2
    height, width = plane.shape

    # F.pad / F.unfold operate on (N, C, H, W)
    batch = plane.unsqueeze(0).unsqueeze(0)
    # Reflect padding requires pad < each spatial size; replicate edges otherwise
    mode = 'reflect' if pad < min(height, width) else 'replicate'
    padded = F.pad(batch, (pad, pad, pad, pad), mode=mode)

    # One column of size*size neighbours per output pixel -> (1, size*size, H, W)
    patches = F.unfold(padded, kernel_size=size, stride=1).reshape(1, size * size, height, width)
    outside = ~footprint.reshape(-1, 1, 1)

    if erode:
        # +inf outside the footprint can never be the minimum
        return patches.masked_fill(outside, float('inf')).min(dim=1)[0].squeeze(0)
    # -inf outside the footprint can never be the maximum
    return patches.masked_fill(outside, float('-inf')).max(dim=1)[0].squeeze(0)


@torch_func
def tophat(
    image: "torch.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "torch.Tensor":
    """
    Apply white top-hat filter to a 3D image for background removal.

    Each Z-slice is processed independently with PyTorch operations:

    1. Bilinearly downsample the slice by ``downsample_factor``; each side is
       kept at least 1 px.
    2. Morphologically open it (erosion, then dilation) with a disk of radius
       ``max(1, selem_radius // downsample_factor)``. The opening is the
       background estimate.
    3. Bilinearly upsample the background back to full size.
    4. Subtract the background from the slice and clip negative values to 0.

    Args:
        image: 3D PyTorch tensor of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk
        downsample_factor: Factor by which to downsample the image for processing

    Returns:
        Filtered 3D PyTorch tensor of shape (Z, Y, X) in the input dtype
    """
    _validate_3d_array(image)

    device = image.device
    result = torch.zeros_like(image)

    # The footprint depends only on the arguments, so build it once
    small_selem_radius = max(1, selem_radius // downsample_factor)
    footprint = _tophat_disk(small_selem_radius, device)

    for z in range(image.shape[0]):
        plane = image[z].float()
        full_size = plane.shape

        # 1) Downsample (never to an empty image)
        new_h = max(1, full_size[0] // downsample_factor)
        new_w = max(1, full_size[1] // downsample_factor)
        image_small = F.interpolate(
            plane.unsqueeze(0).unsqueeze(0),
            size=(new_h, new_w),
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

        # 2) Opening = dilation(erosion(image)); this is the background estimate
        eroded = _tophat_filter_2d(image_small, footprint, erode=True)
        background_small = _tophat_filter_2d(eroded, footprint, erode=False)

        # 3) Upscale background to original size
        background_large = F.interpolate(
            background_small.unsqueeze(0).unsqueeze(0),
            size=full_size,
            mode='bilinear',
            align_corners=False
        ).squeeze(0).squeeze(0)

        # 4) Subtract background, clip negative values, restore dtype
        result[z] = torch.clamp(plane - background_large, min=0.0).to(image.dtype)

    return result