Coverage for openhcs/validation/ast_validator.py: 17.7%

147 statements  

« prev     ^ index     » next       coverage.py v7.10.3, created at 2025-08-14 05:57 +0000

1""" 

2AST-based validation for openhcs. 

3 

4This module provides AST-based validation tools for enforcing type safety, 

5backend parameter validation, and architectural constraints at compile time. 

6""" 

7 

8import ast 

9import functools 

10import inspect 

11import os 

12import sys 

13from pathlib import Path 

14from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union 

15 

# Violation category identifiers recorded on each reported violation
PATH_TYPE = "path_type"
BACKEND_PARAM = "backend_param"
MEMORY_TYPE = "memory_type"
VFS_BOUNDARY = "vfs_boundary"

# Message templates; the numbered placeholders ({0}, {1}) are left unfilled here
ERROR_INVALID_PATH_TYPE = (
    "Invalid type for {0}: expected str or Path, got {1}. "
    "Only str and Path types are allowed, no automatic conversion is performed."
)
ERROR_MISSING_BACKEND = (
    "Missing required backend parameter. "
    "Backend must be provided as a positional parameter."
)
ERROR_INVALID_BACKEND_TYPE = "Invalid type for backend parameter: expected str, got {0}."
ERROR_VFS_BOUNDARY = "VFS Boundary violation: {0}"
ERROR_MEMORY_TYPE = "Memory type violation: {0}"

40 

class ValidationViolation:
    """A single rule breach discovered while analysing a file's AST.

    Instances are plain records exposing ``file_path``, ``line_number``,
    ``violation_type``, ``message`` and the optional offending ``node``.
    """

    def __init__(
        self,
        file_path: str,
        line_number: int,
        violation_type: str,
        message: str,
        node: Optional[ast.AST] = None,
    ):
        self.file_path, self.line_number = file_path, line_number
        self.violation_type, self.message = violation_type, message
        self.node = node

    def __str__(self) -> str:
        # Rendered as "<file>:<line> - <type>: <message>".
        location = "{}:{}".format(self.file_path, self.line_number)
        detail = "{}: {}".format(self.violation_type, self.message)
        return location + " - " + detail

58 

59 

class ASTValidator(ast.NodeVisitor):
    """Shared base class for the static-analysis visitors in this module.

    Violations are accumulated in ``self.violations``; while a function body
    is being walked, ``self.current_function`` refers to the innermost
    ``ast.FunctionDef`` being visited.
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.violations: List[ValidationViolation] = []
        self.current_function: Optional[ast.FunctionDef] = None

    def add_violation(self, node: ast.AST, violation_type: str, message: str) -> None:
        """Record a violation for ``node``; the line is 0 if the node has no ``lineno``."""
        violation = ValidationViolation(
            file_path=self.file_path,
            line_number=getattr(node, "lineno", 0),
            violation_type=violation_type,
            message=message,
            node=node,
        )
        self.violations.append(violation)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Walk a function definition while tracking it as the current function."""
        enclosing = self.current_function
        self.current_function = node
        self.generic_visit(node)
        # Hand control back to the outer function (or None at module level).
        self.current_function = enclosing

89 

90 

