Coverage for openhcs/processing/backends/analysis/dxf_mask_pipeline.py: 18.5%

124 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1from __future__ import annotations 

2 

3import logging 

4from typing import TYPE_CHECKING, List, Tuple, Union 

5 

6from openhcs.utils.import_utils import optional_import, create_placeholder_class 

7from openhcs.core.memory.decorators import torch as torch_func # Changed from numpy_func 

8 

9# --- Backend Imports as optional dependencies --- 

10# PyTorch 

11if TYPE_CHECKING: 11 ↛ 12line 11 didn't jump to line 12 because the condition on line 11 was never true

12 import torch 

13 import torch.nn as nn 

14 import torch.nn.functional as F 

15 

16torch = optional_import("torch") 

17nn = optional_import("torch.nn") if torch is not None else None 

18F = optional_import("torch.nn.functional") if torch is not None else None 

19HAS_TORCH = torch is not None 

20 

21# CuPy 

22if TYPE_CHECKING: 22 ↛ 23line 22 didn't jump to line 23 because the condition on line 22 was never true

23 import cupy as cp 

24 

25cp = optional_import("cupy") 

26HAS_CUPY = cp is not None 

27 

28# JAX 

29if TYPE_CHECKING: 29 ↛ 30line 29 didn't jump to line 30 because the condition on line 29 was never true

30 import jax 

31 import jax.numpy as jnp 

32 

33jax = optional_import("jax") 

34jnp = optional_import("jax.numpy") if jax is not None else None 

35HAS_JAX = jax is not None 

36 

37# TensorFlow 

38if TYPE_CHECKING: 38 ↛ 39line 38 didn't jump to line 39 because the condition on line 38 was never true

39 import tensorflow as tf 

40 

41tf = optional_import("tensorflow") 

42HAS_TENSORFLOW = tf is not None 

43 

logger = logging.getLogger(__name__)

# --- PyTorch Specific Helpers ---
# Base class for the registration network defined below. When `torch.nn` was imported,
# its `Module` is handed to `create_placeholder_class` as the base; otherwise `None` is
# passed. NOTE(review): what the helper returns in each case (nn.Module itself, a subclass,
# or a stand-in that reports "PyTorch" as missing) is defined in
# openhcs.utils.import_utils — confirm there.
ModulePlaceholder = create_placeholder_class(
    "Module",
    required_library="PyTorch",
    base_class=(nn.Module if nn else None),
)

55 

