Coverage for ezstitcher/core/image_processor.py: 94%

141 statements  

coverage.py v7.3.2, created at 2025-04-30 13:20 +0000

import numpy as np
import logging
from skimage import filters, exposure, morphology as morph, transform as trans

logger = logging.getLogger(__name__)


def create_linear_weight_mask(height, width, margin_ratio=0.1):
    """
    Create a 2D weight mask that linearly ramps from 0 at the edges
    to 1 in the center.

    Args:
        height (int): Height of the mask
        width (int): Width of the mask
        margin_ratio (float): Ratio of the margin to the image size

    Returns:
        numpy.ndarray: 2D weight mask
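    Example:
        Illustrative sketch (not part of the original module); the
        dimensions below are arbitrary:

        >>> mask = create_linear_weight_mask(100, 200, margin_ratio=0.1)
        >>> mask.shape
        (100, 200)
        >>> float(mask[50, 100])
        1.0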

20 """ 

21 margin_y = int(np.floor(height * margin_ratio)) 

22 margin_x = int(np.floor(width * margin_ratio)) 

23 

24 weight_y = np.ones(height, dtype=np.float32) 

25 if margin_y > 0: 

26 ramp_top = np.linspace(0, 1, margin_y, endpoint=False) 

27 ramp_bottom = np.linspace(1, 0, margin_y, endpoint=False) 

28 weight_y[:margin_y] = ramp_top 

29 weight_y[-margin_y:] = ramp_bottom 

30 

31 weight_x = np.ones(width, dtype=np.float32) 

32 if margin_x > 0: 

33 ramp_left = np.linspace(0, 1, margin_x, endpoint=False) 

34 ramp_right = np.linspace(1, 0, margin_x, endpoint=False) 

35 weight_x[:margin_x] = ramp_left 

36 weight_x[-margin_x:] = ramp_right 

37 

38 # Create 2D weight mask 

39 weight_mask = np.outer(weight_y, weight_x) 

40 

41 return weight_mask 

42 

43 

44# These functions have been moved to their appropriate classes: 

45# - load_image and save_image are now in FileSystemManager 

46# - parse_positions_csv is now in CSVHandler 

47 

48 

49class ImageProcessor: 

50 """ 

51 Handles image normalization, filtering, and compositing. 

52 All methods are static and do not require an instance. 

53 """ 

54 

55 @staticmethod 

56 def sharpen(image, radius=1, amount=1.0): 

57 """ 

58 Sharpen an image using unsharp masking. 

59 

60 Args: 

61 image (numpy.ndarray): Input image 

62 radius (float): Radius of Gaussian blur 

63 amount (float): Sharpening strength 

64 

65 Returns: 

66 numpy.ndarray: Sharpened image 
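        Example:
            Illustrative sketch (not part of the original module); the random
            input image below is arbitrary:

            >>> import numpy as np
            >>> img = np.random.randint(1, 65535, (64, 64), dtype=np.uint16)
            >>> out = ImageProcessor.sharpen(img, radius=1, amount=1.5)
            >>> out.dtype, out.shape
            (dtype('uint16'), (64, 64))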

67 """ 

68 # Convert to float for processing 

69 image_float = image.astype(np.float32) / np.max(image) 

70 

71 # Create blurred version for unsharp mask 

72 if image_float.ndim == 3: 

73 blurred = filters.gaussian(image_float, sigma=radius, channel_axis=-1) 

74 else: 

75 blurred = filters.gaussian(image_float, sigma=radius) 

76 

77 # Apply unsharp mask: original + amount * (original - blurred) 

78 sharpened = image_float + amount * (image_float - blurred) 

79 

80 # Clip to valid range 

81 sharpened = np.clip(sharpened, 0, 1.0) 

82 

83 # Scale back to original range 

84 sharpened = exposure.rescale_intensity(sharpened, in_range='image', out_range=(0, 65535)) 

85 sharpened = sharpened.astype(np.uint16) 

86 

87 return sharpened 

88 

89 @staticmethod 

90 def percentile_normalize(image, low_percentile=1, high_percentile=99, target_min=0, target_max=65535): 

91 """ 

92 Normalize image using percentile-based contrast stretching. 

93 

94 Args: 

95 image (numpy.ndarray): Input image 

96 low_percentile (float): Lower percentile (0-100) 

97 high_percentile (float): Upper percentile (0-100) 

98 target_min (int): Target minimum value 

99 target_max (int): Target maximum value 

100 

101 Returns: 

102 numpy.ndarray: Normalized image 
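        Example:
            Illustrative sketch (not part of the original module); a synthetic
            gradient image stands in for real data:

            >>> import numpy as np
            >>> img = np.arange(10000, dtype=np.uint16).reshape(100, 100)
            >>> out = ImageProcessor.percentile_normalize(img, 1, 99)
            >>> out.dtype
            dtype('uint16')
            >>> int(out.min())
            0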

103 """ 

104 # Get percentile values 

105 p_low, p_high = np.percentile(image, (low_percentile, high_percentile)) 

106 

107 # Avoid division by zero 

108 if p_high == p_low: 

109 return np.ones_like(image) * target_min 

110 

111 # Clip and normalize to target range 

112 clipped = np.clip(image, p_low, p_high) 

113 normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min 

114 normalized = normalized.astype(np.uint16) 

115 

116 return normalized 

117 

118 @staticmethod 

119 def stack_percentile_normalize(stack, low_percentile=1, high_percentile=99, target_min=0, target_max=65535): 

120 """ 

121 Normalize a stack of images using global percentile-based contrast stretching. 

122 This ensures consistent normalization across all images in the stack. 

123 

124 Args: 

125 stack (list or numpy.ndarray): Stack of images 

126 low_percentile (float): Lower percentile (0-100) 

127 high_percentile (float): Upper percentile (0-100) 

128 target_min (int): Target minimum value 

129 target_max (int): Target maximum value 

130 

131 Returns: 

132 numpy.ndarray: Normalized stack of images 
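        Example:
            Illustrative sketch (not part of the original module); a tiny
            two-plane stack with arbitrary constant values:

            >>> import numpy as np
            >>> stack = [np.full((8, 8), 100, dtype=np.uint16),
            ...          np.full((8, 8), 1000, dtype=np.uint16)]
            >>> out = ImageProcessor.stack_percentile_normalize(stack)
            >>> out.shape, out.dtype
            ((2, 8, 8), dtype('uint16'))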

133 """ 

134 # Convert to numpy array if it's a list 

135 if isinstance(stack, list): 

136 stack = np.array(stack) 

137 

138 # Calculate global percentiles across the entire stack 

139 p_low = np.percentile(stack, low_percentile) 

140 p_high = np.percentile(stack, high_percentile) 

141 

142 # Avoid division by zero 

143 if p_high == p_low: 

144 return np.ones_like(stack) * target_min 

145 

146 # Clip and normalize to target range 

147 clipped = np.clip(stack, p_low, p_high) 

148 normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min 

149 normalized = normalized.astype(np.uint16) 

150 

151 return normalized 

152 

153 @staticmethod 

154 def create_composite(images, weights=None): 

155 """ 

156 Create a grayscale composite image from multiple channels. 

157 

158 Args: 

159 images (list): List of images to composite 

160 weights (list, optional): List of weights for each image. If None, equal weights are used. 

161 

162 Returns: 

163 numpy.ndarray: Grayscale composite image (16-bit) 

164 

165 Raises: 

166 TypeError: If images is not a list or weights is not a list 

167 ValueError: If images list is empty 
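        Example:
            Illustrative sketch (not part of the original module); two
            synthetic channels blended with equal weights:

            >>> import numpy as np
            >>> ch1 = np.full((4, 4), 1000, dtype=np.uint16)
            >>> ch2 = np.full((4, 4), 3000, dtype=np.uint16)
            >>> comp = ImageProcessor.create_composite([ch1, ch2], weights=[0.5, 0.5])
            >>> int(comp[0, 0])
            2000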

168 """ 

169 # Ensure images is a list 

170 if not isinstance(images, list): 

171 raise TypeError("images must be a list of images") 

172 

173 # Check for empty list early 

174 if not images: 

175 raise ValueError("images list cannot be empty") 

176 

177 # Default weights if none provided 

178 if weights is None: 

179 # Equal weights for all images 

180 weights = [1.0 / len(images)] * len(images) 

181 elif not isinstance(weights, list): 

182 raise TypeError("weights must be a list of values") 

183 

184 # Make sure weights list is at least as long as images list 

185 if len(weights) < len(images): 

186 weights = weights + [0.0] * (len(images) - len(weights)) 

187 # Truncate weights if longer than images 

188 weights = weights[:len(images)] 

189 

190 first_image = images[0] 

191 shape = first_image.shape 

192 dtype = first_image.dtype 

193 

194 # Create empty composite 

195 composite = np.zeros(shape, dtype=np.float32) 

196 total_weight = 0.0 

197 

198 # Add each image with its weight 

199 for i, image in enumerate(images): 

200 weight = weights[i] 

201 if weight <= 0.0: 

202 continue 

203 

204 # Add to composite 

205 composite += image.astype(np.float32) * weight 

206 total_weight += weight 

207 

208 # Normalize by total weight 

209 if total_weight > 0: 

210 composite /= total_weight 

211 

212 # Convert back to original dtype (usually uint16) 

213 if np.issubdtype(dtype, np.integer): 

214 max_val = np.iinfo(dtype).max 

215 composite = np.clip(composite, 0, max_val).astype(dtype) 

216 else: 

217 composite = composite.astype(dtype) 

218 

219 return composite 

220 

221 @staticmethod 

222 def apply_mask(image, mask): 

223 """ 

224 Apply a mask to an image. 

225 

226 Args: 

227 image (numpy.ndarray): Input image 

228 mask (numpy.ndarray): Mask image (same shape as input) 

229 

230 Returns: 

231 numpy.ndarray: Masked image 
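        Example:
            Illustrative sketch (not part of the original module); a binary
            mask that zeroes the left half of a constant image:

            >>> import numpy as np
            >>> img = np.full((2, 4), 500, dtype=np.uint16)
            >>> mask = np.array([[0, 0, 1, 1], [0, 0, 1, 1]], dtype=np.float32)
            >>> masked = ImageProcessor.apply_mask(img, mask)
            >>> int(masked[0, 0]), int(masked[0, 3])
            (0, 500)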

232 """ 

233 # Ensure mask has same shape as image 

234 if mask.shape != image.shape: 

235 raise ValueError(f"Mask shape {mask.shape} doesn't match image shape {image.shape}") 

236 

237 # Apply mask 

238 masked = image.astype(np.float32) * mask.astype(np.float32) 

239 masked = masked.astype(image.dtype) 

240 

241 return masked 

242 

243 @staticmethod 

244 def create_weight_mask(shape, margin_ratio=0.1): 

245 """ 

246 Create a weight mask for blending images. 

247 

248 Args: 

249 shape (tuple): Shape of the mask (height, width) 

250 margin_ratio (float): Ratio of image size to use as margin 

251 

252 Returns: 

253 numpy.ndarray: Weight mask 
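        Example:
            Illustrative sketch (not part of the original module):

            >>> mask = ImageProcessor.create_weight_mask((100, 200), margin_ratio=0.1)
            >>> mask.shape
            (100, 200)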

254 """ 

255 return create_linear_weight_mask(shape[0], shape[1], margin_ratio) 

256 

257 @staticmethod 

258 def max_projection(stack): 

259 """ 

260 Create a maximum intensity projection from a Z-stack. 

261 

262 Args: 

263 stack (list or numpy.ndarray): Stack of images 

264 

265 Returns: 

266 numpy.ndarray: Maximum intensity projection 
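        Example:
            Illustrative sketch (not part of the original module); a two-plane
            synthetic stack:

            >>> import numpy as np
            >>> stack = [np.full((2, 2), 10, dtype=np.uint16),
            ...          np.full((2, 2), 40, dtype=np.uint16)]
            >>> int(ImageProcessor.max_projection(stack)[0, 0])
            40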

267 """ 

268 # Convert to numpy array if it's a list 

269 if isinstance(stack, list): 

270 stack = np.array(stack) 

271 

272 # Create max projection 

273 return np.max(stack, axis=0) 

274 

275 @staticmethod 

276 def mean_projection(stack): 

277 """ 

278 Create a mean intensity projection from a Z-stack. 

279 

280 Args: 

281 stack (list or numpy.ndarray): Stack of images 

282 

283 Returns: 

284 numpy.ndarray: Mean intensity projection 
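        Example:
            Illustrative sketch (not part of the original module); the mean is
            cast back to the stack's dtype:

            >>> import numpy as np
            >>> stack = [np.full((2, 2), 10, dtype=np.uint16),
            ...          np.full((2, 2), 40, dtype=np.uint16)]
            >>> proj = ImageProcessor.mean_projection(stack)
            >>> int(proj[0, 0]), proj.dtype
            (25, dtype('uint16'))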

285 """ 

286 # Convert to numpy array if it's a list 

287 if isinstance(stack, list): 

288 stack = np.array(stack) 

289 

290 # Create mean projection 

291 return np.mean(stack, axis=0).astype(stack[0].dtype) 

292 

293 @staticmethod 

294 def stack_equalize_histogram(stack, bins=65536, range_min=0, range_max=65535): 

295 """ 

296 Apply true histogram equalization to an entire stack of images. 

297 This ensures consistent contrast enhancement across all images in the stack. 

298 

299 Unlike standard histogram equalization applied to individual images, 

300 this method computes a global histogram across the entire stack and 

301 applies the same transformation to all images, preserving relative 

302 intensity relationships between Z-planes. 

303 

304 Args: 

305 stack (list or numpy.ndarray): Stack of images 

306 bins (int): Number of bins for histogram computation 

307 range_min (int): Minimum value for histogram range 

308 range_max (int): Maximum value for histogram range 

309 

310 Returns: 

311 numpy.ndarray: Histogram-equalized stack of images 
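        Example:
            Illustrative sketch (not part of the original module); random
            synthetic data stands in for a real Z-stack:

            >>> import numpy as np
            >>> stack = np.random.randint(0, 4096, (3, 32, 32), dtype=np.uint16)
            >>> eq = ImageProcessor.stack_equalize_histogram(stack)
            >>> eq.shape, eq.dtype
            ((3, 32, 32), dtype('uint16'))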

312 """ 

313 # Convert to numpy array if it's a list 

314 if isinstance(stack, list): 

315 stack = np.array(stack) 

316 

317 # Flatten the entire stack to compute the global histogram 

318 flat_stack = stack.flatten() 

319 

320 # Calculate the histogram and cumulative distribution function (CDF) 

321 hist, bin_edges = np.histogram(flat_stack, bins=bins, range=(range_min, range_max)) 

322 cdf = hist.cumsum() 

323 

324 # Normalize the CDF to the range [0, 65535] 

325 # Avoid division by zero 

326 if cdf[-1] > 0: 

327 cdf = 65535 * cdf / cdf[-1] 

328 

329 # Use linear interpolation to map input values to equalized values 

330 equalized_stack = np.interp(stack.flatten(), bin_edges[:-1], cdf).reshape(stack.shape) 

331 

332 # Convert to uint16 

333 return equalized_stack.astype(np.uint16) 

334 

335 

336 @staticmethod 

337 def create_projection(stack, method="max_projection", focus_analyzer=None): 

338 """ 

339 Create a projection from a stack using the specified method. 

340 

341 Args: 

342 stack (list): List of images 

343 method (str): Projection method (max_projection, mean_projection, best_focus) 

344 focus_analyzer (FocusAnalyzer, optional): Focus analyzer for best_focus method 

345 

346 Returns: 

347 numpy.ndarray: Projected image 
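        Example:
            Illustrative sketch (not part of the original module); with the
            max_projection method this is equivalent to calling max_projection
            directly:

            >>> import numpy as np
            >>> stack = [np.full((2, 2), 10, dtype=np.uint16),
            ...          np.full((2, 2), 40, dtype=np.uint16)]
            >>> int(ImageProcessor.create_projection(stack, method="max_projection")[0, 0])
            40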

348 """ 

349 if method == "max_projection": 

350 return ImageProcessor.max_projection(stack) 

351 

352 if method == "mean_projection": 

353 return ImageProcessor.mean_projection(stack) 

354 

355 if method == "best_focus": 

356 if focus_analyzer is None: 

357 logger.warning("No focus analyzer provided for best_focus method, " 

358 "using max_projection instead") 

359 return ImageProcessor.max_projection(stack) 

360 best_idx, _ = focus_analyzer.find_best_focus(stack) 

361 return stack[best_idx] 

362 

363 # Default case for unknown methods 

364 logger.warning("Unknown projection method: %s, using max_projection", method) 

365 return ImageProcessor.max_projection(stack) 

366 

367 @staticmethod 

368 def tophat(image, selem_radius=50, downsample_factor=4): 

369 """ 

370 Apply white top-hat filter to an image for background removal. 

371 

372 This implementation uses downsampling for efficiency with large structuring elements. 

373 

374 Args: 

375 image (numpy.ndarray): Input image 

376 selem_radius (int): Radius of the structuring element disk 

377 downsample_factor (int): Factor by which to downsample the image for processing 

378 

379 Returns: 

380 numpy.ndarray: Filtered image with background removed 
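        Example:
            Illustrative sketch (not part of the original module); a synthetic
            image with a single bright pixel on a dim, uniform background:

            >>> import numpy as np
            >>> img = np.full((256, 256), 100, dtype=np.uint16)
            >>> img[128, 128] = 5000
            >>> filtered = ImageProcessor.tophat(img, selem_radius=50, downsample_factor=4)
            >>> filtered.dtype, filtered.shape
            (dtype('uint16'), (256, 256))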

381 """ 

382 # Store original data type 

383 input_dtype = image.dtype 

384 

385 # 1) Downsample 

386 # For grayscale images: trans.resize with anti_aliasing=True 

387 image_small = trans.resize(image, 

388 (image.shape[0]//downsample_factor, 

389 image.shape[1]//downsample_factor), 

390 anti_aliasing=True, preserve_range=True) 

391 

392 # 2) Build structuring element for the smaller image 

393 selem_small = morph.disk(selem_radius // downsample_factor) 

394 

395 # 3) White top-hat on the smaller image 

396 tophat_small = morph.white_tophat(image_small, selem_small) 

397 

398 # 4) Upscale background to original size 

399 background_small = image_small - tophat_small 

400 background_large = trans.resize(background_small, 

401 image.shape, 

402 anti_aliasing=False, 

403 preserve_range=True) 

404 

405 # 5) Subtract background and clip negative values 

406 result = np.maximum(image - background_large, 0) 

407 

408 # 6) Convert back to original data type 

409 result = result.astype(input_dtype) 

410 

411 return result