class PathTypeValidator(ASTValidator):
    """Validates that path parameters are correctly typed as str or Path.

    Only functions decorated with a ``validate_path_types(...)`` call are
    inspected. In such a function every *annotated* parameter (positional-only,
    regular and keyword-only) must be annotated as ``str``, ``Path``,
    ``Union[str, Path]`` or ``str | Path``. Unannotated parameters are ignored.
    """

    # Bare names accepted on their own as a path annotation.
    _PATH_TYPE_NAMES = ("str", "Path")

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Check function parameters for path type annotations."""
        # Walk the body first so nested decorated functions are checked too.
        super().visit_FunctionDef(node)

        if not any(self._is_path_decorator(dec) for dec in node.decorator_list):
            return

        arguments = node.args
        # ``posonlyargs`` is absent on Python < 3.8, hence the getattr.
        # Keyword-only parameters are included so they are validated as well.
        parameters = [
            *getattr(arguments, "posonlyargs", []),
            *arguments.args,
            *arguments.kwonlyargs,
        ]

        for arg in parameters:
            if arg.annotation is None:
                continue
            if self._is_path_annotation(arg.annotation):
                continue
            self.add_violation(
                node=arg,
                violation_type=PATH_TYPE,
                message=f"Parameter '{arg.arg}' should be annotated as Union[str, Path], str, or Path"
            )

    @staticmethod
    def _is_path_decorator(decorator: ast.expr) -> bool:
        """Return True for ``validate_path_types(...)`` or ``mod.validate_path_types(...)``."""
        if not isinstance(decorator, ast.Call):
            return False
        func = decorator.func
        if isinstance(func, ast.Name):
            return func.id == 'validate_path_types'
        return isinstance(func, ast.Attribute) and func.attr == 'validate_path_types'

    @classmethod
    def _is_path_annotation(cls, annotation: ast.expr) -> bool:
        """Return True if ``annotation`` is str, Path, or a union naming both."""
        if isinstance(annotation, ast.Name):
            return annotation.id in cls._PATH_TYPE_NAMES

        members = cls._union_members(annotation)
        if members is None:
            return False
        # Same rule as for Union[...]: both names must appear among the members.
        names = {member.id for member in members if isinstance(member, ast.Name)}
        return 'str' in names and 'Path' in names

    @staticmethod
    def _union_members(annotation: ast.expr) -> Optional[List[ast.expr]]:
        """Return the member expressions of ``Union[...]`` or ``A | B``, else None."""
        if isinstance(annotation, ast.Subscript):
            target = annotation.value
            if not (isinstance(target, ast.Name) and target.id == 'Union'):
                return None
            slice_node = annotation.slice
            # Python < 3.9 wraps the subscript in ast.Index; the class is
            # deprecated, so look it up defensively rather than assume it exists.
            index_cls = getattr(ast, "Index", None)
            if index_cls is not None and isinstance(slice_node, index_cls):
                slice_node = slice_node.value
            if isinstance(slice_node, ast.Tuple):
                return list(slice_node.elts)
            return None

        if isinstance(annotation, ast.BinOp) and isinstance(annotation.op, ast.BitOr):
            # Flatten a left-nested chain such as ``str | Path | None``.
            members: List[ast.expr] = []
            pending = [annotation]
            while pending:
                expr = pending.pop()
                if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.BitOr):
                    pending.extend((expr.right, expr.left))
                else:
                    members.append(expr)
            return members

        return None

138 

139 

class BackendParameterValidator(ASTValidator):
    """Checks that calls to FileManager-named methods pass a backend argument."""

    def visit_Call(self, node: ast.Call) -> None:
        """Inspect one call expression for a missing or non-string backend."""
        self.generic_visit(node)

        callee = node.func
        # Only ``name.method(...)`` calls are examined; the receiver's actual
        # type is not known here, so matching is by method name alone.
        if not isinstance(callee, ast.Attribute):
            return
        if not isinstance(callee.value, ast.Name):
            return

        # Method names that are expected to receive a backend argument.
        backend_methods = {
            'list_files', 'list_image_files', 'list_dir', 'ensure_directory',
            'exists', 'rename', 'mirror_directory_with_symlinks', 'create_symlink',
            'delete', 'copy_file', 'open_file', 'save', 'load'
        }
        method_name = callee.attr
        if method_name not in backend_methods:
            return

        positional = node.args
        # Fewer than two positional arguments means no backend was supplied.
        if len(positional) < 2:
            self.add_violation(
                node=node,
                violation_type=BACKEND_PARAM,
                message=f"Missing backend parameter in call to '{method_name}'"
            )
            return

        # The backend is taken to be the last positional argument; only
        # literals can be type-checked statically, variables are accepted.
        last_positional = positional[-1]
        is_literal = isinstance(last_positional, ast.Constant)
        if is_literal and not isinstance(last_positional.value, str):
            self.add_violation(
                node=node,
                violation_type=BACKEND_PARAM,
                message=f"Backend parameter must be a string, got {type(last_positional.value).__name__}"
            )

180 

181 

class VFSBoundaryValidator(ASTValidator):
    """Validates VFS boundary enforcement rules.

    Outside of test modules this reports:

    * imports of the modules in ``forbidden_imports`` (including the
      ``from os import path`` spelling of ``os.path``),
    * ``Path(...)`` calls and calls to the dotted names in
      ``forbidden_path_constructors``,
    * ``to_os_path()`` calls made outside a function decorated with
      ``@vfs_escape_hatch``.
    """

    def __init__(self, file_path: str):
        super().__init__(file_path)
        self.forbidden_imports = {"os.path", "pathlib"}
        self.forbidden_path_constructors = {"pathlib.Path"}
        # Normalise separators so Windows paths and relative paths such as
        # "tests/foo.py" are recognised as test modules too.
        parts = file_path.replace("\\", "/").split("/")
        self.in_test_module = "tests" in parts[:-1] or parts[-1].startswith("test_")

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Check for forbidden imports."""
        self.generic_visit(node)

        if self.in_test_module:
            return

        if node.module in self.forbidden_imports:
            self.add_violation(
                node=node,
                violation_type=VFS_BOUNDARY,
                message=f"Forbidden import from '{node.module}'"
            )
            return

        # ``from os import path`` binds the same module as ``import os.path``.
        # Relative imports (level > 0) refer to project modules and are skipped.
        if not node.module or node.level:
            return
        for alias in node.names:
            qualified = f"{node.module}.{alias.name}"
            if qualified in self.forbidden_imports:
                self.add_violation(
                    node=node,
                    violation_type=VFS_BOUNDARY,
                    message=f"Forbidden import of '{qualified}'"
                )

    def visit_Import(self, node: ast.Import) -> None:
        """Check for forbidden imports."""
        self.generic_visit(node)

        if self.in_test_module:
            return

        for alias in node.names:
            if alias.name in self.forbidden_imports:
                self.add_violation(
                    node=node,
                    violation_type=VFS_BOUNDARY,
                    message=f"Forbidden import of '{alias.name}'"
                )

    def visit_Call(self, node: ast.Call) -> None:
        """Check for forbidden path constructors and to_os_path() calls."""
        self.generic_visit(node)

        if self.in_test_module:
            return

        func = node.func

        if self._is_path_constructor(func):
            self.add_violation(
                node=node,
                violation_type=VFS_BOUNDARY,
                message="Direct Path constructor usage is forbidden"
            )

        # to_os_path() is permitted only inside an escape-hatch function.
        if (
            isinstance(func, ast.Attribute)
            and func.attr == 'to_os_path'
            and not self._in_escape_hatch()
        ):
            self.add_violation(
                node=node,
                violation_type=VFS_BOUNDARY,
                message="to_os_path() method can only be used in functions decorated with @vfs_escape_hatch"
            )

    def _is_path_constructor(self, func: ast.expr) -> bool:
        """Return True for ``Path(...)`` or a dotted forbidden constructor."""
        if isinstance(func, ast.Name):
            return func.id == 'Path'
        if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
            dotted = f"{func.value.id}.{func.attr}"
            return dotted in self.forbidden_path_constructors
        return False

    def _in_escape_hatch(self) -> bool:
        """Return True if the innermost enclosing ``def`` has @vfs_escape_hatch.

        Only the innermost function tracked by ``current_function`` is
        consulted; a helper nested inside a decorated function is therefore
        still reported, as are calls inside ``async def`` (which the base
        class does not track).
        """
        function = self.current_function
        if function is None:
            return False
        for decorator in function.decorator_list:
            # Accept both ``@vfs_escape_hatch`` and ``@vfs_escape_hatch(...)``.
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            if isinstance(target, ast.Name) and target.id == 'vfs_escape_hatch':
                return True
            if isinstance(target, ast.Attribute) and target.attr == 'vfs_escape_hatch':
                return True
        return False

