# Coverage for openhcs/processing/backends/analysis/dxf_mask_pipeline.py: 18.5% (124 statements)
# Report generated by coverage.py v7.10.3 at 2025-08-14 05:57 +0000
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Tuple, Union

from openhcs.utils.import_utils import optional_import, create_placeholder_class
from openhcs.core.memory.decorators import torch as torch_func  # openhcs torch-backend decorator (was numpy_func)

# Static type checkers see the real backend modules. At runtime each backend is
# resolved through optional_import below, so a missing library never breaks import.
if TYPE_CHECKING:
    import cupy as cp
    import jax
    import jax.numpy as jnp
    import tensorflow as tf
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

logger = logging.getLogger(__name__)

# --- Backend Imports as optional dependencies ---
# A backend counts as available when optional_import returned something other than
# None. Submodules are only requested once their parent package imported.

# PyTorch
torch = optional_import("torch")
HAS_TORCH = torch is not None
nn = optional_import("torch.nn") if HAS_TORCH else None
F = optional_import("torch.nn.functional") if HAS_TORCH else None

# CuPy
cp = optional_import("cupy")
HAS_CUPY = cp is not None

# JAX
jax = optional_import("jax")
HAS_JAX = jax is not None
jnp = optional_import("jax.numpy") if HAS_JAX else None

# TensorFlow
tf = optional_import("tensorflow")
HAS_TENSORFLOW = tf is not None

# --- PyTorch Specific Helpers ---
# Base class for the registration network. When torch.nn is importable this wraps the
# real nn.Module; otherwise create_placeholder_class supplies a stand-in that names
# PyTorch as the missing requirement.
ModulePlaceholder = create_placeholder_class(
    "Module",  # name given to the generated placeholder
    base_class=nn.Module if nn else None,
    required_library="PyTorch",
)
if HAS_TORCH:  # the nn.Module subclass is only defined when PyTorch imported successfully

    class _RegistrationCNN_torch(ModulePlaceholder):
        """Four-layer convolutional net predicting a per-pixel 2-channel displacement field.

        Input is ``[B, 2, H, W]`` (normalized image channel + rasterized mask channel).
        The output is ``[B, 2, H, W]``, squashed into ``[-1, 1]`` by ``tanh``. Every layer is
        a 3x3 convolution with padding 1, so H and W are preserved.
        """

        def __init__(self):
            super().__init__()  # nn.Module.__init__ (or the placeholder's) must run before submodules are set
            # (attribute name, in_channels, out_channels). Created in this order so that the
            # parameter registration order and state_dict keys stay conv1..conv4.
            layer_specs = (
                ("conv1", 2, 32),
                ("conv2", 32, 64),
                ("conv3", 64, 32),
                ("conv4", 32, 2),
            )
            for attr_name, in_channels, out_channels in layer_specs:
                setattr(self, attr_name, nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))

        def forward(self, x: "torch.Tensor") -> "torch.Tensor":
            # Hidden layers: convolution followed by ReLU.
            for hidden_conv in (self.conv1, self.conv2, self.conv3):
                x = F.relu(hidden_conv(x))
            # Output layer: tanh bounds the displacement to [-1, 1].
            return torch.tanh(self.conv4(x))
