Coverage for openhcs/processing/backends/enhance/focus_torch.py: 11.8% (54 statements)
from __future__ import annotations

from typing import Optional

from openhcs.core.utils import optional_import
from openhcs.core.memory.decorators import torch as torch_decorator

# Import torch modules as optional dependencies
torch = optional_import("torch")
F = optional_import("torch.nn.functional") if torch is not None else None

def laplacian(image: "torch.Tensor") -> "torch.Tensor":
    """Applies a 2D Laplacian filter.

    Accepts [H, W], [N, H, W] (treated as a batch of single-channel images),
    or [N, C, H, W] with a single channel, which is what
    focus_stack_max_sharpness passes ([Z, 1, H, W]).
    """
    # conv2d kernel layout is [out_channels, in_channels/groups, kH, kW]
    kernel = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=image.dtype, device=image.device)
    kernel = kernel.reshape(1, 1, 3, 3)  # Single input and output channel

    # Normalize the input to [N, 1, H, W] for conv2d, remembering the original
    # rank so the output can be restored to the same shape.
    original_ndim = image.ndim
    if original_ndim == 2:    # [H, W] -> [1, 1, H, W]
        image = image.unsqueeze(0).unsqueeze(0)
    elif original_ndim == 3:  # [N, H, W] -> [N, 1, H, W]
        image = image.unsqueeze(1)
    elif original_ndim == 4:  # Already [N, C, H, W]
        pass
    else:
        raise ValueError(f"Unsupported image dimension for laplacian: {original_ndim}")

    # The kernel has in_channels == 1; multi-channel input would need a grouped
    # convolution or a grayscale conversion first.
    laplacian_img = F.conv2d(image, kernel, padding=1)

    # Restore the original dimensionality
    if original_ndim == 2:
        laplacian_img = laplacian_img.squeeze(0).squeeze(0)
    elif original_ndim == 3:
        laplacian_img = laplacian_img.squeeze(1)

    return laplacian_img
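
# Shape-contract sketch (illustrative, not part of the module; assumes torch
# is installed). A constant image has zero Laplacian response away from the
# zero-padded borders:
#
#   x = torch.ones(5, 1, 16, 16)                       # [Z, 1, H, W], as the caller passes
#   out = laplacian(x)                                  # -> [5, 1, 16, 16]
#   assert torch.allclose(out[:, :, 1:-1, 1:-1], torch.zeros(5, 1, 14, 14))
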

@torch_decorator
def focus_stack_max_sharpness(
    image_stack: "torch.Tensor",
    method: str = "laplacian",
    patch_size: Optional[int] = None,
    stride: Optional[int] = None,
    normalize_sharpness: bool = False
) -> "torch.Tensor":
57 """
58 GPU-accelerated focus stacking using PyTorch. Selects sharpest regions from a Z-stack.
60 Args:
61 image_stack: Input tensor of shape [Z, H, W]
62 method: Sharpness metric ('laplacian' or 'gradient')
63 patch_size: Size of analysis patches. Default: max(H,W)//8
64 stride: Stride between patches. Default: patch_size//2
65 normalize_sharpness: Normalize sharpness scores per patch
67 Returns:
68 Composite image of shape [1, H, W] with maximal sharpness regions
69 """
    if image_stack.ndim != 3:
        raise ValueError(f"Input must be a 3D tensor [Z, H, W]. Got {image_stack.ndim}D")
    if image_stack.device.type != "cuda":
        raise ValueError(f"Input must be on a CUDA device. Got '{image_stack.device.type}'")

    Z, H, W = image_stack.shape
    device = image_stack.device
    dtype = image_stack.dtype

    # Set adaptive defaults based on image dimensions
    patch_size = patch_size or max(H, W) // 8
    stride = stride or patch_size // 2

    # Calculate per-pixel sharpness maps, shape [Z, H, W]
    if method == "laplacian":
        sharpness = torch.abs(laplacian(image_stack.unsqueeze(1))).squeeze(1)
    elif method == "gradient":
        gx, gy = torch.gradient(image_stack, dim=(1, 2))
        sharpness = torch.sqrt(gx**2 + gy**2)
    else:
        raise ValueError(f"Invalid method: {method}. Use 'laplacian' or 'gradient'")

    if normalize_sharpness:
        # Standardize each pixel's sharpness across the Z axis
        sharpness = (sharpness - sharpness.mean(dim=0)) / (sharpness.std(dim=0) + 1e-6)

    # Extract sliding-window patches of the sharpness maps. F.unfold returns
    # [Z, patch_size * patch_size, L]; reshape L into its patch grid [n_h, n_w].
    n_h = (H - patch_size) // stride + 1
    n_w = (W - patch_size) // stride + 1
    patches = F.unfold(
        sharpness.unsqueeze(1),
        kernel_size=patch_size,
        stride=stride
    ).view(Z, patch_size * patch_size, n_h, n_w)

    # Aggregate sharpness within each patch, then find the sharpest z-index
    # per patch position
    patch_sharpness = patches.sum(dim=1)                 # [Z, n_h, n_w]
    max_indices = torch.argmax(patch_sharpness, dim=0)   # [n_h, n_w]

    # Paste the sharpest z-slice's pixels at each patch position, accumulating
    # a weight map so overlapping patches can be averaged afterwards
    composite = torch.zeros_like(image_stack[0])
    weights = torch.zeros_like(composite)

    for i in range(n_h):
        for j in range(n_w):
            z_idx = max_indices[i, j]
            h_start = i * stride
            w_start = j * stride

            composite[h_start:h_start + patch_size, w_start:w_start + patch_size] += \
                image_stack[z_idx, h_start:h_start + patch_size, w_start:w_start + patch_size]
            weights[h_start:h_start + patch_size, w_start:w_start + patch_size] += 1

    # Average overlapping contributions; the clamp avoids division by zero in
    # any uncovered border pixels
    return (composite / torch.clamp_min(weights, 1)).unsqueeze(0)
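
# Usage sketch (illustrative, not part of the covered module). Assumes a CUDA
# device is available, as required by the input validation, and that the torch
# memory decorator passes the call through unchanged:
#
#   stack = torch.rand(7, 256, 256, device="cuda")           # [Z, H, W]
#   fused = focus_stack_max_sharpness(stack, method="laplacian")
#   print(fused.shape)                                        # torch.Size([1, 256, 256])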