Coverage for openhcs/processing/backends/processors/tensorflow_processor.py: 12.5%

247 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2TensorFlow Image Processor Implementation 

3 

4This module implements the ImageProcessorInterface using TensorFlow as the backend. 

5It leverages GPU acceleration for image processing operations. 

6 

7Doctrinal Clauses: 

8- Clause 3 — Declarative Primacy: All functions are pure and stateless 

9- Clause 88 — No Inferred Capabilities: Explicit TensorFlow dependency 

10- Clause 106-A — Declared Memory Types: All methods specify TensorFlow tensors 

11""" 

12from __future__ import annotations 

13 

14import logging 

15from typing import Any, List, Optional, Tuple 

16 

17import pkg_resources 

18 

19from openhcs.core.memory.decorators import tensorflow as tensorflow_func 

20from openhcs.core.lazy_gpu_imports import tf 

21 

# Human-readable reason the TensorFlow backend is disabled; "" while it is usable.
TENSORFLOW_ERROR = ""

# DLPack interop is only treated as stable from TensorFlow 2.12.0 onwards.
# When the installed version is older, or the version check itself fails,
# the reason is stored in TENSORFLOW_ERROR and ``tf`` is rebound to None so
# the rest of this module sees TensorFlow as unavailable.
if tf is not None:
    try:
        tf_version = pkg_resources.parse_version(tf.__version__)
        min_version = pkg_resources.parse_version("2.12.0")
    except Exception as e:
        TENSORFLOW_ERROR = str(e)
        tf = None
    else:
        if tf_version < min_version:
            TENSORFLOW_ERROR = (
                f"TensorFlow version {tf.__version__} is not supported for DLPack operations. "
                f"Version 2.12.0 or higher is required for stable DLPack support. "
                f"Clause 88 violation: Cannot infer DLPack capability."
            )
            tf = None

logger = logging.getLogger(__name__)

43 

44 

def _edge_ramp_profile(length: int, margin: int) -> "tf.Tensor":
    """
    Build a 1D float32 profile of ``length`` ones with linear ramps at both ends.

    The first ``margin`` entries ramp 0 -> 1 and the last ``margin`` entries
    ramp 1 -> 0. If the ramps overlap (``2 * margin > length``) the trailing
    ramp wins, because it is written second.

    Args:
        length: Number of samples in the profile.
        margin: Ramp length at each end; values <= 0 disable the ramps.
            Callers must keep ``margin <= length``.

    Returns:
        1D float32 TensorFlow tensor of shape (length,).
    """
    profile = tf.ones(length, dtype=tf.float32)
    if margin <= 0:
        return profile

    leading = tf.range(margin)[:, tf.newaxis]                  # (margin, 1) scatter indices
    trailing = tf.range(length - margin, length)[:, tf.newaxis]
    profile = tf.tensor_scatter_nd_update(profile, leading, tf.linspace(0.0, 1.0, margin))
    profile = tf.tensor_scatter_nd_update(profile, trailing, tf.linspace(1.0, 0.0, margin))
    return profile


@tensorflow_func
def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "tf.Tensor":
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center.

    The mask is the outer product of two 1D profiles (one per axis). Each
    profile ramps over ``floor(size * margin_ratio)`` samples at both ends,
    clamped to the axis length.

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Ratio of the margin to the image size

    Returns:
        2D float32 TensorFlow weight mask of shape (height, width)
    """
    # The margins are computed with plain Python arithmetic rather than
    # tf.math.floor. That keeps them exact (no float32 rounding of the
    # product) and usable as Python ints in graph mode. int() truncates, which
    # equals floor for non-negative products; negative products give margins
    # <= 0, which disable the ramps either way.
    # Clamping to the axis length avoids negative scatter indices when
    # margin_ratio > 1.
    margin_y = min(height, int(height * margin_ratio))
    margin_x = min(width, int(width * margin_ratio))

    weight_y = _edge_ramp_profile(height, margin_y)
    weight_x = _edge_ramp_profile(width, margin_x)

    # axes=0 makes tensordot an outer product: mask[i, j] = weight_y[i] * weight_x[j]
    return tf.tensordot(weight_y, weight_x, axes=0)

102 

103 

