Coverage for openhcs/processing/backends/enhance/self_supervised_3d_deconvolution.py: 13.5%
148 statements
Generated by coverage.py v7.10.3 on 2025-08-14 05:57 +0000
1from __future__ import annotations
2import logging
3from typing import Optional, Tuple, Tuple
5# Import torch decorator and optional_import utility
6from openhcs.utils.import_utils import optional_import, create_placeholder_class
7from openhcs.core.memory.decorators import torch as torch_func
# --- PyTorch Imports as optional dependencies ---
# Every torch-derived name is None when PyTorch is not installed, so the module
# can still be imported (annotations are lazy thanks to `from __future__ import annotations`).
torch = optional_import("torch")
if torch is not None:
    nn = optional_import("torch.nn")
    F = optional_import("torch.nn.functional")
    from torch.fft import irfftn, rfftn
else:
    nn = None
    F = None
    irfftn = None
    rfftn = None

logger = logging.getLogger(__name__)

# Base class for the models below: the real nn.Module when torch is available,
# otherwise a placeholder produced by create_placeholder_class.
nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch",
)
27# --- PyTorch Specific Models and Helpers ---
class _Simple3DCNN_torch(nnModule):
    """Small 3D CNN used as the deconvolution network ``f``.

    Layout: Conv3d -> ReLU -> Conv3d -> ReLU -> Conv3d, every convolution with
    ``kernel_size=3`` and ``padding=1`` (stride 1), so D/H/W are preserved.
    The paper referenced in this module uses 48 initial features for 3D data.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(48, 96)):
        super().__init__()
        first_width = features[0]
        second_width = features[1]
        layers = [
            nn.Conv3d(in_channels, first_width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(first_width, second_width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(second_width, out_channels, kernel_size=3, padding=1),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to ``x`` shaped ``(B, C, D, H, W)``."""
        return self.conv_block(x)
class _LearnedBlur3D_torch(nnModule):
    """Learnable single-channel 3D blur used as the reconvolution model ``g``.

    A bias-free ``Conv3d(1, 1, kernel_size)`` whose weights start as a uniform
    averaging (box) kernel that sums to 1 and are then trained with ``f``.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
        # padding="same" keeps output shape == input shape for odd AND even
        # kernel sizes. `kernel_size // 2` made even-sized outputs one voxel
        # larger per axis, which breaks the loss against the unblurred patch.
        self.blur_conv = nn.Conv3d(
            1, 1, kernel_size=kernel_size, padding="same", bias=False
        )
        # Start from a box kernel: every weight is 1 / kernel_size**3 (sums to 1).
        with torch.no_grad():
            self.blur_conv.weight.fill_(1.0 / kernel_size ** 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blur single-channel ``x`` (e.g. ``(B, 1, D, H, W)``); output keeps its shape."""
        return self.blur_conv(x)
def _gaussian_kernel_3d_torch(shape: Tuple[int, int, int], sigma: Tuple[float, float, float], device) -> torch.Tensor:
    """Return a normalized, separable 3D Gaussian kernel.

    Args:
        shape: Kernel size ``(kD, kH, kW)``.
        sigma: Standard deviation per axis ``(sD, sH, sW)``.
        device: Device on which the float32 kernel is created.

    Returns:
        float32 tensor of shape ``shape``, centred on the middle of the grid
        and divided by its sum (so it sums to 1 up to rounding).
    """
    profiles = []
    for axis in range(3):
        size, s = shape[axis], sigma[axis]
        # Offsets from the (possibly half-integer) centre of this axis.
        offsets = torch.arange(size, dtype=torch.float32, device=device) - (size - 1) / 2.0
        profiles.append(torch.exp(-offsets**2 / (2 * s**2)))

    depth, height, width = profiles
    # Outer product of the three 1D profiles via broadcasting.
    kernel = depth.view(-1, 1, 1) * height.view(1, -1, 1) * width.view(1, 1, -1)
    return kernel / kernel.sum()