def _rasterize_polygons_slice_torch(
    polygons_gpu: List["torch.Tensor"], H: int, W: int, device: "torch.device"
) -> "torch.Tensor":
    """Rasterize polygons into one boolean ``[H, W]`` mask using even-odd ray casting.

    Pixel ``(row, col)`` is tested at the integer point ``(x=col, y=row)``. For each
    polygon, each edge is tested against every pixel of the polygon's image-clipped
    bounding box at once. The cost is therefore ``O(n_vertices)`` tensor operations per
    polygon instead of a Python loop per pixel. Masks of several polygons are OR-ed.

    This is defined at module level; it only needs PyTorch when it is called.

    Args:
        polygons_gpu: Polygons as ``[N_points, 2]`` floating tensors of ``(x, y)``
            vertices. The ring is closed implicitly (last vertex connects to the first).
            Polygons with fewer than 3 vertices are skipped.
        H: Mask height in pixels.
        W: Mask width in pixels.
        device: Device on which the mask is allocated.

    Returns:
        Boolean tensor of shape ``[H, W]``.
    """
    mask_slice = torch.zeros((H, W), dtype=torch.bool, device=device)

    for poly_tensor in polygons_gpu:
        if poly_tensor.shape[0] < 3:
            continue  # not a polygon

        # Integer bounding box of the polygon in pixel coordinates (before clipping).
        min_x, min_y = torch.floor(poly_tensor.min(dim=0).values).long().tolist()
        max_x, max_y = torch.ceil(poly_tensor.max(dim=0).values).long().tolist()

        # A polygon lying entirely outside the image cannot contain any pixel.
        if max_x < 0 or max_y < 0 or min_x > W - 1 or min_y > H - 1:
            continue

        # Clip the bounding box to the image.
        min_x, max_x = max(min_x, 0), min(max_x, W - 1)
        min_y, max_y = max(min_y, 0), min(max_y, H - 1)
        if max_x < min_x or max_y < min_y:
            continue  # degenerate image size (H or W == 0)

        # Pixel sample coordinates inside the box, in the vertices' dtype so the
        # arithmetic below matches a scalar evaluation of the same formula.
        xs = torch.arange(min_x, max_x + 1, device=device, dtype=poly_tensor.dtype)  # [bb_W]
        ys = torch.arange(min_y, max_y + 1, device=device, dtype=poly_tensor.dtype).unsqueeze(1)  # [bb_H, 1]

        inside = torch.zeros((ys.shape[0], xs.shape[0]), dtype=torch.bool, device=device)
        # Edge i runs from the previous vertex (p1) to vertex i (p2); roll closes the ring.
        prev_pts = torch.roll(poly_tensor, shifts=1, dims=0)
        for i in range(poly_tensor.shape[0]):
            p1x, p1y = prev_pts[i, 0], prev_pts[i, 1]
            p2x, p2y = poly_tensor[i, 0], poly_tensor[i, 1]
            # Rows whose rightward ray can hit this edge. The y-interval is half-open, so a
            # vertex shared by two edges is counted once and horizontal edges never qualify.
            in_span = (ys > torch.minimum(p1y, p2y)) & (ys <= torch.maximum(p1y, p2y))
            # x where the edge crosses each row. inf/nan from horizontal edges is harmless
            # because those rows were excluded by in_span.
            x_cross = (ys - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
            crosses = in_span & (xs <= torch.maximum(p1x, p2x)) & ((p1x == p2x) | (xs <= x_cross))
            inside ^= crosses  # even-odd rule: each crossing toggles inside/outside

        mask_slice[min_y:max_y + 1, min_x:max_x + 1] |= inside

    return mask_slice
def _apply_displacement_field_torch(data_slice: "torch.Tensor", displacement_field: "torch.Tensor") -> "torch.Tensor":
    """Warp a 2-D slice by a dense displacement field given in normalized coordinates.

    Each output pixel samples ``data_slice`` bilinearly at its own normalized position
    plus ``(dx, dy)``. Positions are clamped to ``[-1, 1]``, so samples that would fall
    outside the image take the nearest border pixel.

    This is defined at module level; it only needs PyTorch when it is called.

    Args:
        data_slice: ``[H, W]`` or ``[C, H, W]`` tensor to warp. Non-floating input is
            converted to float32.
        displacement_field: ``[2, H, W]`` tensor of ``(dx, dy)`` offsets in the ``[-1, 1]``
            coordinate system, where -1/+1 are the centres of the corner pixels.

    Returns:
        The warped slice, with the same rank and spatial shape as ``data_slice``.
    """
    work = data_slice if data_slice.is_floating_point() else data_slice.float()
    H_data, W_data = work.shape[-2:]

    # Identity sampling grid. linspace(-1, 1, N) places -1/+1 on the centres of the
    # corner pixels; that is the align_corners=True convention used in grid_sample below.
    # (Sampling this grid with align_corners=False would shift/scale the image by up to
    # half a pixel and halve the border values even with zero displacement.)
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(-1, 1, H_data, device=work.device, dtype=work.dtype),
        torch.linspace(-1, 1, W_data, device=work.device, dtype=work.dtype),
        indexing="ij",
    )
    identity_grid = torch.stack((grid_x, grid_y), dim=-1)  # [H, W, 2], (x, y) order for grid_sample

    # [2, H, W] (dx, dy) -> [H, W, 2], then offset the identity positions.
    displaced_grid = identity_grid + displacement_field.permute(1, 2, 0).to(work.dtype)
    # Clamp so out-of-range samples read the border pixel; with this clamp grid_sample's
    # zero padding below is never reached.
    displaced_grid = torch.clamp(displaced_grid, -1, 1)

    batched = work.unsqueeze(0)  # [1, C, H, W] or [1, H, W]
    if batched.ndim == 3:
        batched = batched.unsqueeze(1)  # [1, 1, H, W]

    warped = F.grid_sample(
        batched,
        displaced_grid.unsqueeze(0),  # [1, H, W, 2]
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,  # must match the linspace identity grid above
    )
    warped = warped.squeeze(0)  # [C, H, W]
    # Restore the caller's rank: 2-D input gives 2-D output.
    return warped.squeeze(0) if work.ndim == 2 else warped
