Coverage for openhcs/processing/backends/processors/tensorflow_processor.py: 12.7%

248 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2TensorFlow Image Processor Implementation 

3 

4This module implements the ImageProcessorInterface using TensorFlow as the backend. 

5It leverages GPU acceleration for image processing operations. 

6 

7Doctrinal Clauses: 

8- Clause 3 — Declarative Primacy: All functions are pure and stateless 

9- Clause 88 — No Inferred Capabilities: Explicit TensorFlow dependency 

10- Clause 106-A — Declared Memory Types: All methods specify TensorFlow tensors 

11""" 

12from __future__ import annotations 

13 

14import logging 

15from typing import Any, List, Optional, Tuple 

16 

17import pkg_resources 

18 

19from openhcs.core.memory.decorators import tensorflow as tensorflow_func 

20from openhcs.core.utils import optional_import 

21 

# Human-readable reason TensorFlow was rejected. It stays empty when the import
# either yields None (not installed) or passes the version gate below.
TENSORFLOW_ERROR = ""

# TensorFlow is optional: optional_import hands back None when it is missing.
tf = optional_import("tensorflow")

# DLPack interop is only trusted from TensorFlow 2.12.0 onward. An older version,
# or any failure while inspecting the version, disables TensorFlow for this
# module. It does so by rebinding ``tf`` to None and recording the reason.
if tf is not None:
    try:
        tf_version = pkg_resources.parse_version(tf.__version__)
        min_version = pkg_resources.parse_version("2.12.0")
        if tf_version < min_version:
            TENSORFLOW_ERROR = (
                f"TensorFlow version {tf.__version__} is not supported for DLPack operations. "
                "Version 2.12.0 or higher is required for stable DLPack support. "
                "Clause 88 violation: Cannot infer DLPack capability."
            )
            tf = None
    except Exception as e:  # any version-parsing failure disables TensorFlow
        TENSORFLOW_ERROR = str(e)
        tf = None

logger = logging.getLogger(__name__)

46 

47 

def _linear_ramp_profile(length: int, margin: int) -> "tf.Tensor":
    """
    Build a 1D float32 profile that ramps 0 -> 1, plateaus at 1, then ramps 1 -> 0.

    Args:
        length: Number of samples in the profile.
        margin: Number of samples in each ramp. Callers must keep
            ``2 * margin <= length``. A non-positive margin yields all ones.

    Returns:
        1D float32 TensorFlow tensor of shape (length,).
    """
    if margin <= 0:
        return tf.ones([length], dtype=tf.float32)

    ramp_up = tf.linspace(0.0, 1.0, margin)
    ramp_down = tf.linspace(1.0, 0.0, margin)
    plateau = tf.ones([length - 2 * margin], dtype=tf.float32)
    return tf.concat([ramp_up, plateau, ramp_down], axis=0)


@tensorflow_func
def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "tf.Tensor":
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center.

    Args:
        height: Height of the mask.
        width: Width of the mask.
        margin_ratio: Fraction of each dimension used for the ramp on each side.
            The ramp length is clamped to half the dimension, so the leading and
            trailing ramps never overlap.

    Returns:
        2D float32 TensorFlow weight mask of shape (height, width).
    """
    # The margins are plain Python ints. There is no need to round-trip through
    # a TF tensor. Clamping to half the extent prevents the two ramps from
    # overwriting each other (margin_ratio > 0.5). It also prevents indexing
    # outside the profile (margin_ratio > 1).
    margin_y = min(int(height * margin_ratio), height // 2)
    margin_x = min(int(width * margin_ratio), width // 2)

    weight_y = _linear_ramp_profile(height, margin_y)
    weight_x = _linear_ramp_profile(width, margin_x)

    # Outer product: mask[i, j] = weight_y[i] * weight_x[j].
    return tf.tensordot(weight_y, weight_x, axes=0)

105 

106 

def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Ensure ``array`` is a rank-3 TensorFlow tensor, raising otherwise.

    No conversion is attempted. Anything that is not already a ``tf.Tensor``
    is rejected, so callers keep explicit control over memory types.

    Args:
        array: Object to check.
        name: Label used in the error messages.

    Raises:
        TypeError: If ``array`` is not a ``tf.Tensor``.
        ValueError: If ``array`` does not have exactly three dimensions.
    """
    is_tensor = isinstance(array, tf.Tensor)
    if not is_tensor:
        raise TypeError(f"{name} must be a TensorFlow tensor, got {type(array)}. "
                        f"No automatic conversion is performed to maintain explicit contracts.")

    rank = len(array.shape)
    if rank != 3:
        raise ValueError(f"{name} must be a 3D tensor, got {rank}D")

129 

def _gaussian_blur(image: "tf.Tensor", sigma: float) -> "tf.Tensor":
    """
    Apply an isotropic Gaussian blur to a 2D image.

    Core TensorFlow has no ``tf.image.gaussian_blur`` (that API lived in the
    separate TensorFlow Addons package). The blur is therefore built here from
    two separable 1D convolutions over a reflect-padded copy of the image.

    Args:
        image: 2D TensorFlow tensor of shape (H, W).
        sigma: Standard deviation of the Gaussian kernel, in pixels. Values
            <= 0 return the image unblurred.

    Returns:
        Blurred 2D float32 TensorFlow tensor of shape (H, W).
    """
    image_f = tf.cast(image, tf.float32)
    if sigma <= 0:
        return image_f

    # Truncate the kernel at ~4 sigma on each side. Keep the size odd so the
    # kernel has a well-defined centre tap.
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1
    radius = kernel_size // 2

    offsets = tf.range(-radius, radius + 1, dtype=tf.float32)
    kernel_1d = tf.exp(-0.5 * tf.square(offsets / sigma))
    kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)

    # Reshape (H, W) -> (1, H, W, 1) for tf.nn.conv2d. REFLECT padding requires
    # each spatial dimension to be larger than `radius`.
    img = image_f[tf.newaxis, :, :, tf.newaxis]
    img = tf.pad(
        img,
        [[0, 0], [radius, radius], [radius, radius], [0, 0]],
        mode="REFLECT",
    )

    # A separable blur: rows first, then columns. VALID conv undoes the padding.
    kernel_rows = tf.reshape(kernel_1d, [kernel_size, 1, 1, 1])
    kernel_cols = tf.reshape(kernel_1d, [1, kernel_size, 1, 1])
    blurred = tf.nn.conv2d(img, kernel_rows, strides=[1, 1, 1, 1], padding="VALID")
    blurred = tf.nn.conv2d(blurred, kernel_cols, strides=[1, 1, 1, 1], padding="VALID")

    # Drop only the batch and channel axes, so an H or W of 1 is preserved.
    return blurred[0, :, :, 0]

159 

@tensorflow_func
def sharpen(
    image: "tf.Tensor", radius: float = 1.0, amount: float = 1.0
) -> "tf.Tensor":
    """
    Sharpen a 3D image using unsharp masking, one Z-slice at a time.

    For each slice the following steps run in order:
    1. Normalise by the slice's own maximum. An all-zero slice stays zero.
    2. Apply ``slice + amount * (slice - blur(slice))``.
    3. Clip the result to [0, 1].
    4. When the slice is not constant, stretch its min/max to 0/65535.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X).
        radius: Sigma of the Gaussian blur used for the unsharp mask.
        amount: Sharpening strength.

    Returns:
        Sharpened 3D TensorFlow tensor of shape (Z, Y, X). uint16 input is
        clipped to [0, 65535] and returned as uint16. Any other dtype receives
        a plain cast of the 0-65535-scaled values.
    """
    _validate_3d_array(image)

    dtype = image.dtype
    sharpened_slices = []

    for z in range(image.shape[0]):
        # Cast before taking the max. TensorFlow does not promote mixed dtypes,
        # so float32 / uint16 would fail. divide_no_nan maps an all-zero slice
        # to zeros instead of NaN.
        slice_float = tf.cast(image[z], tf.float32)
        slice_float = tf.math.divide_no_nan(slice_float, tf.reduce_max(slice_float))

        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Unsharp mask: original + amount * (original - blurred).
        sharpened = slice_float + amount * (slice_float - blurred)
        sharpened = tf.clip_by_value(sharpened, 0.0, 1.0)

        # Stretch to the 16-bit range. Constant slices are left as-is.
        min_val = tf.reduce_min(sharpened)
        max_val = tf.reduce_max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535.0 / (max_val - min_val)

        sharpened_slices.append(sharpened)

    result = tf.stack(sharpened_slices, axis=0)

    if dtype == tf.uint16:
        return tf.cast(tf.clip_by_value(result, 0, 65535), tf.uint16)
    return tf.cast(result, dtype)

216 

@tensorflow_func
def percentile_normalize(
    image: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    Each Z-slice is normalized independently. Percentiles use a nearest-rank
    lookup into the sorted slice: index ``floor(n * p / 100)``, clamped to
    ``[0, n - 1]``, with no interpolation.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X).
        low_percentile: Lower percentile (0-100).
        high_percentile: Upper percentile (0-100).
        target_min: Value the lower percentile maps to.
        target_max: Value the upper percentile maps to.

    Returns:
        Normalized uint16 3D TensorFlow tensor of shape (Z, Y, X). Values are
        clipped to [0, 65535] before the cast. Slices whose two percentiles
        coincide are filled with ``target_min``.
    """
    _validate_3d_array(image)

    result_list = []

    for z in range(image.shape[0]):
        # Work in float32 throughout. TensorFlow will not mix the image dtype
        # (e.g. uint16) with float32 in clip_by_value or arithmetic.
        slice_float = tf.cast(image[z], tf.float32)
        sorted_slice = tf.sort(tf.reshape(slice_float, [-1]))

        # Clamp the indices so that p = 100 selects the maximum instead of
        # indexing one past the end.
        slice_size = tf.size(sorted_slice)
        last_idx = slice_size - 1
        size_f = tf.cast(slice_size, tf.float32)
        low_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(size_f * low_percentile / 100.0), tf.int32), 0, last_idx
        )
        high_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(size_f * high_percentile / 100.0), tf.int32), 0, last_idx
        )

        p_low = sorted_slice[low_idx]
        p_high = sorted_slice[high_idx]

        # Avoid division by zero on flat slices.
        if p_high == p_low:
            result_list.append(tf.ones_like(slice_float) * target_min)
            continue

        clipped = tf.clip_by_value(slice_float, p_low, p_high)
        scale = (target_max - target_min) / (p_high - p_low)
        result_list.append((clipped - p_low) * scale + target_min)

    result = tf.stack(result_list, axis=0)

    return tf.cast(tf.clip_by_value(result, 0, 65535), tf.uint16)

