Coverage for src/arraybridge/dtype_scaling.py: 70%
76 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-03 05:09 +0000
1"""
2Dtype scaling and conversion functions for different memory types.
4This module provides framework-specific scaling functions that handle conversion
5between floating point and integer dtypes with proper range scaling.
7Uses enum-driven metaprogramming to eliminate 276 lines of duplication (82% reduction).
8Pattern follows PR #38: pure data → eval() → single generic function.
9"""
11from functools import partial
13import numpy as np
15from arraybridge.framework_config import _FRAMEWORK_CONFIG
16from arraybridge.types import MemoryType
17from arraybridge.utils import optional_import
# Integer target ranges shared by every memory type, keyed by dtype name.
# Unsigned entries are a bare scale: [0, 1] is mapped onto [0, scale].
# Signed entries are (scale, offset): [0, 1] is mapped onto [-offset, scale - offset].
_SCALING_RANGES = {
    "uint8": float(2**8 - 1),
    "uint16": float(2**16 - 1),
    "uint32": float(2**32 - 1),
    "int16": (float(2**16 - 1), float(2**15)),
    "int32": (float(2**32 - 1), float(2**31)),
}
# NOTE: Framework-specific scaling expressions live in framework_config.py, under each
# framework's "scaling_ops" entry of _FRAMEWORK_CONFIG; this replaced the earlier,
# scattered _FRAMEWORK_OPS dict.
def _float32_safe_bound(value):
    """Return ``value``, stepped toward zero if float32 would round it outward.

    float32 represents every integer up to 2**24 exactly, so all 8/16-bit bounds
    and INT32_MIN pass through unchanged. Larger bounds such as INT32_MAX
    (2147483647 -> 2147483648.0) or UINT32_MAX (4294967295 -> 4294967296.0) round
    *past* the integer range, and a clamp evaluated in float32 would then still
    overflow on the integer cast. For those, the nearest float32 on the zero side
    of ``value`` is returned (2147483520.0 and 4294967040.0 respectively).
    """
    as_f32 = np.float32(value)
    if abs(float(as_f32)) > abs(value):
        # Rounding moved the bound away from zero; step back by one float32 ULP.
        as_f32 = np.nextafter(as_f32, np.float32(0))
    return float(as_f32)


def _integer_scaling_range(dtype_name):
    """Return the scaling-range entry for ``dtype_name``, or None if none applies.

    Entries from ``_SCALING_RANGES`` are used as-is. Integer dtypes of at most
    32 bits that the table does not list (e.g. ``int8``) get an entry derived from
    ``np.iinfo`` in the same format: a bare scale for unsigned types, a
    ``(scale, offset)`` pair for signed ones. Anything else (floats, 64-bit ints,
    names NumPy does not know) returns None, meaning "cast the normalised values".
    """
    range_info = _SCALING_RANGES.get(dtype_name)
    if range_info is not None:
        return range_info
    try:
        info = np.iinfo(dtype_name)
    except (TypeError, ValueError):
        # Not a NumPy integer dtype name.
        return None
    if info.bits > 32:
        # 64-bit ranges cannot be reproduced through a float normalisation without
        # overflow; keep the historical pass-through behaviour for them.
        return None
    span = float(info.max) - float(info.min)
    return (span, -float(info.min)) if info.min < 0 else span


