Coverage for ezstitcher/core/utils.py: 81%

156 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2025-04-30 13:20 +0000

1""" 

2Utility functions for the EZStitcher package. 

3""" 

4 

5import threading 

6import time 

7import functools 

8import logging 

9from collections import defaultdict 

10from typing import Dict, List, Any, Callable, Optional 

11 

12logger = logging.getLogger(__name__) 

13 

# Global thread activity tracking.
# thread_activity maps a thread id to the list of start/end event records
# appended by the track_thread_activity decorator.
thread_activity = defaultdict(list)
# IDs of threads currently executing a tracked function.
active_threads = set()
# Guards all mutation of thread_activity and active_threads.
thread_lock = threading.Lock()

18 

def get_thread_activity() -> Dict[int, List[Dict[str, Any]]]:
    """
    Get the current thread activity data.

    Note: this returns the live module-level structure, not a copy, and
    takes no lock; callers that iterate while tracked functions are still
    running may observe concurrent mutation.

    Returns:
        Dict mapping thread IDs to lists of activity records
    """
    return thread_activity

27 

def get_active_threads() -> set:
    """
    Get the set of currently active thread IDs.

    Note: this returns the live module-level set, not a snapshot, and
    takes no lock; its contents can change while the caller inspects it.

    Returns:
        Set of active thread IDs
    """
    return active_threads

36 

def clear_thread_activity():
    """Reset all recorded thread activity and the active-thread set."""
    # Both structures are cleared under the same lock so no tracked call
    # can observe one cleared but not the other.
    with thread_lock:
        active_threads.clear()
        thread_activity.clear()

42 

