Coverage for openhcs/validation/ast_validator.py: 0.0%

144 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 02:09 +0000

1""" 

2AST-based validation for openhcs. 

3 

4This module provides AST-based validation tools for enforcing type safety, 

5backend parameter validation, and architectural constraints at compile time. 

6""" 

7 

8import ast 

9import functools 

10import os 

11from typing import List, Optional 

12 

# Violation categories attached to every ValidationViolation.
PATH_TYPE = "path_type"
BACKEND_PARAM = "backend_param"
MEMORY_TYPE = "memory_type"
VFS_BOUNDARY = "vfs_boundary"

# Error message templates, filled positionally via str.format().
# {0} = parameter name, {1} = offending type name.
ERROR_INVALID_PATH_TYPE = (
    "Invalid type for {0}: expected str or Path, got {1}. "
    "Only str and Path types are allowed, no automatic conversion is performed."
)
ERROR_MISSING_BACKEND = (
    "Missing required backend parameter. "
    "Backend must be provided as a positional parameter."
)
# {0} = offending type name.
ERROR_INVALID_BACKEND_TYPE = "Invalid type for backend parameter: expected str, got {0}."
# {0} = free-form description of the violation.
ERROR_VFS_BOUNDARY = "VFS Boundary violation: {0}"
ERROR_MEMORY_TYPE = "Memory type violation: {0}"

37 

class ValidationViolation:
    """Represents a validation violation found during AST analysis.

    Attributes:
        file_path: Path of the file that was analysed.
        line_number: Line of the offending code (callers in this module pass
            0 when no line is known).
        violation_type: Category string, e.g. one of the module constants
            such as ``PATH_TYPE``, or ``"syntax_error"``.
        message: Human-readable description of the problem.
        node: The offending AST node, or ``None`` if unavailable.
    """

    def __init__(self,
                 file_path: str,
                 line_number: int,
                 violation_type: str,
                 message: str,
                 node: Optional[ast.AST] = None):
        self.file_path = file_path
        self.line_number = line_number
        self.violation_type = violation_type
        self.message = message
        self.node = node

    def __str__(self) -> str:
        return f"{self.file_path}:{self.line_number} - {self.violation_type}: {self.message}"

    def __repr__(self) -> str:
        # The AST node is omitted: its default repr is noisy and unhelpful.
        return (
            f"{type(self).__name__}("
            f"file_path={self.file_path!r}, "
            f"line_number={self.line_number!r}, "
            f"violation_type={self.violation_type!r}, "
            f"message={self.message!r})"
        )

55 

56 

class ASTValidator(ast.NodeVisitor):
    """Base AST visitor that collects ValidationViolation records.

    Subclasses implement ``visit_*`` methods and call :meth:`add_violation`.
    During traversal ``current_function`` holds the innermost enclosing
    ``def`` or ``async def`` node, or ``None`` at module/class level.
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        self.violations: List[ValidationViolation] = []
        # Innermost enclosing function (sync or async); None outside functions.
        self.current_function: "Optional[ast.FunctionDef | ast.AsyncFunctionDef]" = None

    def add_violation(self,
                      node: ast.AST,
                      violation_type: str,
                      message: str) -> None:
        """Record a violation located at *node* (line 0 if the node has no ``lineno``)."""
        self.violations.append(
            ValidationViolation(
                file_path=self.file_path,
                line_number=getattr(node, 'lineno', 0),
                violation_type=violation_type,
                message=message,
                node=node
            )
        )

    def _visit_function_scope(self, node: ast.AST) -> None:
        """Visit *node*'s children with ``current_function`` set to *node*."""
        old_function = self.current_function
        self.current_function = node
        try:
            self.generic_visit(node)
        finally:
            # Restore even if a subclass visitor raises, so state never leaks.
            self.current_function = old_function

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Visit a ``def`` body while tracking it as the current function."""
        self._visit_function_scope(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Visit an ``async def`` body while tracking it as the current function.

        Only scope tracking happens here; subclass ``visit_FunctionDef``
        checks are not applied to async functions themselves.
        """
        self._visit_function_scope(node)

86 

87 

