Coverage for ezstitcher/core/utils.py: 81%
156 statements
« prev ^ index » next coverage.py v7.3.2, created at 2025-04-30 13:20 +0000
1"""
2Utility functions for the EZStitcher package.
3"""
5import threading
6import time
7import functools
8import logging
9from collections import defaultdict
10from typing import Dict, List, Any, Callable, Optional
12logger = logging.getLogger(__name__)
14# Global thread activity tracking
15thread_activity = defaultdict(list)
16active_threads = set()
17thread_lock = threading.Lock()
def get_thread_activity() -> Dict[int, List[Dict[str, Any]]]:
    """
    Return the global per-thread activity records.

    Returns:
        Mapping of thread ID to that thread's chronological activity records.
        Note: this is the live module-level structure, not a copy, so callers
        share state with the tracking decorator.
    """
    return thread_activity
def get_active_threads() -> set:
    """
    Return the set of thread IDs currently inside a tracked call.

    Returns:
        The live module-level set of active thread IDs (not a copy).
    """
    return active_threads
def clear_thread_activity():
    """Reset all recorded thread-activity state.

    The module lock is held so a concurrently running tracked call cannot
    observe a half-cleared state.
    """
    with thread_lock:
        active_threads.clear()
        thread_activity.clear()
# Per-thread reentrancy depth: a thread is only dropped from active_threads
# once its OUTERMOST tracked call finishes.  Without this, nested tracked
# calls on the same thread removed the id twice, raising KeyError from
# set.remove() at the outer call's end.
_tracked_call_depth = defaultdict(int)


