Coverage for openhcs/processing/backends/enhance/dl_edof_unsupervised.py: 11.7%
139 statements
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from openhcs.core.memory.decorators import torch as torch_func
from openhcs.utils.import_utils import optional_import, create_placeholder_class

# For type checking only
if TYPE_CHECKING:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

# Import torch modules using optional_import
torch = optional_import("torch")
nn = optional_import("torch.nn") if torch is not None else None
F = optional_import("torch.nn.functional") if torch is not None else None

nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch"
)

# Helper for sharpness loss
def laplacian_filter_torch(image_batch: "torch.Tensor") -> "torch.Tensor":
    """
    Applies a Laplacian filter to a batch of 2D images.
    Input: (N, 1, H, W)
    Output: (N, 1, H, W)
    """
    kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]],
                          dtype=image_batch.dtype, device=image_batch.device).reshape(1, 1, 3, 3)
    return F.conv2d(image_batch, kernel, padding=1)
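
# The Laplacian kernel above sums to zero: flat image regions respond with ~0 while
# in-focus edges and texture respond strongly, so the variance of the filtered patch
# (used by sharpness_loss_fn below) serves as a focus measure.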

def extract_patches_2d_from_3d_stack(
    stack_3d: "torch.Tensor", patch_size: int, stride: int
) -> "torch.Tensor":
    """
    Extracts 2D patches from a 3D stack.
    Input stack_3d: [Z, H, W]
    Output patches: [N, Z, patch_size, patch_size], where N is num_patches.
    """
    Z, H, W = stack_3d.shape
    patches = stack_3d.unfold(1, patch_size, stride)   # [Z, num_h, W, patch_size]
    patches = patches.unfold(2, patch_size, stride)    # [Z, num_h, num_w, patch_size, patch_size]
    patches = patches.permute(1, 2, 0, 3, 4)           # [num_h, num_w, Z, patch_size, patch_size]
    patches = patches.reshape(-1, Z, patch_size, patch_size)
    return patches
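
# Worked example (illustrative numbers): for a stack of shape [Z=5, H=512, W=512] with
# patch_size=64 and stride=32, unfold yields (512 - 64) // 32 + 1 = 15 positions per
# axis, so the returned tensor has shape [225, 5, 64, 64].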

def blend_patches_to_2d_image(
    patch_outputs: List["torch.Tensor"],  # List of [1, patch_size, patch_size]
    target_h: int,
    target_w: int,
    patch_size: int,
    stride: int,
    device: "torch.device"
) -> "torch.Tensor":
    """
    Blends 2D fused patches back into a single 2D image.
    Input patch_outputs: List of [1, patch_size, patch_size] tensors.
    Output: [1, target_h, target_w]
    """
    fused_image = torch.zeros((target_h, target_w), dtype=torch.float32, device=device)
    count_map = torch.zeros((target_h, target_w), dtype=torch.float32, device=device)

    num_blocks_h = (target_h - patch_size) // stride + 1
    num_blocks_w = (target_w - patch_size) // stride + 1

    patch_idx = 0
    for i in range(num_blocks_h):
        for j in range(num_blocks_w):
            if patch_idx >= len(patch_outputs):
                # This case should ideally not be reached if inputs are consistent
                break

            patch_content = patch_outputs[patch_idx].squeeze(0)  # [patch_size, patch_size]

            h_start = i * stride
            w_start = j * stride
            h_end = h_start + patch_size
            w_end = w_start + patch_size

            fused_image[h_start:h_end, w_start:w_end] += patch_content
            count_map[h_start:h_end, w_start:w_end] += 1.0  # Use float for count_map
            patch_idx += 1

    fused_image /= count_map.clamp(min=1.0)
    return fused_image.unsqueeze(0)
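
# Overlapping patch regions are accumulated and averaged via count_map; clamp(min=1.0)
# guards pixels that no patch covers, which can happen along the right/bottom margins
# when (target size - patch_size) is not an exact multiple of stride (those pixels stay 0).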

class UNetLite(nnModule):
    def __init__(self, in_channels_z: int, model_config_depth: int):
        super().__init__()
        multiplier = 1 if model_config_depth == 3 else 2
        ch1 = 32 * multiplier
        ch2 = 64 * multiplier

        self.conv1 = nn.Conv2d(in_channels_z, ch1, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ch1, ch2, kernel_size=3, padding=1, stride=2)
        self.relu2 = nn.ReLU(inplace=True)

        self.upconv = nn.ConvTranspose2d(ch2, ch1, kernel_size=2, stride=2)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(ch1, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        x1 = self.relu1(self.conv1(x))
        x2 = self.relu2(self.conv2(x1))
        x3 = self.relu3(self.upconv(x2))
        out = self.sigmoid(self.conv_out(x3))
        return out
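
# Shape walk-through for an input patch batch x of shape [N, Z, p, p] (p even, as the
# caller enforces):
#   conv1 + relu1      -> [N, ch1, p, p]
#   conv2 (stride 2)   -> [N, ch2, p/2, p/2]
#   upconv (stride 2)  -> [N, ch1, p, p]
#   conv_out + sigmoid -> [N, 1, p, p], values in (0, 1)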

def sharpness_loss_fn(fused_patch: "torch.Tensor") -> "torch.Tensor":
    laplacian_response = laplacian_filter_torch(fused_patch)
    return -torch.var(laplacian_response, dim=(-1, -2), unbiased=False).mean()


def consistency_loss_fn(fused_patch: "torch.Tensor", input_patch_stack: "torch.Tensor") -> "torch.Tensor":
    diff_sq = (fused_patch - input_patch_stack) ** 2
    min_diff_sq_over_z = torch.min(diff_sq, dim=1)[0]
    return torch.mean(min_diff_sq_over_z)
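
# Together these define the unsupervised objective used below: sharpness_loss_fn rewards
# high Laplacian variance (an in-focus fused patch), while consistency_loss_fn keeps each
# fused pixel close to at least one of the Z input slices, discouraging hallucinated values.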

@torch_func
def dl_edof_unsupervised(
    image_stack: "torch.Tensor",
    model_depth: Optional[int] = None,
    patch_size: Optional[int] = None,
    stride: Optional[int] = None,
    denoise: bool = False,
    normalize: bool = False,
) -> "torch.Tensor":
    if torch is None:
        raise ImportError("PyTorch is required for this function")
    if not (image_stack.ndim == 3 and str(image_stack.device.type) == 'cuda'):
        raise ValueError("Input image_stack must be a 3D CUDA tensor [Z, H, W]. "
                         f"Got {image_stack.ndim}D tensor on {image_stack.device.type}.")

    Z_orig, H_orig, W_orig = image_stack.shape
    device = image_stack.device
    original_dtype = image_stack.dtype

    # Memory usage warning for large images
    total_elements = Z_orig * H_orig * W_orig
    if total_elements > 100_000_000:  # 100M elements
        import logging
        logger = logging.getLogger(__name__)
        logger.warning(f"⚠️ Large image stack ({total_elements:,} elements) may cause high memory usage in deep learning EDoF. "
                       f"Consider using smaller patch sizes or processing smaller regions.")
        logger.warning(f"Current image size: {Z_orig}×{H_orig}×{W_orig}")

    # Estimate patch memory usage
    current_patch_size = patch_size or max(H_orig, W_orig) // 8
    current_stride = stride or current_patch_size // 2
    num_patches_h = (H_orig - current_patch_size) // current_stride + 1
    num_patches_w = (W_orig - current_patch_size) // current_stride + 1
    total_patches = num_patches_h * num_patches_w

    if total_patches > 1000:
        import logging
        logger = logging.getLogger(__name__)
        logger.warning(f"⚠️ Large number of patches ({total_patches:,}) may cause high memory usage. "
                       f"Consider increasing stride or reducing patch size.")

    current_patch_size = patch_size
    if current_patch_size is None:
        current_patch_size = max(H_orig, W_orig) // 8

    current_patch_size = max(current_patch_size, 16)  # Min patch size
    if current_patch_size % 2 != 0:  # Ensure even for CNN
        current_patch_size += 1
    current_patch_size = min(current_patch_size, H_orig, W_orig)

    current_stride = stride
    if current_stride is None:
        current_stride = current_patch_size // 2
    if current_stride <= 0:
        current_stride = 1

    current_model_depth_config = model_depth
    if current_model_depth_config is None:
        current_model_depth_config = 3 if H_orig < 1024 else 5

    if normalize:
        stack_f32 = image_stack.float() / 65535.0
    else:
        stack_f32 = image_stack.float()

    if denoise:
        # torch.nn.functional has no gaussian_blur (that helper lives in torchvision), so
        # build the 3x3 Gaussian kernel (sigma = 0.5) by hand and apply it with conv2d.
        g = torch.exp(-torch.arange(-1, 2, dtype=torch.float32, device=device) ** 2 / (2 * 0.5 ** 2))
        gauss_kernel = torch.outer(g / g.sum(), g / g.sum()).reshape(1, 1, 3, 3)
        stack_f32 = F.conv2d(stack_f32.unsqueeze(1), gauss_kernel, padding=1).squeeze(1)

    patches = extract_patches_2d_from_3d_stack(stack_f32, current_patch_size, current_stride)

    fused_patch_outputs = []
    num_epochs_per_patch = 10

    for i in range(patches.shape[0]):
        patch_stack_z = patches[i]
        model_input = patch_stack_z.unsqueeze(0).to(device)
        Z_patch = model_input.shape[1]

        model = UNetLite(in_channels_z=Z_patch, model_config_depth=current_model_depth_config).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        for epoch in range(num_epochs_per_patch):
            model.train()
            optimizer.zero_grad()
            fused_output_patch = model(model_input)
            loss_s = sharpness_loss_fn(fused_output_patch)
            loss_c = consistency_loss_fn(fused_output_patch, model_input)
            total_loss = loss_s + loss_c
            total_loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            final_fused_patch = model(model_input)
            fused_patch_outputs.append(final_fused_patch.detach().squeeze(0))

    fused_2d_normalized = blend_patches_to_2d_image(
        fused_patch_outputs, H_orig, W_orig, current_patch_size, current_stride, device
    )

    fused_uint16 = fused_2d_normalized.clamp(0, 1).mul(65535.0).to(original_dtype)
    return fused_uint16
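
# Minimal usage sketch (illustrative only; assumes a CUDA-capable PyTorch install and
# that the @torch_func decorator does not require additional OpenHCS pipeline context):
#
#   stack = torch.rand(7, 256, 256, device="cuda") * 65535.0  # synthetic [Z, H, W] focal stack
#   fused = dl_edof_unsupervised(stack, patch_size=64, stride=32, normalize=True)
#   # fused has shape [1, 256, 256] and the same dtype as the input stack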