def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Ensure ``array`` is a rank-3 TensorFlow tensor, raising otherwise.

    No conversion is attempted: non-tensor inputs are rejected outright so the
    caller has to pass TensorFlow tensors explicitly.

    Args:
        array: Object to check.
        name: Label used in the error messages.

    Raises:
        TypeError: If ``array`` is not a ``tf.Tensor``.
        ValueError: If ``array`` does not have exactly three dimensions.
    """
    # TensorFlow is assumed to be importable here; per the original design,
    # callers are only dispatched to this backend when it is available.
    if not isinstance(array, tf.Tensor):
        raise TypeError(f"{name} must be a TensorFlow tensor, got {type(array)}. "
                        f"No automatic conversion is performed to maintain explicit contracts.")

    ndim = len(array.shape)
    if ndim != 3:
        raise ValueError(f"{name} must be a 3D tensor, got {ndim}D")

126 

def _gaussian_blur(image: "tf.Tensor", sigma: float) -> "tf.Tensor":
    """
    Apply a Gaussian blur to a 2D image using core TensorFlow ops only.

    Core ``tf.image`` provides no Gaussian-blur op; that came from TensorFlow
    Addons. The blur is therefore computed as two separable 1D convolutions
    (rows, then columns) with ``tf.nn.conv2d`` over a reflect-padded copy of
    the image, so the output keeps the input's spatial size.

    Args:
        image: 2D TensorFlow tensor of shape (H, W); any numeric dtype.
        sigma: Standard deviation of the Gaussian kernel, in pixels.

    Returns:
        Blurred 2D float32 TensorFlow tensor of shape (H, W). If ``sigma <= 0``
        the image is returned unblurred, cast to float32.

    Note:
        "REFLECT" padding requires H and W to be larger than the kernel
        half-width (``kernel_size // 2``).
    """
    image_f32 = tf.cast(image, tf.float32)
    if sigma <= 0:
        # A zero-width Gaussian is the identity; also avoids dividing by zero below.
        return image_f32

    # Kernel spans roughly +/- 4 sigma, has at least 3 taps, and is always odd.
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1
    half = kernel_size // 2

    taps = tf.range(-half, half + 1, dtype=tf.float32)
    kernel_1d = tf.exp(-0.5 * tf.square(taps / sigma))
    kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)  # normalise so flat regions are preserved

    # (H, W) -> (1, H, W, 1): NHWC layout expected by conv2d
    img = image_f32[tf.newaxis, :, :, tf.newaxis]
    img = tf.pad(img, [[0, 0], [half, half], [half, half], [0, 0]], mode="REFLECT")

    # Separable blur: vertical pass then horizontal pass. VALID padding
    # removes exactly the `half` pixels added on each side.
    img = tf.nn.conv2d(img, tf.reshape(kernel_1d, [kernel_size, 1, 1, 1]), strides=1, padding="VALID")
    img = tf.nn.conv2d(img, tf.reshape(kernel_1d, [1, kernel_size, 1, 1]), strides=1, padding="VALID")

    # Explicit indexing (not tf.squeeze) so 1-pixel-high/wide images keep rank 2.
    return img[0, :, :, 0]

156 

@tensorflow_func
def sharpen(
    image: "tf.Tensor", radius: float = 1.0, amount: float = 1.0
) -> "tf.Tensor":
    """
    Sharpen a 3D image using unsharp masking.

    Each Z-slice is processed independently:

    1. Scale the slice to [0, 1] by its own maximum.
    2. Compute ``original + amount * (original - blurred)`` with a Gaussian blur
       of standard deviation ``radius``.
    3. Clip to [0, 1].
    4. Stretch the slice's resulting [min, max] onto [0, 65535].

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        radius: Radius (Gaussian sigma) of the blur
        amount: Sharpening strength

    Returns:
        Sharpened 3D TensorFlow tensor of shape (Z, Y, X) with the input dtype.
        For integer dtypes the values are clipped to the dtype's range before
        the cast.
    """
    _validate_3d_array(image)

    dtype = image.dtype
    result_list = []

    for z in range(image.shape[0]):
        slice_f32 = tf.cast(image[z], tf.float32)
        # Both operands must be float32: TF does not promote uint16 / float32.
        # divide_no_nan keeps an all-zero slice at zero instead of producing NaNs.
        slice_float = tf.math.divide_no_nan(slice_f32, tf.reduce_max(slice_f32))

        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Unsharp mask: original + amount * (original - blurred)
        sharpened = slice_float + amount * (slice_float - blurred)
        sharpened = tf.clip_by_value(sharpened, 0.0, 1.0)

        # Stretch this slice's [min, max] onto [0, 65535]; a constant slice is left as-is.
        min_val = tf.reduce_min(sharpened)
        max_val = tf.reduce_max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535.0 / (max_val - min_val)

        result_list.append(sharpened)

    result = tf.stack(result_list, axis=0)

    # Clip integer outputs to their representable range so the cast cannot wrap.
    # For uint16 this is exactly the previous [0, 65535] clip.
    if dtype.is_integer:
        result = tf.clip_by_value(result, float(dtype.min), float(dtype.max))
    return tf.cast(result, dtype)

213 

@tensorflow_func
def percentile_normalize(
    image: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    Each Z-slice is normalized independently: its values are clipped to the
    slice's [low_percentile, high_percentile] range and linearly mapped onto
    [target_min, target_max]. Percentiles are taken as the value at rank
    ``floor(N * p / 100)`` of the sorted slice, with no interpolation and the
    rank clamped to the valid range.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D uint16 TensorFlow tensor of shape (Z, Y, X). Values are
        clipped to [0, 65535] before the cast. A slice whose two percentiles
        coincide is filled with ``target_min``.
    """
    _validate_3d_array(image)

    result_list = []

    for z in range(image.shape[0]):
        # Work in float32 throughout. TF does not promote mixed dtypes, so the
        # percentile bounds, the clip and the scale must all share one dtype.
        slice_f32 = tf.cast(image[z], tf.float32)

        # TensorFlow has no percentile op, so sort the slice and index into it.
        sorted_slice = tf.sort(tf.reshape(slice_f32, [-1]))
        last_idx = tf.size(sorted_slice) - 1
        slice_size = tf.cast(tf.size(sorted_slice), tf.float32)

        # Clamp the ranks so 0 / 100 map to the first / last element instead of
        # indexing past the end of the sorted slice.
        low_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(slice_size * low_percentile / 100.0), tf.int32), 0, last_idx
        )
        high_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(slice_size * high_percentile / 100.0), tf.int32), 0, last_idx
        )
        p_low = sorted_slice[low_idx]
        p_high = sorted_slice[high_idx]

        # Degenerate range: avoid division by zero.
        if p_high == p_low:
            result_list.append(tf.ones_like(slice_f32) * target_min)
            continue

        clipped = tf.clip_by_value(slice_f32, p_low, p_high)
        scale = (target_max - target_min) / (p_high - p_low)
        result_list.append((clipped - p_low) * scale + target_min)

    result = tf.stack(result_list, axis=0)

    return tf.cast(tf.clip_by_value(result, 0, 65535), tf.uint16)