class PathTypeValidator(ASTValidator):
    """Validates that path parameters are correctly typed as str or Path.

    Only functions decorated with ``@validate_path_types(...)`` are checked.
    The checked parameters are ``node.args.args``; unannotated ones are
    skipped. Each annotated parameter must be one of:

    * ``str`` or ``Path`` (also ``pathlib.Path``), or
    * a union containing both ``str`` and ``Path``: ``Union[str, Path]``,
      ``typing.Union[...]``, or the PEP 604 form ``str | Path``.
    """

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Check parameter annotations of ``@validate_path_types`` functions."""
        super().visit_FunctionDef(node)

        if not self._has_path_validator(node):
            return

        for arg in node.args.args:
            if arg.annotation is None or self._is_path_annotation(arg.annotation):
                continue
            self.add_violation(
                node=arg,
                violation_type=PATH_TYPE,
                message=f"Parameter '{arg.arg}' should be annotated as Union[str, Path], str, or Path"
            )

    @staticmethod
    def _has_path_validator(node: ast.FunctionDef) -> bool:
        """Return True if *node* has a ``@validate_path_types(...)`` decorator."""
        return any(
            isinstance(decorator, ast.Call)
            and isinstance(decorator.func, ast.Name)
            and decorator.func.id == 'validate_path_types'
            for decorator in node.decorator_list
        )

    @staticmethod
    def _type_name(expr: ast.AST) -> Optional[str]:
        """Return the bare name of ``X`` or ``module.X``; None for anything else."""
        if isinstance(expr, ast.Name):
            return expr.id
        if isinstance(expr, ast.Attribute):
            return expr.attr
        return None

    @classmethod
    def _flatten_bitor(cls, expr: ast.AST) -> List[ast.AST]:
        """Split a PEP 604 ``A | B | C`` chain into its operands."""
        if isinstance(expr, ast.BinOp) and isinstance(expr.op, ast.BitOr):
            return cls._flatten_bitor(expr.left) + cls._flatten_bitor(expr.right)
        return [expr]

    @classmethod
    def _is_path_annotation(cls, annotation: ast.AST) -> bool:
        """Return True if *annotation* is an accepted path type."""
        if cls._type_name(annotation) in ('str', 'Path'):
            return True

        members: List[ast.AST] = []
        if isinstance(annotation, ast.Subscript) and cls._type_name(annotation.value) == 'Union':
            slice_node = annotation.slice
            # Python < 3.9 wraps subscripts in ast.Index. That class is
            # deprecated and may be removed, so look it up defensively.
            index_cls = getattr(ast, 'Index', None)
            if index_cls is not None and isinstance(slice_node, index_cls):
                slice_node = slice_node.value
            if isinstance(slice_node, ast.Tuple):
                members = list(slice_node.elts)
        elif isinstance(annotation, ast.BinOp) and isinstance(annotation.op, ast.BitOr):
            members = cls._flatten_bitor(annotation)

        names = {cls._type_name(member) for member in members}
        return 'str' in names and 'Path' in names

135 

136 

class BackendParameterValidator(ASTValidator):
    """Validates that backend parameters are correctly passed and typed.

    A call ``name.method(...)`` counts as a FileManager call when the receiver
    is a plain name and ``method`` is in ``FILEMANAGER_METHODS``. Such a call
    must have at least two positional arguments. The last positional argument
    is taken as the backend; if it is a literal, it must be a string.

    Calls are matched purely by method name, so unrelated objects with a
    method of the same name are checked too.
    """

    # FileManager methods that require a backend parameter.
    FILEMANAGER_METHODS = frozenset({
        'list_files', 'list_image_files', 'list_dir', 'ensure_directory',
        'exists', 'rename', 'mirror_directory_with_symlinks', 'create_symlink',
        'delete', 'copy_file', 'open_file', 'save', 'load'
    })

    def visit_Call(self, node: ast.Call) -> None:
        """Check FileManager-style calls for a trailing positional backend."""
        self.generic_visit(node)

        func = node.func
        if not (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name)):
            return
        if func.attr not in self.FILEMANAGER_METHODS:
            return

        # With *args unpacking, the number and position of arguments cannot
        # be known statically. Skip rather than report a false "missing".
        if any(isinstance(arg, ast.Starred) for arg in node.args):
            return

        if len(node.args) < 2:
            self.add_violation(
                node=node,
                violation_type=BACKEND_PARAM,
                message=f"Missing backend parameter in call to '{func.attr}'"
            )
            return

        # Backend should be the last positional argument. Only literals can be
        # type-checked statically; variables are accepted as-is.
        backend_arg = node.args[-1]
        if isinstance(backend_arg, ast.Constant) and not isinstance(backend_arg.value, str):
            self.add_violation(
                node=node,
                violation_type=BACKEND_PARAM,
                message=f"Backend parameter must be a string, got {type(backend_arg.value).__name__}"
            )

177 

178 

class VFSBoundaryValidator(ASTValidator):
    """Validates VFS boundary enforcement rules.

    Outside test modules this flags:

    * imports of ``os.path`` / ``pathlib`` (``import X`` or ``from X import ...``);
    * direct ``Path(...)`` and ``pathlib.Path(...)`` construction;
    * ``.to_os_path()`` calls, unless the innermost enclosing function is
      decorated with ``@vfs_escape_hatch``.

    A file is a test module if its path contains a ``tests`` directory or its
    basename starts with ``test_``.
    """

    ESCAPE_HATCH_DECORATOR = "vfs_escape_hatch"

    def __init__(self, file_path: str):
        super().__init__(file_path)
        self.forbidden_imports = {"os.path", "pathlib"}
        self.forbidden_path_constructors = {"pathlib.Path"}
        # Normalise Windows separators so the "/tests/" check is platform independent.
        normalized_path = file_path.replace("\\", "/")
        self.in_test_module = (
            "/tests/" in normalized_path
            or os.path.basename(normalized_path).startswith("test_")
        )

    @staticmethod
    def _dotted_name(expr: ast.AST) -> Optional[str]:
        """Return ``"a.b.c"`` for a Name/Attribute chain, else None."""
        parts = []
        while isinstance(expr, ast.Attribute):
            parts.append(expr.attr)
            expr = expr.value
        if not isinstance(expr, ast.Name):
            return None
        parts.append(expr.id)
        return ".".join(reversed(parts))

    def _in_escape_hatch(self) -> bool:
        """Return True if the innermost enclosing function has ``@vfs_escape_hatch``.

        Accepts bare, called, and module-qualified decorator forms.
        """
        function = self.current_function
        if function is None:
            return False
        for decorator in function.decorator_list:
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            name = self._dotted_name(target)
            if name is not None and name.rsplit(".", 1)[-1] == self.ESCAPE_HATCH_DECORATOR:
                return True
        return False

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Check ``from X import ...`` for forbidden modules."""
        self.generic_visit(node)

        if self.in_test_module:
            return

        if node.module in self.forbidden_imports:
            self.add_violation(
                node=node,
                violation_type=VFS_BOUNDARY,
                message=f"Forbidden import from '{node.module}'"
            )

    def visit_Import(self, node: ast.Import) -> None:
        """Check ``import X`` for forbidden modules."""
        self.generic_visit(node)

        if self.in_test_module:
            return

        for alias in node.names:
            if alias.name in self.forbidden_imports:
                self.add_violation(
                    node=node,
                    violation_type=VFS_BOUNDARY,
                    message=f"Forbidden import of '{alias.name}'"
                )

    def visit_Call(self, node: ast.Call) -> None:
        """Check for forbidden path constructors and to_os_path() calls."""
        self.generic_visit(node)

        if self.in_test_module:
            return

        func = node.func

        # Path(...) after "from pathlib import Path", or a fully qualified pathlib.Path(...).
        if (isinstance(func, ast.Name) and func.id == 'Path') or \
                self._dotted_name(func) in self.forbidden_path_constructors:
            self.add_violation(
                node=node,
                violation_type=VFS_BOUNDARY,
                message="Direct Path constructor usage is forbidden"
            )

        if isinstance(func, ast.Attribute) and func.attr == 'to_os_path' and not self._in_escape_hatch():
            self.add_violation(
                node=node,
                violation_type=VFS_BOUNDARY,
                message="to_os_path() method can only be used in functions decorated with @vfs_escape_hatch"
            )

