Coverage for openhcs/processing/backends/enhance/focus_torch.py: 11.8%

54 statements  

coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

from __future__ import annotations

from typing import Optional

from openhcs.core.utils import optional_import
from openhcs.core.memory.decorators import torch as torch_decorator

# Import torch modules as optional dependencies
torch = optional_import("torch")
F = optional_import("torch.nn.functional") if torch is not None else None
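# Note: optional_import appears to return None when the package is missing, so
# this module can be imported without torch; the functions below then fail at
# call time rather than at import time.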

def laplacian(image: "torch.Tensor") -> "torch.Tensor":
    """Applies a 2D Laplacian filter."""
    # conv2d expects input of shape [N, C, H, W] and a kernel of shape
    # [out_channels, in_channels/groups, kH, kW].
    kernel = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=image.dtype, device=image.device)
    kernel = kernel.reshape(1, 1, 3, 3)  # Single input channel, single output channel

    # Normalize the input to [N, C, H, W] by adding batch/channel dims as needed.
    original_ndim = image.ndim
    if original_ndim == 2:  # [H, W]
        image = image.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
    elif original_ndim == 3:  # [Z, H, W]: treat each slice as a batch item
        image = image.unsqueeze(1)  # [Z, 1, H, W]
    elif original_ndim == 4:  # Already [N, C, H, W]; focus_stack_max_sharpness
        pass                  # passes image_stack.unsqueeze(1) -> [Z, 1, H, W]
    else:
        raise ValueError(f"Unsupported image dimension for laplacian: {original_ndim}")

    # Apply the convolution. The kernel assumes a single input channel; inputs
    # with multiple channels would need a per-channel filter or a grayscale conversion.
    laplacian_img = F.conv2d(image, kernel, padding=1)

    # Restore the original dimensions
    if original_ndim == 2:
        laplacian_img = laplacian_img.squeeze(0).squeeze(0)  # [H, W]
    elif original_ndim == 3:
        laplacian_img = laplacian_img.squeeze(1)  # [Z, H, W]

    return laplacian_img
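
# Illustrative example (assumes torch is available): a 2D input keeps its
# shape, since laplacian temporarily lifts it to [1, 1, H, W]:
#   >>> img = torch.rand(64, 64)
#   >>> laplacian(img).shape
#   torch.Size([64, 64])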

@torch_decorator
def focus_stack_max_sharpness(
    image_stack: "torch.Tensor",
    method: str = "laplacian",
    patch_size: Optional[int] = None,
    stride: Optional[int] = None,
    normalize_sharpness: bool = False
) -> "torch.Tensor":
    """
    GPU-accelerated focus stacking using PyTorch. Selects the sharpest regions from a Z-stack.

    Args:
        image_stack: Input tensor of shape [Z, H, W]
        method: Sharpness metric ('laplacian' or 'gradient')
        patch_size: Size of analysis patches. Default: max(H, W) // 8
        stride: Stride between patches. Default: patch_size // 2
        normalize_sharpness: Normalize sharpness scores per patch

    Returns:
        Composite image of shape [1, H, W] built from the maximally sharp regions
    """

    if image_stack.ndim != 3:
        raise ValueError(f"Input must be a 3D tensor [Z, H, W]. Got {image_stack.ndim}D")
    if image_stack.device.type != "cuda":
        raise ValueError(f"Input must be on a CUDA device. Got '{image_stack.device.type}'")

    Z, H, W = image_stack.shape
    device = image_stack.device
    dtype = image_stack.dtype

    # Set adaptive defaults based on image dimensions
    patch_size = patch_size or max(H, W) // 8
    stride = stride or patch_size // 2

    # Calculate sharpness maps
    if method == "laplacian":
        sharpness = torch.abs(laplacian(image_stack.unsqueeze(1))).squeeze(1)
    elif method == "gradient":
        gx, gy = torch.gradient(image_stack, dim=(1, 2))
        sharpness = torch.sqrt(gx**2 + gy**2)
    else:
        raise ValueError(f"Invalid method: {method}. Use 'laplacian' or 'gradient'")

    if normalize_sharpness:
        sharpness = (sharpness - sharpness.mean(dim=0)) / (sharpness.std(dim=0) + 1e-6)
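
    # Note: normalization z-scores each pixel's sharpness across the Z axis,
    # so slices with globally higher contrast do not dominate patch selection.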

    # Score each sliding-window patch by its mean sharpness
    num_h = (H - patch_size) // stride + 1
    num_w = (W - patch_size) // stride + 1
    patches = F.unfold(
        sharpness.unsqueeze(1),  # [Z, 1, H, W]
        kernel_size=patch_size,
        stride=stride
    )  # [Z, patch_size * patch_size, num_h * num_w]
    patch_scores = patches.mean(dim=1).view(Z, num_h, num_w)

    # Find the sharpest z-index per patch position
    _, max_indices = torch.max(patch_scores, dim=0)  # [num_h, num_w]
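
    # Worked example (illustrative): with H = W = 256, patch_size = 32, and
    # stride = 16, each axis has (256 - 32) // 16 + 1 = 15 patch positions,
    # so patch_scores is [Z, 15, 15] and max_indices is [15, 15].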

    # Create composite image using the max-sharpness indices
    composite = torch.zeros_like(image_stack[0])
    weights = torch.zeros_like(composite)

    for i in range(num_h):
        for j in range(num_w):
            z_idx = max_indices[i, j]
            h_start = i * stride
            w_start = j * stride

            composite_slice = composite[h_start:h_start + patch_size, w_start:w_start + patch_size]
            weight_slice = weights[h_start:h_start + patch_size, w_start:w_start + patch_size]

            composite_slice += image_stack[z_idx, h_start:h_start + patch_size, w_start:w_start + patch_size]
            weight_slice += torch.ones_like(weight_slice)

    # Average overlapping patches; the clamp avoids division by zero at any uncovered pixels
    return (composite / torch.clamp_min(weights, 1)).unsqueeze(0)
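
# Minimal usage sketch (illustrative, not part of the module). It assumes a
# CUDA-capable device, that torch resolved via optional_import, and that
# torch_decorator preserves the call signature.
if __name__ == "__main__":
    z_stack = torch.rand(5, 256, 256, device="cuda")  # synthetic [Z, H, W] stack
    fused = focus_stack_max_sharpness(z_stack, method="laplacian", patch_size=32)
    print(fused.shape)  # expected: torch.Size([1, 256, 256])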