242 

243 

class MemoryTypeValidator(ASTValidator):
    """Validates memory type declarations and usage.

    For files whose path contains ``'processing'``, every public module-level
    function must carry one of the memory type decorators (``numpy``,
    ``cupy``, ``torch``, ``tensorflow``, ``jax``), either bare or called.
    Private functions, nested functions and methods are exempt.
    """

    def __init__(self, file_path: str):
        super().__init__(file_path)
        # Number of class bodies currently being walked; > 0 means "in a class".
        self._class_depth = 0

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Track class nesting so methods can be told apart from functions."""
        self._class_depth += 1
        self.generic_visit(node)
        self._class_depth -= 1

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Check function decorators for memory type declarations."""
        super().visit_FunctionDef(node)

        memory_decorators = {'numpy', 'cupy', 'torch', 'tensorflow', 'jax'}

        def is_memory_decorator(decorator: ast.expr) -> bool:
            # Accept both ``@numpy`` and ``@numpy(...)``.
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            return isinstance(target, ast.Name) and target.id in memory_decorators

        if any(is_memory_decorator(dec) for dec in node.decorator_list):
            return
        if 'processing' not in self.file_path:
            return
        if node.name.startswith('_'):
            return

        # After super() returns, current_function is the *enclosing* function,
        # so a non-None value means this definition is nested. That test alone
        # cannot see methods of a top-level class, hence the class-depth check.
        if self.current_function is not None or self._class_depth > 0:
            return

        self.add_violation(
            node=node,
            violation_type=MEMORY_TYPE,
            message=f"Function '{node.name}' in processing module should have a memory type decorator"
        )

