# Coverage for openhcs/core/memory/dtype_scaling.py: 10.5%
# 67 statements
# coverage.py v7.11.0, created at 2025-11-04 02:09 +0000
"""
Dtype scaling and conversion functions for different memory types.

This module provides framework-specific scaling functions that handle conversion
between floating point and integer dtypes with proper range scaling.

Uses enum-driven metaprogramming to eliminate 276 lines of duplication (82% reduction).
Pattern follows PR #38: pure data → eval() → single generic function.
"""
11import numpy as np
12from openhcs.constants.constants import MemoryType
13from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG
14from openhcs.core.utils import optional_import
# Scaling ranges for integer dtypes (shared across all memory types).
# Unsigned dtypes map to a single float — the dtype's max value — used as the
# multiplier after data is normalized to [0, 1].
# Signed dtypes map to a (scale, offset) tuple: normalized * scale - offset
# spans the full signed range (e.g. int16 → [-32768.0, 32767.0]).
_SCALING_RANGES = {
    'uint8': 255.0,
    'uint16': 65535.0,
    'uint32': 4294967295.0,
    'int16': (65535.0, 32768.0),  # (scale, offset)
    'int32': (4294967295.0, 2147483648.0),
}

# NOTE: Framework-specific scaling operations now defined in framework_config.py
# This eliminates the scattered _FRAMEWORK_OPS dict
def _scale_generic(result, target_dtype, mem_type: MemoryType):
    """
    Generic scaling function that works for all memory types using framework config.

    This single function replaces 6 nearly-identical scaling functions.

    Args:
        result: Framework array/tensor to convert. Returned unchanged when the
            framework module cannot be imported or ``result`` has no ``dtype``.
        target_dtype: NumPy dtype class to convert to (e.g. ``np.uint16``).
        mem_type: MemoryType enum value selecting the framework's
            ``scaling_ops`` entry in ``_FRAMEWORK_CONFIG``.

    Returns:
        The converted array. Float → integer conversions are min/max
        normalized and scaled to the target dtype's range via
        ``_SCALING_RANGES``; every other combination is a direct cast.

    NOTE(review): the ``ops`` strings are executed with eval() against this
    function's locals (``result``, ``mod``, ``target_dtype_mapped``,
    ``normalized``, ``jnp``, ``result_min``, ``result_max``) — do NOT rename
    locals here without updating framework_config.py. eval() is acceptable
    only while ``scaling_ops`` remains project-internal data; it must never
    be fed from untrusted input.
    """
    # Special case: pyclesperanto — its GPU API needs dedicated handling.
    if mem_type == MemoryType.PYCLESPERANTO:
        return _scale_pyclesperanto(result, target_dtype)

    config = _FRAMEWORK_CONFIG[mem_type]
    ops = config['scaling_ops']
    mod = optional_import(mem_type.value)  # noqa: F841 (used in eval)
    if mod is None:
        # Framework not installed: best-effort pass-through.
        return result

    if not hasattr(result, 'dtype'):
        # Not an array-like object; nothing to scale.
        return result

    # Handle dtype mapping for frameworks that need it
    target_dtype_mapped = target_dtype  # noqa: F841 (used in eval)
    if ops.get('needs_dtype_map'):
        dtype_map = {
            np.uint8: mod.uint8, np.int8: mod.int8, np.int16: mod.int16,
            np.int32: mod.int32, np.int64: mod.int64, np.float16: mod.float16,
            np.float32: mod.float32, np.float64: mod.float64,
        }
        # NOTE(review): np.uint16/np.uint32 are absent from dtype_map, so those
        # targets fall back to mod.float32 — presumably because some frameworks
        # lack unsigned dtypes; confirm this is intentional per framework.
        target_dtype_mapped = dtype_map.get(target_dtype, mod.float32)  # noqa: F841

    # Extra imports (e.g., jax.numpy)
    if 'extra_import' in ops:
        jnp = optional_import(ops['extra_import'])  # noqa: F841 (used in eval)

    # Check if conversion needed (float → int)
    result_is_float = eval(ops['check_float'])
    target_is_int = eval(ops['check_int'])

    if not (result_is_float and target_is_int):
        # Direct conversion
        return eval(ops['astype'])

    # Get min/max
    result_min = eval(ops['min'])  # noqa: F841 (used in eval)
    result_max = eval(ops['max'])  # noqa: F841 (used in eval)

    if result_max <= result_min:
        # Constant image — normalization would divide by zero; cast directly.
        return eval(ops['astype'])

    # Normalize to [0, 1]
    normalized = (result - result_min) / (result_max - result_min)  # noqa: F841 (used in eval)

    # Scale to target range. NumPy dtype classes have __name__; the fallback
    # handles framework dtype objects whose repr ends in the dtype name.
    dtype_name = target_dtype.__name__ if hasattr(target_dtype, '__name__') else str(target_dtype).split('.')[-1]

    if dtype_name in _SCALING_RANGES:
        range_info = _SCALING_RANGES[dtype_name]
        if isinstance(range_info, tuple):
            # Signed target: stretch to the full unsigned span, then shift
            # down by the offset to center on the signed range.
            scale_val, offset_val = range_info
            # Rebinding `result` here feeds the scaled values into ops['astype'].
            result = normalized * scale_val - offset_val  # noqa: F841 (used in eval)
        else:
            result = normalized * range_info  # noqa: F841 (used in eval)
    else:
        # Unknown integer target: leave values in [0, 1] and just cast.
        result = normalized  # noqa: F841 (used in eval)

    # Convert dtype
    return eval(ops['astype'])
