Coverage for openhcs/processing/backends/enhance/dl_edof_unsupervised.py: 11.7%

139 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1from __future__ import annotations 

2from typing import TYPE_CHECKING, List, Optional 

3 

4from openhcs.core.memory.decorators import torch as torch_func 

5from openhcs.utils.import_utils import optional_import, create_placeholder_class 

6 

7 

8 

9# For type checking only 

10if TYPE_CHECKING: 10 ↛ 11line 10 didn't jump to line 11 because the condition on line 10 was never true

11 import torch 

12 import torch.nn as nn 

13 import torch.nn.functional as F 

14 

# Resolve the PyTorch modules through optional_import; the submodules are only
# looked up when the top-level package resolved to something other than None.
torch = optional_import("torch")
if torch is not None:
    nn = optional_import("torch.nn")
    F = optional_import("torch.nn.functional")
else:
    nn = None
    F = None

# Base class for the network defined below: the real ``nn.Module`` when
# ``torch.nn`` is available, otherwise whatever create_placeholder_class
# produces for a missing base class.
nnModule = create_placeholder_class(
    "Module",  # Name for the placeholder if generated
    base_class=nn.Module if nn else None,
    required_library="PyTorch",
)

25 

26# Helper for sharpness loss 

def laplacian_filter_torch(image_batch: "torch.Tensor") -> "torch.Tensor":
    """
    Convolve a batch of single-channel images with a 3x3 Laplacian kernel.

    The kernel is all ones with -8 at the centre. One pixel of zero padding
    keeps the spatial size unchanged.
    Input: (N, 1, H, W)
    Output: (N, 1, H, W)
    """
    weights = [
        [1, 1, 1],
        [1, -8, 1],
        [1, 1, 1],
    ]
    # Build the kernel with the input's dtype/device so conv2d accepts it as-is.
    laplacian = torch.tensor(
        weights, dtype=image_batch.dtype, device=image_batch.device
    )
    laplacian = laplacian.reshape(1, 1, 3, 3)
    return F.conv2d(image_batch, laplacian, padding=1)

36 

def extract_patches_2d_from_3d_stack(
    stack_3d: "torch.Tensor", patch_size: int, stride: int
) -> torch.Tensor:
    """
    Cut a 3D stack into square spatial patches that keep the full Z extent.

    Input stack_3d: [Z, H, W]
    Output patches: [N, Z, patch_size, patch_size], where N is num_patches.
    Patches are ordered row-major over the (H, W) window grid.
    """
    depth, _, _ = stack_3d.shape
    # Slide the window along H, then along W:
    # [Z, nH, nW, patch_size, patch_size].
    windows = stack_3d.unfold(1, patch_size, stride).unfold(2, patch_size, stride)
    # Bring the grid positions to the front so flattening them yields one
    # [Z, patch_size, patch_size] block per window.
    per_position = windows.permute(1, 2, 0, 3, 4)
    return per_position.reshape(-1, depth, patch_size, patch_size)

51 

def blend_patches_to_2d_image(
    patch_outputs: List["torch.Tensor"],  # List of [1, patch_size, patch_size]
    target_h: int,
    target_w: int,
    patch_size: int,
    stride: int,
    device: torch.device
) -> torch.Tensor:
    """
    Reassemble fused 2D patches into one image, averaging where they overlap.

    Patches are consumed in row-major order over the window grid implied by
    ``target_h``, ``target_w``, ``patch_size`` and ``stride``. Pixels that no
    patch covers stay zero.
    Input patch_outputs: List of [1, patch_size, patch_size] tensors.
    Output: [1, target_h, target_w]
    """
    canvas_shape = (target_h, target_w)
    accumulated = torch.zeros(canvas_shape, dtype=torch.float32, device=device)
    hits = torch.zeros(canvas_shape, dtype=torch.float32, device=device)

    rows = (target_h - patch_size) // stride + 1
    cols = (target_w - patch_size) // stride + 1

    # Place at most one patch per grid cell and stop early if the list is
    # shorter than the grid; a non-positive grid dimension places nothing.
    usable = min(len(patch_outputs), max(rows, 0) * max(cols, 0))
    for flat_idx in range(usable):
        row, col = divmod(flat_idx, cols)
        top = row * stride
        left = col * stride
        tile = patch_outputs[flat_idx].squeeze(0)  # [patch_size, patch_size]
        region = (slice(top, top + patch_size), slice(left, left + patch_size))
        accumulated[region] += tile
        hits[region] += 1.0  # float counter so the division below stays float

    # Clamp so uncovered pixels divide by one instead of zero.
    accumulated /= hits.clamp(min=1.0)
    return accumulated.unsqueeze(0)

92 

