Coverage for openhcs/processing/backends/processors/tensorflow_processor.py: 12.7%
248 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2TensorFlow Image Processor Implementation
4This module implements the ImageProcessorInterface using TensorFlow as the backend.
5It leverages GPU acceleration for image processing operations.
7Doctrinal Clauses:
8- Clause 3 — Declarative Primacy: All functions are pure and stateless
9- Clause 88 — No Inferred Capabilities: Explicit TensorFlow dependency
10- Clause 106-A — Declared Memory Types: All methods specify TensorFlow tensors
11"""
12from __future__ import annotations
14import logging
15from typing import Any, List, Optional, Tuple
17import pkg_resources
19from openhcs.core.memory.decorators import tensorflow as tensorflow_func
20from openhcs.core.utils import optional_import
# Human-readable reason why TensorFlow was disabled by the checks below.
# Stays an empty string when no problem was detected.
TENSORFLOW_ERROR = ""

# Import TensorFlow as an optional dependency.
# NOTE(review): the `is not None` test below suggests optional_import returns
# None when the package is missing - confirm against openhcs.core.utils.
tf = optional_import("tensorflow")

# Check TensorFlow version for DLPack compatibility if available
if tf is not None:
    try:
        tf_version = pkg_resources.parse_version(tf.__version__)
        min_version = pkg_resources.parse_version("2.12.0")

        if tf_version < min_version:
            TENSORFLOW_ERROR = (
                f"TensorFlow version {tf.__version__} is not supported for DLPack operations. "
                f"Version 2.12.0 or higher is required for stable DLPack support. "
                f"Clause 88 violation: Cannot infer DLPack capability."
            )
            # Too old: rebind `tf` to None so the rest of the module sees
            # TensorFlow as unavailable.
            tf = None
    except Exception as e:
        # Any failure while reading or parsing the version also disables
        # TensorFlow; the exception text is kept in TENSORFLOW_ERROR.
        TENSORFLOW_ERROR = str(e)
        tf = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _edge_ramp_profile(length: int, margin: int) -> "tf.Tensor":
    """Build a 1D float32 profile of ones whose two ends ramp linearly to 0."""
    profile = tf.ones(length, dtype=tf.float32)
    if margin > 0:
        rising = tf.linspace(0.0, 1.0, margin)
        falling = tf.linspace(1.0, 0.0, margin)

        # Scatter indices must have shape (margin, 1) for a rank-1 target.
        leading_idx = tf.expand_dims(tf.range(margin), axis=1)
        trailing_idx = tf.expand_dims(tf.range(length - margin, length), axis=1)

        profile = tf.tensor_scatter_nd_update(profile, leading_idx, rising)
        profile = tf.tensor_scatter_nd_update(profile, trailing_idx, falling)
    return profile


@tensorflow_func
def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "tf.Tensor":
    """
    Build a 2D blending mask that is 1 in the centre and fades linearly to 0 at the borders.

    The mask is the outer product of a vertical and a horizontal 1D profile.
    Each profile ramps over ``floor(size * margin_ratio)`` samples at both ends.

    Args:
        height: Number of rows in the mask
        width: Number of columns in the mask
        margin_ratio: Fraction of each dimension used for the ramp on each side

    Returns:
        2D float32 TensorFlow tensor of shape (height, width)
    """
    # TensorFlow availability is not checked here; the module relies on the
    # caller only invoking this when `tf` is usable.
    margin_y = int(tf.math.floor(height * margin_ratio))
    margin_x = int(tf.math.floor(width * margin_ratio))

    rows = _edge_ramp_profile(height, margin_y)
    cols = _edge_ramp_profile(width, margin_x)

    # axes=0 on two vectors is the outer product.
    return tf.tensordot(rows, cols, axes=0)
def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Check that ``array`` is a rank-3 TensorFlow tensor.

    Args:
        array: Object to check
        name: Label used for the object in error messages

    Raises:
        TypeError: If ``array`` is not a ``tf.Tensor``
        ValueError: If ``array`` does not have exactly three dimensions
    """
    # TensorFlow availability is not checked here; `tf` is assumed usable.
    if isinstance(array, tf.Tensor):
        rank = len(array.shape)
        if rank != 3:
            raise ValueError(f"{name} must be a 3D tensor, got {rank}D")
        return

    raise TypeError(f"{name} must be a TensorFlow tensor, got {type(array)}. "
                    f"No automatic conversion is performed to maintain explicit contracts.")