def _scale_pyclesperanto(result, target_dtype):
    """Scale pyclesperanto results (GPU operations require special handling).

    Args:
        result: A pyclesperanto image. Returned unchanged when pyclesperanto
            cannot be imported or ``result`` has no ``dtype``.
        target_dtype: NumPy dtype class to convert to (e.g. ``np.uint16``).

    Returns:
        A pyclesperanto image converted to ``target_dtype``. Float → integer
        conversions are min/max normalized and scaled to the target dtype's
        range via ``_SCALING_RANGES``; every other combination is a direct
        cast (round-tripped through host memory via pull/astype/push).
    """
    cle = optional_import("pyclesperanto")
    if cle is None or not hasattr(result, 'dtype'):
        return result

    # Check if result is floating point and target is integer
    result_is_float = np.issubdtype(result.dtype, np.floating)
    target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]

    if not (result_is_float and target_is_int):
        # Direct conversion
        return cle.push(cle.pull(result).astype(target_dtype))

    # Get min/max
    result_min = float(cle.minimum_of_all_pixels(result))
    result_max = float(cle.maximum_of_all_pixels(result))

    if result_max <= result_min:
        # Constant image — normalization would divide by zero; cast directly.
        return cle.push(cle.pull(result).astype(target_dtype))

    # Normalize to [0, 1] using GPU operations.
    # FIX: subtract_image_from_scalar computes (scalar - image), which inverted
    # the data into [-1, 0]; add_image_and_scalar with a negated scalar
    # computes (image - result_min) as intended.
    normalized = cle.add_image_and_scalar(result, scalar=-result_min)
    range_val = result_max - result_min
    normalized = cle.multiply_image_and_scalar(normalized, scalar=1.0 / range_val)

    # Scale to target range
    dtype_name = target_dtype.__name__
    if dtype_name in _SCALING_RANGES:
        range_info = _SCALING_RANGES[dtype_name]
        if isinstance(range_info, tuple):
            # Signed target: stretch to the full unsigned span, then shift
            # down by the offset to center on the signed range.
            # FIX: same subtract_image_from_scalar misuse — it computed
            # (offset - scaled) instead of (scaled - offset).
            scale_val, offset_val = range_info
            scaled = cle.multiply_image_and_scalar(normalized, scalar=scale_val)
            scaled = cle.add_image_and_scalar(scaled, scalar=-offset_val)
        else:
            scaled = cle.multiply_image_and_scalar(normalized, scalar=range_info)
    else:
        # Unknown integer target: leave values in [0, 1] and just cast.
        scaled = normalized

    # Convert dtype
    return cle.push(cle.pull(scaled).astype(target_dtype))
# Auto-generate all scaling functions using partial application
from functools import partial

# One entry per MemoryType, keyed by the enum's string value; each entry is
# _scale_generic with its mem_type argument pre-bound.
_SCALING_FUNCTIONS_GENERATED = dict(
    (mem_type.value, partial(_scale_generic, mem_type=mem_type))
    for mem_type in MemoryType
)

# Registry mapping memory type names to scaling functions (backward compatibility)
SCALING_FUNCTIONS = _SCALING_FUNCTIONS_GENERATED