Coverage for openhcs/processing/backends/enhance/self_supervised_2d_deconvolution.py: 14.0%
134 statements
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
« prev ^ index » next coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1from __future__ import annotations
2import logging
3from typing import Optional, Tuple
5# Import torch decorator and optional_import utility
6from openhcs.utils.import_utils import optional_import, create_placeholder_class
7from openhcs.core.memory.decorators import torch as torch_func
9# --- PyTorch Imports as optional dependencies ---
10torch = optional_import("torch")
11nn = optional_import("torch.nn") if torch is not None else None
12F = optional_import("torch.nn.functional") if torch is not None else None
13if torch is not None: 13 ↛ 14line 13 didn't jump to line 14 because the condition on line 13 was never true
14 from torch.fft import irfft2, rfft2
15else:
16 irfft2 = None
17 rfft2 = None
logger = logging.getLogger(__name__)

# Base class shared by the torch models below. When torch.nn was importable the
# real nn.Module is passed as base_class; otherwise None is passed and
# create_placeholder_class presumably builds a stand-in named "Module" that
# reports PyTorch as the missing library (behaviour lives in openhcs.utils -
# confirm there).
nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=None if not nn else nn.Module,
    required_library="PyTorch",
)
26# --- PyTorch Specific Models and Helpers for 2D ---
27class _Simple2DCNN_torch(nnModule):
28 """Simple 2D CNN for deconvolution - optimized for 2D data per paper."""
29 def __init__(self, in_channels=1, out_channels=1, features=(96, 192)): # Paper: 96 initial features for 2D
30 super().__init__()
31 self.conv_block = nn.Sequential(
32 nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1),
33 nn.ReLU(inplace=True),
34 nn.Conv2d(features[0], features[1], kernel_size=3, padding=1),
35 nn.ReLU(inplace=True),
36 nn.Conv2d(features[1], out_channels, kernel_size=3, padding=1)
37 )
39 def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, C, H, W)
40 return self.conv_block(x)
42class _LearnedBlur2D_torch(nnModule):
43 """Learned blur for 2D deconvolution."""
44 def __init__(self, kernel_size=3):
45 super().__init__()
46 self.blur_conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=kernel_size//2, bias=False)
47 # Initialize weights to be somewhat like a Gaussian blur
48 if kernel_size > 0:
49 weights = torch.ones(kernel_size, kernel_size)
50 weights = weights / weights.sum()
51 self.blur_conv.weight.data = weights.reshape(1, 1, kernel_size, kernel_size)
53 def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, 1, H, W)
54 return self.blur_conv(x)
56def _gaussian_kernel_2d_torch(shape: Tuple[int, int], sigma: Tuple[float, float], device) -> torch.Tensor:
57 """Generate 2D Gaussian kernel."""
58 coords_h = torch.arange(shape[0], dtype=torch.float32, device=device) - (shape[0] - 1) / 2.0
59 coords_w = torch.arange(shape[1], dtype=torch.float32, device=device) - (shape[1] - 1) / 2.0
61 kernel_h = torch.exp(-coords_h**2 / (2 * sigma[0]**2))
62 kernel_w = torch.exp(-coords_w**2 / (2 * sigma[1]**2))
64 kernel = torch.outer(kernel_h, kernel_w)
65 return kernel / torch.sum(kernel)
67def _blur_fft_2d_torch(image: torch.Tensor, kernel: torch.Tensor, device) -> torch.Tensor:
68 """FFT-based 2D blur convolution."""
69 # image: (B, 1, H, W), kernel: (kH, kW)
70 B, C, H, W = image.shape
71 kH, kW = kernel.shape
73 # Pad kernel to image size for FFT
74 kernel_padded = F.pad(kernel, (
75 (W - kW) // 2, (W - kW + 1) // 2,
76 (H - kH) // 2, (H - kH + 1) // 2,
77 ))
78 kernel_padded = kernel_padded.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
80 img_fft = rfft2(image, dim=(-2, -1))
81 ker_fft = rfft2(kernel_padded.to(device), dim=(-2, -1))
83 blurred_fft = img_fft * ker_fft
84 blurred_img = irfft2(blurred_fft, s=(H, W), dim=(-2, -1))
85 return blurred_img
87def _extract_random_patches_2d_torch(
88 image_single_batch_channel: torch.Tensor, # (H, W)
89 patch_size_hw: Tuple[int, int],
90 num_patches: int,
91 device
92) -> torch.Tensor: # (num_patches, 1, pH, pW)
93 """Extract random 2D patches - GPU-native."""
94 H, W = image_single_batch_channel.shape
95 pH, pW = patch_size_hw
96 patches = torch.empty((num_patches, 1, pH, pW), device=device)
97 for i in range(num_patches):
98 # Force GPU device for random operations - NO CPU FALLBACK
99 h_start = torch.randint(0, H - pH + 1, (1,), device=device).item()
100 w_start = torch.randint(0, W - pW + 1, (1,), device=device).item()
101 patch = image_single_batch_channel[
102 h_start:h_start+pH, w_start:w_start+pW
103 ]
104 patches[i, 0, ...] = patch
105 return patches
# --- Main 2D Deconvolution Function ---
@torch_func
def self_supervised_2d_deconvolution(
    image: torch.Tensor,  # Expected (H, W), (1, H, W) or (1, 1, H, W)
    apply_deconvolution: bool = True,
    n_epochs: int = 10,  # Reduced for testing
    patch_size_hw: Tuple[int, int] = (128, 128),  # Paper: 128x128 for 2D
    mask_fraction: float = 0.005,  # Paper: 0.5%
    sigma_noise: float = 0.2,
    lambda_rec: float = 1.0,
    lambda_inv_d: float = 2.0,  # Paper: deconvolved invariance for 2D
    lambda_bound_d: float = 0.1,  # Paper: boundary loss for 2D
    min_val: float = 0.0,
    max_val: float = 1.0,
    learning_rate: float = 4e-4,  # Paper: Adam 4e-4
    blur_mode: str = "gaussian",  # 'fft', 'gaussian', 'learned'
    blur_sigma_spatial: float = 1.5,
    blur_kernel_size: int = 5,
    **kwargs,
) -> torch.Tensor:
    """
    Self-supervised 2D deconvolution optimized for 2D imaging data.

    Based on the paper's optimal 2D configuration:
    - 96 initial features (vs 48 for 3D)
    - 128x128 patches (vs 64x64x64 for 3D)
    - Batch size 16 (vs 4 for 3D)
    - Loss (4) with deconvolved invariance (vs reconvolved for 3D)

    NOTE: this implementation trains on a single random patch per step,
    not a batch of 16.

    A fresh network f is trained on the input image alone. Each step:
    - a random patch is masked: pixels are replaced by clamped noise with
      probability ``mask_fraction``;
    - the reconstruction loss compares blur(f(masked)) with the clean patch;
    - the invariance loss compares f(clean) with f(masked) on the masked pixels;
    - the boundary loss penalises raw f outputs outside [min_val, max_val].
    The trained f is then applied to the whole image.

    Args:
        image: CUDA tensor holding one 2D image: (H, W), (1, H, W) or (1, 1, H, W).
        apply_deconvolution: if False, ``image`` is returned unchanged.
        n_epochs: number of optimisation steps (one random patch per step).
        patch_size_hw: training patch size, clamped to the image size.
        mask_fraction: per-pixel probability of replacing a patch pixel with noise.
        sigma_noise: std of the replacement noise before clamping to [min_val, max_val].
        lambda_rec, lambda_inv_d, lambda_bound_d: loss weights.
        min_val, max_val: intensity range the image is normalised to during training.
        learning_rate: Adam learning rate.
        blur_mode: 'gaussian' (fixed kernel, zero-padded conv), 'fft' (fixed kernel,
            circular FFT conv) or 'learned' (trainable kernel).
        blur_sigma_spatial: Gaussian sigma for the fixed-kernel modes.
        blur_kernel_size: kernel size for all modes.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        float32 deconvolved image with the same shape as ``image``, mapped back
        to the input's original intensity range.

    Raises:
        TypeError: ``image`` is not a torch.Tensor.
        RuntimeError: ``image`` is not on a CUDA device.
        ValueError: unsupported shape, unknown ``blur_mode``, or ``max_val <= min_val``.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image must be a PyTorch Tensor. Got {type(image)}")

    if not apply_deconvolution:
        return image

    # --- PyTorch Backend Implementation ---
    device = image.device

    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if device.type != "cuda":
        raise RuntimeError(f"@torch_func requires CUDA tensor, got device: {device}")

    # Validate configuration before any GPU work. An unknown blur_mode used to
    # surface only inside the training loop (and never when n_epochs == 0).
    # max_val <= min_val makes the final denormalisation divide by zero.
    if blur_mode not in ("fft", "gaussian", "learned"):
        raise ValueError(f"Unknown blur_mode: {blur_mode}")
    if max_val <= min_val:
        raise ValueError(f"max_val ({max_val}) must be greater than min_val ({min_val})")

    # Bring the input to (1, 1, H, W)
    if image.ndim == 2:  # (H, W)
        img_norm = image.unsqueeze(0).unsqueeze(0).float()
    elif image.ndim == 3:  # (1, H, W)
        img_norm = image.unsqueeze(1).float()
    elif image.ndim == 4:  # (1, 1, H, W)
        img_norm = image.float()
    else:
        raise ValueError(f"Unsupported image ndim: {image.ndim}")
    if img_norm.shape[0] != 1 or img_norm.shape[1] != 1:
        # Stacks / multi-channel inputs previously failed with an opaque unpacking error.
        raise ValueError(
            "Expected a single 2D image shaped (H, W), (1, H, W) or (1, 1, H, W); "
            f"got {tuple(image.shape)}"
        )
    height, width = img_norm.shape[-2:]

    # Normalize to [min_val, max_val]
    img_min_orig, img_max_orig = torch.min(img_norm), torch.max(img_norm)
    has_range = bool(img_max_orig > img_min_orig)
    if has_range:
        img_norm = (img_norm - img_min_orig) / (img_max_orig - img_min_orig)
        img_norm = img_norm * (max_val - min_val) + min_val
    else:
        img_norm = torch.full_like(img_norm, min_val)

    # Create 2D model with paper's optimal architecture
    f_model = _Simple2DCNN_torch().to(device)
    g_model_blur: Optional[nn.Module] = None

    # Build the blur operator g once; the fixed kernels are loop-invariant.
    if blur_mode == "learned":
        g_model_blur = _LearnedBlur2D_torch(kernel_size=blur_kernel_size).to(device)
        optimizer = torch.optim.Adam(
            list(f_model.parameters()) + list(g_model_blur.parameters()), lr=learning_rate
        )
        blur = g_model_blur
    else:
        optimizer = torch.optim.Adam(f_model.parameters(), lr=learning_rate)
        fixed_blur_kernel = _gaussian_kernel_2d_torch(
            (blur_kernel_size, blur_kernel_size),
            (blur_sigma_spatial, blur_sigma_spatial),
            device,
        )
        if blur_mode == "fft":
            def blur(x):
                return _blur_fft_2d_torch(x, fixed_blur_kernel, device)
        else:  # "gaussian"
            conv_kernel = fixed_blur_kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, k, k)

            def blur(x):
                # padding="same" keeps (H, W) for even kernel sizes too; the
                # previous k // 2 padding grew the output by one pixel for even k.
                return F.conv2d(x, conv_kernel, padding="same")

    def _bound_penalty(x):
        # Mean magnitude by which x leaves [min_val, max_val].
        return (torch.relu(x - max_val) + torch.relu(min_val - x)).mean()

    # The default 128x128 patch would make random cropping fail on smaller images.
    patch_hw = (min(patch_size_hw[0], height), min(patch_size_hw[1], width))
    img_2d = img_norm[0, 0]
    log_every = max(n_epochs // 10, 1)

    # Training Loop
    for epoch in range(n_epochs):
        f_model.train()
        if g_model_blur is not None:
            g_model_blur.train()

        # Extract one 2D patch: (1, 1, pH, pW)
        patch = _extract_random_patches_2d_torch(img_2d, patch_hw, 1, device)

        # Create masked variant
        mask = torch.rand_like(patch) < mask_fraction
        noise = (torch.randn_like(patch) * sigma_noise).clamp(min_val, max_val)
        patch_masked = torch.where(mask, noise, patch)

        # Forward pass f(x). Keep the raw outputs: the boundary loss must see
        # values outside the range, which the clamp below erases.
        raw_orig = f_model(patch)
        raw_masked = f_model(patch_masked)
        f_x_orig = raw_orig.clamp(min_val, max_val)
        f_x_masked = raw_masked.clamp(min_val, max_val)

        # Losses - Paper's optimal Loss (4) for 2D: deconvolved invariance
        loss_rec = F.mse_loss(blur(f_x_masked), patch)

        # Deconvolved invariance loss (before PSF), on the masked pixels only
        loss_inv_d = torch.tensor(0.0, device=device)
        if mask.any():
            loss_inv_d = F.mse_loss(f_x_orig[mask], f_x_masked[mask])

        # Boundary loss on the unclamped deconvolved outputs. Computed on the
        # clamped values (as before) it was identically zero, so lambda_bound_d
        # had no effect.
        loss_bound_d = (_bound_penalty(raw_masked) + _bound_penalty(raw_orig)) / 2.0

        total_loss = lambda_rec * loss_rec + lambda_inv_d * loss_inv_d + lambda_bound_d * loss_bound_d

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if epoch % log_every == 0:
            logger.info(f"2D Deconv Epoch {epoch}/{n_epochs}, Loss: {total_loss.item():.4f} "
                        f"(Rec: {loss_rec.item():.4f}, Inv_d: {loss_inv_d.item():.4f}, Bound_d: {loss_bound_d.item():.4f})")

    # Inference on the full normalised image
    f_model.eval()
    with torch.no_grad():
        deconvolved_norm = f_model(img_norm).clamp(min_val, max_val)

    # Denormalize back to the input's original intensity range
    if has_range:
        deconvolved_final = (deconvolved_norm - min_val) / (max_val - min_val)
        deconvolved_final = deconvolved_final * (img_max_orig - img_min_orig) + img_min_orig
    else:
        deconvolved_final = torch.full_like(deconvolved_norm, img_min_orig.item())

    # Return in original input shape
    if image.ndim == 2:
        return deconvolved_final.squeeze(0).squeeze(0)
    if image.ndim == 3:
        return deconvolved_final.squeeze(1)
    return deconvolved_final