def track_thread_activity(func: Optional[Callable] = None, *, log_level: str = "info"):
    """
    Decorator to track thread activity for a function.

    Records a 'start' and an 'end' event (with timing and concurrency data)
    in the module-level thread_activity mapping for every call. Supports
    both bare usage (@track_thread_activity) and parameterized usage
    (@track_thread_activity(log_level="debug")).

    Args:
        func: The function to decorate (None when called with arguments).
        log_level: Logging level to use ("debug", "info", "warning", "error")

    Returns:
        Decorated function that tracks thread activity
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Identify the calling thread.
            thread_id = threading.get_ident()
            thread_name = threading.current_thread().name

            # Record thread start time.
            start_time = time.time()

            # Extract function name and arguments for context.
            func_name = f.__name__
            # If the first positional argument looks like self/cls for this
            # function, prefix the class name for a more informative context.
            context = ""
            if args and hasattr(args[0], "__class__"):
                if hasattr(args[0].__class__, func_name):
                    # It's likely a method, extract class name
                    context = f"{args[0].__class__.__name__}."

            # Extract well information if present in kwargs or args
            # (conventionally the second positional argument in methods
            # like process_well(self, well, ...)).
            well = kwargs.get('well', None)
            if well is None and len(args) > 1 and isinstance(args[1], str):
                well = args[1]

            # Register this thread as active and record the start event.
            # The active-thread count is snapshotted inside the lock so the
            # logged value always matches the recorded one (the original
            # unsynchronized len() read outside the lock could race).
            with thread_lock:
                active_threads.add(thread_id)
                active_count = len(active_threads)
                thread_activity[thread_id].append({
                    'well': well,
                    'thread_name': thread_name,
                    'time': time.time(),
                    'action': 'start',
                    'function': f"{context}{func_name}",
                    'active_threads': active_count
                })

            # Log the start of the function (lazy %-args: only formatted if
            # the level is enabled).
            log_func = getattr(logger, log_level.lower())
            log_func("Thread %s (ID: %s) started %s%s for well %s",
                     thread_name, thread_id, context, func_name, well)
            log_func("Active threads: %s", active_count)

            try:
                # Call the original function.
                return f(*args, **kwargs)
            finally:
                # Record thread end time.
                duration = time.time() - start_time

                # Deregister and record the end event. discard() instead of
                # remove(): if a tracked function calls another tracked
                # function on the same thread, the inner call's finally block
                # already removed this id and remove() would raise KeyError.
                with thread_lock:
                    active_threads.discard(thread_id)
                    active_count = len(active_threads)
                    thread_activity[thread_id].append({
                        'well': well,
                        'thread_name': thread_name,
                        'time': time.time(),
                        'action': 'end',
                        'function': f"{context}{func_name}",
                        'duration': duration,
                        'active_threads': active_count
                    })

                log_func("Thread %s (ID: %s) finished %s%s for well %s in %.2f seconds",
                         thread_name, thread_id, context, func_name, well, duration)
                log_func("Active threads: %s", active_count)

        return wrapper

    # Handle both @track_thread_activity and @track_thread_activity(log_level="debug")
    if func is None:
        return decorator
    return decorator(func)

129 

def analyze_thread_activity():
    """
    Analyze thread activity data and return a report.

    Returns:
        Dict containing analysis results
    """
    max_concurrent = 0
    thread_starts = []
    thread_ends = []

    # Split the raw per-thread records into start and end event tuples,
    # tracking the highest concurrent-thread count seen in any record.
    for activities in thread_activity.values():
        for record in activities:
            max_concurrent = max(max_concurrent, record['active_threads'])
            if record['action'] == 'start':
                thread_starts.append((
                    record.get('well'),
                    record['thread_name'],
                    record['time'],
                    record.get('function', '')
                ))
            else:  # 'end'
                thread_ends.append((
                    record.get('well'),
                    record['thread_name'],
                    record['time'],
                    record.get('duration', 0),
                    record.get('function', '')
                ))

    # Chronological order for both event lists.
    thread_starts.sort(key=lambda event: event[2])
    thread_ends.sort(key=lambda event: event[2])

    def _find_end(thread, well, function):
        """Return the earliest end time matching (thread, well, function), or None."""
        for w, t, end_time, _d, f in thread_ends:
            if t == thread and w == well and f == function:
                return end_time
        return None

    # Find overlapping time periods between different threads.
    overlaps = []
    for i, (well1, thread1, start1, func1) in enumerate(thread_starts):
        end1 = _find_end(thread1, well1, func1)
        if end1 is None:
            continue  # No matching end event; cannot compute an interval.

        for j, (well2, thread2, start2, func2) in enumerate(thread_starts):
            # Skip the same event and events from the same thread.
            if i == j or thread1 == thread2:
                continue

            end2 = _find_end(thread2, well2, func2)
            if end2 is None:
                continue

            # Intervals [start1, end1) and [start2, end2) intersect?
            if start1 < end2 and start2 < end1:
                overlap_duration = min(end1, end2) - max(start1, start2)
                if overlap_duration > 0:
                    overlaps.append({
                        'thread1': thread1,
                        'well1': well1,
                        'function1': func1,
                        'thread2': thread2,
                        'well2': well2,
                        'function2': func2,
                        'duration': overlap_duration
                    })

    return {
        'max_concurrent': max_concurrent,
        'thread_starts': thread_starts,
        'thread_ends': thread_ends,
        'overlaps': overlaps
    }

215 

def print_thread_activity_report():
    """Print a detailed report of thread activity."""
    analysis = analyze_thread_activity()
    separator = "=" * 80

    print("\n" + separator)
    print("Thread Activity Report")
    print(separator)

    print("\nThread Start Events:")
    for well, thread_name, time_val, func in analysis['thread_starts']:
        print(f"Thread {thread_name} started {func} for well {well} at {time_val:.2f}")

    print("\nThread End Events:")
    for well, thread_name, time_val, duration, func in analysis['thread_ends']:
        print(f"Thread {thread_name} finished {func} for well {well} at {time_val:.2f} (duration: {duration:.2f}s)")

    print("\nOverlap Analysis:")
    for overlap in analysis['overlaps']:
        print(f"Threads {overlap['thread1']} and {overlap['thread2']} overlapped for {overlap['duration']:.2f}s")
        print(f"  {overlap['thread1']} was processing {overlap['function1']} for well {overlap['well1']}")
        print(f"  {overlap['thread2']} was processing {overlap['function2']} for well {overlap['well2']}")

    print(f"\nFound {len(analysis['overlaps'])} thread overlaps")
    print(f"Maximum concurrent threads: {analysis['max_concurrent']}")
    print(separator)

    # Hand the computed analysis back so callers can inspect it as well.
    return analysis

243 

244 

245import numpy as np 

246 

247 

def prepare_patterns_and_functions(patterns, processing_funcs, component='default'):
    """
    Prepare patterns, processing functions, and processing args for processing.

    This function handles three main tasks:
    1. Ensuring patterns are in a component-keyed dictionary format
    2. Determining which processing functions to use for each component
    3. Determining which processing args to use for each component

    Args:
        patterns (list or dict): Patterns to process, either as a flat list or grouped by component
        processing_funcs (callable, list, dict, tuple, optional): Processing functions to apply.
            Can be a single callable, a tuple of (callable, kwargs), a list of either,
            or a dictionary mapping component values to any of these.
        component (str): Component name for grouping (only used for clarity in the result)

    Returns:
        tuple: (grouped_patterns, component_to_funcs, component_to_args)
            - grouped_patterns: Dictionary mapping component values to patterns
            - component_to_funcs: Dictionary mapping component values to processing functions
            - component_to_args: Dictionary mapping component values to processing args
    """
    # Normalize patterns to a component-keyed dict (flat lists get wrapped
    # under the given component name).
    grouped_patterns = patterns if isinstance(patterns, dict) else {component: patterns}

    def _resolve_func_item(item):
        """Normalize a function item to a (callable, kwargs) pair."""
        if isinstance(item, tuple) and len(item) == 2 and callable(item[0]):
            # Already a (function, kwargs) tuple.
            return item
        if callable(item):
            # Bare function: no extra kwargs.
            return item, {}
        # Anything else is invalid: warn and substitute an identity function.
        logger.warning(
            "Invalid function item: %s. Expected callable or (callable, kwargs) tuple.",
            str(item)
        )
        return (lambda x, **kwargs: x), {}

    component_to_funcs = {}
    component_to_args = {}

    for comp_value in grouped_patterns:
        # Pick the function item for this component.
        if isinstance(processing_funcs, dict) and comp_value in processing_funcs:
            # Direct per-component mapping.
            func_item = processing_funcs[comp_value]
        elif isinstance(processing_funcs, dict) and component == 'channel':
            # Channel grouping: fall back to the mapping itself when no
            # channel-specific entry exists.
            func_item = processing_funcs.get(comp_value, processing_funcs)
        else:
            # Same function item for every component.
            func_item = processing_funcs

        if isinstance(func_item, list):
            # Lists are passed through whole; per-item args are extracted
            # later during processing.
            component_to_funcs[comp_value] = func_item
            component_to_args[comp_value] = {}
        else:
            func, args = _resolve_func_item(func_item)
            component_to_funcs[comp_value] = func
            component_to_args[comp_value] = args

    return grouped_patterns, component_to_funcs, component_to_args

319 

320 

def stack(single_image_func: Callable) -> Callable[[List[np.ndarray], Optional[Dict[str, Any]]], List[np.ndarray]]:
    """Wraps a function designed for single images to operate on a stack (list) of images.

    Args:
        single_image_func: A function that processes a single numpy array image
            and returns a processed numpy array image.

    Returns:
        A new function that accepts a list of images and keyword arguments,
        applies the original function to each image in the list, and returns
        a list of the processed images. Images for which the wrapped function
        raises or returns None are skipped with a log message (best-effort
        batch processing).
    """
    # Some callables (e.g. functools.partial) have no __name__; fall back to
    # repr so the log messages below never raise AttributeError.
    func_label = getattr(single_image_func, '__name__', repr(single_image_func))

    @functools.wraps(single_image_func)
    def stack_wrapper(images: List[np.ndarray], **kwargs) -> List[np.ndarray]:
        """Applies the wrapped single-image function to each image in the stack."""
        processed_stack = []
        if not images:
            return processed_stack  # Return empty list if input is empty

        for img in images:
            # Keep the try body minimal: only the call can raise.
            try:
                processed_img = single_image_func(img, **kwargs)
            except Exception as e:
                # Lazy %-args (consistent with the module's logging style).
                logger.error(
                    "Error applying %s to an image in the stack: %s. Skipping image.",
                    func_label, e
                )
                continue
            if processed_img is not None:
                processed_stack.append(processed_img)
            else:
                logger.warning(
                    "Function %s returned None for an image. Skipping.",
                    func_label
                )
        return processed_stack

    # Attempt to give the wrapper a more informative name.
    try:
        stack_wrapper.__name__ = f"stacked_{single_image_func.__name__}"
    except AttributeError:
        pass  # Some callables might not have __name__

    return stack_wrapper

360