def _gaussian_blur(image: "tf.Tensor", sigma: float) -> "tf.Tensor":
    """
    Apply Gaussian blur to a 2D image.

    TensorFlow core has no ``tf.image.gaussian_blur``, so the blur is done
    here as two 1D convolutions (rows, then columns) with a normalised
    Gaussian kernel, using reflect padding at the borders.

    Args:
        image: 2D TensorFlow tensor of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel (must be > 0)

    Returns:
        Blurred 2D float32 TensorFlow tensor of shape (H, W)

    Raises:
        ValueError: If ``sigma`` is not positive
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    # Kernel spans roughly +/- 4 sigma and is forced to an odd size so it has
    # a well-defined centre tap.
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1
    half = kernel_size // 2

    # Normalised 1D Gaussian; the 2D kernel is separable into two of these.
    offsets = tf.range(-half, half + 1, dtype=tf.float32)
    kernel_1d = tf.exp(-0.5 * tf.square(offsets / sigma))
    kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)

    # conv2d expects NHWC input, so add batch and channel dimensions.
    img = tf.cast(image, tf.float32)[tf.newaxis, :, :, tf.newaxis]

    # Pad explicitly so the VALID convolutions below return the input size.
    # NOTE: tf.pad's REFLECT mode requires each spatial dimension to be larger
    # than `half`; smaller images make tf.pad raise.
    padded = tf.pad(
        img,
        [[0, 0], [half, half], [half, half], [0, 0]],
        mode="REFLECT",
    )

    vertical = tf.reshape(kernel_1d, [kernel_size, 1, 1, 1])
    horizontal = tf.reshape(kernel_1d, [1, kernel_size, 1, 1])
    blurred = tf.nn.conv2d(padded, vertical, strides=[1, 1, 1, 1], padding="VALID")
    blurred = tf.nn.conv2d(blurred, horizontal, strides=[1, 1, 1, 1], padding="VALID")

    # Squeeze only the batch/channel axes so size-1 spatial dims survive.
    return tf.squeeze(blurred, axis=[0, -1])
@tensorflow_func
def sharpen(
    image: "tf.Tensor", radius: float = 1.0, amount: float = 1.0
) -> "tf.Tensor":
    """
    Sharpen a 3D image using unsharp masking.

    Each Z-slice is processed independently: it is normalised by its own
    maximum, blurred, combined as ``original + amount * (original - blurred)``,
    clipped to [0, 1] and then stretched to the range [0, 65535].

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        radius: Sigma of the Gaussian blur used for the unsharp mask
        amount: Sharpening strength

    Returns:
        Sharpened 3D TensorFlow tensor of shape (Z, Y, X) in the input dtype
    """
    _validate_3d_array(image)

    # Remember the input dtype so the result can be cast back at the end.
    dtype = image.dtype

    result_list = []

    for z in range(image.shape[0]):
        # Cast BEFORE taking the maximum so numerator and denominator share a
        # dtype; dividing a float32 tensor by e.g. a uint16 scalar is rejected
        # by TensorFlow.
        slice_float = tf.cast(image[z], tf.float32)
        # divide_no_nan keeps an all-zero slice at zero instead of NaN.
        slice_float = tf.math.divide_no_nan(slice_float, tf.reduce_max(slice_float))

        # Blurred copy for the unsharp mask.
        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Unsharp mask: original + amount * (original - blurred)
        sharpened = slice_float + amount * (slice_float - blurred)
        sharpened = tf.clip_by_value(sharpened, 0.0, 1.0)

        # Stretch to the full 16-bit range; a constant slice is left as-is
        # to avoid dividing by zero.
        min_val = tf.reduce_min(sharpened)
        max_val = tf.reduce_max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535.0 / (max_val - min_val)

        result_list.append(sharpened)

    result = tf.stack(result_list, axis=0)

    # Cast back to the input dtype; uint16 is clipped first to stay in range.
    if dtype == tf.uint16:
        result = tf.cast(tf.clip_by_value(result, 0, 65535), tf.uint16)
    else:
        result = tf.cast(result, dtype)

    return result
@tensorflow_func
def percentile_normalize(
    image: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    Each Z-slice is normalised independently: values are clipped to the
    slice's [low, high] percentile range and mapped linearly onto
    [target_min, target_max].

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D uint16 TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    result_list = []

    for z in range(image.shape[0]):
        # Work in float32 throughout so the percentile bounds and the data
        # being clipped/scaled share one dtype.
        slice_float = tf.cast(image[z], tf.float32)

        # Percentiles are read from the sorted, flattened slice.
        sorted_slice = tf.sort(tf.reshape(slice_float, [-1]))

        slice_size = tf.size(sorted_slice)
        slice_size_f = tf.cast(slice_size, tf.float32)
        last_idx = slice_size - 1

        # Clamp the indices: a percentile of 100 would otherwise index one
        # element past the end of the sorted slice.
        low_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(slice_size_f * low_percentile / 100.0), tf.int32),
            0, last_idx
        )
        high_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(slice_size_f * high_percentile / 100.0), tf.int32),
            0, last_idx
        )

        p_low = sorted_slice[low_idx]
        p_high = sorted_slice[high_idx]

        # Degenerate range: emit a constant slice instead of dividing by zero.
        if p_high == p_low:
            result_list.append(tf.ones_like(slice_float) * target_min)
            continue

        # Clip and normalize to target range
        clipped = tf.clip_by_value(slice_float, p_low, p_high)
        scale = (target_max - target_min) / (p_high - p_low)
        result_list.append((clipped - p_low) * scale + target_min)

    result = tf.stack(result_list, axis=0)

    # Clip before the cast so out-of-range targets cannot wrap around.
    result = tf.cast(tf.clip_by_value(result, 0.0, 65535.0), tf.uint16)

    return result
@tensorflow_func
def stack_percentile_normalize(
    stack: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    The percentiles are computed once over the whole stack, so every Z-slice
    is stretched with the same bounds.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D uint16 TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # All arithmetic is done in float32 so the data and the percentile bounds
    # share one dtype (clip_by_value rejects mixed dtypes).
    stack_float = tf.cast(stack, tf.float32)

    # Prefer TensorFlow Probability when installed; only the import itself is
    # guarded so genuine errors inside the percentile call are not masked.
    try:
        import tensorflow_probability as tfp
    except ImportError:
        tfp = None

    if tfp is not None:
        p_low = tf.cast(tfp.stats.percentile(stack_float, low_percentile), tf.float32)
        p_high = tf.cast(tfp.stats.percentile(stack_float, high_percentile), tf.float32)
    else:
        # Fallback: read the percentiles from the sorted, flattened stack.
        sorted_stack = tf.sort(tf.reshape(stack_float, [-1]))

        stack_size = tf.size(sorted_stack)
        stack_size_f = tf.cast(stack_size, tf.float32)
        last_idx = stack_size - 1

        # Clamp the indices: a percentile of 100 would otherwise index one
        # element past the end of the sorted data.
        low_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(stack_size_f * low_percentile / 100.0), tf.int32),
            0, last_idx
        )
        high_idx = tf.clip_by_value(
            tf.cast(tf.math.floor(stack_size_f * high_percentile / 100.0), tf.int32),
            0, last_idx
        )

        p_low = sorted_stack[low_idx]
        p_high = sorted_stack[high_idx]

    if p_high == p_low:
        # Degenerate range: constant output instead of dividing by zero,
        # returned as uint16 like the normal path.
        normalized = tf.ones_like(stack_float) * target_min
    else:
        # Clip and normalize to target range
        clipped = tf.clip_by_value(stack_float, p_low, p_high)
        normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min

    # Clip before the cast so out-of-range targets cannot wrap around.
    return tf.cast(tf.clip_by_value(normalized, 0.0, 65535.0), tf.uint16)
@tensorflow_func
def create_composite(
    images: List["tf.Tensor"], weights: Optional[List[float]] = None
) -> "tf.Tensor":
    """
    Blend several equally-shaped 3D tensors into one weighted average.

    Images whose weight is missing or not positive are left out of the blend.
    The sum is divided by the total of the weights actually used and cast
    back to the dtype of the first image (integer dtypes are clipped to
    ``[0, dtype max]`` first).

    Args:
        images: List of 3D TensorFlow tensors, each of shape (Z, Y, X)
        weights: One weight per image; extra entries are ignored. If None,
            every image gets the same weight.

    Returns:
        Composite 3D TensorFlow tensor of shape (Z, Y, X)

    Raises:
        TypeError: If ``images`` or ``weights`` is not a list, or an image is
            not a TensorFlow tensor
        ValueError: If ``images`` is empty, an image is not 3D, or the shapes
            differ
    """
    if not isinstance(images, list):
        raise TypeError("images must be a list of TensorFlow tensors")
    if not images:
        raise ValueError("images list cannot be empty")

    reference = images[0]
    for idx, candidate in enumerate(images):
        _validate_3d_array(candidate, f"images[{idx}]")
        if candidate.shape != reference.shape:
            raise ValueError(
                f"All images must have the same shape. "
                f"images[0] has shape {reference.shape}, "
                f"images[{idx}] has shape {candidate.shape}"
            )

    if weights is None:
        weights = [1.0 / len(images)] * len(images)
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    out_dtype = reference.dtype
    blended = tf.zeros(reference.shape, dtype=tf.float32)
    weight_sum = 0.0

    # zip stops at the shorter sequence: surplus weights are dropped and
    # images without a weight contribute nothing, i.e. an implicit weight of 0.
    for layer, layer_weight in zip(images, weights):
        if layer_weight <= 0.0:
            continue
        blended += tf.cast(layer, tf.float32) * layer_weight
        weight_sum += layer_weight

    if weight_sum > 0:
        blended /= weight_sum

    # Upper clip bound for each supported integer dtype.
    integer_ceilings = {
        tf.uint8: 255,
        tf.uint16: 65535,
        tf.uint32: 4294967295,
        tf.int8: 127,
        tf.int16: 32767,
        tf.int32: 2147483647,
        tf.int64: 9223372036854775807,
    }
    ceiling = integer_ceilings.get(out_dtype)
    if ceiling is None:
        # Non-integer (or unlisted) dtype: plain cast, no clipping.
        return tf.cast(blended, out_dtype)

    return tf.cast(tf.clip_by_value(blended, 0, ceiling), out_dtype)
@tensorflow_func
def apply_mask(image: "tf.Tensor", mask: "tf.Tensor") -> "tf.Tensor":
    """
    Apply a mask to a 3D image by element-wise multiplication.

    A 2D mask is applied to every Z-slice; a 3D mask is applied directly.
    The product is computed in float32 and cast back to the image dtype.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        mask: 3D TensorFlow tensor of shape (Z, Y, X) or 2D TensorFlow tensor of shape (Y, X)

    Returns:
        Masked 3D TensorFlow tensor of shape (Z, Y, X)

    Raises:
        TypeError: If ``mask`` is not a 2D or 3D TensorFlow tensor
        ValueError: If the mask shape does not match the image
    """
    _validate_3d_array(image)

    if not isinstance(mask, tf.Tensor) or len(mask.shape) not in (2, 3):
        raise TypeError(f"mask must be a 2D or 3D TensorFlow tensor, got {type(mask)}")

    if len(mask.shape) == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )
    elif mask.shape != image.shape:
        raise ValueError(
            f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
        )

    # Broadcasting applies a (Y, X) mask to every slice of a (Z, Y, X) image,
    # so no per-slice Python loop is needed and an empty Z axis is handled too.
    masked = tf.cast(image, tf.float32) * tf.cast(mask, tf.float32)
    return tf.cast(masked, image.dtype)
@tensorflow_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "tf.Tensor":
    """
    Build a blending weight mask for an image of the given shape.

    Thin wrapper that unpacks ``shape`` and delegates to
    ``create_linear_weight_mask``.

    Args:
        shape: ``(height, width)`` tuple
        margin_ratio: Ratio of image size to use as margin

    Returns:
        2D TensorFlow weight mask of shape (Y, X)

    Raises:
        TypeError: If ``shape`` is not a tuple of exactly two elements
    """
    is_pair = isinstance(shape, tuple) and len(shape) == 2
    if not is_pair:
        raise TypeError("shape must be a tuple of (height, width)")

    return create_linear_weight_mask(shape[0], shape[1], margin_ratio)
@tensorflow_func
def max_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Collapse a Z-stack to a single slice by taking the per-pixel maximum.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # keepdims retains the reduced Z axis with length 1.
    return tf.reduce_max(stack, axis=0, keepdims=True)
@tensorflow_func
def mean_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Collapse a Z-stack to a single slice by taking the per-pixel mean.

    The mean is computed in float32 and cast back to the input dtype.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    as_float = tf.cast(stack, tf.float32)
    # keepdims retains the reduced Z axis with length 1.
    averaged = tf.reduce_mean(as_float, axis=0, keepdims=True)
    return tf.cast(averaged, stack.dtype)
@tensorflow_func
def stack_equalize_histogram(
    stack: "tf.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "tf.Tensor":
    """
    Apply histogram equalization to an entire stack.

    One histogram is computed over all voxels, so every Z-slice is remapped
    through the same cumulative distribution function (CDF).

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range

    Returns:
        Equalized 3D uint16 TensorFlow tensor of shape (Z, Y, X)

    Raises:
        ValueError: If ``bins`` is not positive or the range is empty
    """
    _validate_3d_array(stack)

    if bins <= 0:
        raise ValueError(f"bins must be positive, got {bins}")
    if range_max <= range_min:
        raise ValueError(
            f"range_max ({range_max}) must be greater than range_min ({range_min})"
        )

    # Flatten the entire stack to compute the global histogram
    flat_stack = tf.reshape(tf.cast(stack, tf.float32), [-1])

    hist = tf.histogram_fixed_width(
        flat_stack,
        [range_min, range_max],
        nbins=bins
    )

    # The histogram is int32; cast the CDF to float32 before scaling, because
    # multiplying an integer tensor by a Python float is rejected.
    cdf = tf.cast(tf.cumsum(hist), tf.float32)

    # Normalise so the last entry maps to 65535; divide_no_nan covers an
    # empty input, where the total count is zero.
    cdf = 65535.0 * tf.math.divide_no_nan(cdf, cdf[-1])

    # Map every value to its bin index, clamping out-of-range values into
    # the first/last bin.
    indices = tf.cast(tf.clip_by_value(
        tf.math.floor((flat_stack - range_min) / (range_max - range_min) * bins),
        0, bins - 1
    ), tf.int32)

    # Use the CDF as a lookup table and restore the original shape.
    equalized_flat = tf.gather(cdf, indices)
    equalized_stack = tf.reshape(equalized_flat, tf.shape(stack))

    return tf.cast(equalized_stack, tf.uint16)
@tensorflow_func
def create_projection(
    stack: "tf.Tensor", method: str = "max_projection"
) -> "tf.Tensor":
    """
    Project a Z-stack onto a single slice with the named method.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        method: Either ``"max_projection"`` or ``"mean_projection"``

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)

    Raises:
        ValueError: If ``method`` is not one of the supported names
    """
    _validate_3d_array(stack)

    # FAIL FAST: unknown methods are rejected, there is no fallback.
    if method not in ("max_projection", "mean_projection"):
        raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")

    projector = max_projection if method == "max_projection" else mean_projection
    return projector(stack)
@tensorflow_func
def tophat(
    image: "tf.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "tf.Tensor":
    """
    Apply white top-hat filter to a 3D image for background removal.

    The background is estimated as the morphological opening (erosion
    followed by dilation) of a downsampled copy of each slice with a disk
    structuring element. It is upsampled back to full size and subtracted
    from the original; negative results are clipped to zero. All Z-slices
    are processed together as one batch but remain independent of each other.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk
        downsample_factor: Factor by which to downsample the image for processing

    Returns:
        Filtered 3D TensorFlow tensor of shape (Z, Y, X) in the input dtype

    Raises:
        ValueError: If ``downsample_factor`` is smaller than 1
    """
    _validate_3d_array(image)

    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    input_dtype = image.dtype

    # Treat Z as the batch axis and add a channel axis: (Z, Y, X, 1).
    image_float = tf.cast(image, tf.float32)[..., tf.newaxis]

    # 1) Downsample; never let a dimension collapse to zero.
    full_hw = tf.shape(image)[1:3]
    small_hw = tf.maximum(full_hw // downsample_factor, 1)
    image_small = tf.image.resize(
        image_float,
        small_hw,
        method=tf.image.ResizeMethod.BILINEAR
    )

    # 2) Disk-shaped structuring element at the downsampled scale.
    small_radius = max(1, int(selem_radius // downsample_factor))
    coords = tf.range(-small_radius, small_radius + 1, dtype=tf.float32)
    grid_y, grid_x = tf.meshgrid(coords, coords, indexing="ij")
    inside_disk = tf.cast(
        tf.square(grid_y) + tf.square(grid_x) <= float(small_radius) ** 2,
        tf.float32
    )

    # erosion2d/dilation2d take an additive grayscale filter: 0 inside the
    # disk keeps the pixel value, a huge negative offset outside the disk
    # removes those positions from the min/max. Shape (k, k, depth=1).
    footprint = ((inside_disk - 1.0) * 1e9)[..., tf.newaxis]

    # 3) Opening = erosion followed by dilation. A convolution would SUM over
    # the neighbourhood; morphology needs the min/max, hence these ops.
    morph_args = dict(
        strides=[1, 1, 1, 1],
        padding="SAME",
        data_format="NHWC",
        dilations=[1, 1, 1, 1],
    )
    eroded = tf.nn.erosion2d(image_small, footprint, **morph_args)
    # The opening IS the background estimate (image - (image - opening)).
    background_small = tf.nn.dilation2d(eroded, footprint, **morph_args)

    # 4) Upscale the background to the original size.
    background_large = tf.image.resize(
        background_small,
        full_hw,
        method=tf.image.ResizeMethod.BILINEAR
    )

    # 5) Subtract the background and clip negative values.
    result = tf.maximum(image_float - background_large, 0.0)

    # 6) Drop the channel axis and restore the input dtype.
    return tf.cast(tf.squeeze(result, axis=-1), input_dtype)