def _smooth_field_z_torch(displacement_field_stack: "torch.Tensor", sigma_z: float) -> "torch.Tensor":
    """Gaussian-smooth a ``[Z, C, H, W]`` field along Z, independently at every (c, y, x).

    The kernel has ``max(3, int(4 * sigma_z + 1))`` taps (bumped to an odd count) and is
    normalized to sum to 1. The ends of the stack are padded by replicating the first and
    last slice. Zero padding would bias the boundary slices toward zero displacement.

    This is defined at module level; it only needs PyTorch when it is called.

    Args:
        displacement_field_stack: ``[Z, C, H, W]`` tensor. Non-floating input is
            converted to float32.
        sigma_z: Gaussian standard deviation, in slices. ``<= 0`` returns the input unchanged.

    Returns:
        Smoothed tensor of shape ``[Z, C, H, W]``.
    """
    if sigma_z <= 0:
        return displacement_field_stack

    field = (
        displacement_field_stack
        if displacement_field_stack.is_floating_point()
        else displacement_field_stack.float()
    )
    Z, C, H, W = field.shape

    kernel_size_z = max(3, int(2 * 2 * sigma_z + 1))  # roughly +/- 2 sigma
    if kernel_size_z % 2 == 0:
        kernel_size_z += 1
    padding_z = kernel_size_z // 2

    coords_z = torch.arange(kernel_size_z, dtype=field.dtype, device=field.device)
    coords_z -= (kernel_size_z - 1) / 2
    kernel_1d_z = torch.exp(-(coords_z ** 2) / (2 * sigma_z ** 2))
    kernel_1d_z = (kernel_1d_z / kernel_1d_z.sum()).view(1, 1, kernel_size_z)  # [out=1, in=1, k]

    # Treat every (c, y, x) position as an independent batch item: [C*H*W, 1, Z].
    # A single shared kernel replaces the former C*H*W-times repeated grouped-conv weight.
    series = field.permute(1, 2, 3, 0).reshape(C * H * W, 1, Z)
    series = F.pad(series, (padding_z, padding_z), mode="replicate")
    smoothed = F.conv1d(series, kernel_1d_z)  # [C*H*W, 1, Z]

    return smoothed.view(C, H, W, Z).permute(3, 0, 1, 2).contiguous()  # [Z, C, H, W]
# --- Main pipeline helpers ---
_DXF_MASKING_MODES = ("zero_out", "multiply", "nan_out")  # accepted masking_mode values


def _register_dxf_masks_torch(
    image_stack_reg: "torch.Tensor",
    dxf_polygons: List[List[Tuple[float, float]]],
    smoothing_sigma_z: float,
) -> "torch.Tensor":
    """Rasterize the DXF polygons and warp them onto every slice of a ``[Z, H, W]`` stack.

    Returns:
        Boolean ``[Z, H, W]`` tensor on the stack's device.
    """
    Z, H, W = image_stack_reg.shape
    device = image_stack_reg.device
    if Z == 0:
        # Nothing to register; torch.stack would fail on an empty list below.
        return torch.zeros((0, H, W), dtype=torch.bool, device=device)

    polygons_gpu = [torch.tensor(p, dtype=torch.float32, device=device) for p in dxf_polygons]
    # The polygons do not depend on z, so the initial mask is identical for every slice:
    # rasterize once instead of once per slice.
    raster_mask = _rasterize_polygons_slice_torch(polygons_gpu, H, W, device).float()  # [H, W]

    # NOTE(review): the network is constructed fresh on every call and no weights are
    # loaded here, so its parameters are PyTorch's default random initialization. The
    # resulting displacement field is therefore arbitrary and differs between calls.
    # TODO: load trained weights.
    registration_cnn = _RegistrationCNN_torch().to(device)
    registration_cnn.eval()

    displacement_field_slices = []
    with torch.no_grad():
        for raw_slice in image_stack_reg:  # iterates over Z, each [H, W]
            # The CNN weights are float32, so normalize in float32 whatever the input dtype.
            gray = raw_slice.float()
            img_min, img_max = torch.min(gray), torch.max(gray)
            if img_max > img_min:
                image_norm = (gray - img_min) / (img_max - img_min + 1e-6)
            else:
                image_norm = torch.zeros_like(gray)
            cnn_input = torch.stack([image_norm, raster_mask], dim=0).unsqueeze(0)  # [1, 2, H, W]
            displacement_field_slices.append(registration_cnn(cnn_input).squeeze(0))  # [2, H, W]

    displacement_field_stack = torch.stack(displacement_field_slices, dim=0)  # [Z, 2, H, W]
    if smoothing_sigma_z > 0:
        displacement_field_stack = _smooth_field_z_torch(displacement_field_stack, smoothing_sigma_z)

    aligned_mask_slices = []
    for z_idx in range(Z):
        warped = _apply_displacement_field_torch(raster_mask, displacement_field_stack[z_idx])
        if warped.ndim == 3 and warped.shape[0] == 1:
            warped = warped.squeeze(0)  # tolerate a [1, H, W] result
        aligned_mask_slices.append(warped > 0.5)  # binarize the bilinearly warped mask
    return torch.stack(aligned_mask_slices, dim=0)  # [Z, H, W] bool