278 

@tensorflow_func
def stack_percentile_normalize(
    stack: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    Percentiles are computed once over the entire stack, so every Z-slice is
    stretched consistently. TensorFlow Probability's ``tfp.stats.percentile``
    is used when importable. Otherwise a nearest-rank lookup into the sorted,
    flattened stack is used: index ``floor(n * p / 100)``, clamped to
    ``[0, n - 1]``.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X).
        low_percentile: Lower percentile (0-100).
        high_percentile: Upper percentile (0-100).
        target_min: Value the lower percentile maps to.
        target_max: Value the upper percentile maps to.

    Returns:
        Normalized uint16 3D TensorFlow tensor of shape (Z, Y, X). Values are
        clipped to [0, 65535] before the cast. A stack whose two percentiles
        coincide is filled with ``target_min``.
    """
    _validate_3d_array(stack)

    # All arithmetic happens in float32. TensorFlow will not clip a uint16
    # tensor against float32 bounds.
    stack_float = tf.cast(stack, tf.float32)

    try:
        import tensorflow_probability as tfp
    except ImportError:
        tfp = None

    if tfp is not None:
        p_low = tf.cast(tfp.stats.percentile(stack_float, low_percentile), tf.float32)
        p_high = tf.cast(tfp.stats.percentile(stack_float, high_percentile), tf.float32)
    else:
        sorted_stack = tf.sort(tf.reshape(stack_float, [-1]))

        # Clamp so that p = 100 selects the maximum instead of overrunning.
        stack_size = tf.size(sorted_stack)
        last_idx = stack_size - 1
        size_f = tf.cast(stack_size, tf.float32)
        low_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(size_f * low_percentile / 100.0), tf.int32), 0, last_idx
        )
        high_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(size_f * high_percentile / 100.0), tf.int32), 0, last_idx
        )

        p_low = sorted_stack[low_idx]
        p_high = sorted_stack[high_idx]

    # Avoid division by zero. Return the same dtype as the normal path.
    if p_high == p_low:
        flat = tf.fill(tf.shape(stack), float(target_min))
        return tf.cast(tf.clip_by_value(flat, 0.0, 65535.0), tf.uint16)

    # Clip and normalize to the target range (same order of operations as the
    # NumPy implementation), then saturate into uint16.
    clipped = tf.clip_by_value(stack_float, p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min

    return tf.cast(tf.clip_by_value(normalized, 0.0, 65535.0), tf.uint16)

336 

@tensorflow_func
def create_composite(
    images: List["tf.Tensor"], weights: Optional[List[float]] = None
) -> "tf.Tensor":
    """
    Create a weighted-average composite from multiple 3D tensors.

    Args:
        images: Non-empty list of 3D TensorFlow tensors, all of the same shape
            (Z, Y, X).
        weights: Per-image weights. If None, all images are weighted equally.
            A shorter list is padded with 0.0 and a longer one is truncated.
            Images with weight <= 0 are skipped.

    Returns:
        Composite 3D TensorFlow tensor with the shape and dtype of
        ``images[0]``. For integer dtypes the average is saturated to that
        dtype's representable range before casting.

    Raises:
        TypeError: If ``images`` or ``weights`` is not a list, or an image is
            not a TensorFlow tensor.
        ValueError: If ``images`` is empty, or an image is not 3D or its shape
            differs from ``images[0]``.
    """
    if not isinstance(images, list):
        raise TypeError("images must be a list of TensorFlow tensors")

    if not images:
        raise ValueError("images list cannot be empty")

    # Validate every entry before touching images[0].shape, so a non-tensor at
    # index 0 reports a TypeError rather than an AttributeError.
    for i, img in enumerate(images):
        _validate_3d_array(img, f"images[{i}]")
        if img.shape != images[0].shape:
            raise ValueError(
                f"All images must have the same shape. "
                f"images[0] has shape {images[0].shape}, "
                f"images[{i}] has shape {img.shape}"
            )

    if weights is None:
        weights = [1.0 / len(images)] * len(images)
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    # Pad missing weights with 0.0 (image ignored) and drop extras. This builds
    # a new list, so the caller's list is never mutated.
    weights = (weights + [0.0] * len(images))[:len(images)]

    dtype = images[0].dtype
    composite = tf.zeros(images[0].shape, dtype=tf.float32)
    total_weight = 0.0

    for image, weight in zip(images, weights):
        if weight <= 0.0:
            continue
        composite += tf.cast(image, tf.float32) * weight
        total_weight += weight

    if total_weight > 0:
        composite /= total_weight

    if dtype.is_integer:
        # Saturate to the dtype's real range. Signed types keep their negative
        # values instead of being clipped at 0.
        composite = tf.clip_by_value(composite, float(dtype.min), float(dtype.max))

    return tf.cast(composite, dtype)

427 

@tensorflow_func
def apply_mask(image: "tf.Tensor", mask: "tf.Tensor") -> "tf.Tensor":
    """
    Multiply a 3D image by a mask.

    The mask can take two forms. A 2D mask of shape (Y, X) is multiplied into
    every Z-slice. A 3D mask of shape (Z, Y, X) is multiplied element-wise.
    The product is formed in float32 and cast back to the image's dtype.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X).
        mask: 2D TensorFlow tensor of shape (Y, X) or 3D tensor of shape (Z, Y, X).

    Returns:
        Masked 3D TensorFlow tensor of shape (Z, Y, X) with ``image.dtype``.

    Raises:
        ValueError: If the mask shape does not match the image.
        TypeError: If ``mask`` is not a 2D or 3D TensorFlow tensor.
    """
    _validate_3d_array(image)

    mask_rank = len(mask.shape) if isinstance(mask, tf.Tensor) else None

    if mask_rank == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )
        mask_float = tf.cast(mask, tf.float32)
        per_slice = [
            tf.cast(image[z], tf.float32) * mask_float
            for z in range(image.shape[0])
        ]
        return tf.cast(tf.stack(per_slice, axis=0), image.dtype)

    if mask_rank == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )
        product = tf.cast(image, tf.float32) * tf.cast(mask, tf.float32)
        return tf.cast(product, image.dtype)

    raise TypeError(f"mask must be a 2D or 3D TensorFlow tensor, got {type(mask)}")

473 

@tensorflow_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "tf.Tensor":
    """
    Build a blending weight mask for a (height, width) tile.

    This is a thin wrapper that validates ``shape`` and delegates to
    ``create_linear_weight_mask``.

    Args:
        shape: Tuple ``(height, width)``.
        margin_ratio: Fraction of each dimension used for the edge ramps.

    Returns:
        2D TensorFlow weight mask of shape (height, width).

    Raises:
        TypeError: If ``shape`` is not a 2-element tuple.
    """
    is_pair = isinstance(shape, tuple) and len(shape) == 2
    if not is_pair:
        raise TypeError("shape must be a tuple of (height, width)")

    return create_linear_weight_mask(shape[0], shape[1], margin_ratio)

493 

@tensorflow_func
def max_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Collapse a Z-stack into its maximum-intensity projection.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X).

    Returns:
        3D TensorFlow tensor of shape (1, Y, X) with the stack's dtype.
    """
    _validate_3d_array(stack)

    # keepdims retains Z as a length-1 axis, giving (1, Y, X) directly.
    return tf.reduce_max(stack, axis=0, keepdims=True)

