Coverage for openhcs/core/memory/wrapper.py: 19.4%
90 statements
coverage.py v7.10.3, created at 2025-08-14 05:57 +0000
1"""
2Memory wrapper implementation for OpenHCS.
4This module provides the MemoryWrapper class for encapsulating in-memory data arrays
5with explicit type declarations and conversion methods, enforcing Clause 251
6(Declarative Memory Conversion Interface) and Clause 106-A (Declared Memory Types).
7"""
9from typing import Any, Optional
11from openhcs.constants.constants import MemoryType
13from .converters import (convert_memory, validate_data_compatibility,
14 validate_memory_type)
15from .exceptions import MemoryConversionError
16from .utils import _ensure_module, _get_device_id


class MemoryWrapper:
    """
    Immutable wrapper for in-memory data arrays with explicit type declarations.

    This class enforces Clause 251 (Declarative Memory Conversion Interface) and
    Clause 106-A (Declared Memory Types) by requiring explicit memory type declarations
    and providing declarative conversion methods.

    Attributes:
        memory_type: The declared memory type (e.g., "numpy", "cupy")
        data: The wrapped data array (read-only)
        gpu_id: The GPU device ID (for GPU memory types) or None for CPU
        input_memory_type: Alias for memory_type (for canonical access pattern)
        output_memory_type: Alias for memory_type (for canonical access pattern)
    """

    def __init__(self, data: Any, memory_type: str, gpu_id: Optional[int]):
        """
        Initialize a MemoryWrapper with data and an explicit memory type.

        Args:
            data: The in-memory data array (e.g., numpy, cupy, torch, tensorflow,
                jax, pyclesperanto)
            memory_type: The explicit memory type declaration (e.g., "numpy", "cupy")
            gpu_id: The GPU device ID (required for GPU memory types)

        Raises:
            ValueError: If memory_type is not supported, data is incompatible,
                or gpu_id is negative
        """
        # Validate memory type
        validate_memory_type(memory_type)

        # Validate data compatibility
        validate_data_compatibility(data, memory_type)

        # Store data and memory type
        self._data = data
        self._memory_type = memory_type

        # Store the provided gpu_id for all memory types: it is needed even for
        # numpy data when converting TO GPU memory types later.
        if gpu_id is not None and gpu_id < 0:
            raise ValueError(f"Invalid GPU ID: {gpu_id}. Must be a non-negative integer.")
        self._gpu_id = gpu_id
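
    # Note (illustrative): constructing with a negative device ID, e.g.
    # MemoryWrapper(data, "numpy", gpu_id=-1), raises ValueError per the check above.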

    @property
    def memory_type(self) -> str:
        """
        Get the declared memory type.

        Returns:
            The memory type as a string
        """
        return self._memory_type

    @property
    def data(self) -> Any:
        """
        Get the wrapped data array.

        Returns:
            The wrapped data array
        """
        return self._data

    @property
    def gpu_id(self) -> Optional[int]:
        """
        Get the GPU device ID.

        Returns:
            The GPU device ID, or None for CPU memory types
        """
        return self._gpu_id

    @property
    def input_memory_type(self) -> str:
        """
        Get the input memory type (same as memory_type).

        This property is provided for compatibility with the canonical memory type
        access pattern defined in Clause 106-A.2.

        Returns:
            The memory type as a string
        """
        return self._memory_type

    @property
    def output_memory_type(self) -> str:
        """
        Get the output memory type (same as memory_type).

        This property is provided for compatibility with the canonical memory type
        access pattern defined in Clause 106-A.2.

        Returns:
            The memory type as a string
        """
        return self._memory_type

    def to_numpy(self) -> "MemoryWrapper":
        """
        Convert to a numpy array and return a new MemoryWrapper.

        Returns:
            A new MemoryWrapper with numpy array data

        Raises:
            ValueError: If conversion to numpy is not supported for this memory type
            MemoryConversionError: If conversion fails
        """
        if self._memory_type == MemoryType.NUMPY.value:
            # Already numpy, return self (zero-copy)
            return self

        # Convert to numpy (always goes to CPU), so the CPU roundtrip is always allowed
        numpy_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.NUMPY.value,
            allow_cpu_roundtrip=True,
            gpu_id=0  # Placeholder; ignored for numpy
        )
        # Use 0 as a placeholder for gpu_id since it is ignored for numpy
        return MemoryWrapper(numpy_data, MemoryType.NUMPY.value, 0)
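
    # Illustrative sketch (assumes a GPU-backed wrapper already exists): converting
    # back to host memory always lands on CPU, and the new wrapper carries the
    # 0 placeholder used for numpy data.
    #
    #     host = gpu_wrapper.to_numpy()
    #     host.memory_type   # "numpy"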

    def to_cupy(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to a cupy array and return a new MemoryWrapper.

        Preserves the GPU device ID if possible.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to a CPU roundtrip

        Returns:
            A new MemoryWrapper with cupy array data

        Raises:
            ValueError: If conversion to cupy is not supported for this memory type
            ImportError: If cupy is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not authorized
        """
        if self._memory_type == MemoryType.CUPY.value:
            # Already cupy, return self (zero-copy)
            return self

        # Convert to cupy, preserving the GPU ID if possible
        cupy_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.CUPY.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # Get the GPU ID from the result (it may have changed during conversion)
        result_gpu_id = _get_device_id(cupy_data, MemoryType.CUPY.value)

        # Ensure we have a GPU ID for GPU memory
        if result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.CUPY.value,
                method="device_detection",
                reason="Failed to detect GPU ID for CuPy array after conversion"
            )

        return MemoryWrapper(cupy_data, MemoryType.CUPY.value, result_gpu_id)

    def to_torch(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to a torch tensor and return a new MemoryWrapper.

        Preserves the GPU device ID if possible.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to a CPU roundtrip

        Returns:
            A new MemoryWrapper with torch tensor data

        Raises:
            ValueError: If conversion to torch is not supported for this memory type
            ImportError: If torch is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not authorized
        """
        if self._memory_type == MemoryType.TORCH.value:
            # Already torch, return self (zero-copy)
            return self

        # Convert to torch, preserving the GPU ID if possible
        torch_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.TORCH.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # Get the GPU ID from the result (it may have changed during conversion)
        result_gpu_id = _get_device_id(torch_data, MemoryType.TORCH.value)

        # For GPU tensors, ensure we have a GPU ID
        if torch_data.is_cuda and result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.TORCH.value,
                method="device_detection",
                reason="Failed to detect GPU ID for CUDA tensor after conversion"
            )

        return MemoryWrapper(torch_data, MemoryType.TORCH.value, result_gpu_id)
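
    # Illustrative sketch (assumes cupy and torch with CUDA available): the source
    # GPU ID is passed through convert_memory and then re-read from the resulting
    # tensor.
    #
    #     gpu = MemoryWrapper(cupy_array, "cupy", gpu_id=1).to_torch()
    #     gpu.gpu_id   # 1, if the tensor stayed on the same device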

    def to_tensorflow(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to a tensorflow tensor and return a new MemoryWrapper.

        Preserves the GPU device ID if possible.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to a CPU roundtrip

        Returns:
            A new MemoryWrapper with tensorflow tensor data

        Raises:
            ValueError: If conversion to tensorflow is not supported for this memory type
            ImportError: If tensorflow is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not authorized
        """
        if self._memory_type == MemoryType.TENSORFLOW.value:
            # Already tensorflow, return self (zero-copy)
            return self

        # Convert to tensorflow, preserving the GPU ID if possible
        tf_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.TENSORFLOW.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # Get the GPU ID from the result (it may have changed during conversion)
        result_gpu_id = _get_device_id(tf_data, MemoryType.TENSORFLOW.value)

        # Check whether this is a GPU tensor and ensure we have a GPU ID
        device_str = tf_data.device.lower()
        is_gpu_tensor = "gpu" in device_str

        if is_gpu_tensor and result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.TENSORFLOW.value,
                method="device_detection",
                reason="Failed to detect GPU ID for TensorFlow GPU tensor after conversion"
            )

        return MemoryWrapper(tf_data, MemoryType.TENSORFLOW.value, result_gpu_id)

    def to_jax(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to a JAX array and return a new MemoryWrapper.

        Preserves the GPU device ID if possible.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to a CPU roundtrip

        Returns:
            A new MemoryWrapper with JAX array data

        Raises:
            ValueError: If conversion to JAX is not supported for this memory type
            ImportError: If JAX is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not authorized
        """
        if self._memory_type == MemoryType.JAX.value:
            # Already JAX, return self (zero-copy)
            return self

        # Convert to JAX, preserving the GPU ID if possible
        jax_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.JAX.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # Get the GPU ID from the JAX array
        result_gpu_id = _get_device_id(jax_data, MemoryType.JAX.value)

        # Check whether this is a GPU array and ensure we have a GPU ID
        device_str = str(jax_data.device).lower()
        is_gpu_array = "gpu" in device_str

        if is_gpu_array and result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.JAX.value,
                method="device_detection",
                reason="Failed to detect GPU ID for JAX GPU array after conversion"
            )

        return MemoryWrapper(jax_data, MemoryType.JAX.value, result_gpu_id)

    def to_pyclesperanto(self, allow_cpu_roundtrip: bool = False) -> "MemoryWrapper":
        """
        Convert to a pyclesperanto array and return a new MemoryWrapper.

        Preserves the GPU device ID if possible.

        Args:
            allow_cpu_roundtrip: Whether to allow fallback to a CPU roundtrip

        Returns:
            A new MemoryWrapper with pyclesperanto array data

        Raises:
            ValueError: If conversion to pyclesperanto is not supported for this memory type
            ImportError: If pyclesperanto is not installed
            MemoryConversionError: If conversion fails and CPU fallback is not authorized
        """
        if self._memory_type == MemoryType.PYCLESPERANTO.value:
            # Already pyclesperanto, return self (zero-copy)
            return self

        # Convert to pyclesperanto, preserving the GPU ID if possible
        pyclesperanto_data = convert_memory(
            self._data,
            self._memory_type,
            MemoryType.PYCLESPERANTO.value,
            gpu_id=self._gpu_id,
            allow_cpu_roundtrip=allow_cpu_roundtrip
        )

        # Get the GPU ID from the result (it may have changed during conversion)
        result_gpu_id = _get_device_id(pyclesperanto_data, MemoryType.PYCLESPERANTO.value)

        # Ensure we have a GPU ID for GPU memory
        if result_gpu_id is None:
            raise MemoryConversionError(
                source_type=self._memory_type,
                target_type=MemoryType.PYCLESPERANTO.value,
                method="device_detection",
                reason="Failed to detect GPU ID for pyclesperanto array after conversion"
            )

        return MemoryWrapper(pyclesperanto_data, MemoryType.PYCLESPERANTO.value, result_gpu_id)

    def __repr__(self) -> str:
        """
        Get a string representation of the MemoryWrapper.

        Returns:
            A string representation
        """
        return f"MemoryWrapper(memory_type='{self._memory_type}', shape={self._get_shape()})"

    def _get_shape(self) -> tuple:
        """
        Get the shape of the wrapped data array.

        Returns:
            The shape as a tuple
        """
        if self._memory_type == MemoryType.NUMPY.value:
            return self._data.shape
        if self._memory_type == MemoryType.CUPY.value:
            return self._data.shape
        if self._memory_type == MemoryType.TORCH.value:
            return tuple(self._data.shape)
        if self._memory_type == MemoryType.TENSORFLOW.value:
            return tuple(self._data.shape)
        if self._memory_type == MemoryType.JAX.value:
            # JAX arrays expose .shape like the other backends
            return tuple(self._data.shape)
        if self._memory_type == MemoryType.PYCLESPERANTO.value:
            return tuple(self._data.shape)

        # This should never happen if validate_memory_type is called in __init__
        return tuple()
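

# A minimal end-to-end sketch (illustrative only; assumes numpy and torch are
# installed and a CUDA device is visible; actual device placement depends on
# convert_memory):
#
#     import numpy as np
#     wrapped = MemoryWrapper(np.ones((4, 4), dtype=np.float32), "numpy", gpu_id=0)
#     on_gpu = wrapped.to_torch()   # new wrapper holding a torch tensor
#     back = on_gpu.to_numpy()      # always returns to host memory
#     repr(back)                    # "MemoryWrapper(memory_type='numpy', shape=(4, 4))"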