Coverage for openhcs/core/memory/wrapper.py: 19.4%

90 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2Memory wrapper implementation for OpenHCS. 

3 

4This module provides the MemoryWrapper class for encapsulating in-memory data arrays 

5with explicit type declarations and conversion methods, enforcing Clause 251 

6(Declarative Memory Conversion Interface) and Clause 106-A (Declared Memory Types). 

7""" 

8 

9from typing import Any, Optional 

10 

11from openhcs.constants.constants import MemoryType 

12 

13from .converters import (convert_memory, validate_data_compatibility, 

14 validate_memory_type) 

15from .exceptions import MemoryConversionError 

16from .utils import _ensure_module, _get_device_id 

17 

18 

class MemoryWrapper:
    """
    Read-only wrapper for in-memory data arrays with explicit type declarations.

    This class enforces Clause 251 (Declarative Memory Conversion Interface) and
    Clause 106-A (Declared Memory Types) by requiring explicit memory type declarations
    and providing declarative conversion methods.

    The wrapper exposes its state only through read-only properties. Note that the
    wrapped array itself is not copied, so mutating it through another reference is
    still possible.

    Attributes:
        memory_type: The declared memory type (e.g., "numpy", "cupy")
        data: The wrapped data array (read-only property)
        gpu_id: The GPU device ID carried by this wrapper. It is stored for every
            memory type, CPU ones included, so that later conversions to a GPU
            memory type know which device to target. It is None only when None was
            passed explicitly, or when a conversion produced a CPU tensor/array
            whose device ID could not be detected.
        input_memory_type: Alias for memory_type (for canonical access pattern)
        output_memory_type: Alias for memory_type (for canonical access pattern)
    """

    def __init__(self, data: Any, memory_type: str, gpu_id: Optional[int]):
        """
        Initialize a MemoryWrapper with data and explicit memory type.

        Args:
            data: The in-memory data array (numpy, cupy, torch, tensorflow, jax,
                pyclesperanto)
            memory_type: The explicit memory type declaration (e.g., "numpy", "cupy")
            gpu_id: The GPU device ID to associate with the data. It is required for
                GPU memory types and is also kept for CPU types as the target device
                for later GPU conversions. None is accepted.

        Raises:
            ValueError: If gpu_id is negative, or (via the validators) if memory_type
                is not supported or data is incompatible with it
        """
        # Check the cheap gpu_id invariant first so an invalid ID is reported
        # immediately, before the memory-type validators run.
        if gpu_id is not None and gpu_id < 0:
            raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")

        validate_memory_type(memory_type)
        validate_data_compatibility(data, memory_type)

        self._data = data
        self._memory_type = memory_type
        # gpu_id is stored for all memory types: even numpy data needs it when
        # it is later converted TO a GPU memory type.
        self._gpu_id = gpu_id

    @property
    def memory_type(self) -> str:
        """
        Get the declared memory type.

        Returns:
            The memory type as a string
        """
        return self._memory_type

    @property
    def data(self) -> Any:
        """
        Get the wrapped data array.

        Returns:
            The wrapped data array (the same object, not a copy)
        """
        return self._data

    @property
    def gpu_id(self) -> Optional[int]:
        """
        Get the GPU device ID associated with this wrapper.

        Returns:
            The stored GPU device ID. This may be None if None was passed at
            construction.
        """
        return self._gpu_id

    @property
    def input_memory_type(self) -> str:
        """
        Get input memory type (same as memory_type).

        This property is provided for compatibility with the canonical memory type
        access pattern defined in Clause 106-A.2.

        Returns:
            The memory type as a string
        """
        return self._memory_type

    @property
    def output_memory_type(self) -> str:
        """
        Get output memory type (same as memory_type).

        This property is provided for compatibility with the canonical memory type
        access pattern defined in Clause 106-A.2.

        Returns:
            The memory type as a string
        """
        return self._memory_type

    def to_numpy(self) -> "MemoryWrapper":
        """
        Convert to numpy array and return a new MemoryWrapper.

        The resulting wrapper keeps this wrapper's gpu_id (or 0 if it is None), so
        that a subsequent conversion back to a GPU memory type targets the same
        device the data came from.

        Returns:
            A new MemoryWrapper with numpy array data, or self if already numpy

        Raises:
            ValueError: If conversion to numpy is not supported for this memory type
            MemoryConversionError: If conversion fails
        """
        if self._memory_type == MemoryType.NUMPY.value:
            # Already numpy, return self (zero-copy)
            return self

        # Converting to numpy always lands on the CPU, so a CPU roundtrip is
        # inherently allowed here.
        numpy_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.NUMPY.value,
            allow_cpu_roundtrip=True,
            gpu_id=0,  # Placeholder: the numpy target does not use a device
        )
        # Carry the source device forward instead of resetting it to 0. Otherwise
        # data that came from e.g. GPU 1 would later be converted onto GPU 0.
        carried_gpu_id = self._gpu_id if self._gpu_id is not None else 0
        return MemoryWrapper(numpy_data, MemoryType.NUMPY.value, carried_gpu_id)

    def to_cupy(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to cupy array and return a new MemoryWrapper.

        The stored gpu_id is passed to the converter. The result's gpu_id is taken
        from the device the converted array actually resides on.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip

        Returns:
            A new MemoryWrapper with cupy array data, or self if already cupy

        Raises:
            ValueError: If conversion to cupy is not supported for this memory type
            ImportError: If cupy is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not
                authorized, or if no GPU ID can be detected on the result
        """
        if self._memory_type == MemoryType.CUPY.value:
            # Already cupy, return self (zero-copy)
            return self

        cupy_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.CUPY.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # The device may differ from the requested one, so read it back.
        result_gpu_id = _get_device_id(cupy_data, MemoryType.CUPY.value)

        # CuPy arrays always live on a GPU, so a missing ID is an error.
        if result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.CUPY.value,
                method="device_detection",
                reason="Failed to detect GPU ID for CuPy array after conversion"
            )

        return MemoryWrapper(cupy_data, MemoryType.CUPY.value, result_gpu_id)

    def to_torch(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to torch tensor and return a new MemoryWrapper.

        The stored gpu_id is passed to the converter. The result's gpu_id is taken
        from the converted tensor's device. For a CPU tensor it may be None.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip

        Returns:
            A new MemoryWrapper with torch tensor data, or self if already torch

        Raises:
            ValueError: If conversion to torch is not supported for this memory type
            ImportError: If torch is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not
                authorized, or if a CUDA tensor's GPU ID cannot be detected
        """
        if self._memory_type == MemoryType.TORCH.value:
            # Already torch, return self (zero-copy)
            return self

        torch_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.TORCH.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # The device may differ from the requested one, so read it back.
        result_gpu_id = _get_device_id(torch_data, MemoryType.TORCH.value)

        # Only CUDA tensors must have a GPU ID; CPU tensors may legitimately lack one.
        if torch_data.is_cuda and result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.TORCH.value,
                method="device_detection",
                reason="Failed to detect GPU ID for CUDA tensor after conversion"
            )

        return MemoryWrapper(torch_data, MemoryType.TORCH.value, result_gpu_id)

    def to_tensorflow(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to tensorflow tensor and return a new MemoryWrapper.

        The stored gpu_id is passed to the converter. The result's gpu_id is taken
        from the converted tensor's device. For a CPU tensor it may be None.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip

        Returns:
            A new MemoryWrapper with tensorflow tensor data, or self if already
            tensorflow

        Raises:
            ValueError: If conversion to tensorflow is not supported for this memory type
            ImportError: If tensorflow is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not
                authorized, or if a GPU tensor's GPU ID cannot be detected
        """
        if self._memory_type == MemoryType.TENSORFLOW.value:
            # Already tensorflow, return self (zero-copy)
            return self

        tf_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.TENSORFLOW.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # The device may differ from the requested one, so read it back.
        result_gpu_id = _get_device_id(tf_data, MemoryType.TENSORFLOW.value)

        # TensorFlow reports placement as a device string; treat any "gpu" in it
        # as a GPU tensor, which must have a detectable GPU ID.
        is_gpu_tensor = "gpu" in tf_data.device.lower()

        if is_gpu_tensor and result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.TENSORFLOW.value,
                method="device_detection",
                reason="Failed to detect GPU ID for TensorFlow GPU tensor after conversion"
            )

        return MemoryWrapper(tf_data, MemoryType.TENSORFLOW.value, result_gpu_id)

    def to_jax(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to JAX array and return a new MemoryWrapper.

        The stored gpu_id is passed to the converter. The result's gpu_id is taken
        from the converted array's device. For a CPU array it may be None.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip

        Returns:
            A new MemoryWrapper with JAX array data, or self if already JAX

        Raises:
            ValueError: If conversion to JAX is not supported for this memory type
            ImportError: If JAX is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not
                authorized, or if a GPU array's GPU ID cannot be detected
        """
        if self._memory_type == MemoryType.JAX.value:
            # Already JAX, return self (zero-copy)
            return self

        jax_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.JAX.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # The device may differ from the requested one, so read it back.
        result_gpu_id = _get_device_id(jax_data, MemoryType.JAX.value)

        # Treat any "gpu" in the stringified device as a GPU array, which must
        # have a detectable GPU ID.
        is_gpu_array = "gpu" in str(jax_data.device).lower()

        if is_gpu_array and result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.JAX.value,
                method="device_detection",
                reason="Failed to detect GPU ID for JAX GPU array after conversion"
            )

        return MemoryWrapper(jax_data, MemoryType.JAX.value, result_gpu_id)

    def to_pyclesperanto(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to pyclesperanto array and return a new MemoryWrapper.

        The stored gpu_id is passed to the converter. The result's gpu_id is taken
        from the converted array and must be detectable.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to CPU roundtrip

        Returns:
            A new MemoryWrapper with pyclesperanto array data, or self if already
            pyclesperanto

        Raises:
            ValueError: If conversion to pyclesperanto is not supported for this memory type
            ImportError: If pyclesperanto is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not
                authorized, or if no GPU ID can be detected on the result
        """
        if self._memory_type == MemoryType.PYCLESPERANTO.value:
            # Already pyclesperanto, return self (zero-copy)
            return self

        pyclesperanto_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.PYCLESPERANTO.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # The device may differ from the requested one, so read it back.
        result_gpu_id = _get_device_id(pyclesperanto_data, MemoryType.PYCLESPERANTO.value)

        # A GPU ID is mandatory for pyclesperanto results (no CPU case is allowed here).
        if result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="device_detection",
                reason="Failed to detect GPU ID for pyclesperanto array after conversion"
            )

        return MemoryWrapper(pyclesperanto_data, MemoryType.PYCLESPERANTO.value, result_gpu_id)

    def __repr__(self) -> str:
        """
        Get a string representation of the MemoryWrapper.

        Returns:
            A string showing the memory type and the data's shape
        """
        return f"MemoryWrapper(memory_type='{self._memory_type}', shape={self._get_shape()})"

    def _get_shape(self) -> tuple:
        """
        Get the shape of the wrapped data array.

        All supported array types expose a ``shape`` attribute, so it is read
        generically rather than per memory type. The previous per-type dispatch
        omitted JAX and reported an empty shape for JAX arrays.

        Returns:
            The shape as a plain tuple, or an empty tuple if the data has no shape
        """
        shape = getattr(self._data, "shape", None)
        if shape is None:
            return tuple()
        # torch.Size / TensorShape / lists are normalized to a plain tuple.
        return tuple(shape)