510 

@tensorflow_func
def mean_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Collapse a Z-stack into its mean-intensity projection.

    The mean is taken in float32 and cast back to the stack's dtype, which
    truncates for integer dtypes.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X).

    Returns:
        3D TensorFlow tensor of shape (1, Y, X) with the stack's dtype.
    """
    _validate_3d_array(stack)

    as_float = tf.cast(stack, tf.float32)
    averaged = tf.reduce_mean(as_float, axis=0, keepdims=True)
    return tf.cast(averaged, stack.dtype)

527 

@tensorflow_func
def stack_equalize_histogram(
    stack: "tf.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "tf.Tensor":
    """
    Apply histogram equalization to an entire stack.

    A single global histogram is computed over all Z-slices. This gives every
    slice the same intensity mapping. Each voxel is then replaced by the
    cumulative distribution value of its bin, scaled to [0, 65535].

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X).
        bins: Number of histogram bins.
        range_min: Lower edge of the histogram range. Values below it fall
            into the first bin.
        range_max: Upper edge of the histogram range. Values above it fall
            into the last bin.

    Returns:
        Equalized uint16 3D TensorFlow tensor of shape (Z, Y, X).

    Raises:
        ValueError: If ``range_max <= range_min``.
    """
    _validate_3d_array(stack)

    if range_max <= range_min:
        raise ValueError(
            f"range_max ({range_max}) must be greater than range_min ({range_min})"
        )

    flat_stack = tf.reshape(tf.cast(stack, tf.float32), [-1])

    # value_range must match the values' dtype. Building it explicitly as
    # float32 also accepts integer range arguments.
    value_range = tf.constant([range_min, range_max], dtype=tf.float32)
    hist = tf.histogram_fixed_width(flat_stack, value_range, nbins=bins)

    # Accumulate in float64. An int32 cumsum can overflow on very large stacks,
    # and TensorFlow will not multiply an int32 tensor by a float anyway.
    cdf = tf.cumsum(tf.cast(hist, tf.float64))

    # Scale the CDF to [0, 65535]. divide_no_nan leaves an empty histogram at 0.
    cdf = tf.cast(tf.math.divide_no_nan(65535.0 * cdf, cdf[-1]), tf.float32)

    # Map every voxel to its bin using the same rule as histogram_fixed_width,
    # then look up the equalized value.
    indices = tf.cast(tf.clip_by_value(
        tf.math.floor((flat_stack - range_min) / (range_max - range_min) * bins),
        0, bins - 1
    ), tf.int32)
    equalized_flat = tf.gather(cdf, indices)

    equalized_stack = tf.reshape(equalized_flat, tf.shape(stack))

    return tf.cast(equalized_stack, tf.uint16)

591 

@tensorflow_func
def create_projection(
    stack: "tf.Tensor", method: str = "max_projection"
) -> "tf.Tensor":
    """
    Project a Z-stack to a single plane using the named method.

    The stack is validated before the method name is checked.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X).
        method: Either "max_projection" or "mean_projection".

    Returns:
        3D TensorFlow tensor of shape (1, Y, X).

    Raises:
        ValueError: If ``method`` is not one of the supported names. There is
            no fallback method.
    """
    _validate_3d_array(stack)

    if method not in ("max_projection", "mean_projection"):
        raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")

    if method == "max_projection":
        return max_projection(stack)
    return mean_projection(stack)

616 

def _disk_structuring_element(radius: float) -> "tf.Tensor":
    """
    Build a flat disk structuring element for tf.nn.erosion2d / tf.nn.dilation2d.

    Args:
        radius: Disk radius in pixels (>= 1).

    Returns:
        float32 tensor of shape (2*radius+1, 2*radius+1, 1). Entries are 0
        inside the disk and a large negative value outside it. Off-disk taps
        therefore never win the erosion minimum (value - filter) or the
        dilation maximum (value + filter).
    """
    offsets = tf.range(-radius, radius + 1, dtype=tf.float32)
    yy, xx = tf.meshgrid(offsets, offsets, indexing="ij")
    inside = tf.sqrt(tf.square(yy) + tf.square(xx)) <= float(radius)
    return tf.where(inside, 0.0, -1e9)[..., tf.newaxis]


@tensorflow_func
def tophat(
    image: "tf.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "tf.Tensor":
    """
    Remove slowly varying background from a 3D image with a white top-hat.

    Every Z-slice is processed independently, with Z used as the batch axis:

    1. Downsample the slice by ``downsample_factor`` using bilinear resize.
    2. Compute a grayscale morphological opening (erosion followed by
       dilation). It uses a flat disk of radius
       ``max(1, selem_radius // downsample_factor)``.
    3. Upsample the opening back to full resolution; this is the background.
    4. Subtract the background from the original slice and clip at 0.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X).
        selem_radius: Radius of the structuring-element disk at full resolution.
        downsample_factor: Integer factor by which slices are shrunk before the
            morphology. Must be >= 1.

    Returns:
        Filtered 3D TensorFlow tensor of shape (Z, Y, X) with the input dtype.

    Raises:
        ValueError: If ``downsample_factor`` is smaller than 1.
    """
    _validate_3d_array(image)

    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    input_dtype = image.dtype

    # Reshape (Z, Y, X) -> (Z, Y, X, 1), so Z acts as the batch axis for the
    # tf.image and tf.nn morphology ops.
    image_4d = tf.cast(image, tf.float32)[..., tf.newaxis]

    # 1) Downsample. Keep at least one pixel per axis for tiny inputs.
    full_size = tf.shape(image)[1:3]
    small_size = tf.maximum(full_size // downsample_factor, 1)
    image_small = tf.image.resize(
        image_4d, small_size, method=tf.image.ResizeMethod.BILINEAR
    )

    # 2) Grayscale opening with a flat disk. The structuring element does not
    # depend on the slice, so it is built once.
    selem = _disk_structuring_element(max(1, selem_radius // downsample_factor))
    morph_kwargs = dict(
        strides=[1, 1, 1, 1],
        padding="SAME",
        data_format="NHWC",
        dilations=[1, 1, 1, 1],
    )
    eroded = tf.nn.erosion2d(image_small, selem, **morph_kwargs)
    opened = tf.nn.dilation2d(eroded, selem, **morph_kwargs)

    # 3) The opening is the background estimate; bring it back to full size.
    background = tf.image.resize(
        opened, full_size, method=tf.image.ResizeMethod.BILINEAR
    )

    # 4) Subtract the background, clip negatives, and restore the dtype.
    result = tf.maximum(image_4d - background, 0.0)
    return tf.cast(result[..., 0], input_dtype)