def track_thread_activity(func: Optional[Callable] = None, *, log_level: str = "info"):
    """
    Decorator to track thread activity for a function.

    Appends a 'start' and an 'end' record to the module-level
    ``thread_activity`` mapping and maintains ``active_threads`` so that
    concurrency can be analyzed later (see ``analyze_thread_activity``).
    Safe for nested tracked calls on the same thread.

    Args:
        func: The function to decorate (``None`` when used with arguments,
            e.g. ``@track_thread_activity(log_level="debug")``).
        log_level: Logging level to use ("debug", "info", "warning", "error").

    Returns:
        Decorated function that tracks thread activity.
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Identify the executing thread.
            thread_id = threading.get_ident()
            thread_name = threading.current_thread().name

            # Record thread start time (also reused in the 'start' record so
            # the logged time and the recorded time agree).
            start_time = time.time()

            func_name = f.__name__
            # If the first positional argument looks like self/cls and owns a
            # method of this name, prefix the class name for context.
            context = ""
            if args and hasattr(args[0], "__class__"):
                if hasattr(args[0].__class__, func_name):
                    context = f"{args[0].__class__.__name__}."

            # Extract well information if present in kwargs or args.
            well = kwargs.get('well', None)
            if well is None and len(args) > 1 and isinstance(args[1], str):
                # Assume second argument might be well in methods like
                # process_well(self, well, ...)
                well = args[1]

            with thread_lock:
                _tracked_call_depth[thread_id] += 1
                active_threads.add(thread_id)
                # Snapshot the count under the lock; reading it afterwards
                # for logging would race with other threads.
                active_count = len(active_threads)
                thread_activity[thread_id].append({
                    'well': well,
                    'thread_name': thread_name,
                    'time': start_time,
                    'action': 'start',
                    'function': f"{context}{func_name}",
                    'active_threads': active_count
                })

            log_func = getattr(logger, log_level.lower())
            log_func(f"Thread {thread_name} (ID: {thread_id}) started {context}{func_name} for well {well}")
            log_func(f"Active threads: {active_count}")

            try:
                return f(*args, **kwargs)
            finally:
                end_time = time.time()
                duration = end_time - start_time

                with thread_lock:
                    _tracked_call_depth[thread_id] -= 1
                    if _tracked_call_depth[thread_id] <= 0:
                        # Outermost tracked call on this thread is done.
                        del _tracked_call_depth[thread_id]
                        # discard() never raises, even if bookkeeping was
                        # disturbed by clear_thread_activity() mid-call.
                        active_threads.discard(thread_id)
                    active_count = len(active_threads)
                    thread_activity[thread_id].append({
                        'well': well,
                        'thread_name': thread_name,
                        'time': end_time,
                        'action': 'end',
                        'function': f"{context}{func_name}",
                        'duration': duration,
                        'active_threads': active_count
                    })

                log_func(f"Thread {thread_name} (ID: {thread_id}) finished {context}{func_name} for well {well} in {duration:.2f} seconds")
                log_func(f"Active threads: {active_count}")

        return wrapper

    # Handle both @track_thread_activity and @track_thread_activity(log_level="debug")
    if func is None:
        return decorator
    return decorator(func)
def analyze_thread_activity(activity: Optional[Dict[int, List[Dict[str, Any]]]] = None):
    """
    Analyze thread activity data and return a report.

    Args:
        activity: Activity mapping to analyze (thread ID -> list of records).
            Defaults to the module-level ``thread_activity`` collected by
            ``track_thread_activity``.

    Returns:
        Dict with keys:
            - 'max_concurrent': highest observed concurrent thread count
            - 'thread_starts': (well, thread_name, time, function) tuples,
              sorted by time
            - 'thread_ends': (well, thread_name, time, duration, function)
              tuples, sorted by time
            - 'overlaps': one record per overlapping PAIR of tracked
              intervals (previously every pair was reported twice because
              both orderings (i, j) and (j, i) were scanned)
    """
    if activity is None:
        activity = thread_activity

    max_concurrent = 0
    thread_starts = []
    thread_ends = []

    for activities in activity.values():
        for record in activities:
            max_concurrent = max(max_concurrent, record['active_threads'])
            if record['action'] == 'start':
                thread_starts.append((
                    record.get('well'),
                    record['thread_name'],
                    record['time'],
                    record.get('function', '')
                ))
            else:  # 'end'
                thread_ends.append((
                    record.get('well'),
                    record['thread_name'],
                    record['time'],
                    record.get('duration', 0),
                    record.get('function', '')
                ))

    # Sort by time
    thread_starts.sort(key=lambda x: x[2])
    thread_ends.sort(key=lambda x: x[2])

    # Precompute the earliest end time per (thread, well, function) so the
    # pair scan below costs O(1) per lookup instead of a linear search over
    # thread_ends for every pair.
    first_end = {}
    for w, t, end, _d, f in thread_ends:
        first_end.setdefault((t, w, f), end)

    # Find overlapping time periods.  Each unordered pair is examined once
    # (j > i) so an overlap is reported exactly once.
    overlaps = []
    for i, (well1, thread1, start1, func1) in enumerate(thread_starts):
        end1 = first_end.get((thread1, well1, func1))
        if end1 is None:
            continue  # Skip if we can't find the end time

        for well2, thread2, start2, func2 in thread_starts[i + 1:]:
            if thread1 == thread2:
                continue  # Same thread cannot overlap with itself
            end2 = first_end.get((thread2, well2, func2))
            if end2 is None:
                continue  # Skip if we can't find the end time

            # Intervals intersect iff each starts before the other ends.
            if start1 < end2 and start2 < end1:
                overlap_duration = min(end1, end2) - max(start1, start2)
                if overlap_duration > 0:
                    overlaps.append({
                        'thread1': thread1,
                        'well1': well1,
                        'function1': func1,
                        'thread2': thread2,
                        'well2': well2,
                        'function2': func2,
                        'duration': overlap_duration
                    })

    return {
        'max_concurrent': max_concurrent,
        'thread_starts': thread_starts,
        'thread_ends': thread_ends,
        'overlaps': overlaps
    }
def print_thread_activity_report():
    """Print a human-readable summary of recorded thread activity.

    Returns:
        The analysis dict from ``analyze_thread_activity`` so callers can
        also inspect the raw numbers.
    """
    analysis = analyze_thread_activity()

    separator = "=" * 80
    print("\n" + separator)
    print("Thread Activity Report")
    print(separator)

    print("\nThread Start Events:")
    for well, thread_name, time_val, func in analysis['thread_starts']:
        print(f"Thread {thread_name} started {func} for well {well} at {time_val:.2f}")

    print("\nThread End Events:")
    for well, thread_name, time_val, duration, func in analysis['thread_ends']:
        print(f"Thread {thread_name} finished {func} for well {well} at {time_val:.2f} (duration: {duration:.2f}s)")

    print("\nOverlap Analysis:")
    for overlap in analysis['overlaps']:
        print(f"Threads {overlap['thread1']} and {overlap['thread2']} overlapped for {overlap['duration']:.2f}s")
        print(f" {overlap['thread1']} was processing {overlap['function1']} for well {overlap['well1']}")
        print(f" {overlap['thread2']} was processing {overlap['function2']} for well {overlap['well2']}")

    print(f"\nFound {len(analysis['overlaps'])} thread overlaps")
    print(f"Maximum concurrent threads: {analysis['max_concurrent']}")
    print(separator)

    return analysis
245import numpy as np
def prepare_patterns_and_functions(patterns, processing_funcs, component='default'):
    """
    Prepare patterns, processing functions, and processing args for processing.

    This function handles three main tasks:
    1. Ensuring patterns are in a component-keyed dictionary format
    2. Determining which processing functions to use for each component
    3. Determining which processing args to use for each component

    Args:
        patterns (list or dict): Patterns to process, either as a flat list or
            grouped by component.
        processing_funcs (callable, list, dict, tuple, optional): Processing
            functions to apply.  Can be a single callable, a tuple of
            (callable, kwargs), a list of either, or a dictionary mapping
            component values to any of these.
        component (str): Component name used as the dictionary key when
            ``patterns`` is a flat list.

    Returns:
        tuple: (grouped_patterns, component_to_funcs, component_to_args)
            - grouped_patterns: Dictionary mapping component values to patterns
            - component_to_funcs: Dictionary mapping component values to
              processing functions (a callable, or a list kept as-is)
            - component_to_args: Dictionary mapping component values to
              processing kwargs ({} for lists; args are extracted per item
              during processing)
    """
    # Ensure patterns are in a dictionary format: wrap a flat list under the
    # given component key.
    grouped_patterns = patterns if isinstance(patterns, dict) else {component: patterns}

    component_to_funcs = {}
    component_to_args = {}

    def extract_func_and_args(func_item):
        """Split a function item into (callable, kwargs)."""
        if isinstance(func_item, tuple) and len(func_item) == 2 and callable(func_item[0]):
            # It's a (function, kwargs) tuple
            return func_item[0], func_item[1]
        if callable(func_item):
            # It's just a function, use default args
            return func_item, {}
        # Invalid function item: warn and fall back to the identity function
        # so the pipeline keeps running with images unchanged.
        logger.warning(
            "Invalid function item: %s. Expected callable or (callable, kwargs) tuple.",
            str(func_item)
        )
        return lambda x, **kwargs: x, {}

    for comp_value in grouped_patterns.keys():
        # Resolve the function item for this component.  NOTE: the original
        # code had an extra `component == 'channel'` branch, but since it was
        # only reached when comp_value was absent from the dict, its
        # `.get(comp_value, processing_funcs)` always returned the default —
        # identical to the plain fallback below, so it was dead logic.
        if isinstance(processing_funcs, dict):
            # Per-component mapping; an unmapped component falls back to the
            # whole dict, which extract_func_and_args turns into the identity
            # function with a warning.
            func_item = processing_funcs.get(comp_value, processing_funcs)
        else:
            # Use the same function (or list/tuple) for all components
            func_item = processing_funcs

        if isinstance(func_item, list):
            # List of functions or function tuples: keep as-is; args are
            # extracted per item during processing.
            component_to_funcs[comp_value] = func_item
            component_to_args[comp_value] = {}
        else:
            # Single function or (function, kwargs) tuple
            func, args = extract_func_and_args(func_item)
            component_to_funcs[comp_value] = func
            component_to_args[comp_value] = args

    return grouped_patterns, component_to_funcs, component_to_args
def stack(single_image_func: Callable) -> Callable[[List[np.ndarray], Optional[Dict[str, Any]]], List[np.ndarray]]:
    """Lift a single-image function so it operates on a stack (list) of images.

    Args:
        single_image_func: A function that processes a single numpy array
            image and returns a processed numpy array image.

    Returns:
        A callable that accepts a list of images plus keyword arguments,
        applies the wrapped function to each image in order, and returns the
        list of processed images.  Images for which the function raises or
        returns ``None`` are skipped (with a log message).
    """
    @functools.wraps(single_image_func)
    def stack_wrapper(images: List[np.ndarray], **kwargs) -> List[np.ndarray]:
        """Applies the wrapped single-image function to each image in the stack."""
        results: List[np.ndarray] = []
        if not images:
            return results  # Empty input yields an empty output
        for img in images:
            try:
                # Forward the image and any keyword arguments unchanged.
                out = single_image_func(img, **kwargs)
            except Exception as e:
                logger.error(f"Error applying {single_image_func.__name__} to an image in the stack: {e}. Skipping image.")
                continue
            if out is None:
                logger.warning(f"Function {single_image_func.__name__} returned None for an image. Skipping.")
            else:
                results.append(out)
        return results

    # A stack-specific name aids debugging; guard against exotic callables
    # that reject attribute assignment.
    try:
        stack_wrapper.__name__ = f"stacked_{single_image_func.__name__}"
    except AttributeError:
        pass
    return stack_wrapper