275 

@tensorflow_func
def stack_percentile_normalize(
    stack: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    This ensures consistent normalization across all Z-slices by computing
    global percentiles across the entire stack.

    Percentiles come from ``tensorflow_probability.stats.percentile`` when that
    package can be imported (its default interpolation is used). Otherwise the
    fallback sorts the flattened stack and takes the value at rank
    ``floor(N * p / 100)``, clamped to the valid range.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D uint16 TensorFlow tensor of shape (Z, Y, X). Values are
        clipped to [0, 65535] before the cast. If the two percentiles coincide,
        the result is filled with ``target_min``.
    """
    _validate_3d_array(stack)

    # All arithmetic is done in float32 so the bounds, the clip and the scale
    # share one dtype (TF does not implicitly promote integer tensors).
    stack_f32 = tf.cast(stack, tf.float32)

    try:
        import tensorflow_probability as tfp
    except ImportError:
        tfp = None

    if tfp is not None:
        p_low = tf.cast(tfp.stats.percentile(stack_f32, low_percentile), tf.float32)
        p_high = tf.cast(tfp.stats.percentile(stack_f32, high_percentile), tf.float32)
    else:
        # Fallback: sort the whole flattened stack and index into it.
        sorted_stack = tf.sort(tf.reshape(stack_f32, [-1]))
        last_idx = tf.size(sorted_stack) - 1
        stack_size = tf.cast(tf.size(sorted_stack), tf.float32)

        # Clamp so 0 / 100 map to the first / last element instead of overrunning.
        low_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(stack_size * low_percentile / 100.0), tf.int32), 0, last_idx
        )
        high_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(stack_size * high_percentile / 100.0), tf.int32), 0, last_idx
        )
        p_low = sorted_stack[low_idx]
        p_high = sorted_stack[high_idx]

    # Degenerate range: avoid division by zero. Return the same uint16 dtype
    # as the normal path.
    if p_high == p_low:
        filled = tf.ones_like(stack_f32) * target_min
        return tf.cast(tf.clip_by_value(filled, 0.0, 65535.0), tf.uint16)

    clipped = tf.clip_by_value(stack_f32, p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min

    # Clip before casting so a target range outside [0, 65535] cannot wrap around.
    return tf.cast(tf.clip_by_value(normalized, 0.0, 65535.0), tf.uint16)

333 

@tensorflow_func
def create_composite(
    images: List["tf.Tensor"], weights: Optional[List[float]] = None
) -> "tf.Tensor":
    """
    Blend several same-shaped 3D tensors into one weighted-average composite.

    Images whose weight is <= 0 are ignored. The weighted sum is divided by the
    total of the positive weights. Integer outputs are clipped to
    [0, dtype max] before being cast back to the first image's dtype.

    Args:
        images: Non-empty list of 3D TensorFlow tensors, each of shape (Z, Y, X)
        weights: Per-image weights. If None, every image gets 1 / len(images).
            Missing trailing weights count as 0.0; extra weights are ignored.

    Returns:
        Composite 3D TensorFlow tensor of shape (Z, Y, X) with the dtype of images[0]

    Raises:
        TypeError: If ``images`` or ``weights`` is not a list, or an image is not a tensor.
        ValueError: If ``images`` is empty, an image is not 3D, or shapes differ.
    """
    if not isinstance(images, list):
        raise TypeError("images must be a list of TensorFlow tensors")
    if not images:
        raise ValueError("images list cannot be empty")

    reference = images[0]
    for idx, img in enumerate(images):
        _validate_3d_array(img, f"images[{idx}]")
        if img.shape != reference.shape:
            raise ValueError(
                f"All images must have the same shape. "
                f"images[0] has shape {reference.shape}, "
                f"images[{idx}] has shape {img.shape}"
            )

    count = len(images)
    if weights is None:
        weights = [1.0 / count] * count
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    # Pad missing weights with zeros, then drop any surplus entries.
    weights = (weights + [0.0] * max(0, count - len(weights)))[:count]

    composite = tf.zeros(reference.shape, dtype=tf.float32)
    total_weight = 0.0
    for img, w in zip(images, weights):
        if w <= 0.0:
            continue
        composite += tf.cast(img, tf.float32) * w
        total_weight += w

    if total_weight > 0:
        composite /= total_weight

    # Upper clip bound for each supported integer dtype; None means "cast without clipping".
    out_dtype = reference.dtype
    int_ceiling = {
        tf.uint8: 255,
        tf.uint16: 65535,
        tf.uint32: 4294967295,
        tf.int8: 127,
        tf.int16: 32767,
        tf.int32: 2147483647,
        tf.int64: 9223372036854775807,
    }.get(out_dtype)

    if int_ceiling is None:
        return tf.cast(composite, out_dtype)
    return tf.cast(tf.clip_by_value(composite, 0, int_ceiling), out_dtype)

424 

@tensorflow_func
def apply_mask(image: "tf.Tensor", mask: "tf.Tensor") -> "tf.Tensor":
    """
    Multiply a 3D image by a mask.

    A 2D (Y, X) mask is applied identically to every Z-slice. A 3D mask must
    match the image shape exactly. The product is computed in float32 and cast
    back to the image's dtype.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        mask: 3D TensorFlow tensor of shape (Z, Y, X) or 2D TensorFlow tensor of shape (Y, X)

    Returns:
        Masked 3D TensorFlow tensor of shape (Z, Y, X)

    Raises:
        TypeError: If ``mask`` is not a 2D or 3D TensorFlow tensor.
        ValueError: If the mask shape does not match the image (or image slice) shape.
    """
    _validate_3d_array(image)

    if not isinstance(mask, tf.Tensor):
        raise TypeError(f"mask must be a 2D or 3D TensorFlow tensor, got {type(mask)}")

    mask_rank = len(mask.shape)

    if mask_rank == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )
        mask_f32 = tf.cast(mask, tf.float32)
        masked_slices = [tf.cast(image[z], tf.float32) * mask_f32 for z in range(image.shape[0])]
        return tf.cast(tf.stack(masked_slices, axis=0), image.dtype)

    if mask_rank == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )
        product = tf.cast(image, tf.float32) * tf.cast(mask, tf.float32)
        return tf.cast(product, image.dtype)

    raise TypeError(f"mask must be a 2D or 3D TensorFlow tensor, got {type(mask)}")

470 

@tensorflow_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "tf.Tensor":
    """
    Build a 2D blending weight mask for an image of the given shape.

    This is a thin wrapper around ``create_linear_weight_mask`` that takes the
    size as a ``(height, width)`` tuple.

    Args:
        shape: Shape of the mask as a ``(height, width)`` tuple
        margin_ratio: Ratio of image size to use as margin

    Returns:
        2D TensorFlow weight mask of shape (Y, X)

    Raises:
        TypeError: If ``shape`` is not a 2-element tuple.
    """
    is_hw_tuple = isinstance(shape, tuple) and len(shape) == 2
    if not is_hw_tuple:
        raise TypeError("shape must be a tuple of (height, width)")

    mask_height, mask_width = shape
    return create_linear_weight_mask(mask_height, mask_width, margin_ratio)

490 

@tensorflow_func
def max_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Collapse a Z-stack to its per-pixel maximum over Z.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X) with the same dtype as ``stack``
    """
    _validate_3d_array(stack)
    # keepdims keeps the reduced Z axis with length 1, giving (1, Y, X) directly.
    return tf.reduce_max(stack, axis=0, keepdims=True)

507 

@tensorflow_func
def mean_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Collapse a Z-stack to its per-pixel mean over Z.

    The mean is accumulated in float32 and then cast back to the stack's dtype.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X) with the same dtype as ``stack``
    """
    _validate_3d_array(stack)
    as_float = tf.cast(stack, tf.float32)
    # keepdims keeps the reduced Z axis with length 1, giving (1, Y, X) directly.
    averaged = tf.reduce_mean(as_float, axis=0, keepdims=True)
    return tf.cast(averaged, stack.dtype)

524 

@tensorflow_func
def stack_equalize_histogram(
    stack: "tf.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "tf.Tensor":
    """
    Apply histogram equalization to an entire stack.

    This ensures consistent contrast enhancement across all Z-slices by
    computing a global histogram across the entire stack. Every voxel is then
    mapped through the normalized CDF, scaled to [0, 65535].

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range (must exceed ``range_min``)

    Returns:
        Equalized 3D uint16 TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # Flatten the entire stack so the histogram (and hence the mapping) is global.
    flat_stack = tf.reshape(tf.cast(stack, tf.float32), [-1])

    # value_range must share the values' dtype; building it explicitly as float32
    # also accepts integer range_min / range_max.
    value_range = tf.constant([float(range_min), float(range_max)], dtype=tf.float32)
    hist = tf.histogram_fixed_width(flat_stack, value_range, nbins=bins)

    # tf.histogram_fixed_width returns int32 counts. TF will not combine an int32
    # tensor with float operands, so convert the CDF to float32 before scaling.
    cdf = tf.cast(tf.cumsum(hist), tf.float32)

    # Normalize the CDF to [0, 65535]. divide_no_nan leaves the all-zero CDF of
    # an empty stack at zero instead of producing NaNs.
    cdf = tf.math.divide_no_nan(65535.0 * cdf, cdf[-1])

    # Map each value to its bin index. Values outside the range fall into the
    # edge bins, as they do in tf.histogram_fixed_width.
    indices = tf.cast(
        tf.clip_by_value(
            tf.math.floor((flat_stack - range_min) / (range_max - range_min) * bins),
            0, bins - 1,
        ),
        tf.int32,
    )

    # Look up each voxel's equalized value and restore the (Z, Y, X) layout.
    equalized_flat = tf.gather(cdf, indices)
    equalized_stack = tf.reshape(equalized_flat, tf.shape(stack))

    return tf.cast(equalized_stack, tf.uint16)

588 

@tensorflow_func
def create_projection(
    stack: "tf.Tensor", method: str = "max_projection"
) -> "tf.Tensor":
    """
    Project a Z-stack to a single plane with the named method.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        method: Either "max_projection" or "mean_projection"

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    _validate_3d_array(stack)

    # FAIL FAST: only the two explicit methods are accepted; there is no fallback.
    if method not in ("max_projection", "mean_projection"):
        raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")

    projector = max_projection if method == "max_projection" else mean_projection
    return projector(stack)

613 

@tensorflow_func
def tophat(
    image: "tf.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "tf.Tensor":
    """
    Apply white top-hat filter to a 3D image for background removal.

    Each Z-slice is processed independently:

    1. Downsample the slice by ``downsample_factor`` (bilinear; at least 1 pixel per axis).
    2. Estimate the background as the grayscale morphological opening
       (erosion followed by dilation) with a flat disk of radius
       ``max(1, selem_radius // downsample_factor)``.
    3. Upsample the background back to full resolution (bilinear).
    4. Subtract it from the original slice and clip negatives to zero.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk at full resolution
        downsample_factor: Integer factor (>= 1) by which to downsample for processing

    Returns:
        Filtered 3D TensorFlow tensor of shape (Z, Y, X) with the input dtype

    Raises:
        ValueError: If ``downsample_factor`` is smaller than 1.
    """
    _validate_3d_array(image)
    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    input_dtype = image.dtype

    # Full and reduced (Y, X) sizes; never let the reduced size hit 0.
    full_size = tf.shape(image)[1:3]
    small_size = tf.maximum(full_size // downsample_factor, 1)

    # Flat disk structuring element, built once for all slices.
    # tf.nn.dilation2d takes max(input + filter) and tf.nn.erosion2d takes
    # min(input - filter). A filter of 0 inside the disk makes the element flat;
    # a large negative value outside removes those offsets from the max/min.
    radius = max(1, selem_radius // downsample_factor)
    offsets = tf.range(-radius, radius + 1, dtype=tf.float32)
    yy, xx = tf.meshgrid(offsets, offsets, indexing="ij")
    in_disk = tf.square(yy) + tf.square(xx) <= float(radius * radius)
    selem = tf.where(in_disk, 0.0, -1e9)[..., tf.newaxis]  # (k, k, 1): one channel

    # SAME padding: out-of-image neighbours are simply not considered.
    morph_args = dict(
        strides=[1, 1, 1, 1],
        padding="SAME",
        data_format="NHWC",
        dilations=[1, 1, 1, 1],
    )

    result_list = []
    for z in range(image.shape[0]):
        # (Y, X) -> (1, Y, X, 1) for the NHWC image/morphology ops
        slice_4d = tf.cast(image[z], tf.float32)[tf.newaxis, :, :, tf.newaxis]

        # 1) Downsample
        slice_small = tf.image.resize(slice_4d, small_size, method=tf.image.ResizeMethod.BILINEAR)

        # 2) Opening = dilation(erosion(x)). This is the background; the
        #    small-scale top-hat itself would be slice_small - opened.
        eroded = tf.nn.erosion2d(slice_small, selem, **morph_args)
        background_small = tf.nn.dilation2d(eroded, selem, **morph_args)

        # 3) Upscale background to original size
        background = tf.image.resize(
            background_small, full_size, method=tf.image.ResizeMethod.BILINEAR
        )

        # 4) Subtract background and clip negative values
        cleaned = tf.maximum(slice_4d - background, 0.0)

        # Drop batch/channel axes explicitly and restore the input dtype
        result_list.append(tf.cast(cleaned[0, :, :, 0], input_dtype))

    return tf.stack(result_list, axis=0)