def _scale_generic(result, target_dtype, mem_type: MemoryType):
    """
    Convert ``result`` to ``target_dtype`` for any memory type, rescaling float -> int.

    This single function replaces 6 nearly-identical per-framework scaling functions.
    Framework specifics come from ``_FRAMEWORK_CONFIG[mem_type]["scaling_ops"]``. The
    expression strings ``check_float``, ``check_int``, ``astype``, ``min``, ``max`` and
    the optional ``clamp`` are ``eval``-ed in this function's scope. The flags
    ``needs_dtype_map`` and ``extra_import`` select optional setup.

    The expressions read this function's locals by name (each ``# noqa: F841 (used
    in eval)`` marker), so do not rename ``result``, ``target_dtype``,
    ``target_dtype_mapped``, ``mod``, ``jnp``, ``result_min``, ``result_max``,
    ``normalized``, ``min_val`` or ``max_val``. The strings are package configuration
    imported from ``arraybridge.framework_config``; never route user-supplied text
    through this path.

    Flow:

    * ``MemoryType.PYCLESPERANTO`` is delegated to ``_scale_pyclesperanto``.
    * If the framework module cannot be imported, or ``result`` has no ``dtype``,
      ``result`` is returned unchanged.
    * Unless ``result`` is floating point and the target is an integer type (as
      decided by ``check_float`` / ``check_int``), it is cast directly.
    * Otherwise it is min-max normalised to [0, 1]; a constant image is cast directly.
    * A non-constant image is then mapped onto the target's integer range, clamped
      when the config provides ``clamp``, and cast. Targets with no known integer
      range are cast from the normalised values.

    Args:
        result: Array/tensor belonging to the framework identified by ``mem_type``.
        target_dtype: Desired dtype (NumPy type, ``np.dtype`` or framework dtype).
        mem_type: Memory type selecting the framework configuration.

    Returns:
        The converted array, or ``result`` unchanged in the early-exit cases above.
    """
    # Special case: pyclesperanto
    if mem_type == MemoryType.PYCLESPERANTO:
        return _scale_pyclesperanto(result, target_dtype)

    config = _FRAMEWORK_CONFIG[mem_type]
    ops = config["scaling_ops"]
    mod = optional_import(mem_type.value)  # noqa: F841 (used in eval)
    if mod is None:
        return result

    if not hasattr(result, "dtype"):
        return result

    # Extra imports (e.g., jax.numpy) - load first as dtype_map might need it
    if "extra_import" in ops:
        jnp = optional_import(ops["extra_import"])  # noqa: F841 (used in eval)

    # Handle dtype mapping for frameworks that need it
    target_dtype_mapped = target_dtype  # noqa: F841 (used in eval)
    if ops.get("needs_dtype_map"):
        # Use jnp for JAX, mod for others
        dtype_module = jnp if "extra_import" in ops and jnp is not None else mod
        dtype_map = {
            np.uint8: dtype_module.uint8,
            np.int8: dtype_module.int8,
            np.int16: dtype_module.int16,
            np.int32: dtype_module.int32,
            np.int64: dtype_module.int64,
            np.float16: dtype_module.float16,
            np.float32: dtype_module.float32,
            np.float64: dtype_module.float64,
        }
        # ``arr.dtype`` is an np.dtype instance, which does not match the scalar-type
        # keys above; look it up by its scalar type instead of falling back to float32.
        lookup_key = target_dtype.type if isinstance(target_dtype, np.dtype) else target_dtype
        target_dtype_mapped = dtype_map.get(lookup_key, dtype_module.float32)  # noqa: F841

    # Check if conversion needed (float -> int)
    result_is_float = eval(ops["check_float"])
    target_is_int = eval(ops["check_int"])

    if not (result_is_float and target_is_int):
        # Direct conversion
        return eval(ops["astype"])

    result_min = eval(ops["min"])  # noqa: F841 (used in eval)
    result_max = eval(ops["max"])  # noqa: F841 (used in eval)

    if result_max <= result_min:
        # Constant image
        return eval(ops["astype"])

    # Normalize to [0, 1]
    normalized = (result - result_min) / (result_max - result_min)  # noqa: F841 (used in eval)

    if hasattr(target_dtype, "__name__"):
        dtype_name = target_dtype.__name__
    else:
        dtype_name = str(target_dtype).split(".")[-1]

    range_info = _integer_scaling_range(dtype_name)
    if range_info is None:
        # No known integer range for this target: cast the normalised values.
        result = normalized  # noqa: F841 (used in eval)
        return eval(ops["astype"])

    if isinstance(range_info, tuple):
        # Signed: [0, 1] -> [-offset, scale - offset]
        scale_val, offset_val = range_info
        result = normalized * scale_val - offset_val  # noqa: F841 (used in eval)
        min_val = -offset_val  # noqa: F841 (used in eval)
        upper = scale_val - offset_val
    else:
        # Unsigned: [0, 1] -> [0, scale]
        result = normalized * range_info  # noqa: F841 (used in eval)
        min_val = 0  # noqa: F841 (used in eval)
        upper = range_info
    # Only the 32-bit maxima are not exact in float32; all other bounds pass through
    # unchanged, so e.g. int16 keeps its full [-32768, 32767] range.
    max_val = _float32_safe_bound(upper)  # noqa: F841 (used in eval)

    # Clamp to prevent overflow due to float32 precision issues
    if ops.get("clamp"):
        result = eval(ops["clamp"])  # noqa: F841 (used in eval)

    # Convert dtype
    return eval(ops["astype"])
