Coverage for src/arraybridge/dtype_scaling.py: 70%

76 statements  


1""" 

2Dtype scaling and conversion functions for different memory types. 

3 

4This module provides framework-specific scaling functions that handle conversion 

5between floating point and integer dtypes with proper range scaling. 

6 

7Uses enum-driven metaprogramming to eliminate 276 lines of duplication (82% reduction). 

8Pattern follows PR #38: pure data → eval() → single generic function. 

9""" 

10 

11from functools import partial 

12 

13import numpy as np 

14 

15from arraybridge.framework_config import _FRAMEWORK_CONFIG 

16from arraybridge.types import MemoryType 

17from arraybridge.utils import optional_import 

18 

19# Scaling ranges for integer dtypes (shared across all memory types) 

20_SCALING_RANGES = { 

21 "uint8": 255.0, 

22 "uint16": 65535.0, 

23 "uint32": 4294967295.0, 

24 "int16": (65535.0, 32768.0), # (scale, offset) 

25 "int32": (4294967295.0, 2147483648.0), 

26} 
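
# Worked example of the (scale, offset) form: for int16, a value n already
# normalized to [0, 1] maps via n * 65535.0 - 32768.0, so n = 0 gives -32768
# and n = 1 gives 32767, covering the full signed range.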


# NOTE: Framework-specific scaling operations are now defined in framework_config.py.
# This eliminates the scattered _FRAMEWORK_OPS dict.
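
# For illustration only, a hypothetical "scaling_ops" entry in the shape that
# _scale_generic() evaluates below; the real entries live in
# framework_config.py and may differ. Each string is eval()'d inside
# _scale_generic, where result, target_dtype, target_dtype_mapped, mod,
# normalized, min_val, and max_val are in scope:
#
#     "scaling_ops": {
#         "check_float": "np.issubdtype(result.dtype, np.floating)",
#         "check_int": "np.issubdtype(target_dtype, np.integer)",
#         "min": "float(result.min())",
#         "max": "float(result.max())",
#         "clamp": "np.clip(result, min_val, max_val)",
#         "astype": "result.astype(target_dtype_mapped)",
#     }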


def _scale_generic(result, target_dtype, mem_type: MemoryType):
    """
    Generic scaling function that works for all memory types using framework config.

    This single function replaces six nearly identical scaling functions.
    """
    # Special case: pyclesperanto
    if mem_type == MemoryType.PYCLESPERANTO:
        return _scale_pyclesperanto(result, target_dtype)

    config = _FRAMEWORK_CONFIG[mem_type]
    ops = config["scaling_ops"]
    mod = optional_import(mem_type.value)  # noqa: F841 (used in eval)
    if mod is None:
        return result

    if not hasattr(result, "dtype"):
        return result

    # Extra imports (e.g., jax.numpy) - loaded first because dtype_map may need them
    if "extra_import" in ops:
        jnp = optional_import(ops["extra_import"])  # noqa: F841 (used in eval)

    # Handle dtype mapping for frameworks that need it
    target_dtype_mapped = target_dtype  # noqa: F841 (used in eval)
    if ops.get("needs_dtype_map"):
        # Use jnp for JAX, mod for others
        dtype_module = jnp if "extra_import" in ops and jnp is not None else mod
        dtype_map = {
            np.uint8: dtype_module.uint8,
            np.int8: dtype_module.int8,
            np.int16: dtype_module.int16,
            np.int32: dtype_module.int32,
            np.int64: dtype_module.int64,
            np.float16: dtype_module.float16,
            np.float32: dtype_module.float32,
            np.float64: dtype_module.float64,
        }
        target_dtype_mapped = dtype_map.get(target_dtype, dtype_module.float32)  # noqa: F841

    # Check whether conversion is needed (float → int)
    result_is_float = eval(ops["check_float"])
    target_is_int = eval(ops["check_int"])

    if not (result_is_float and target_is_int):
        # Direct conversion
        return eval(ops["astype"])

    # Get min/max
    result_min = eval(ops["min"])  # noqa: F841 (used in eval)
    result_max = eval(ops["max"])  # noqa: F841 (used in eval)

    if result_max <= result_min:
        # Constant image
        return eval(ops["astype"])

    # Normalize to [0, 1]
    normalized = (result - result_min) / (result_max - result_min)  # noqa: F841 (used in eval)

    # Scale to target range
    if hasattr(target_dtype, "__name__"):
        dtype_name = target_dtype.__name__
    else:
        dtype_name = str(target_dtype).split(".")[-1]

    if dtype_name in _SCALING_RANGES:
        range_info = _SCALING_RANGES[dtype_name]
        if isinstance(range_info, tuple):
            scale_val, offset_val = range_info
            result = normalized * scale_val - offset_val  # noqa: F841 (used in eval)
            # Clamp to avoid float32 precision overflow. For int32 the target
            # range is [-2147483648, 2147483647], but float32 cannot represent
            # 2147483647 exactly; it rounds up to 2147483648. With only ~7
            # decimal digits of precision, large integers lose exactness, so
            # max_val must stay safely below INT32_MAX after rounding.
            # Subtracting 128 gives a safe margin while still using most of
            # the range.
            min_val = -offset_val  # noqa: F841 (used in eval)
            max_val = scale_val - offset_val - 128  # safe margin for float32 precision  # noqa: F841
        else:
            result = normalized * range_info  # noqa: F841 (used in eval)
            # For unsigned types the target range is [0, range_info]
            min_val = 0  # noqa: F841 (used in eval)
            max_val = range_info  # noqa: F841 (used in eval)

        # Clamp to prevent overflow due to float32 precision issues
        if ops.get("clamp"):
            result = eval(ops["clamp"])  # noqa: F841 (used in eval)
    else:
        result = normalized  # noqa: F841 (used in eval)

    # Convert dtype
    return eval(ops["astype"])
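
# A quick sketch of the float32 precision issue handled above (illustration
# only, not part of the module): INT32_MAX is not exactly representable in
# float32 and rounds up to 2**31, which would overflow an int32 cast.
#
#     >>> import numpy as np
#     >>> int(np.float32(2147483647))
#     2147483648
#     >>> bool(np.float32(2147483647) > np.iinfo(np.int32).max)
#     True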


def _scale_pyclesperanto(result, target_dtype):
    """Scale pyclesperanto results (GPU operations require special handling)."""
    cle = optional_import("pyclesperanto")
    if cle is None or not hasattr(result, "dtype"):
        return result

    # Check if the result is floating point and the target is an integer type
    result_is_float = np.issubdtype(result.dtype, np.floating)
    target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]

    if not (result_is_float and target_is_int):
        # Direct conversion
        return cle.push(cle.pull(result).astype(target_dtype))

    # Get min/max
    result_min = float(cle.minimum_of_all_pixels(result))
    result_max = float(cle.maximum_of_all_pixels(result))

    if result_max <= result_min:
        # Constant image
        return cle.push(cle.pull(result).astype(target_dtype))

    # Normalize to [0, 1] using GPU operations. Note that
    # subtract_image_from_scalar computes s - x, which would flip the sign;
    # use add_image_and_scalar with a negated scalar to get x - s.
    normalized = cle.add_image_and_scalar(result, scalar=-result_min)
    range_val = result_max - result_min
    normalized = cle.multiply_image_and_scalar(normalized, scalar=1.0 / range_val)

    # Scale to target range
    dtype_name = target_dtype.__name__
    if dtype_name in _SCALING_RANGES:
        range_info = _SCALING_RANGES[dtype_name]
        if isinstance(range_info, tuple):
            scale_val, offset_val = range_info
            scaled = cle.multiply_image_and_scalar(normalized, scalar=scale_val)
            scaled = cle.add_image_and_scalar(scaled, scalar=-offset_val)
        else:
            scaled = cle.multiply_image_and_scalar(normalized, scalar=range_info)
    else:
        scaled = normalized

    # Convert dtype
    return cle.push(cle.pull(scaled).astype(target_dtype))


# Auto-generate all scaling functions using partial application
_SCALING_FUNCTIONS_GENERATED = {
    mem_type.value: partial(_scale_generic, mem_type=mem_type) for mem_type in MemoryType
}

# Registry mapping memory-type names to scaling functions (kept for backward compatibility)
SCALING_FUNCTIONS = _SCALING_FUNCTIONS_GENERATED
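
# Example usage (a sketch; assumes MemoryType has a member whose .value is
# "numpy" -- check arraybridge.types for the actual member names):
#
#     import numpy as np
#     from arraybridge.dtype_scaling import SCALING_FUNCTIONS
#
#     img = np.linspace(0.0, 1.0, 16, dtype=np.float32).reshape(4, 4)
#     out = SCALING_FUNCTIONS["numpy"](img, np.uint16)  # scaled to [0, 65535]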