Coverage for openhcs/core/memory/dtype_scaling.py: 10.5%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2Dtype scaling and conversion functions for different memory types. 

3 

4This module provides framework-specific scaling functions that handle conversion 

5between floating point and integer dtypes with proper range scaling. 

6 

7Uses enum-driven metaprogramming to eliminate 276 lines of duplication (82% reduction). 

8Pattern follows PR #38: pure data → eval() → single generic function. 

9""" 

10 

11import numpy as np 

12from openhcs.constants.constants import MemoryType 

13from openhcs.core.memory.framework_config import _FRAMEWORK_CONFIG 

14from openhcs.core.utils import optional_import 

15 

16 

17# Scaling ranges for integer dtypes (shared across all memory types) 

18_SCALING_RANGES = { 

19 'uint8': 255.0, 

20 'uint16': 65535.0, 

21 'uint32': 4294967295.0, 

22 'int16': (65535.0, 32768.0), # (scale, offset) 

23 'int32': (4294967295.0, 2147483648.0), 

24} 

25 

26 

27# NOTE: Framework-specific scaling operations now defined in framework_config.py 

28# This eliminates the scattered _FRAMEWORK_OPS dict 

29 

30 

def _scale_generic(result, target_dtype, mem_type: MemoryType):
    """
    Generic scaling function that works for all memory types using framework config.

    This single function replaces 6 nearly-identical scaling functions.

    Float -> integer conversions are min-max normalized to [0, 1] and then
    stretched using ``_SCALING_RANGES``. Every other combination is a direct
    conversion. The framework-specific pieces (float/int checks, min, max,
    astype) are Python expression strings from
    ``_FRAMEWORK_CONFIG[mem_type]['scaling_ops']``, evaluated with ``eval()``
    in THIS function's local scope.

    IMPORTANT: The locals marked ``# noqa: F841 (used in eval)`` below are
    read by those config strings. Renaming or removing them, or moving
    statements so a local is not yet bound when an ``eval`` runs, breaks the
    config expressions at runtime.

    Args:
        result: Array/tensor to convert. Presumably belongs to the framework
            named by ``mem_type.value`` (TODO confirm against callers).
            Returned unchanged if it has no ``dtype`` attribute.
        target_dtype: Desired output dtype.
        mem_type: Selects the framework module and its ``scaling_ops``.
            ``MemoryType.PYCLESPERANTO`` is delegated to
            ``_scale_pyclesperanto``.

    Returns:
        The converted array, or ``result`` unchanged when the framework
        module cannot be imported or ``result`` has no ``dtype``.

    Raises:
        Presumably ``KeyError`` if ``mem_type`` has no entry in
        ``_FRAMEWORK_CONFIG`` (it is indexed directly).
    """
    # Special case: pyclesperanto
    if mem_type == MemoryType.PYCLESPERANTO:
        return _scale_pyclesperanto(result, target_dtype)

    config = _FRAMEWORK_CONFIG[mem_type]
    ops = config['scaling_ops']
    # Framework module (e.g. numpy/torch); config strings may reference it as ``mod``.
    mod = optional_import(mem_type.value)  # noqa: F841 (used in eval)
    if mod is None:
        # Best-effort: framework unavailable, leave data untouched.
        return result

    if not hasattr(result, 'dtype'):
        return result

    # Handle dtype mapping for frameworks that need it
    target_dtype_mapped = target_dtype  # noqa: F841 (used in eval)
    if ops.get('needs_dtype_map'):
        # NOTE(review): there are no keys for np.uint16 / np.uint32, so those
        # targets fall back to mod.float32. Confirm whether this is intentional
        # (e.g. the framework lacks those dtypes).
        # NOTE(review): keys are NumPy scalar types. If callers pass np.dtype
        # instances instead, the lookup may miss and fall back to float32.
        dtype_map = {
            np.uint8: mod.uint8, np.int8: mod.int8, np.int16: mod.int16,
            np.int32: mod.int32, np.int64: mod.int64, np.float16: mod.float16,
            np.float32: mod.float32, np.float64: mod.float64,
        }
        target_dtype_mapped = dtype_map.get(target_dtype, mod.float32)  # noqa: F841

    # Extra imports (e.g., jax.numpy)
    # ``jnp`` is only bound when the config requests it; config strings that
    # use ``jnp`` must come with an ``extra_import`` entry.
    if 'extra_import' in ops:
        jnp = optional_import(ops['extra_import'])  # noqa: F841 (used in eval)

    # Check if conversion needed (float → int)
    result_is_float = eval(ops['check_float'])
    target_is_int = eval(ops['check_int'])

    if not (result_is_float and target_is_int):
        # Direct conversion
        return eval(ops['astype'])

    # Get min/max
    result_min = eval(ops['min'])  # noqa: F841 (used in eval)
    result_max = eval(ops['max'])  # noqa: F841 (used in eval)

    if result_max <= result_min:
        # Constant image
        # (max <= min means there is no range to stretch; also avoids a
        # division by zero below)
        return eval(ops['astype'])

    # Normalize to [0, 1]
    normalized = (result - result_min) / (result_max - result_min)  # noqa: F841 (used in eval)

    # Scale to target range
    # Scalar types expose __name__ (np.uint8 -> 'uint8'). Otherwise take the
    # last dotted component of str(), e.g. 'torch.uint8' -> 'uint8'.
    dtype_name = target_dtype.__name__ if hasattr(target_dtype, '__name__') else str(target_dtype).split('.')[-1]

    # ``result`` is rebound to the scaled data so the final ``astype`` config
    # string (which reads ``result``) converts the scaled values.
    # NOTE(review): if ``normalized`` is float32, products with 4294967295.0
    # can round to 2**32 and overflow on the uint32/int32 cast. Confirm or clip.
    if dtype_name in _SCALING_RANGES:
        range_info = _SCALING_RANGES[dtype_name]
        if isinstance(range_info, tuple):
            scale_val, offset_val = range_info
            result = normalized * scale_val - offset_val  # noqa: F841 (used in eval)
        else:
            result = normalized * range_info  # noqa: F841 (used in eval)
    else:
        # Integer dtype without a scaling range: data stays in [0, 1] before
        # the cast.
        result = normalized  # noqa: F841 (used in eval)

    # Convert dtype
    return eval(ops['astype'])

98 

99 

def _scale_pyclesperanto(result, target_dtype):
    """Scale pyclesperanto results (GPU operations require special handling).

    A floating-point image converted to one of the integer dtypes uint8,
    uint16, uint32, int8, int16 or int32 is processed in three steps:

    1. Min-max normalize to [0, 1] on the GPU.
    2. Stretch to the target range listed in ``_SCALING_RANGES``.
    3. Clip to the dtype's limits and cast.

    Every other combination is a plain ``astype`` round-trip through host
    memory.

    Args:
        result: pyclesperanto image. Returned unchanged if pyclesperanto
            cannot be imported or the object has no ``dtype`` attribute.
        target_dtype: NumPy scalar type or ``np.dtype`` to convert to.

    Returns:
        A pyclesperanto image of ``target_dtype``, or ``result`` unchanged.
    """
    cle = optional_import("pyclesperanto")
    if cle is None or not hasattr(result, 'dtype'):
        return result

    # Only float -> integer conversions need range scaling.
    result_is_float = np.issubdtype(result.dtype, np.floating)
    target_is_int = target_dtype in [np.uint8, np.uint16, np.uint32, np.int8, np.int16, np.int32]

    if not (result_is_float and target_is_int):
        # Direct conversion
        return cle.push(cle.pull(result).astype(target_dtype))

    result_min = float(cle.minimum_of_all_pixels(result))
    result_max = float(cle.maximum_of_all_pixels(result))

    if result_max <= result_min:
        # Constant image: no range to stretch (and avoids dividing by zero).
        return cle.push(cle.pull(result).astype(target_dtype))

    # Normalize to [0, 1] on the GPU.
    # cle.subtract_image_from_scalar computes ``scalar - image``, which would
    # invert and negate the data. Shift with add_image_and_scalar
    # (``image + scalar``) and a negated scalar instead.
    normalized = cle.add_image_and_scalar(result, scalar=-result_min)
    normalized = cle.multiply_image_and_scalar(
        normalized, scalar=1.0 / (result_max - result_min)
    )

    # np.dtype(...).name works for scalar types and np.dtype instances alike.
    # np.dtype instances have no __name__, which the old lookup relied on.
    range_info = _SCALING_RANGES.get(np.dtype(target_dtype).name)
    if range_info is None:
        scaled = normalized
    elif isinstance(range_info, tuple):
        # Signed target: normalized * scale - offset.
        scale_val, offset_val = range_info
        scaled = cle.multiply_image_and_scalar(normalized, scalar=scale_val)
        scaled = cle.add_image_and_scalar(scaled, scalar=-offset_val)
    else:
        scaled = cle.multiply_image_and_scalar(normalized, scalar=range_info)

    # Clip in float64 on the host before casting. Float rounding at the
    # extremes could otherwise wrap around on the integer cast; for example,
    # float32 cannot represent 4294967295 and rounds it to 2**32.
    limits = np.iinfo(target_dtype)
    host = np.asarray(cle.pull(scaled), dtype=np.float64)
    host = np.clip(host, limits.min, limits.max)
    return cle.push(host.astype(target_dtype))

142 

143 

# Build one pre-bound scaler per MemoryType via partial application.
from functools import partial

# Keyed by the enum member's string value. ``partial`` captures each member
# eagerly, so every entry is bound to its own memory type (no late binding).
_SCALING_FUNCTIONS_GENERATED = dict(
    (member.value, partial(_scale_generic, mem_type=member))
    for member in MemoryType
)

# Public name kept so existing importers of SCALING_FUNCTIONS keep working.
SCALING_FUNCTIONS = _SCALING_FUNCTIONS_GENERATED

154