272 

273 

274# Decorator functions for runtime validation 

275 

def validate_path_types(**type_annotations):
    """
    Build a decorator that tags a function with path type annotations.

    The decorated function is invoked unchanged; nothing is checked at call
    time.

    Args:
        **type_annotations: Arbitrary keyword arguments, stored verbatim on
            the wrapper as ``__path_type_annotations__``.

    Returns:
        A decorator that returns a ``functools.wraps``-preserving wrapper.
    """
    def decorator(func):
        def passthrough(*args, **kwargs):
            # Placeholder for runtime validation; currently a plain forward.
            return func(*args, **kwargs)

        tagged = functools.wraps(func)(passthrough)
        # Expose the declared annotations for later inspection.
        tagged.__path_type_annotations__ = type_annotations
        return tagged

    return decorator

297 

298 

def validate_backend_parameter(func):
    """
    Tag a function as subject to backend parameter validation.

    The decorated function is invoked unchanged; nothing is checked at call
    time.

    Args:
        func: Function to decorate.

    Returns:
        A ``functools.wraps``-preserving wrapper whose ``__validate_backend__``
        attribute is ``True``.
    """
    def passthrough(*args, **kwargs):
        # Placeholder for runtime validation; currently a plain forward.
        return func(*args, **kwargs)

    tagged = functools.wraps(func)(passthrough)
    # Marker attribute set on the wrapper.
    tagged.__validate_backend__ = True
    return tagged

317 

318 

319# Main validation function 

320 

def validate_file(file_path: str) -> List[ValidationViolation]:
    """
    Validate a Python file using AST-based analysis.

    The file is parsed once and then walked by each validator in turn
    (path types, backend parameters, VFS boundary, memory types).

    Args:
        file_path: Path to the Python file.

    Returns:
        List of validation violations, in validator order. If the file cannot
        be parsed, a single ``"syntax_error"`` violation is returned instead.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Read raw bytes so the compiler applies the file's own encoding rules
    # (UTF-8 BOM, PEP 263 coding cookie). A forced UTF-8 text read raised
    # UnicodeDecodeError on other encodings and turned a BOM into a bogus
    # syntax error.
    with open(file_path, 'rb') as f:
        source = f.read()

    try:
        tree = ast.parse(source, filename=file_path)
    except (SyntaxError, ValueError) as e:
        # ValueError covers sources containing null bytes on Python < 3.12;
        # it carries no line number, hence the getattr.
        return [ValidationViolation(
            file_path=file_path,
            line_number=getattr(e, 'lineno', None) or 0,
            violation_type="syntax_error",
            message=f"Syntax error: {e}",
            node=None
        )]

    # Run all validators over the same tree.
    validators = [
        PathTypeValidator(file_path),
        BackendParameterValidator(file_path),
        VFSBoundaryValidator(file_path),
        MemoryTypeValidator(file_path)
    ]

    violations: List[ValidationViolation] = []
    for validator in validators:
        validator.visit(tree)
        violations.extend(validator.violations)

    return violations