def _mask_image_stack_torch(
    image_stack: "torch.Tensor", mask_zhw: "torch.Tensor", masking_mode: str
) -> "torch.Tensor":
    """Apply a boolean ``[Z, H, W]`` mask to a ``[Z, H, W]`` or ``[Z, C, H, W]`` stack."""
    mask = mask_zhw.unsqueeze(1) if image_stack.ndim == 4 else mask_zhw  # broadcast over C
    if masking_mode in ("zero_out", "multiply"):
        # torch.where keeps the input dtype exactly. A float32 round-trip would lose
        # precision for int32/float64 and turn inf/nan outside the mask into nan.
        zero = torch.zeros((), dtype=image_stack.dtype, device=image_stack.device)
        return torch.where(mask, image_stack, zero)
    if masking_mode == "nan_out":
        # NaN needs a floating dtype: keep floating inputs as-is, promote others to float32.
        img = image_stack if image_stack.is_floating_point() else image_stack.float()
        nan = torch.full((), float("nan"), dtype=img.dtype, device=img.device)
        return torch.where(mask, img, nan)  # NaN where mask is False
    raise ValueError(f"Unknown masking_mode: {masking_mode}")


# --- Main Pipeline Function ---
@torch_func  # openhcs torch-backend decorator
def dxf_mask_pipeline(
    image_stack,  # Expected to be a torch.Tensor if torch_func is used
    dxf_polygons: List[List[Tuple[float, float]]],
    apply_mask: bool = False,
    masking_mode: str = "zero_out",
    smoothing_sigma_z: float = 0.0,
    **kwargs
) -> Union["torch.Tensor", "cp.ndarray", "jnp.ndarray", "tf.Tensor"]:  # type: ignore
    """Rasterize DXF polygons, warp them onto each slice, and optionally mask the stack.

    A CNN predicts one displacement field per Z slice from the normalized image and the
    rasterized polygons. The fields can be Gaussian-smoothed along Z, and the polygon
    mask is warped by each field and binarized.

    Args:
        image_stack: ``[Z, H, W]`` or ``[Z, C, H, W]`` tensor. For a 4-D stack only
            channel 0 drives the registration.
        dxf_polygons: Polygons as lists of ``(x, y)`` vertices in pixel coordinates
            (x = column, y = row).
        apply_mask: If False, return the mask. If True, return the masked image stack.
        masking_mode: ``"zero_out"`` / ``"multiply"`` set pixels outside the mask to 0
            and keep the input dtype. ``"nan_out"`` sets them to NaN; floating inputs keep
            their dtype and other inputs become float32.
        smoothing_sigma_z: Sigma, in slices, of Gaussian smoothing of the displacement
            fields along Z. ``<= 0`` disables smoothing.
        **kwargs: Accepted and ignored.

    Returns:
        A boolean ``[Z, H, W]`` mask, or the masked stack with the shape of ``image_stack``.

    Raises:
        ImportError: PyTorch is not available.
        ValueError: ``image_stack`` is not 3-D/4-D, or ``masking_mode`` is unknown while
            ``apply_mask`` is True.
    """
    if not HAS_TORCH:
        raise ImportError("dxf_mask_pipeline requires PyTorch, which is not available.")

    if image_stack.ndim == 4:  # Z, C, H, W
        if image_stack.shape[1] > 1:
            logger.warning("Multi-channel image stack provided, using first channel for registration.")
        image_stack_reg = image_stack[:, 0, :, :]  # first channel drives registration: (Z, H, W)
    elif image_stack.ndim == 3:  # Z, H, W
        image_stack_reg = image_stack
    else:
        raise ValueError(f"image_stack has unsupported ndim: {image_stack.ndim}. Expected 3 or 4.")

    # Fail fast on a bad mode rather than after the expensive registration work.
    if apply_mask and masking_mode not in _DXF_MASKING_MODES:
        raise ValueError(f"Unknown masking_mode: {masking_mode}")

    aligned_mask_stack_bool = _register_dxf_masks_torch(image_stack_reg, dxf_polygons, smoothing_sigma_z)

    if not apply_mask:
        return aligned_mask_stack_bool
    return _mask_image_stack_torch(image_stack, aligned_mask_stack_bool, masking_mode)