def _blur_fft_torch(volume: torch.Tensor, kernel: torch.Tensor, device) -> torch.Tensor:
    """Blur ``volume`` with ``kernel`` by multiplication in the Fourier domain.

    The volume is not padded, so this is a *circular* convolution: intensity
    near one face wraps around to the opposite face.

    Args:
        volume: ``(B, C, D, H, W)`` tensor.
        kernel: ``(kD, kH, kW)`` kernel; each dimension must not exceed the
            matching volume dimension.
        device: Device the padded kernel is moved to (should match ``volume``).

    Returns:
        Real tensor with the same shape as ``volume``.

    Raises:
        ValueError: if the kernel is larger than the volume along any axis.
    """
    _, _, D, H, W = volume.shape
    kD, kH, kW = kernel.shape
    if kD > D or kH > H or kW > W:
        raise ValueError(
            f"FFT blur kernel {(kD, kH, kW)} is larger than the volume {(D, H, W)}"
        )

    # Zero-pad at the END of each axis to the volume size, then roll so the
    # kernel's centre voxel sits at index 0. FFT convolution treats index 0 as
    # the origin; padding the kernel symmetrically (centred in the volume)
    # shifts the blurred output circularly by about half the volume.
    kernel_padded = F.pad(kernel, (0, W - kW, 0, H - kH, 0, D - kD))
    kernel_padded = torch.roll(
        kernel_padded,
        shifts=(-(kD // 2), -(kH // 2), -(kW // 2)),
        dims=(0, 1, 2),
    )
    kernel_padded = kernel_padded.to(device=device, dtype=volume.dtype)[None, None]  # (1, 1, D, H, W)

    vol_fft = rfftn(volume, dim=(-3, -2, -1))
    ker_fft = rfftn(kernel_padded, dim=(-3, -2, -1))
    return irfftn(vol_fft * ker_fft, s=(D, H, W), dim=(-3, -2, -1))
def _blur_gaussian_conv_torch(volume: torch.Tensor, sigma_spatial: float, sigma_depth: float, kernel_size: int, device) -> torch.Tensor:
    """Blur ``volume`` with a normalized cubic Gaussian kernel via ``F.conv3d``.

    Args:
        volume: Single-channel tensor, e.g. ``(B, 1, D, H, W)`` (the kernel has
            one input and one output channel).
        sigma_spatial: Gaussian sigma along H and W.
        sigma_depth: Gaussian sigma along D.
        kernel_size: Edge length of the cubic kernel (odd or even).
        device: Device the kernel is created on.

    Returns:
        Blurred tensor with the same shape as ``volume``.
    """
    # The product of three normalized 1D Gaussians equals (up to float rounding)
    # the normalized separable 3D Gaussian, so build it with a single call.
    kernel = _gaussian_kernel_3d_torch(
        (kernel_size, kernel_size, kernel_size),
        (sigma_depth, sigma_spatial, sigma_spatial),
        device,
    )
    # (1, 1, kD, kH, kW), matching the volume's dtype so conv3d does not reject it.
    kernel = kernel.to(device=device, dtype=volume.dtype).unsqueeze(0).unsqueeze(0)
    # padding="same" keeps the output size equal to the input size for even
    # kernel sizes too; `kernel_size // 2` made even-sized outputs one voxel larger.
    return F.conv3d(volume, kernel, padding="same")
def _extract_random_patches_torch(
    volume_single_batch_channel: torch.Tensor,  # (D, H, W)
    patch_size_dhw: Tuple[int, int, int],
    num_patches: int,
    device,
) -> torch.Tensor:  # (num_patches, 1, pD, pH, pW)
    """Sample ``num_patches`` random axis-aligned sub-volumes.

    Start offsets are drawn independently and uniformly, so patches may overlap
    or repeat.

    Args:
        volume_single_batch_channel: 3D tensor ``(D, H, W)``.
        patch_size_dhw: Patch size ``(pD, pH, pW)``. Each entry must be >= 1
            and no larger than the matching volume dimension.
        num_patches: Number of patches to draw.
        device: Device for the random offsets and the output tensor.

    Returns:
        Tensor ``(num_patches, 1, pD, pH, pW)`` in torch's default float dtype.

    Raises:
        ValueError: if the patch does not fit inside the volume.
    """
    D, H, W = volume_single_batch_channel.shape
    pD, pH, pW = patch_size_dhw
    if not (1 <= pD <= D and 1 <= pH <= H and 1 <= pW <= W):
        raise ValueError(
            f"patch_size_dhw {(pD, pH, pW)} must be >= 1 and fit inside the volume {(D, H, W)}"
        )

    # Random offsets are generated on `device` (no CPU fallback). They are drawn
    # with one randint call per axis and copied to the host once via .tolist(),
    # instead of one .item() sync per offset.
    d_starts = torch.randint(0, D - pD + 1, (num_patches,), device=device).tolist()
    h_starts = torch.randint(0, H - pH + 1, (num_patches,), device=device).tolist()
    w_starts = torch.randint(0, W - pW + 1, (num_patches,), device=device).tolist()

    patches = torch.empty((num_patches, 1, pD, pH, pW), device=device)
    for i, (d0, h0, w0) in enumerate(zip(d_starts, h_starts, w_starts)):
        patches[i, 0] = volume_single_batch_channel[d0:d0 + pD, h0:h0 + pH, w0:w0 + pW]
    return patches
121# --- Main Deconvolution Function ---
@torch_func
def self_supervised_3d_deconvolution(
    image_volume: torch.Tensor,  # Expected (Z, H, W), (1, Z, H, W) or (1, 1, Z, H, W)
    apply_deconvolution: bool = True,
    n_epochs: int = 10,  # Reduced default for quick test (was 100)
    patch_size: Tuple[int, int, int] = (16, 32, 32),  # Reduced for small test images (paper: 64x64x64)
    mask_fraction: float = 0.005,  # Paper: 0.5%
    sigma_noise: float = 0.2,
    lambda_rec: float = 1.0,
    lambda_inv: float = 2.0,  # Paper: reconvolved invariance for 3D
    lambda_bound: float = 0.0,  # Paper: λbound = 0 for 3D
    min_val: float = 0.0,
    max_val: float = 1.0,
    learning_rate: float = 4e-4,
    blur_mode: str = "gaussian",  # 'fft', 'gaussian', 'learned'
    blur_sigma_spatial: float = 1.5,
    blur_sigma_depth: float = 1.5,
    blur_kernel_size: int = 5,  # For gaussian/learned/fft blur kernel
    **kwargs,
) -> torch.Tensor:
    """Deconvolve one 3D volume with a network trained on that volume alone.

    A small 3D CNN ``f`` is trained for ``n_epochs`` steps, one random patch of
    the normalized input per step. A fraction of the patch voxels is replaced
    by noise. ``f`` is applied to the clean and to the masked patch, and both
    outputs are re-blurred by ``g``: a fixed Gaussian, or a learned kernel
    (see ``blur_mode``). The loss is the weighted sum of:

    * reconstruction: ``g(f(masked))`` vs. the clean patch (``lambda_rec``);
    * invariance: ``g(f(clean))`` vs. ``g(f(masked))`` on masked voxels (``lambda_inv``);
    * boundary: penalty on re-blurred values outside ``[min_val, max_val]``
      (``lambda_bound``, only computed when > 0).

    After training, ``f`` is applied to the full volume. The result is then
    mapped back to the input's original intensity range.

    Args:
        image_volume: CUDA tensor shaped ``(D, H, W)``, ``(1, D, H, W)`` or
            ``(1, 1, D, H, W)``.
        apply_deconvolution: If False, ``image_volume`` is returned unchanged.
        n_epochs: Number of optimisation steps (one patch per step).
        patch_size: Training patch size ``(pD, pH, pW)``; clamped per axis to
            the volume size.
        mask_fraction: Probability that a patch voxel is replaced by noise.
        sigma_noise: Standard deviation of the replacement noise (before clamping).
        lambda_rec, lambda_inv, lambda_bound: Loss weights.
        min_val, max_val: Working intensity range; ``max_val`` must exceed ``min_val``.
        learning_rate: Adam learning rate.
        blur_mode: ``"gaussian"`` (direct conv), ``"fft"`` (FFT conv) or ``"learned"``.
        blur_sigma_spatial, blur_sigma_depth: Gaussian sigmas for H/W and for D.
        blur_kernel_size: Edge length of the cubic blur kernel.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        float32 tensor in the same layout as ``image_volume``, rescaled to the
        input's original min/max.

    Raises:
        TypeError: ``image_volume`` is not a tensor.
        RuntimeError: ``image_volume`` is not on a CUDA device.
        ValueError: unsupported rank or multi-volume input, ``patch_size`` not
            of length 3, unknown ``blur_mode``, or ``max_val <= min_val``.
    """
    if not isinstance(image_volume, torch.Tensor):
        raise TypeError(f"Input image_volume must be a PyTorch Tensor. Got {type(image_volume)}")

    if not apply_deconvolution:
        return image_volume

    device = image_volume.device
    # FAIL LOUDLY if not on CUDA - no CPU fallback allowed
    if device.type != "cuda":
        raise RuntimeError(f"@torch_func requires CUDA tensor, got device: {device}")

    # Validate configuration up front. An unknown blur_mode used to surface
    # only inside the training loop, and not at all when n_epochs == 0.
    if blur_mode not in ("fft", "gaussian", "learned"):
        raise ValueError(f"Unknown blur_mode: {blur_mode}")
    if max_val <= min_val:
        # Denormalisation divides by (max_val - min_val).
        raise ValueError(f"max_val ({max_val}) must be greater than min_val ({min_val})")
    if len(patch_size) != 3:
        raise ValueError(f"patch_size must have 3 entries (D, H, W), got {patch_size}")
    if kwargs:
        logger.debug("self_supervised_3d_deconvolution ignoring unused kwargs: %s", sorted(kwargs))

    # Bring the input to (1, 1, D, H, W) float32.
    input_ndim = image_volume.ndim
    if input_ndim == 3:  # (D, H, W)
        volume = image_volume.unsqueeze(0).unsqueeze(0).float()
    elif input_ndim == 4:  # (1, D, H, W)
        volume = image_volume.unsqueeze(1).float()
    elif input_ndim == 5:  # (1, 1, D, H, W)
        volume = image_volume.float()
    else:
        raise ValueError(f"Unsupported image_volume ndim: {input_ndim}")
    if volume.shape[0] != 1 or volume.shape[1] != 1:
        raise ValueError(
            "Expected a single volume (leading dimensions of size 1), "
            f"got shape {tuple(image_volume.shape)}"
        )

    def _to_input_layout(t: torch.Tensor) -> torch.Tensor:
        """Undo the reshape above so the result matches the caller's layout."""
        if input_ndim == 3:
            return t[0, 0]
        if input_ndim == 4:
            return t[:, 0]
        return t

    img_min_orig, img_max_orig = torch.min(volume), torch.max(volume)
    if not img_max_orig > img_min_orig:
        # Constant image: the result is that same constant. Skip training
        # entirely instead of fitting a model whose output would be discarded.
        return _to_input_layout(torch.full_like(volume, img_min_orig.item()))

    # Normalize to [min_val, max_val] (typically [0, 1]).
    span = max_val - min_val
    img_vol_norm = (volume - img_min_orig) / (img_max_orig - img_min_orig) * span + min_val

    # Clamp the training patch per axis so volumes smaller than patch_size still work.
    volume_dhw_shape = tuple(img_vol_norm.shape[-3:])
    patch_size_dhw = tuple(min(int(p), s) for p, s in zip(patch_size, volume_dhw_shape))
    if patch_size_dhw != tuple(int(p) for p in patch_size):
        logger.info(
            "patch_size %s exceeds volume %s; using %s",
            tuple(patch_size), volume_dhw_shape, patch_size_dhw,
        )

    f_model = _Simple3DCNN_torch().to(device)
    trainable = list(f_model.parameters())
    g_model_blur = None

    # Choose the re-blur operator g once, rather than re-dispatching every step.
    if blur_mode == "learned":
        g_model_blur = _LearnedBlur3D_torch(kernel_size=blur_kernel_size).to(device)
        trainable += list(g_model_blur.parameters())
        reblur = g_model_blur
    else:
        fixed_blur_kernel = _gaussian_kernel_3d_torch(
            (blur_kernel_size, blur_kernel_size, blur_kernel_size),
            (blur_sigma_depth, blur_sigma_spatial, blur_sigma_spatial),
            device,
        )
        if blur_mode == "fft":
            def reblur(x):
                return _blur_fft_torch(x, fixed_blur_kernel, device)
        else:  # "gaussian"
            conv_kernel = fixed_blur_kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, kD, kH, kW)

            def reblur(x):
                # padding="same" keeps the output size equal to the input size
                # for even kernel sizes too (kernel_size // 2 did not).
                return F.conv3d(x, conv_kernel, padding="same")

    optimizer = torch.optim.Adam(trainable, lr=learning_rate)

    f_model.train()
    if g_model_blur is not None:
        g_model_blur.train()

    volume_dhw = img_vol_norm[0, 0]  # (D, H, W) view used for patch sampling
    log_every = max(n_epochs // 10, 1)

    for epoch in range(n_epochs):
        # One random patch per step: (1, 1, pD, pH, pW)
        current_patch_orig = _extract_random_patches_torch(volume_dhw, patch_size_dhw, 1, device)

        # Masked variant: each voxel is replaced by noise with probability mask_fraction.
        # NOTE(review): the noise is zero-mean and then clamped to [min_val, max_val],
        # so with the default [0, 1] range roughly half of the masked voxels become
        # exactly min_val. Confirm this is intended.
        mask = torch.rand_like(current_patch_orig) < mask_fraction
        noise = (torch.randn_like(current_patch_orig) * sigma_noise).clamp(min_val, max_val)
        current_patch_masked = torch.where(mask, noise, current_patch_orig)

        # f(x), then the re-blurred g(f(x))
        f_x_orig = f_model(current_patch_orig).clamp(min_val, max_val)
        f_x_masked = f_model(current_patch_masked).clamp(min_val, max_val)
        g_f_x_orig = reblur(f_x_orig)
        g_f_x_masked = reblur(f_x_masked)

        # Reconstruction: the re-blurred prediction from the masked input should match the clean patch.
        loss_rec = F.mse_loss(g_f_x_masked, current_patch_orig)

        # Reconvolved invariance on the masked voxels (skipped when nothing was masked).
        loss_inv = torch.tensor(0.0, device=device)
        if mask.any():
            loss_inv = F.mse_loss(g_f_x_orig[mask], g_f_x_masked[mask])

        # Boundary penalty on re-blurred outputs (lambda_bound = 0 for 3D per paper).
        loss_bound = torch.tensor(0.0, device=device)
        if lambda_bound > 0:
            loss_bound_f_masked = (torch.relu(g_f_x_masked - max_val) + torch.relu(min_val - g_f_x_masked)).mean()
            loss_bound_f_orig = (torch.relu(g_f_x_orig - max_val) + torch.relu(min_val - g_f_x_orig)).mean()
            loss_bound = (loss_bound_f_masked + loss_bound_f_orig) / 2.0

        total_loss = lambda_rec * loss_rec + lambda_inv * loss_inv + lambda_bound * loss_bound

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if epoch % log_every == 0:
            logger.info(f"3D Deconv Epoch {epoch}/{n_epochs}, Loss: {total_loss.item():.4f} "
                        f"(Rec: {loss_rec.item():.4f}, Inv: {loss_inv.item():.4f}, Bound: {loss_bound.item():.4f})")

    # Inference on the full volume. This may need patching for large volumes on limited GPU memory.
    f_model.eval()
    with torch.no_grad():
        deconvolved_norm = f_model(img_vol_norm).clamp(min_val, max_val)

    # Map [min_val, max_val] back to the input's original [min, max] intensity range.
    deconvolved_final = (deconvolved_norm - min_val) / span * (img_max_orig - img_min_orig) + img_min_orig
    return _to_input_layout(deconvolved_final)