if HAS_TORCH:  # The nn.Module subclass and torch helpers exist only when PyTorch imported.

    class _RegistrationCNN_torch(ModulePlaceholder):
        """Four-layer conv net mapping ``[B, 2, H, W]`` inputs to ``[B, 2, H, W]`` outputs.

        ``dxf_mask_pipeline`` feeds the normalized image slice as channel 0 and the
        rasterized mask as channel 1, and reads the two output channels as a ``(dx, dy)``
        displacement field. 3x3 kernels with ``padding=1`` keep the spatial size, and the
        final ``tanh`` bounds every displacement to ``[-1, 1]``.

        NOTE(review): nothing in this module loads weights, so the layers keep PyTorch's
        default random initialisation and each new instance predicts a different field.
        Confirm whether trained weights are meant to be loaded before relying on this.
        """

        def __init__(self):
            super().__init__()  # nn.Module.__init__ (or the placeholder's __init__)
            self.conv1 = nn.Conv2d(2, 32, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
            self.conv4 = nn.Conv2d(32, 2, kernel_size=3, padding=1)

        def forward(self, x: "torch.Tensor") -> "torch.Tensor":
            """Return the ``[B, 2, H, W]`` displacement prediction for ``x`` of ``[B, 2, H, W]``."""
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))
            return torch.tanh(self.conv4(x))  # displacement in [-1, 1]

    def _rasterize_polygons_slice_torch(
        polygons_gpu: List["torch.Tensor"], H: int, W: int, device: "torch.device"
    ) -> "torch.Tensor":
        """Rasterize polygons into a boolean ``[H, W]`` mask with an even-odd ray-casting test.

        Each polygon is an ``[N, 2]`` tensor of ``(x, y)`` vertices (implicitly closed);
        polygons with fewer than 3 vertices are skipped, and the per-polygon masks are
        OR-ed together. Pixel ``(row, col)`` is tested at the coordinate ``(x=col, y=row)``.
        An edge counts as crossed when ``y_lo < y <= y_hi`` (horizontal edges never count)
        and the point lies left of (or on) the edge's intersection with the row.

        All pixels of the polygon's image-clipped bounding box are tested against all
        edges at once, in row chunks so the ``[rows, cols, edges]`` intermediates stay
        bounded instead of issuing one scalar tensor op per pixel per vertex.
        """
        max_chunk_elems = 1 << 22  # cap on rows * cols * edges evaluated per chunk
        mask_slice = torch.zeros((H, W), dtype=torch.bool, device=device)

        for poly in polygons_gpu:  # poly shape [N_points, 2] as (x, y)
            if poly.shape[0] < 3:
                continue

            x0, y0 = (int(v) for v in torch.floor(poly.min(dim=0).values).tolist())
            x1, y1 = (int(v) for v in torch.ceil(poly.max(dim=0).values).tolist())
            if x1 < 0 or y1 < 0 or x0 > W - 1 or y0 > H - 1:
                continue  # bounding box misses the image entirely: nothing can be inside
            x0, x1 = max(x0, 0), min(x1, W - 1)
            y0, y1 = max(y0, 0), min(y1, H - 1)

            # Edge i joins the previous vertex prev[i] = poly[i - 1] to poly[i].
            prev = torch.roll(poly, shifts=1, dims=0)
            p1x, p1y = prev[:, 0], prev[:, 1]
            p2x, p2y = poly[:, 0], poly[:, 1]
            edge_y_lo = torch.minimum(p1y, p2y)
            edge_y_hi = torch.maximum(p1y, p2y)
            edge_x_hi = torch.maximum(p1x, p2x)
            sloped = p1y != p2y  # horizontal edges are never counted
            vertical = p1x == p2x
            # Dummy denominator for horizontal edges; their result is masked by `sloped`.
            safe_dy = torch.where(sloped, p2y - p1y, torch.ones_like(p2y))
            dx = p2x - p1x

            n_cols = x1 - x0 + 1
            rows_per_chunk = max(1, max_chunk_elems // (n_cols * poly.shape[0]))
            xs = torch.arange(x0, x1 + 1, device=device, dtype=poly.dtype).view(1, -1, 1)

            for r0 in range(y0, y1 + 1, rows_per_chunk):
                r1 = min(r0 + rows_per_chunk, y1 + 1)
                ys = torch.arange(r0, r1, device=device, dtype=poly.dtype).view(-1, 1, 1)
                in_band = (ys > edge_y_lo) & (ys <= edge_y_hi) & (xs <= edge_x_hi) & sloped
                x_cross = (ys - p1y) * dx / safe_dy + p1x  # [rows, 1, edges]
                crosses = in_band & (vertical | (xs <= x_cross))  # [rows, cols, edges]
                inside = (crosses.sum(dim=-1) % 2) == 1  # odd crossing count => inside
                mask_slice[r0:r1, x0:x1 + 1] |= inside

        return mask_slice

    def _apply_displacement_field_torch(
        data_slice: "torch.Tensor", displacement_field: "torch.Tensor"
    ) -> "torch.Tensor":
        """Warp ``data_slice`` by a dense displacement field using bilinear sampling.

        Args:
            data_slice: ``[H, W]`` or ``[C, H, W]`` tensor.
            displacement_field: ``[2, H, W]`` tensor of ``(dx, dy)`` in normalized
                ``grid_sample`` units, where the image spans ``[-1, 1]`` (so 1.0 is about
                half the image).

        Output pixel ``(i, j)`` samples the input at ``identity(i, j) + displacement(i, j)``.
        The identity grid is ``linspace(-1, 1)``, which is the ``align_corners=True``
        convention, so a zero field reproduces the input exactly. Samples landing outside
        the image read 0 (``padding_mode='zeros'``; the grid is deliberately not clamped,
        because clamping would replicate border pixels instead).

        Returns:
            Tensor of the same rank as ``data_slice`` (``[H, W]`` or ``[C, H, W]``) with
            the displacement field's dtype.
        """
        H_data, W_data = data_slice.shape[-2:]
        grid_dtype = displacement_field.dtype
        ys = torch.linspace(-1, 1, H_data, device=data_slice.device, dtype=grid_dtype)
        xs = torch.linspace(-1, 1, W_data, device=data_slice.device, dtype=grid_dtype)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
        identity_grid = torch.stack((grid_x, grid_y), dim=-1)  # [H, W, 2], last dim (x, y)

        sample_grid = identity_grid + displacement_field.permute(1, 2, 0)  # [H, W, 2]

        was_2d = data_slice.ndim == 2
        batch = data_slice.reshape(1, 1, H_data, W_data) if was_2d else data_slice.unsqueeze(0)
        batch = batch.to(grid_dtype)  # grid_sample requires input and grid dtypes to match

        warped = F.grid_sample(
            batch,
            sample_grid.unsqueeze(0),  # [1, H, W, 2]
            mode='bilinear', padding_mode='zeros', align_corners=True,
        )
        return warped[0, 0] if was_2d else warped[0]

    def _smooth_field_z_torch(displacement_field_stack: "torch.Tensor", sigma_z: float) -> "torch.Tensor":
        """Gaussian-smooth a ``[Z, C, H, W]`` displacement stack along Z only.

        The kernel is a normalized Gaussian with std ``sigma_z`` and odd length
        ``max(3, int(4 * sigma_z + 1))``. Every ``(c, y, x)`` position is filtered as an
        independent 1-D signal over Z with one shared kernel. The ends use replicate
        padding, so the first and last slices are not pulled toward zero.

        Returns the input unchanged when ``sigma_z <= 0`` or the stack is empty.
        """
        if sigma_z <= 0:
            return displacement_field_stack
        Z, C, H, W = displacement_field_stack.shape  # [Z, 2, H, W]
        if Z == 0:
            return displacement_field_stack

        kernel_size_z = max(3, int(2 * 2 * sigma_z + 1))
        if kernel_size_z % 2 == 0:
            kernel_size_z += 1
        half = kernel_size_z // 2

        offsets = torch.arange(
            kernel_size_z, dtype=displacement_field_stack.dtype, device=displacement_field_stack.device
        ) - half
        kernel = torch.exp(-(offsets ** 2) / (2 * sigma_z ** 2))
        kernel = (kernel / kernel.sum()).view(1, 1, kernel_size_z)  # [out=1, in=1, k]

        # [Z, C, H, W] -> [C*H*W, 1, Z]: a batch of single-channel sequences, so one
        # ordinary (ungrouped) conv1d with the shared kernel filters all of them.
        signals = displacement_field_stack.permute(1, 2, 3, 0).reshape(C * H * W, 1, Z)
        padded = F.pad(signals, (half, half), mode='replicate')
        smoothed = F.conv1d(padded, kernel)  # [C*H*W, 1, Z]

        return smoothed.view(C, H, W, Z).permute(3, 0, 1, 2).contiguous()  # [Z, C, H, W]

213 

214# --- Main Pipeline Function --- 

215@torch_func # Decorate with torch_func 

216def dxf_mask_pipeline( 

217 image_stack, # Expected to be a torch.Tensor if torch_func is used 

218 dxf_polygons: List[List[Tuple[float, float]]], 

219 apply_mask: bool = False, 

220 masking_mode: str = "zero_out", 

221 smoothing_sigma_z: float = 0.0, 

222 **kwargs 

223) -> Union["torch.Tensor", "cp.ndarray", "jnp.ndarray", "tf.Tensor"]: # type: ignore 

224 

225 # Assuming image_stack is (Z, H, W) or (Z, C, H, W) 

226 # If (Z,C,H,W), C is usually 1 for grayscale, or we take the first channel. 

227 if image_stack.ndim == 4: # Z, C, H, W 

228 Z, C_img, H, W = image_stack.shape 

229 if C_img > 1: logger.warning("Multi-channel image stack provided, using first channel for registration.") 

230 image_stack_reg = image_stack[:, 0, :, :] # Use first channel for registration: (Z, H, W) 

231 elif image_stack.ndim == 3: # Z, H, W 

232 Z, H, W = image_stack.shape 

233 image_stack_reg = image_stack 

234 else: 

235 raise ValueError(f"image_stack has unsupported ndim: {image_stack.ndim}. Expected 3 or 4.") 

236 

237 device = image_stack.device # image_stack is now expected to be a torch.Tensor 

238 polygons_gpu = [torch.tensor(p, dtype=torch.float32, device=device) for p in dxf_polygons] 

239 

240 initial_rasterized_masks_float = torch.zeros((Z, H, W), device=device, dtype=torch.float32) 

241 displacement_field_slices = [] 

242 

243 registration_cnn = _RegistrationCNN_torch().to(device) 

244 registration_cnn.eval() 

245 

246 for z_idx in range(Z): 

247 image_slice_gray = image_stack_reg[z_idx] # Shape [H, W] 

248 

249 img_min, img_max = torch.min(image_slice_gray), torch.max(image_slice_gray) 

250 image_slice_norm = (image_slice_gray - img_min) / (img_max - img_min + 1e-6) if img_max > img_min else torch.zeros_like(image_slice_gray) 

251 

252 raster_slice = _rasterize_polygons_slice_torch(polygons_gpu, H, W, device).float() 

253 initial_rasterized_masks_float[z_idx] = raster_slice 

254 

255 cnn_input = torch.stack([image_slice_norm, raster_slice], dim=0).unsqueeze(0) # [1, 2, H, W] 

256 with torch.no_grad(): 

257 displacement_field_slice = registration_cnn(cnn_input).squeeze(0) # [2, H, W] 

258 displacement_field_slices.append(displacement_field_slice) 

259 

260 displacement_field_stack = torch.stack(displacement_field_slices, dim=0) # [Z, 2, H, W] 

261 

262 if smoothing_sigma_z > 0: 

263 displacement_field_stack = _smooth_field_z_torch(displacement_field_stack, smoothing_sigma_z) 

264 

265 aligned_mask_slices_list = [] 

266 for z_idx in range(Z): 

267 aligned_slice = _apply_displacement_field_torch( 

268 initial_rasterized_masks_float[z_idx], 

269 displacement_field_stack[z_idx] 

270 ) # Output can be [1,H,W] or [H,W] 

271 if aligned_slice.ndim == 3 and aligned_slice.shape[0] == 1: 

272 aligned_slice = aligned_slice.squeeze(0) # to [H,W] 

273 aligned_mask_slices_list.append(aligned_slice > 0.5) # Binarize 

274 

275 aligned_mask_stack_bool = torch.stack(aligned_mask_slices_list, dim=0) # [Z, H, W] bool 

276 

277 if apply_mask: 

278 original_dtype = image_stack.dtype 

279 # Prepare mask for broadcasting if image_stack is (Z,C,H,W) 

280 mask_to_apply = aligned_mask_stack_bool.float() 

281 if image_stack.ndim == 4: # Z,C,H,W 

282 mask_to_apply = mask_to_apply.unsqueeze(1) # -> (Z,1,H,W) 

283 

284 if masking_mode == "zero_out" or masking_mode == "multiply": 

285 masked_img = image_stack.float() * mask_to_apply 

286 return masked_img.to(original_dtype) 

287 elif masking_mode == "nan_out": 

288 masked_img_float = image_stack.float() 

289 nans = torch.full_like(masked_img_float, float('nan')) 

290 return torch.where(mask_to_apply.bool(), masked_img_float, nans) # Nan where mask is False 

291 else: 

292 raise ValueError(f"Unknown masking_mode: {masking_mode}") 

293 else: 

294 return aligned_mask_stack_bool