239 

240 

class MemoryTypeValidator(ASTValidator):
    """Validates memory type declarations and usage.

    This check runs only when ``'processing'`` occurs in the file path. There,
    every public module-level function must carry a memory-type decorator:
    ``@numpy``, ``@cupy``, ``@torch``, ``@tensorflow`` or ``@jax``, either
    bare or called.

    Private functions (leading underscore), nested functions and methods are
    skipped.
    """

    MEMORY_DECORATORS = frozenset({'numpy', 'cupy', 'torch', 'tensorflow', 'jax'})

    def __init__(self, file_path: str):
        super().__init__(file_path)
        # True while visiting statements that belong directly to a class body.
        self._in_class_body = False

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Mark the class body so functions defined directly in it count as methods."""
        was_in_class_body = self._in_class_body
        self._in_class_body = True
        self.generic_visit(node)
        self._in_class_body = was_in_class_body

    def _has_memory_decorator(self, node: ast.FunctionDef) -> bool:
        """Return True if *node* has ``@<memory>`` or ``@<memory>(...)``."""
        for decorator in node.decorator_list:
            target = decorator.func if isinstance(decorator, ast.Call) else decorator
            if isinstance(target, ast.Name) and target.id in self.MEMORY_DECORATORS:
                return True
        return False

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Check function decorators for memory type declarations."""
        is_method = self._in_class_body
        # Statements inside this function's body are no longer directly in a class.
        self._in_class_body = False
        super().visit_FunctionDef(node)
        self._in_class_body = is_method

        # After super() returns, current_function is the *enclosing* function
        # again, so a non-None value means this is a nested function.
        if is_method or self.current_function is not None:
            return
        if node.name.startswith('_') or 'processing' not in self.file_path:
            return
        if self._has_memory_decorator(node):
            return

        self.add_violation(
            node=node,
            violation_type=MEMORY_TYPE,
            message=f"Function '{node.name}' in processing module should have a memory type decorator"
        )

