Coverage for openhcs/processing/backends/processors/tensorflow_processor.py: 12.5%
247 statements
coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
"""
TensorFlow Image Processor Implementation

This module implements the ImageProcessorInterface using TensorFlow as the backend.
It leverages GPU acceleration for image processing operations.

Doctrinal Clauses:
- Clause 3 — Declarative Primacy: All functions are pure and stateless
- Clause 88 — No Inferred Capabilities: Explicit TensorFlow dependency
- Clause 106-A — Declared Memory Types: All methods specify TensorFlow tensors
"""
from __future__ import annotations

import logging
from typing import Any, List, Optional, Tuple

import pkg_resources

from openhcs.core.memory.decorators import tensorflow as tensorflow_func
from openhcs.core.lazy_gpu_imports import tf

# Define error variable
TENSORFLOW_ERROR = ""

# Check TensorFlow version for DLPack compatibility if available
if tf is not None:
    try:
        tf_version = pkg_resources.parse_version(tf.__version__)
        min_version = pkg_resources.parse_version("2.12.0")

        if tf_version < min_version:
            TENSORFLOW_ERROR = (
                f"TensorFlow version {tf.__version__} is not supported for DLPack operations. "
                f"Version 2.12.0 or higher is required for stable DLPack support. "
                f"Clause 88 violation: Cannot infer DLPack capability."
            )
            tf = None
    except Exception as e:
        TENSORFLOW_ERROR = str(e)
        tf = None

logger = logging.getLogger(__name__)


@tensorflow_func
def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "tf.Tensor":
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges to 1 in the center.

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Ratio of the margin to the image size

    Returns:
        2D TensorFlow weight mask of shape (height, width)
    """
    # The compiler will ensure this function is only called when TensorFlow is available
    # No need to check for TensorFlow availability here

    margin_y = int(tf.math.floor(height * margin_ratio))
    margin_x = int(tf.math.floor(width * margin_ratio))

    weight_y = tf.ones(height, dtype=tf.float32)
    if margin_y > 0:
        ramp_top = tf.linspace(0.0, 1.0, margin_y)
        ramp_bottom = tf.linspace(1.0, 0.0, margin_y)

        # Update slices of the weight_y tensor
        weight_y = tf.tensor_scatter_nd_update(
            weight_y,
            tf.stack([tf.range(margin_y)], axis=1),
            ramp_top
        )
        weight_y = tf.tensor_scatter_nd_update(
            weight_y,
            tf.stack([tf.range(height - margin_y, height)], axis=1),
            ramp_bottom
        )

    weight_x = tf.ones(width, dtype=tf.float32)
    if margin_x > 0:
        ramp_left = tf.linspace(0.0, 1.0, margin_x)
        ramp_right = tf.linspace(1.0, 0.0, margin_x)

        # Update slices of the weight_x tensor
        weight_x = tf.tensor_scatter_nd_update(
            weight_x,
            tf.stack([tf.range(margin_x)], axis=1),
            ramp_left
        )
        weight_x = tf.tensor_scatter_nd_update(
            weight_x,
            tf.stack([tf.range(width - margin_x, width)], axis=1),
            ramp_right
        )

    # Create 2D weight mask using outer product
    weight_mask = tf.tensordot(weight_y, weight_x, axes=0)

    return weight_mask
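

# Editor's illustrative sketch (not part of the original module): a quick
# check of the ramp behaviour, assuming TensorFlow is importable, eager
# execution (the TF2 default), and that the @tensorflow_func decorator
# preserves the plain call signature.
def _example_linear_weight_mask() -> None:
    mask = create_linear_weight_mask(100, 200, margin_ratio=0.1)
    assert mask.shape == (100, 200)
    assert float(mask[0, 0]) == 0.0     # edges ramp down to zero
    assert float(mask[50, 100]) == 1.0  # interior is full weight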


def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a 3D TensorFlow tensor.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        TypeError: If the array is not a TensorFlow tensor
        ValueError: If the array is not 3D
    """
    # The compiler will ensure this function is only called when TensorFlow is available
    # No need to check for TensorFlow availability here

    if not isinstance(array, tf.Tensor):
        raise TypeError(f"{name} must be a TensorFlow tensor, got {type(array)}. "
                        f"No automatic conversion is performed to maintain explicit contracts.")

    if len(array.shape) != 3:
        raise ValueError(f"{name} must be a 3D tensor, got {len(array.shape)}D")
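

# Editor's illustrative sketch (not part of the original module): the
# validator fails fast rather than converting, so a NumPy array is rejected
# even though TensorFlow could convert it implicitly.
def _example_validate_3d_array() -> None:
    import numpy as np
    try:
        _validate_3d_array(np.zeros((2, 4, 4)), "numpy_input")
    except TypeError as exc:
        print(f"rejected as designed: {exc}")
    _validate_3d_array(tf.zeros((2, 4, 4)))  # passes silently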


def _gaussian_blur(image: "tf.Tensor", sigma: float) -> "tf.Tensor":
    """
    Apply Gaussian blur to a 2D image.

    Args:
        image: 2D TensorFlow tensor of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel

    Returns:
        Blurred 2D TensorFlow tensor of shape (H, W)
    """
    # Calculate kernel size based on sigma
    kernel_size = max(3, int(2 * 4 * sigma + 1))
    if kernel_size % 2 == 0:
        kernel_size += 1  # Ensure odd kernel size

    # Core TensorFlow has no built-in Gaussian blur (tf.image.gaussian_blur
    # does not exist), so build a normalized 1D Gaussian kernel and apply it
    # as a separable depthwise convolution
    coords = tf.range(kernel_size, dtype=tf.float32) - (kernel_size - 1) / 2.0
    kernel_1d = tf.exp(-tf.square(coords) / (2.0 * sigma * sigma))
    kernel_1d = kernel_1d / tf.reduce_sum(kernel_1d)

    # Add batch and channel dimensions: (1, H, W, 1)
    img = tf.expand_dims(tf.expand_dims(image, 0), -1)

    # Reflect-pad so the output keeps the input spatial size
    pad = kernel_size // 2
    img = tf.pad(img, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')

    # Separable Gaussian: blur along rows, then along columns
    kernel_h = tf.reshape(kernel_1d, [1, kernel_size, 1, 1])
    kernel_v = tf.reshape(kernel_1d, [kernel_size, 1, 1, 1])
    blurred = tf.nn.depthwise_conv2d(img, kernel_h, strides=[1, 1, 1, 1], padding='VALID')
    blurred = tf.nn.depthwise_conv2d(blurred, kernel_v, strides=[1, 1, 1, 1], padding='VALID')

    # Remove batch and channel dimensions
    return tf.squeeze(blurred)
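

# Editor's illustrative sketch (not part of the original module): because the
# kernel is normalized to sum to 1, blurring a constant image leaves it
# unchanged up to float rounding, which makes a handy sanity check.
def _example_gaussian_blur() -> None:
    flat = tf.ones((32, 32), dtype=tf.float32) * 7.0
    blurred = _gaussian_blur(flat, sigma=1.5)
    assert blurred.shape == (32, 32)
    assert float(tf.reduce_max(tf.abs(blurred - flat))) < 1e-3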


@tensorflow_func
def sharpen(
    image: "tf.Tensor", radius: float = 1.0, amount: float = 1.0
) -> "tf.Tensor":
    """
    Sharpen a 3D image using unsharp masking.

    This applies sharpening to each Z-slice independently.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        radius: Radius of Gaussian blur
        amount: Sharpening strength

    Returns:
        Sharpened 3D TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Store original dtype
    dtype = image.dtype

    # Process each Z-slice independently
    result_list = []

    for z in range(image.shape[0]):
        # Convert to float in [0, 1] for processing (cast before reduce_max
        # so both division operands are float32)
        slice_float = tf.cast(image[z], tf.float32) / tf.reduce_max(tf.cast(image[z], tf.float32))

        # Create blurred version for unsharp mask
        blurred = _gaussian_blur(slice_float, sigma=radius)

        # Apply unsharp mask: original + amount * (original - blurred)
        sharpened = slice_float + amount * (slice_float - blurred)

        # Clip to valid range
        sharpened = tf.clip_by_value(sharpened, 0.0, 1.0)

        # Scale back to original range
        min_val = tf.reduce_min(sharpened)
        max_val = tf.reduce_max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535.0 / (max_val - min_val)

        result_list.append(sharpened)

    # Stack results back into a 3D tensor
    result = tf.stack(result_list, axis=0)

    # Convert back to original dtype
    if dtype == tf.uint16:
        result = tf.cast(tf.clip_by_value(result, 0, 65535), tf.uint16)
    else:
        result = tf.cast(result, dtype)

    return result
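

# Editor's illustrative sketch (not part of the original module): unsharp
# masking should preserve shape and dtype while boosting contrast around a
# bright square; same eager-execution assumptions as the sketches above.
def _example_sharpen() -> None:
    import numpy as np
    frame = np.full((64, 64), 1000, dtype=np.uint16)
    frame[24:40, 24:40] = 30000  # bright square on a dim background
    stack = tf.constant(frame[np.newaxis, ...])  # shape (1, 64, 64)
    sharpened = sharpen(stack, radius=2.0, amount=1.5)
    assert sharpened.shape == stack.shape
    assert sharpened.dtype == tf.uint16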


@tensorflow_func
def percentile_normalize(
    image: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    This applies normalization to each Z-slice independently.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Process each Z-slice independently
    result_list = []

    for z in range(image.shape[0]):
        # Get percentile values for this slice
        # TensorFlow doesn't have a direct percentile function, so we use a workaround
        flat_slice = tf.reshape(image[z], [-1])
        sorted_slice = tf.sort(flat_slice)

        # Calculate indices for percentiles (clamped so that
        # high_percentile == 100 stays in range)
        slice_size = tf.cast(tf.size(flat_slice), tf.float32)
        low_idx = tf.cast(tf.math.floor(slice_size * low_percentile / 100.0), tf.int32)
        high_idx = tf.cast(tf.math.floor(slice_size * high_percentile / 100.0), tf.int32)
        high_idx = tf.minimum(high_idx, tf.size(flat_slice) - 1)

        # Get percentile values (cast to float32 so the arithmetic below has
        # consistent dtypes)
        p_low = tf.cast(sorted_slice[low_idx], tf.float32)
        p_high = tf.cast(sorted_slice[high_idx], tf.float32)

        # Avoid division by zero
        if p_high == p_low:
            result_list.append(tf.ones_like(image[z], dtype=tf.float32) * target_min)
            continue

        # Clip and normalize to target range
        clipped = tf.clip_by_value(tf.cast(image[z], tf.float32), p_low, p_high)
        scale = (target_max - target_min) / (p_high - p_low)
        normalized = (clipped - p_low) * scale + target_min
        result_list.append(normalized)

    # Stack results back into a 3D tensor
    result = tf.stack(result_list, axis=0)

    # Convert to uint16
    result = tf.cast(tf.clip_by_value(result, 0, 65535), tf.uint16)

    return result
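

# Editor's illustrative sketch (not part of the original module): after
# per-slice normalization every slice spans roughly the full target range,
# even when the input slices had very different intensity ranges.
def _example_percentile_normalize() -> None:
    import numpy as np
    rng = np.random.default_rng(0)
    dim_slice = rng.integers(0, 500, size=(1, 64, 64))
    bright_slice = rng.integers(20000, 60000, size=(1, 64, 64))
    stack_np = np.concatenate([dim_slice, bright_slice]).astype(np.uint16)
    result = percentile_normalize(tf.constant(stack_np))
    # Both slices now reach (close to) the same maximum
    print(float(tf.reduce_max(result[0])), float(tf.reduce_max(result[1])))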


@tensorflow_func
def stack_percentile_normalize(
    stack: "tf.Tensor",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "tf.Tensor":
    """
    Normalize a stack using global percentile-based contrast stretching.

    This ensures consistent normalization across all Z-slices by computing
    global percentiles across the entire stack.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # Calculate global percentiles across the entire stack using TensorFlow Probability
    # This is memory-efficient and doesn't require sorting the entire array
    try:
        import tensorflow_probability as tfp
        p_low = tf.cast(tfp.stats.percentile(stack, low_percentile), tf.float32)
        p_high = tf.cast(tfp.stats.percentile(stack, high_percentile), tf.float32)
    except ImportError:
        # Fallback to manual calculation if TensorFlow Probability is not available
        # This sorts the full array, so it is less memory-efficient but works
        flat_stack = tf.reshape(stack, [-1])
        sorted_stack = tf.sort(flat_stack)

        # Calculate indices for percentiles (clamped so that
        # high_percentile == 100 stays in range)
        stack_size = tf.cast(tf.size(flat_stack), tf.float32)
        low_idx = tf.cast(tf.math.floor(stack_size * low_percentile / 100.0), tf.int32)
        high_idx = tf.cast(tf.math.floor(stack_size * high_percentile / 100.0), tf.int32)
        high_idx = tf.minimum(high_idx, tf.size(flat_stack) - 1)

        # Get percentile values and cast to float32 for consistency
        p_low = tf.cast(sorted_stack[low_idx], tf.float32)
        p_high = tf.cast(sorted_stack[high_idx], tf.float32)

    # Avoid division by zero
    if p_high == p_low:
        return tf.cast(tf.ones_like(stack, dtype=tf.float32) * target_min, tf.uint16)

    # Clip and normalize to target range (match the NumPy implementation exactly)
    # Cast to float32 first so the clipping bounds and data share a dtype
    clipped = tf.clip_by_value(tf.cast(stack, tf.float32), p_low, p_high)
    normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min
    normalized = tf.cast(normalized, tf.uint16)

    return normalized
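

# Editor's illustrative sketch (not part of the original module): unlike the
# per-slice variant above, global percentiles preserve relative intensity
# between slices, so a dim slice stays dimmer after normalization. Input is
# float32 so both percentile code paths see a consistent dtype.
def _example_stack_percentile_normalize() -> None:
    import numpy as np
    rng = np.random.default_rng(0)
    stack_np = rng.integers(0, 60000, size=(3, 64, 64)).astype(np.float32)
    stack_np[0] /= 10.0  # make slice 0 much dimmer
    result = stack_percentile_normalize(tf.constant(stack_np))
    assert float(tf.reduce_max(result[0])) < float(tf.reduce_max(result[1]))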


@tensorflow_func
def create_composite(
    images: List["tf.Tensor"], weights: Optional[List[float]] = None
) -> "tf.Tensor":
    """
    Create a composite image from multiple 3D arrays.

    Args:
        images: List of 3D TensorFlow tensors, each of shape (Z, Y, X)
        weights: List of weights for each image. If None, equal weights are used.

    Returns:
        Composite 3D TensorFlow tensor of shape (Z, Y, X)
    """
    # Ensure images is a list
    if not isinstance(images, list):
        raise TypeError("images must be a list of TensorFlow tensors")

    # Check for empty list early
    if not images:
        raise ValueError("images list cannot be empty")

    # Validate all images are 3D TensorFlow tensors with the same shape
    for i, img in enumerate(images):
        _validate_3d_array(img, f"images[{i}]")
        if img.shape != images[0].shape:
            raise ValueError(
                f"All images must have the same shape. "
                f"images[0] has shape {images[0].shape}, "
                f"images[{i}] has shape {img.shape}"
            )

    # Default weights if none provided
    if weights is None:
        # Equal weights for all images
        weights = [1.0 / len(images)] * len(images)
    elif not isinstance(weights, list):
        raise TypeError("weights must be a list of values")

    # Make sure weights list is at least as long as images list
    if len(weights) < len(images):
        weights = weights + [0.0] * (len(images) - len(weights))
    # Truncate weights if longer than images
    weights = weights[:len(images)]

    first_image = images[0]
    shape = first_image.shape
    dtype = first_image.dtype

    # Create empty composite
    composite = tf.zeros(shape, dtype=tf.float32)
    total_weight = 0.0

    # Add each image with its weight
    for i, image in enumerate(images):
        weight = weights[i]
        if weight <= 0.0:
            continue

        # Add to composite
        composite += tf.cast(image, tf.float32) * weight
        total_weight += weight

    # Normalize by total weight
    if total_weight > 0:
        composite /= total_weight

    # Convert back to original dtype (usually uint16)
    if dtype in [tf.uint8, tf.uint16, tf.uint32, tf.int8, tf.int16, tf.int32, tf.int64]:
        # Clip to the dtype's representable range before casting
        composite = tf.cast(tf.clip_by_value(composite, 0, dtype.max), dtype)
    else:
        composite = tf.cast(composite, dtype)

    return composite
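

# Editor's illustrative sketch (not part of the original module): a 70/30
# blend of two constant channels; the weighted mean lands between the two
# input values and the uint16 dtype is preserved.
def _example_create_composite() -> None:
    chan_a = tf.fill((1, 8, 8), tf.constant(1000, dtype=tf.uint16))
    chan_b = tf.fill((1, 8, 8), tf.constant(3000, dtype=tf.uint16))
    composite = create_composite([chan_a, chan_b], weights=[0.7, 0.3])
    assert composite.dtype == tf.uint16
    assert float(composite[0, 0, 0]) == 1600.0  # 0.7*1000 + 0.3*3000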


@tensorflow_func
def apply_mask(image: "tf.Tensor", mask: "tf.Tensor") -> "tf.Tensor":
    """
    Apply a mask to a 3D image.

    This applies the mask to each Z-slice independently if mask is 2D,
    or applies the 3D mask directly if mask is 3D.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        mask: 3D TensorFlow tensor of shape (Z, Y, X) or 2D TensorFlow tensor of shape (Y, X)

    Returns:
        Masked 3D TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Handle 2D mask (apply to each Z-slice)
    if isinstance(mask, tf.Tensor) and len(mask.shape) == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )

        # Apply 2D mask to each Z-slice
        result_list = []
        for z in range(image.shape[0]):
            result_list.append(tf.cast(image[z], tf.float32) * tf.cast(mask, tf.float32))

        result = tf.stack(result_list, axis=0)
        return tf.cast(result, image.dtype)

    # Handle 3D mask
    if isinstance(mask, tf.Tensor) and len(mask.shape) == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )

        # Apply 3D mask directly
        masked = tf.cast(image, tf.float32) * tf.cast(mask, tf.float32)
        return tf.cast(masked, image.dtype)

    # If we get here, the mask is neither a 2D nor a 3D TensorFlow tensor
    raise TypeError(f"mask must be a 2D or 3D TensorFlow tensor, got {type(mask)}")
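

# Editor's illustrative sketch (not part of the original module): applying a
# 2D blending mask across every slice of a constant stack; edge pixels are
# zeroed while the center keeps its intensity.
def _example_apply_mask() -> None:
    stack = tf.fill((4, 32, 32), tf.constant(10000, dtype=tf.uint16))
    mask = create_linear_weight_mask(32, 32, margin_ratio=0.25)
    masked = apply_mask(stack, mask)
    assert masked.shape == (4, 32, 32)
    assert float(masked[0, 0, 0]) == 0.0        # edge weight is zero
    assert float(masked[0, 16, 16]) == 10000.0  # center keeps full intensity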


@tensorflow_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "tf.Tensor":
    """
    Create a weight mask for blending images.

    Args:
        shape: Shape of the mask (height, width)
        margin_ratio: Ratio of image size to use as margin

    Returns:
        2D TensorFlow weight mask of shape (Y, X)
    """
    if not isinstance(shape, tuple) or len(shape) != 2:
        raise TypeError("shape must be a tuple of (height, width)")

    height, width = shape
    return create_linear_weight_mask(height, width, margin_ratio)


@tensorflow_func
def max_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Create a maximum intensity projection from a Z-stack.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # Create max projection
    projection_2d = tf.reduce_max(stack, axis=0)
    return tf.expand_dims(projection_2d, axis=0)


@tensorflow_func
def mean_projection(stack: "tf.Tensor") -> "tf.Tensor":
    """
    Create a mean intensity projection from a Z-stack.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    # Create mean projection
    projection_2d = tf.cast(tf.reduce_mean(tf.cast(stack, tf.float32), axis=0), stack.dtype)
    return tf.expand_dims(projection_2d, axis=0)


@tensorflow_func
def stack_equalize_histogram(
    stack: "tf.Tensor",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "tf.Tensor":
    """
    Apply histogram equalization to an entire stack.

    This ensures consistent contrast enhancement across all Z-slices by
    computing a global histogram across the entire stack.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range

    Returns:
        Equalized 3D TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    # TensorFlow doesn't have a direct histogram equalization function,
    # so we implement it manually

    # Flatten the entire stack to compute the global histogram
    flat_stack = tf.reshape(tf.cast(stack, tf.float32), [-1])

    # Calculate the histogram
    # TensorFlow doesn't have a direct equivalent to numpy's histogram,
    # so we use tf.histogram_fixed_width
    hist = tf.histogram_fixed_width(
        flat_stack,
        [range_min, range_max],
        nbins=bins
    )

    # Calculate cumulative distribution function (CDF)
    cdf = tf.cumsum(hist)

    # Normalize the CDF to the range [0, 65535]
    # (cast to float32 first: hist and cumsum are integer tensors, and
    # multiplying an int32 tensor by a Python float would fail)
    cdf = tf.cast(cdf, tf.float32)
    if cdf[-1] > 0:
        cdf = 65535.0 * cdf / cdf[-1]

    # Scale input values to bin indices
    indices = tf.cast(tf.clip_by_value(
        tf.math.floor((flat_stack - range_min) / (range_max - range_min) * bins),
        0, bins - 1
    ), tf.int32)

    # Look up CDF values
    equalized_flat = tf.gather(cdf, indices)

    # Reshape back to original shape
    equalized_stack = tf.reshape(equalized_flat, stack.shape)

    # Convert to uint16
    return tf.cast(equalized_stack, tf.uint16)
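

# Editor's illustrative sketch (not part of the original module): histogram
# equalization maps each intensity to its (scaled) CDF value, so a narrow,
# low input range gets stretched toward the full uint16 range.
def _example_stack_equalize_histogram() -> None:
    import numpy as np
    rng = np.random.default_rng(0)
    narrow = rng.integers(100, 400, size=(2, 64, 64)).astype(np.uint16)
    equalized = stack_equalize_histogram(tf.constant(narrow))
    assert equalized.dtype == tf.uint16
    print(int(tf.reduce_min(equalized)), int(tf.reduce_max(equalized)))  # ~0 .. ~65535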


@tensorflow_func
def create_projection(
    stack: "tf.Tensor", method: str = "max_projection"
) -> "tf.Tensor":
    """
    Create a projection from a stack using the specified method.

    Args:
        stack: 3D TensorFlow tensor of shape (Z, Y, X)
        method: Projection method (max_projection, mean_projection)

    Returns:
        3D TensorFlow tensor of shape (1, Y, X)
    """
    _validate_3d_array(stack)

    if method == "max_projection":
        return max_projection(stack)

    if method == "mean_projection":
        return mean_projection(stack)

    # FAIL FAST: No fallback projection methods
    raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")
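

# Editor's illustrative sketch (not part of the original module): the max
# projection keeps the brightest voxel along Z, while the mean projection
# averages it with the dim slices.
def _example_create_projection() -> None:
    import numpy as np
    stack_np = np.zeros((4, 16, 16), dtype=np.uint16)
    stack_np[2, 8, 8] = 4000  # single bright voxel in slice 2
    stack = tf.constant(stack_np)
    assert int(create_projection(stack, "max_projection")[0, 8, 8]) == 4000
    assert int(create_projection(stack, "mean_projection")[0, 8, 8]) == 1000  # 4000 / 4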


@tensorflow_func
def tophat(
    image: "tf.Tensor",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "tf.Tensor":
    """
    Apply white top-hat filter to a 3D image for background removal.

    This applies the filter to each Z-slice independently using TensorFlow's
    native operations.

    Args:
        image: 3D TensorFlow tensor of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk
        downsample_factor: Factor by which to downsample the image for processing

    Returns:
        Filtered 3D TensorFlow tensor of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    # Process each Z-slice independently
    result_list = []

    for z in range(image.shape[0]):
        # Store original data type
        input_dtype = image[z].dtype

        # 1) Downsample using TensorFlow's resize function
        # First, add batch and channel dimensions for resize
        img_4d = tf.expand_dims(tf.expand_dims(tf.cast(image[z], tf.float32), 0), -1)

        # Calculate new dimensions
        new_h = tf.cast(tf.shape(image[z])[0] // downsample_factor, tf.int32)
        new_w = tf.cast(tf.shape(image[z])[1] // downsample_factor, tf.int32)

        # Resize using TensorFlow's resize function
        image_small = tf.squeeze(tf.image.resize(
            img_4d,
            [new_h, new_w],
            method=tf.image.ResizeMethod.BILINEAR
        ), axis=[0, -1])

        # 2) Create a circular structuring element
        # (plain Python max keeps the radius an int for the shape math below)
        small_selem_radius = max(1, selem_radius // downsample_factor)
        small_grid_size = 2 * small_selem_radius + 1

        # Create grid for structuring element
        y_range = tf.range(-small_selem_radius, small_selem_radius + 1, dtype=tf.float32)
        x_range = tf.range(-small_selem_radius, small_selem_radius + 1, dtype=tf.float32)
        small_y_grid, small_x_grid = tf.meshgrid(y_range, x_range)

        # Create circular mask
        small_mask = tf.cast(
            tf.sqrt(tf.square(small_y_grid) + tf.square(small_x_grid)) <= small_selem_radius,
            tf.float32
        )

        # 3) Apply white top-hat using TensorFlow's morphological operations
        # White top-hat is the original image minus its morphological opening
        # (erosion followed by dilation with the same structuring element)

        # Pad the image to handle boundary conditions
        pad_size = small_selem_radius
        padded = tf.pad(
            tf.expand_dims(tf.expand_dims(image_small, 0), -1),
            [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]],
            mode='SYMMETRIC'
        )

        # Build a flat structuring element for grayscale morphology:
        # 0 inside the disk and a large negative value outside, so off-disk
        # pixels can never win the neighborhood min/max
        large_neg = tf.constant(-1e9, dtype=tf.float32)
        flat_selem = tf.where(small_mask > 0, 0.0, large_neg)
        flat_selem = tf.reshape(flat_selem, [small_grid_size, small_grid_size, 1])

        # Erosion: neighborhood minimum over the disk
        # (tf.nn.erosion2d computes min over value - filters; a weighted-sum
        # convolution cannot express a morphological min/max)
        eroded = tf.nn.erosion2d(
            padded,
            filters=flat_selem,
            strides=[1, 1, 1, 1],
            padding='VALID',
            data_format='NHWC',
            dilations=[1, 1, 1, 1]
        )

        # Pad the eroded image before dilating
        padded_eroded = tf.pad(
            eroded,
            [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]],
            mode='SYMMETRIC'
        )

        # Dilation: neighborhood maximum over the disk
        # (tf.nn.dilation2d computes max over value + filters);
        # erosion followed by dilation is the morphological opening
        opened = tf.nn.dilation2d(
            padded_eroded,
            filters=flat_selem,
            strides=[1, 1, 1, 1],
            padding='VALID',
            data_format='NHWC',
            dilations=[1, 1, 1, 1]
        )

        # Remove batch and channel dimensions
        opened = tf.squeeze(opened, axis=[0, -1])

        # 4) The opening is the estimated background; the white top-hat
        # (image_small minus the opening) is the foreground signal
        background_small = opened

        # 5) Upscale background to original size
        background_4d = tf.expand_dims(tf.expand_dims(background_small, 0), -1)
        background_large = tf.squeeze(tf.image.resize(
            background_4d,
            tf.shape(image[z])[:2],
            method=tf.image.ResizeMethod.BILINEAR
        ), axis=[0, -1])

        # 6) Subtract background and clip negative values
        slice_result = tf.maximum(tf.cast(image[z], tf.float32) - background_large, 0.0)

        # 7) Convert back to original data type
        result_list.append(tf.cast(slice_result, input_dtype))

    # Stack results back into a 3D tensor
    result = tf.stack(result_list, axis=0)

    return result
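

# Editor's illustrative sketch (not part of the original module): top-hat
# background removal should flatten a slowly varying background while keeping
# a small bright feature, assuming eager execution and a TF build that
# provides tf.nn.erosion2d / tf.nn.dilation2d.
def _example_tophat() -> None:
    import numpy as np
    y = np.linspace(0.0, 1.0, 128, dtype=np.float32)
    ramp = (2000 + 6000 * y)[:, np.newaxis]            # (128, 1) vertical ramp
    frame = np.tile(ramp, (1, 128)).astype(np.uint16)  # (128, 128) background
    frame[60:68, 60:68] += 20000                       # small bright square
    stack = tf.constant(frame[np.newaxis, ...])        # (1, 128, 128)
    filtered = tophat(stack, selem_radius=16, downsample_factor=4)
    # The ramp is largely removed; the bright feature survives
    print(int(tf.reduce_max(filtered)), int(tf.reduce_min(filtered)))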