def _scale_pyclesperanto(result, target_dtype):
    """
    Convert a pyclesperanto image to ``target_dtype``, rescaling float -> int.

    Rescaling only happens when ``result`` is floating point and ``target_dtype`` is
    one of NumPy's 8/16/32-bit integer types. Every other combination is a plain cast
    via ``pull`` -> ``astype`` -> ``push``.

    Rescaling works in three steps:

    * The image is min-max normalised to [0, 1] with device-side arithmetic.
    * It is mapped onto the target's integer range.
    * It is clamped to that range on the host in float64, then cast.

    Args:
        result: pyclesperanto image (anything exposing ``dtype``).
        target_dtype: NumPy dtype (scalar type or ``np.dtype``) to convert to.

    Returns:
        A newly pushed image of ``target_dtype``, or ``result`` unchanged when
        pyclesperanto is unavailable or ``result`` has no ``dtype``.
    """
    cle = optional_import("pyclesperanto")
    if cle is None or not hasattr(result, "dtype"):
        return result

    # Check if result is floating point and target is integer
    result_is_float = np.issubdtype(result.dtype, np.floating)
    target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]

    if not (result_is_float and target_is_int):
        # Direct conversion
        return cle.push(cle.pull(result).astype(target_dtype))

    result_min = float(cle.minimum_of_all_pixels(result))
    result_max = float(cle.maximum_of_all_pixels(result))

    if result_max <= result_min:
        # Constant image: nothing to stretch.
        return cle.push(cle.pull(result).astype(target_dtype))

    # Normalize to [0, 1]: (x - min) / (max - min). subtract_image_from_scalar
    # computes ``scalar - image`` (it would invert the image), so add -min instead.
    normalized = cle.add_image_and_scalar(result, scalar=-result_min)
    normalized = cle.multiply_image_and_scalar(
        normalized, scalar=1.0 / (result_max - result_min)
    )

    # target_dtype is one of the integer types listed above at this point, so
    # np.dtype() accepts it; ``target_dtype.__name__`` fails for np.dtype instances.
    dtype_name = np.dtype(target_dtype).name
    range_info = _SCALING_RANGES.get(dtype_name)
    if range_info is None:
        # e.g. int8: accepted above but absent from the shared table. Derive an entry
        # in the same format from the dtype's limits.
        info = np.iinfo(target_dtype)
        span = float(info.max) - float(info.min)
        range_info = (span, -float(info.min)) if info.min < 0 else span

    if isinstance(range_info, tuple):
        # Signed: [0, 1] -> [-offset, scale - offset]
        scale_val, offset_val = range_info
        lower, upper = -offset_val, scale_val - offset_val
        scaled = cle.multiply_image_and_scalar(normalized, scalar=scale_val)
        scaled = cle.add_image_and_scalar(scaled, scalar=-offset_val)
    else:
        # Unsigned: [0, 1] -> [0, scale]
        lower, upper = 0.0, range_info
        scaled = cle.multiply_image_and_scalar(normalized, scalar=range_info)

    # Device arithmetic may run in float32, which can round extremes past the integer
    # range (2147483647 -> 2147483648.0). Clamp in float64, where every 32-bit bound is
    # exact, before casting.
    host = np.clip(np.asarray(cle.pull(scaled), dtype=np.float64), lower, upper)
    return cle.push(host.astype(target_dtype))
# One scaler per memory type: _scale_generic with ``mem_type`` pre-bound, keyed by
# the enum member's ``value``. (The pyclesperanto entry also goes through
# _scale_generic, which delegates to _scale_pyclesperanto.)
_SCALING_FUNCTIONS_GENERATED = {
    member.value: partial(_scale_generic, mem_type=member)
    for member in MemoryType
}

# Public name kept for backward compatibility; it is the very same dict object.
SCALING_FUNCTIONS = _SCALING_FUNCTIONS_GENERATED