269 

270 

# Decorator functions for runtime validation

def validate_path_types(**type_annotations):
    """
    Build a decorator that marks a function for path-type validation.

    The returned decorator wraps the target in a pass-through wrapper (no
    runtime checks are performed yet). It attaches the given keyword
    arguments to the wrapper as ``__path_type_annotations__`` so that
    tooling can inspect them.

    Args:
        **type_annotations: Keyword arguments stored verbatim on the wrapper
            (presumably parameter name -> expected type; confirm with callers).

    Returns:
        A decorator that returns the wrapped function.
    """
    def apply(target):
        def passthrough(*args, **kwargs):
            return target(*args, **kwargs)

        wrapped = functools.wraps(target)(passthrough)
        wrapped.__path_type_annotations__ = type_annotations
        return wrapped

    return apply

294 

295 

def validate_backend_parameter(func):
    """
    Mark *func* for backend-parameter validation.

    No runtime checks are performed yet. The result is a pass-through
    wrapper, metadata-copied from *func* via ``functools.wraps``, that
    carries the marker attribute ``__validate_backend__ = True``.

    Args:
        func: Function to decorate.

    Returns:
        The marked wrapper function.
    """
    def forward(*args, **kwargs):
        return func(*args, **kwargs)

    marked = functools.wraps(func)(forward)
    marked.__validate_backend__ = True
    return marked

314 

315 

# Main validation function

def validate_file(file_path: str) -> List[ValidationViolation]:
    """
    Validate a Python file using AST-based analysis.

    The file is read as bytes and handed straight to ``ast.parse``. This means
    PEP 263 encoding declarations and a UTF-8 BOM are honoured, with UTF-8 as
    the default, exactly as the interpreter would read the file.

    Args:
        file_path: Path to the Python file.

    Returns:
        Violations from PathTypeValidator, BackendParameterValidator,
        VFSBoundaryValidator and MemoryTypeValidator, in that order. If the
        file cannot be parsed or decoded, a single ``"syntax_error"``
        violation is returned instead.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(file_path, 'rb') as f:
        source = f.read()

    try:
        tree = ast.parse(source, filename=file_path)
    except (SyntaxError, ValueError) as e:
        # ValueError covers two cases: null bytes on Python < 3.12 (a
        # SyntaxError on 3.12+), and decode errors (UnicodeDecodeError is a
        # ValueError subclass). Neither carries a line number.
        return [ValidationViolation(
            file_path=file_path,
            line_number=getattr(e, 'lineno', None) or 0,
            violation_type="syntax_error",
            message=f"Syntax error: {e}",
            node=None
        )]

    validators = [
        PathTypeValidator(file_path),
        BackendParameterValidator(file_path),
        VFSBoundaryValidator(file_path),
        MemoryTypeValidator(file_path)
    ]

    violations: List[ValidationViolation] = []
    for validator in validators:
        validator.visit(tree)
        violations.extend(validator.violations)

    return violations