class UNetLite(nnModule):
    """
    Small encoder/decoder CNN that maps a Z-stack patch to one fused plane.

    Layout: 3x3 conv -> stride-2 3x3 conv (downsample) -> stride-2 transposed
    conv (upsample) -> 3x3 conv to a single channel -> sigmoid.

    Args:
        in_channels_z: Number of input channels, i.e. Z planes in the patch.
        model_config_depth: 3 selects 32/64 feature channels; any other value
            selects 64/128.
    """

    def __init__(self, in_channels_z: int, model_config_depth: int):
        super().__init__()
        multiplier = 1 if model_config_depth == 3 else 2
        ch1 = 32 * multiplier
        ch2 = 64 * multiplier

        self.conv1 = nn.Conv2d(in_channels_z, ch1, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ch1, ch2, kernel_size=3, padding=1, stride=2)
        self.relu2 = nn.ReLU(inplace=True)

        self.upconv = nn.ConvTranspose2d(ch2, ch1, kernel_size=2, stride=2)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv_out = nn.Conv2d(ch1, 1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: "torch.Tensor") -> "torch.Tensor":
        """
        Fuse a batch of patches.

        Args:
            x: Tensor of shape [N, in_channels_z, H, W].

        Returns:
            Tensor of shape [N, 1, H, W] with values in [0, 1].
        """
        x1 = self.relu1(self.conv1(x))
        x2 = self.relu2(self.conv2(x1))
        x3 = self.relu3(self.upconv(x2))
        # For odd H or W the stride-2 conv rounds the size up, so the
        # transposed conv returns one extra row/column. Crop back to the
        # input size so the output lines up with the input (no-op when even).
        x3 = x3[..., : x.shape[-2], : x.shape[-1]]
        out = self.sigmoid(self.conv_out(x3))
        return out

116 

def sharpness_loss_fn(fused_patch: "torch.Tensor") -> "torch.Tensor":
    """
    Return the negated mean variance of the Laplacian response.

    The variance is taken over the last two (spatial) dimensions without
    Bessel's correction and then averaged over the remaining dimensions.
    Negating it means that minimizing this loss increases the variance.
    """
    edges = laplacian_filter_torch(fused_patch)
    per_image_variance = torch.var(edges, dim=(-1, -2), unbiased=False)
    return -per_image_variance.mean()

120 

def consistency_loss_fn(fused_patch: "torch.Tensor", input_patch_stack: "torch.Tensor") -> "torch.Tensor":
    """
    Mean over pixels of the smallest squared difference to any input plane.

    ``fused_patch`` is broadcast against ``input_patch_stack``; the minimum is
    taken along dimension 1, then the result is averaged to a scalar.
    """
    squared_error = (fused_patch - input_patch_stack).pow(2)
    closest_plane_error = squared_error.min(dim=1).values
    return closest_plane_error.mean()

125 

def _gaussian_blur_3x3(planes: "torch.Tensor", sigma: float = 0.5) -> "torch.Tensor":
    """Blur every plane of a [Z, H, W] stack with a normalized 3x3 Gaussian kernel."""
    offsets = torch.tensor([-1.0, 0.0, 1.0], dtype=planes.dtype, device=planes.device)
    weights_1d = torch.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
    weights_1d = weights_1d / weights_1d.sum()
    kernel = (weights_1d[:, None] * weights_1d[None, :]).reshape(1, 1, 3, 3)
    # Replicate the border by one pixel so the output keeps the H x W size.
    padded = F.pad(planes.unsqueeze(1), (1, 1, 1, 1), mode="replicate")
    return F.conv2d(padded, kernel).squeeze(1)


