Coverage for openhcs/processing/backends/processors/jax_processor.py: 13.4%

227 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2JAX Image Processor Implementation 

3 

4This module implements the ImageProcessorInterface using JAX as the backend. 

5It leverages GPU acceleration for image processing operations. 

6 

7Doctrinal Clauses: 

8- Clause 3 — Declarative Primacy: All functions are pure and stateless 

9- Clause 88 — No Inferred Capabilities: Explicit JAX dependency 

10- Clause 106-A — Declared Memory Types: All methods specify JAX arrays 

11""" 

12from __future__ import annotations 

13 

14import logging 

15from typing import Any, List, Optional, Tuple 

16 

17from openhcs.core.memory.decorators import jax as jax_func 

18from openhcs.core.utils import optional_import 

19 

# JAX is an optional dependency. optional_import is expected to return None
# when a module cannot be imported, in which case jax, jnp and lax are all
# None and the JAX functions below must not be called.
jax = optional_import("jax")
if jax is None:
    jnp = None
    lax = None
else:
    jnp = optional_import("jax.numpy")
    lax = jax.lax

logger = logging.getLogger(__name__)

26 

27 

@jax_func
def create_linear_weight_mask(height: int, width: int, margin_ratio: float = 0.1) -> "jnp.ndarray":
    """
    Create a 2D weight mask that ramps linearly from the borders up to 1 in the interior.

    The mask is the outer product of a row profile and a column profile. Each
    profile is 1.0 except for ``m = floor(size * margin_ratio)`` samples at
    each end:

    - leading ramp: ``linspace(0, 1, m, endpoint=False)`` (first sample is 0)
    - trailing ramp: ``linspace(1, 0, m, endpoint=False)`` (last sample is 1/m)

    Args:
        height: Height of the mask
        width: Width of the mask
        margin_ratio: Fraction of each dimension used as the ramp margin

    Returns:
        2D float32 JAX weight mask of shape (height, width)
    """
    import math

    # Floor in Python float arithmetic. jnp.floor would first convert the
    # product to float32 (JAX's default float width), which can round e.g.
    # 28.999999999999996 up to 29.0 and give a one-pixel-larger margin.
    margin_y = math.floor(height * margin_ratio)
    margin_x = math.floor(width * margin_ratio)

    def _profile(length, margin):
        # 1D profile: ones, with linear ramps of `margin` samples at both ends.
        profile = jnp.ones(length, dtype=jnp.float32)
        if margin > 0:
            profile = profile.at[:margin].set(jnp.linspace(0, 1, margin, endpoint=False))
            profile = profile.at[-margin:].set(jnp.linspace(1, 0, margin, endpoint=False))
        return profile

    # Outer product of the two 1D profiles gives the (height, width) mask.
    return jnp.outer(_profile(height, margin_y), _profile(width, margin_x))

65 

66 

def _validate_3d_array(array: Any, name: str = "input") -> None:
    """
    Validate that the input is a 3D JAX array.

    Args:
        array: Array to validate
        name: Name of the array for error messages

    Raises:
        ImportError: If JAX is not available (module-level ``jnp`` is None)
        TypeError: If the array is not a JAX array
        ValueError: If the array is not 3D
    """
    # Without this check a missing JAX surfaces as an AttributeError on
    # `None.ndarray` instead of the documented ImportError.
    if jnp is None:
        raise ImportError(f"JAX is required to validate {name}, but it could not be imported")

    if not isinstance(array, jnp.ndarray):
        raise TypeError(f"{name} must be a JAX array, got {type(array)}. "
                        f"No automatic conversion is performed to maintain explicit contracts.")

    if array.ndim != 3:
        raise ValueError(f"{name} must be a 3D array, got {array.ndim}D")

89 

@jax_func
def _gaussian_kernel(sigma: float, kernel_size: int) -> "jnp.ndarray":
    """
    Create a normalized 2D Gaussian kernel.

    Args:
        sigma: Standard deviation of the Gaussian; must be > 0
        kernel_size: Requested kernel size; an even value is bumped to the
            next odd number so the kernel has a center pixel

    Returns:
        2D float32 JAX array of shape (k, k), where k is the (odd) kernel
        size, whose entries sum to 1

    Raises:
        ValueError: If sigma is not positive
    """
    # sigma == 0 would compute 0/0 at the center and silently yield an
    # all-NaN kernel; a non-positive standard deviation is invalid anyway.
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")

    if kernel_size % 2 == 0:
        kernel_size += 1

    half = kernel_size // 2
    x = jnp.arange(-half, half + 1, dtype=jnp.float32)
    kernel_1d = jnp.exp(-0.5 * (x / sigma) ** 2)
    kernel_1d = kernel_1d / jnp.sum(kernel_1d)

    # The 2D Gaussian is separable: outer product of the 1D profile with itself.
    return jnp.outer(kernel_1d, kernel_1d)

115 

@jax_func
def _gaussian_blur(image: "jnp.ndarray", sigma: float) -> "jnp.ndarray":
    """
    Apply Gaussian blur to a 2D image.

    The image is reflect-padded by half the kernel size and convolved with a
    2D Gaussian kernel ('VALID' padding), so the output has the input's shape.

    Args:
        image: 2D JAX array of shape (H, W)
        sigma: Standard deviation of the Gaussian kernel

    Returns:
        Blurred 2D JAX array of shape (H, W) in the kernel's dtype (float32)
    """
    # Truncate the kernel at about 4 sigma on each side (minimum 3x3).
    requested_size = max(3, int(2 * 4 * sigma + 1))
    kernel = _gaussian_kernel(sigma, requested_size)

    # _gaussian_kernel bumps an even size to the next odd one (e.g. sigma=1.2
    # requests 10 and gets 11), so take the size from the kernel itself;
    # reshaping with the requested size would fail.
    kernel_size = kernel.shape[0]
    pad_size = kernel_size // 2

    # lax.conv_general_dilated needs both operands in the same dtype.
    padded = jnp.pad(
        image.astype(kernel.dtype),
        ((pad_size, pad_size), (pad_size, pad_size)),
        mode='reflect',
    )

    # Single batch / single channel: input as NHWC, kernel as HWIO.
    result = lax.conv_general_dilated(
        padded.reshape(1, padded.shape[0], padded.shape[1], 1),
        kernel.reshape(kernel_size, kernel_size, 1, 1),
        window_strides=(1, 1),
        padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )

    return result[0, :, :, 0]

157 

@jax_func
def sharpen(image: "jnp.ndarray", radius: float = 1.0, amount: float = 1.0
) -> "jnp.ndarray":
    """
    Sharpen a 3D image using unsharp masking.

    Each Z-slice is processed independently:

    1. The slice is scaled by its maximum.
    2. It is sharpened as ``s + amount * (s - gaussian_blur(s, radius))``.
    3. The result is clipped to [0, 1].
    4. It is min-max stretched to [0, 65535]; a constant slice is left as is.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        radius: Standard deviation of the Gaussian blur
        amount: Sharpening strength

    Returns:
        Sharpened 3D JAX array of shape (Z, Y, X): uint16 (clipped to
        [0, 65535]) for integer input, otherwise the input dtype
    """
    _validate_3d_array(image)

    dtype = image.dtype
    result_list = []

    for z in range(image.shape[0]):
        slice_raw = image[z].astype(jnp.float32)

        # An all-zero slice has maximum 0; dividing by it would turn the
        # whole slice into NaN. Divide by 1 instead so it stays zero.
        peak = jnp.max(slice_raw)
        slice_float = slice_raw / jnp.where(peak == 0, 1.0, peak)

        # Unsharp mask: original + amount * (original - blurred), then clip.
        blurred = _gaussian_blur(slice_float, sigma=radius)
        sharpened = jnp.clip(slice_float + amount * (slice_float - blurred), 0.0, 1.0)

        # Stretch to the full 16-bit range unless the slice is constant.
        min_val = jnp.min(sharpened)
        max_val = jnp.max(sharpened)
        if max_val > min_val:
            sharpened = (sharpened - min_val) * 65535.0 / (max_val - min_val)

        result_list.append(sharpened)

    result = jnp.stack(result_list, axis=0)

    if jnp.issubdtype(dtype, jnp.integer):
        return jnp.clip(result, 0, 65535).astype(jnp.uint16)
    return result.astype(dtype)

213 

@jax_func
def percentile_normalize(
    image: "jnp.ndarray",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "jnp.ndarray":
    """
    Normalize a 3D image using percentile-based contrast stretching.

    Each Z-slice is normalized independently:

    - Values are clipped to that slice's [low, high] percentile range.
    - They are then mapped linearly onto [target_min, target_max].
    - A slice whose two percentiles are equal is set to target_min.

    The result is clipped to [0, 65535] and returned as uint16.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D uint16 JAX array of shape (Z, Y, X)
    """
    _validate_3d_array(image)

    n_slices, height, width = image.shape

    # Per-slice percentiles in one vectorized call: flatten each slice into a
    # row, reduce along it, and reshape to (Z, 1, 1) so the bounds broadcast
    # against the (Z, Y, X) stack.
    flat = image.reshape(n_slices, height * width)
    p_low = jnp.percentile(flat, low_percentile, axis=1).reshape(n_slices, 1, 1)
    p_high = jnp.percentile(flat, high_percentile, axis=1).reshape(n_slices, 1, 1)

    # Only an exactly zero span is degenerate. A tolerance-based comparison
    # would also flatten slices with a small but valid span at high
    # intensities. Any non-zero span is safe to divide by, because values
    # are clipped into [p_low, p_high].
    degenerate = p_high == p_low
    span = jnp.where(degenerate, 1.0, p_high - p_low)

    clipped = jnp.clip(image.astype(jnp.float32), p_low, p_high)
    scale = (target_max - target_min) / span
    normalized = jnp.where(degenerate, target_min, (clipped - p_low) * scale + target_min)

    return jnp.clip(normalized, 0, 65535).astype(jnp.uint16)

288 

@jax_func
def stack_percentile_normalize(
    stack: "jnp.ndarray",
    low_percentile: float = 1.0,
    high_percentile: float = 99.0,
    target_min: float = 0.0,
    target_max: float = 65535.0
) -> "jnp.ndarray":
    """
    Normalize a stack using global percentile-based contrast stretching.

    This ensures consistent normalization across all Z-slices: the percentiles
    are computed once over the entire stack. Values are clipped to
    [p_low, p_high] and mapped linearly onto [target_min, target_max]. If the
    two percentiles are equal, every voxel becomes target_min.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)
        low_percentile: Lower percentile (0-100)
        high_percentile: Upper percentile (0-100)
        target_min: Target minimum value
        target_max: Target maximum value

    Returns:
        Normalized 3D uint16 JAX array of shape (Z, Y, X), clipped to [0, 65535]
    """
    _validate_3d_array(stack)

    # Global percentiles across the entire stack.
    p_low = jnp.percentile(stack, low_percentile)
    p_high = jnp.percentile(stack, high_percentile)

    if p_high == p_low:
        # Zero-width range: nothing to stretch. Build the constant result in
        # float so it goes through the same clip-and-cast as the normal
        # path, and the return dtype is always uint16.
        normalized = jnp.full(stack.shape, target_min, dtype=jnp.float32)
    else:
        clipped = jnp.clip(stack, p_low, p_high)
        normalized = (clipped - p_low) * (target_max - target_min) / (p_high - p_low) + target_min

    # Clip before casting. Converting out-of-range floats (e.g. target_max >
    # 65535, target_min < 0, or rounding just past the bounds) to uint16 is
    # implementation-dependent.
    return jnp.clip(normalized, 0, 65535).astype(jnp.uint16)

329 

@jax_func
def create_composite(
    stack: "jnp.ndarray", weights: Optional[List[float]] = None
) -> "jnp.ndarray":
    """
    Create a composite image from a 3D stack where each slice is a channel.

    The composite is the weighted sum of the slices, with the weights
    normalized to sum to 1. The sum is computed in float32 and cast back to
    the stack's dtype; this is a plain cast, so integer results are not rounded.

    Args:
        stack: 3D JAX array of shape (N, Y, X) where N is number of channel slices
        weights: List or tuple with one weight per slice. If None, equal weights are used.

    Returns:
        Composite 3D JAX array of shape (1, Y, X) with the stack's dtype

    Raises:
        ValueError: If the stack has no slices, the number of weights differs
            from the number of slices, or the weights sum to zero
        TypeError: If weights is neither None nor a list/tuple
    """
    _validate_3d_array(stack)

    n_slices = stack.shape[0]
    # Without this guard the equal-weights default divides by zero.
    if n_slices == 0:
        raise ValueError("stack must contain at least one slice")

    if weights is None:
        weights = [1.0 / n_slices] * n_slices
    elif isinstance(weights, (list, tuple)):
        weights = list(weights)
        if len(weights) != n_slices:
            raise ValueError(f"Number of weights ({len(weights)}) must match number of slices ({n_slices})")
    else:
        raise TypeError(f"weights must be a list of values or None, got {type(weights)}: {weights}")

    weight_sum = sum(weights)
    if weight_sum == 0:
        raise ValueError("Sum of weights cannot be zero")

    # float32 (not stack.dtype) keeps fractional weights for integer stacks.
    # Shape (N, 1, 1) broadcasts against the (N, Y, X) stack.
    weights_array = jnp.array(
        [w / weight_sum for w in weights], dtype=jnp.float32
    ).reshape(n_slices, 1, 1)

    # Weighted sum over the slice axis, keeping it as length 1 -> (1, Y, X).
    composite = jnp.sum(stack.astype(jnp.float32) * weights_array, axis=0, keepdims=True)
    return composite.astype(stack.dtype)

384 

@jax_func
def apply_mask(image: "jnp.ndarray", mask: "jnp.ndarray") -> "jnp.ndarray":
    """
    Apply a mask to a 3D image by elementwise multiplication.

    A 2D mask is applied to every Z-slice; a 3D mask is applied voxel-wise.
    The product is computed in float32 and cast back to the image's dtype.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        mask: 3D JAX array of shape (Z, Y, X) or 2D JAX array of shape (Y, X)

    Returns:
        Masked 3D JAX array of shape (Z, Y, X) with the image's dtype

    Raises:
        ValueError: If the mask shape does not match the image (slice) shape
        TypeError: If the mask is not a 2D or 3D JAX array
    """
    _validate_3d_array(image)

    if isinstance(mask, jnp.ndarray) and mask.ndim == 2:
        if mask.shape != image.shape[1:]:
            raise ValueError(
                f"2D mask shape {mask.shape} doesn't match image slice shape {image.shape[1:]}"
            )
    elif isinstance(mask, jnp.ndarray) and mask.ndim == 3:
        if mask.shape != image.shape:
            raise ValueError(
                f"3D mask shape {mask.shape} doesn't match image shape {image.shape}"
            )
    else:
        raise TypeError(f"mask must be a 2D or 3D JAX array, got {type(mask)}")

    # A (Y, X) mask broadcasts across the leading Z axis. Both cases are
    # therefore a single elementwise multiply, with no per-slice Python loop.
    masked = image.astype(jnp.float32) * mask.astype(jnp.float32)
    return masked.astype(image.dtype)

430 

@jax_func
def create_weight_mask(
    shape: Tuple[int, int], margin_ratio: float = 0.1
) -> "jnp.ndarray":
    """
    Build a 2D blending weight mask for an image of the given shape.

    Thin wrapper around ``create_linear_weight_mask`` that takes the size as
    a single ``(height, width)`` tuple.

    Args:
        shape: ``(height, width)`` tuple; other sequence types are rejected
        margin_ratio: Fraction of each dimension used as the ramp margin

    Returns:
        2D JAX weight mask of shape (height, width)

    Raises:
        TypeError: If ``shape`` is not a tuple of exactly two elements
    """
    is_pair = isinstance(shape, tuple) and len(shape) == 2
    if not is_pair:
        raise TypeError("shape must be a tuple of (height, width)")

    # Unpack (height, width) positionally, followed by the margin ratio.
    return create_linear_weight_mask(*shape, margin_ratio)

450 

@jax_func
def max_projection(stack: "jnp.ndarray") -> "jnp.ndarray":
    """
    Collapse a Z-stack to one plane by taking the per-pixel maximum along Z.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)

    Returns:
        3D JAX array of shape (1, Y, X) with the stack's dtype
    """
    _validate_3d_array(stack)

    # keepdims retains the reduced Z axis with length 1, giving (1, Y, X).
    return jnp.max(stack, axis=0, keepdims=True)

467 

@jax_func
def mean_projection(stack: "jnp.ndarray") -> "jnp.ndarray":
    """
    Collapse a Z-stack to one plane by averaging along Z.

    The mean is accumulated in float32 and cast back to the stack's dtype.
    This is a plain cast, so integer results are not rounded.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)

    Returns:
        3D JAX array of shape (1, Y, X) with the stack's dtype
    """
    _validate_3d_array(stack)

    # keepdims retains the reduced Z axis with length 1, giving (1, Y, X).
    averaged = jnp.mean(stack.astype(jnp.float32), axis=0, keepdims=True)
    return averaged.astype(stack.dtype)

484 

@jax_func
def stack_equalize_histogram(
    stack: "jnp.ndarray",
    bins: int = 65536,
    range_min: float = 0.0,
    range_max: float = 65535.0
) -> "jnp.ndarray":
    """
    Apply histogram equalization to an entire stack.

    A single histogram is computed over every voxel, so all Z-slices share
    the same intensity mapping. Each voxel is replaced by the cumulative
    count of its bin, scaled so that the largest cumulative count maps to 65535.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)
        bins: Number of bins for histogram computation
        range_min: Minimum value for histogram range
        range_max: Maximum value for histogram range

    Returns:
        Equalized 3D uint16 JAX array of shape (Z, Y, X)
    """
    _validate_3d_array(stack)

    voxels = stack.flatten()

    # Global histogram and its cumulative distribution.
    counts, _ = jnp.histogram(voxels, bins=bins, range=(range_min, range_max))
    cumulative = jnp.cumsum(counts)

    # Scale the CDF to [0, 65535]. If every count is zero (no voxel inside
    # the range), leave it unscaled to avoid dividing by zero.
    total = jnp.max(cumulative)
    lookup = lax.cond(
        total > 0,
        lambda c: 65535.0 * c / total,
        lambda c: c,
        cumulative,
    )

    # Map every voxel to its bin index, clamped to the valid bin range.
    width_per_bin = (range_max - range_min) / bins
    bin_index = jnp.floor((voxels - range_min) / width_per_bin).astype(jnp.int32)
    bin_index = jnp.clip(bin_index, 0, bins - 1)

    equalized = jnp.take(lookup, bin_index).reshape(stack.shape)
    return equalized.astype(jnp.uint16)

543 

@jax_func
def create_projection(
    stack: "jnp.ndarray", method: str = "max_projection"
) -> "jnp.ndarray":
    """
    Project a Z-stack to a single plane using a named method.

    Args:
        stack: 3D JAX array of shape (Z, Y, X)
        method: ``"max_projection"`` or ``"mean_projection"``

    Returns:
        3D JAX array of shape (1, Y, X)

    Raises:
        ValueError: If ``method`` is not one of the supported names
    """
    _validate_3d_array(stack)

    # Checked in order; the first matching name wins.
    supported = (
        ("max_projection", max_projection),
        ("mean_projection", mean_projection),
    )
    for name, project in supported:
        if method == name:
            return project(stack)

    # FAIL FAST: No fallback projection methods
    raise ValueError(f"Unknown projection method: {method}. Valid methods: max_projection, mean_projection")

568 

def _disk_offsets(radius: int) -> List[Tuple[int, int]]:
    """Return the (row, col) offsets in a (2r+1) x (2r+1) window that lie inside a disk of radius r."""
    size = 2 * radius + 1
    return [
        (dy, dx)
        for dy in range(size)
        for dx in range(size)
        if (dy - radius) ** 2 + (dx - radius) ** 2 <= radius ** 2
    ]


def _disk_filter(image: "jnp.ndarray", offsets: List[Tuple[int, int]], radius: int,
                 reduce_fn) -> "jnp.ndarray":
    """Flat disk min/max filter of a 2D image with reflect padding (reduce_fn: jnp.minimum or jnp.maximum)."""
    height, width = image.shape
    padded = jnp.pad(image, radius, mode='reflect')

    # Output (y, x) reduces padded[y + dy, x + dx] over every disk offset.
    # Each offset is one shifted full-image view, so the work is vectorized
    # over pixels rather than looping over them in Python. The offset list
    # always contains the disk center, so it is never empty.
    first_dy, first_dx = offsets[0]
    result = padded[first_dy:first_dy + height, first_dx:first_dx + width]
    for dy, dx in offsets[1:]:
        result = reduce_fn(result, padded[dy:dy + height, dx:dx + width])
    return result


def _tophat_slice(slice_data: "jnp.ndarray", offsets: List[Tuple[int, int]], radius: int,
                  factor: int) -> "jnp.ndarray":
    """Subtract a background estimate (disk opening at reduced resolution) from one 2D slice."""
    input_dtype = slice_data.dtype
    height, width = slice_data.shape
    small_h, small_w = height // factor, width // factor
    slice_float = slice_data.astype(jnp.float32)

    # 1) Downsample by block averaging. reshape needs exact multiples, so
    #    rows/columns past the last full factor x factor block are left out here.
    blocks = slice_float[:small_h * factor, :small_w * factor].reshape(
        small_h, factor, small_w, factor
    )
    image_small = jnp.mean(blocks, axis=(1, 3))

    # 2) Morphological opening with the disk: erosion (min) then dilation (max).
    eroded = _disk_filter(image_small, offsets, radius, jnp.minimum)
    opened = _disk_filter(eroded, offsets, radius, jnp.maximum)

    # 3) White top-hat is original minus opening; the background is the remainder.
    tophat_small = image_small - opened
    background_small = image_small - tophat_small

    # 4) Nearest-neighbour upscale. Then replicate the last row/column to
    #    cover any remainder rows/columns that were skipped in step 1.
    background = jnp.repeat(jnp.repeat(background_small, factor, axis=0), factor, axis=1)
    pad_h = height - background.shape[0]
    pad_w = width - background.shape[1]
    if pad_h or pad_w:
        background = jnp.pad(background, ((0, pad_h), (0, pad_w)), mode='edge')

    # 5) Subtract background, clip negatives, restore the input dtype.
    return jnp.maximum(slice_float - background, 0).astype(input_dtype)


@jax_func
def tophat(
    image: "jnp.ndarray",
    selem_radius: int = 50,
    downsample_factor: int = 4
) -> "jnp.ndarray":
    """
    Apply white top-hat filter to a 3D image for background removal.

    Each Z-slice is processed independently:

    1. It is block-averaged down by ``downsample_factor``.
    2. It is opened with a flat disk of radius
       ``max(1, selem_radius // downsample_factor)``, using reflect padding
       at the borders.
    3. The opening is upscaled by nearest-neighbour as the background.
    4. The background is subtracted, with negatives clipped to 0.

    Results may differ from scikit-image because resizing and border
    handling are simplified.

    Args:
        image: 3D JAX array of shape (Z, Y, X)
        selem_radius: Radius of the structuring element disk (full resolution)
        downsample_factor: Factor by which to downsample the image for processing

    Returns:
        Filtered 3D JAX array of shape (Z, Y, X) with the input dtype

    Raises:
        ValueError: If downsample_factor is less than 1
    """
    _validate_3d_array(image)

    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    # The structuring element depends only on the parameters, so it is built
    # once instead of once per slice.
    small_radius = max(1, selem_radius // downsample_factor)
    offsets = _disk_offsets(small_radius)

    return jnp.stack(
        [
            _tophat_slice(image[z], offsets, small_radius, downsample_factor)
            for z in range(image.shape[0])
        ],
        axis=0,
    )