@torch_func
def dl_edof_unsupervised(
    image_stack: "torch.Tensor",
    model_depth: Optional[int] = None,
    patch_size: Optional[int] = None,
    stride: Optional[int] = None,
    denoise: bool = False,
    normalize: bool = False,
) -> torch.Tensor:
    """
    Fuse a Z-stack into one extended-depth-of-field image, patch by patch.

    For every spatial patch a fresh ``UNetLite`` is trained for a fixed number
    of steps on that patch alone, minimizing sharpness loss + consistency
    loss. The trained network's output is the fused patch; patches are then
    blended (overlaps averaged) into the full image.

    Args:
        image_stack: 3D CUDA tensor of shape [Z, H, W].
        model_depth: Network width selector passed to ``UNetLite``. Defaults
            to 3 when H < 1024, otherwise 5.
        patch_size: Requested patch edge length. Defaults to
            ``max(H, W) // 8``. The value is raised to at least 16, rounded up
            to an even number, then limited to ``min(H, W)``.
        stride: Step between patches. Defaults to half the resolved patch
            size; non-positive values are replaced by 1.
        denoise: If True, blur each plane with a 3x3 Gaussian (sigma 0.5)
            before patch extraction.
        normalize: If True, divide the float-converted input by 65535.

    Returns:
        Tensor of shape [1, H, W] in the dtype of ``image_stack``. The network
        output (sigmoid, so in [0, 1]) is clamped and multiplied by 65535.

    Raises:
        ImportError: If PyTorch is not available.
        ValueError: If ``image_stack`` is not a 3D CUDA tensor or has an
            empty dimension.

    NOTE(review): with ``normalize=False`` the consistency loss compares the
    sigmoid output against the raw input intensities -- confirm that callers
    either pass data already in [0, 1] or set ``normalize=True``.
    """
    import logging

    if torch is None:
        raise ImportError("PyTorch is required for this function")
    if not (image_stack.ndim == 3 and str(image_stack.device.type) == 'cuda'):
        raise ValueError("Input image_stack must be a 3D CUDA tensor [Z, H, W]. "
                         f"Got {image_stack.ndim}D tensor on {image_stack.device.type}.")

    Z_orig, H_orig, W_orig = image_stack.shape
    if min(Z_orig, H_orig, W_orig) == 0:
        raise ValueError("Input image_stack must not have an empty dimension.")
    device = image_stack.device
    original_dtype = image_stack.dtype
    logger = logging.getLogger(__name__)

    # Memory usage warning for large images
    total_elements = Z_orig * H_orig * W_orig
    if total_elements > 100_000_000:  # 100M elements
        logger.warning(f"⚠️  Large image stack ({total_elements:,} elements) may cause high memory usage in deep learning EDoF. "
                       f"Consider using smaller patch sizes or processing smaller regions.")
        logger.warning(f"Current image size: {Z_orig}×{H_orig}×{W_orig}")

    # Resolve the patch geometry first so that every later computation
    # (estimate, extraction, blending) uses the same validated values.
    current_patch_size = patch_size
    if current_patch_size is None:
        current_patch_size = max(H_orig, W_orig) // 8
    current_patch_size = max(current_patch_size, 16)  # Min patch size
    if current_patch_size % 2 != 0:  # Prefer even sizes for the stride-2 CNN
        current_patch_size += 1
    current_patch_size = min(current_patch_size, H_orig, W_orig)

    current_stride = stride
    if current_stride is None:
        current_stride = current_patch_size // 2
    # Applied to user-supplied strides too: a non-positive step cannot tile.
    if current_stride <= 0:
        current_stride = 1

    # The window grid only reaches the bottom/right edge when
    # (size - patch) is a multiple of the stride. Pad by the remainder so
    # the last window ends exactly on the (padded) edge; the padding is
    # cropped away again after blending.
    pad_h = (-(H_orig - current_patch_size)) % current_stride
    pad_w = (-(W_orig - current_patch_size)) % current_stride
    H_pad = H_orig + pad_h
    W_pad = W_orig + pad_w

    # Estimate patch count from the resolved geometry
    num_patches_h = (H_pad - current_patch_size) // current_stride + 1
    num_patches_w = (W_pad - current_patch_size) // current_stride + 1
    total_patches = num_patches_h * num_patches_w
    if total_patches > 1000:
        logger.warning(f"⚠️  Large number of patches ({total_patches:,}) may cause high memory usage. "
                       f"Consider increasing stride or reducing patch size.")

    current_model_depth_config = model_depth
    if current_model_depth_config is None:
        current_model_depth_config = 3 if H_orig < 1024 else 5

    stack_f32 = image_stack.float()
    if normalize:
        stack_f32 = stack_f32 / 65535.0

    if denoise:
        # torch.nn.functional has no gaussian_blur; use an explicit 3x3 kernel.
        stack_f32 = _gaussian_blur_3x3(stack_f32, sigma=0.5)

    if pad_h or pad_w:
        # F.pad's replicate mode pads the last two dims of a 4D tensor.
        stack_f32 = F.pad(
            stack_f32.unsqueeze(0), (0, pad_w, 0, pad_h), mode="replicate"
        ).squeeze(0)

    patches = extract_patches_2d_from_3d_stack(stack_f32, current_patch_size, current_stride)

    fused_patch_outputs = []
    num_epochs_per_patch = 10

    for patch_stack_z in patches:
        model_input = patch_stack_z.unsqueeze(0).to(device)
        Z_patch = model_input.shape[1]

        # A new network and optimizer per patch: nothing is shared between patches.
        model = UNetLite(in_channels_z=Z_patch, model_config_depth=current_model_depth_config).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        model.train()
        for _ in range(num_epochs_per_patch):
            optimizer.zero_grad()
            # Crop in case an odd patch size made the network output one
            # pixel larger than its input (no-op otherwise).
            fused_output_patch = model(model_input)[..., :current_patch_size, :current_patch_size]
            loss_s = sharpness_loss_fn(fused_output_patch)
            loss_c = consistency_loss_fn(fused_output_patch, model_input)
            total_loss = loss_s + loss_c
            total_loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            final_fused_patch = model(model_input)[..., :current_patch_size, :current_patch_size]
        fused_patch_outputs.append(final_fused_patch.detach().squeeze(0))

    fused_2d_normalized = blend_patches_to_2d_image(
        fused_patch_outputs, H_pad, W_pad, current_patch_size, current_stride, device
    )
    # Drop the rows/columns that were only added to complete the window grid.
    fused_2d_normalized = fused_2d_normalized[:, :H_orig, :W_orig]

    fused_uint16 = fused_2d_normalized.clamp(0, 1).mul(65535.0).to(original_dtype)
    return fused_uint16