potluck.mast
Mast: Matcher of ASTs for Python
Copyright (c) 2017-2018, Benjamin P. Wood, Wellesley College
mast.py
This version is included in potluck instead of codder.
See NOTE below for how to run this file as a script.
Uses the builtin ast package for parsing of concrete syntax and representation of abstract syntax for both source code and patterns.
The main API for pattern matching with mast:
- `parse`: parse an AST from a Python source string
- `parse_file`: parse an AST from a Python source file
- `parse_pattern`: parse a pattern from a pattern source string
- `match`: check if an AST matches a pattern
- `find`: search for the first (improper) sub-AST matching a pattern
- `findall`: search for all (improper) sub-ASTs matching a pattern
- `count`: count all (improper) sub-ASTs matching a pattern
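To make "improper sub-AST" concrete: the searched node itself counts as a candidate, not only its descendants. The sketch below illustrates that idea with the standard `ast` module alone. It is not mast's implementation, and the helper names `count_sub_asts` and `is_call_to` are hypothetical, introduced only for this example.

```python
import ast


def count_sub_asts(tree, predicate):
    """Count nodes in `tree`, including `tree` itself, that satisfy predicate.

    ast.walk yields the starting node as well as every descendant, which is
    what makes this an "improper" sub-AST search.
    """
    return sum(1 for node in ast.walk(tree) if predicate(node))


def is_call_to(name):
    """Build a predicate (hypothetical helper) matching calls to a plain name."""
    def check(node):
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == name
        )
    return check


# Two print calls nested in a module; the len call is not counted.
tree = ast.parse("print(len(xs))\nprint(1)")
print_calls = count_sub_asts(tree, is_call_to("print"))

# Here the root of the search is itself the matching node.
root_call = ast.parse("print(1)", mode="eval").body
root_hits = count_sub_asts(root_call, is_call_to("print"))
```

A pattern-based matcher such as mast replaces the hand-written predicate with a pattern source string, but the traversal question (does the root count?) is the same.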
API utility functions applicable to all ASTs in ast:
- `dump`: a nicer version of ast.dump for pretty-printing AST structure
- `ast2source`: pretty-print an AST to Python source
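For orientation, the standard library offers rough counterparts to these two utilities: `ast.dump` prints node structure, and `ast.unparse` (Python 3.9 and later) turns an AST back into source. The sketch below uses only those standard calls; how mast's own `dump` and `ast2source` format their output is not shown here and may differ.

```python
import ast

tree = ast.parse("x = 1 + 2")

# Structural view: node types and fields. The exact text varies between
# Python versions, so treat it as a debugging aid rather than stable output.
structure = ast.dump(tree)

# Source view: only available from Python 3.9 onward, hence the guard.
source = ast.unparse(tree) if hasattr(ast, "unparse") else None

print(structure)
print(source)
```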
NOTE: To run this file as a script, go up one directory and execute
python -m potluck.mast
For example, if this file is ~/Sites/potluck/potluck/mast.py, then
cd ~/Sites/potluck
python -m potluck.mast
Because of relative imports, you cannot run this file directly as a script from within ~/Sites/potluck/potluck. Here's what will happen:
$ cd ~/Sites/potluck/potluck
$ python mast.py
Traceback (most recent call last):
File "mast.py", line 40, in <module>
from .util import (...
SystemError: Parent module '' not loaded, cannot perform relative import
This StackOverflow post is helpful for understanding this issue: https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time
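The difference between the two invocations can be reproduced with a throwaway package. The sketch below is an illustration only: the package name `demo_pkg` and its files are hypothetical and unrelated to potluck. It builds a tiny package in a temporary directory, then runs the same module both ways. Recent Python 3 versions report an `ImportError` rather than the `SystemError` shown above, but the cause is the same.

```python
import os
import subprocess
import sys
import tempfile


def run_both_ways():
    """Run a module with a relative import directly and via -m."""
    with tempfile.TemporaryDirectory() as root:
        pkg = os.path.join(root, "demo_pkg")  # hypothetical package name
        os.mkdir(pkg)
        files = {
            "__init__.py": "",
            "util.py": "GREETING = 'ok'\n",
            "mod.py": "from .util import GREETING\nprint(GREETING)\n",
        }
        for name, text in files.items():
            with open(os.path.join(pkg, name), "w") as f:
                f.write(text)

        # Run as a plain script from inside the package directory:
        # the module has no parent package, so the relative import fails.
        direct = subprocess.run(
            [sys.executable, "mod.py"],
            cwd=pkg, capture_output=True, text=True,
        )
        # Run with -m from one directory up: the parent package is known.
        as_module = subprocess.run(
            [sys.executable, "-m", "demo_pkg.mod"],
            cwd=root, capture_output=True, text=True,
        )
        return direct, as_module


direct, as_module = run_both_ways()
print("direct exit code:", direct.returncode)
print("-m exit code:", as_module.returncode)
```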
SOME HISTORY
[2019/01/19-23, lyn] Python 2 to 3 conversion plus debugging aids
[2021/06/22ish, Peter Mawhorter] Linting pass; import into potluck
TODO:
- have flag to control debugging prints (look for $)
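Documentation of the Python ast package:
- Python 2: https://docs.python.org/2/library/ast.html
- Python 3:
    - https://docs.python.org/3/library/ast.html
    - https://greentreesnakes.readthedocs.io/en/latest/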
"""
Mast: Matcher of ASTs for Python
Copyright (c) 2017-2018, Benjamin P. Wood, Wellesley College

mast.py

This version is included in potluck instead of codder.

See NOTE below for how to run this file as a script.

Uses the builtin ast package for parsing of concrete syntax and
representation of abstract syntax for both source code and patterns.

The main API for pattern matching with mast:

- `parse`: parse an AST from a python source string
- `parse_file`: parse an AST from a python source file
- `parse_pattern`: parse a pattern from a pattern source string
- `match`: check if an AST matches a pattern
- `find`: search for the first (improper) sub-AST matching a pattern
- `findall`: search for all (improper) sub-ASTs matching a pattern
- `count`: count all (improper) sub-ASTs matching a pattern

API utility functions applicable to all ASTs in ast:

- `dump`: a nicer version of ast.dump AST structure pretty-printing
- `ast2source`: source: pretty-print an AST to python source

NOTE: To run this file as a script, go up one directory and execute

```sh
python -m potluck.mast
```

For example, if this file is ~/Sites/potluck/potluck/mast.py, then

```sh
cd ~/Sites/potluck
python -m potluck.mast
```

Because of relative imports, you *cannot* attempt to run this file
as a script in ~/Sites/potluck/potluck. Here's what will happen:

```sh
$ cd ~/Sites/potluck/potluck
$ python mast.py
Traceback (most recent call last):
File "mast.py", line 40, in <module>
    from .util import (...
SystemError: Parent module '' not loaded, cannot perform relative import
```

This StackOverflow post is helpful for understanding this issue:
https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time

SOME HISTORY
[2019/01/19-23, lyn] Python 2 to 3 conversion plus debugging aids
[2021/06/22ish, Peter Mawhorter] Linting pass; import into potluck

TODO:

* have flag to control debugging prints (look for $)

Documentation of the Python ast package:

- Python 2: https://docs.python.org/2/library/ast.html
- Python 3:
    - https://docs.python.org/3/library/ast.html
    - https://greentreesnakes.readthedocs.io/en/latest/
"""

import ast
import os
import re
import sys

from collections import OrderedDict as odict

from .mast_utils import (
    Some,
    takeone,
    FiniteIterator,
    iterone,
    iterempty,
    dict_unbind
)

import itertools
ichain = itertools.chain.from_iterable


#------------------------------------#
# AST data structure pretty printing #
#------------------------------------#

def dump(node):
    """Nicer display of ASTs."""
    return (
        ast.dump(node) if isinstance(node, ast.AST)
        else '[' + ', '.join(map(dump, node)) + ']' if type(node) == list
        else '(' + ', '.join(map(dump, node)) + ')' if type(node) == tuple
        else '{' + ', '.join(dump(key) + ': ' + dump(value)
                             for key, value in node.items()) + '}'
        if type(node) == dict
        # Sneaky way to have dump descend into Some and FiniteIterator objects
        # for debugging
        else node.dump(dump)
        if (type(node) == Some or type(node) == FiniteIterator)
        else repr(node)
    )


def showret(x, prefix=''):
    """Shim for use debugging AST return values."""
    print(prefix, dump(x))
    return x


#-------------------#
# Pattern Variables #
#-------------------#

NODE_VAR_SPEC = (1, re.compile(r'\A_(?:([a-zA-Z0-9]+)_)?\Z'))
"""Spec for names of scalar
pattern variables."""

SEQ_TYPES = set([
    ast.arguments,
    ast.Assign,
    ast.Call,
    ast.Expr,
    # ast.For,
    # ast.Print,  # removed from ast module in Python 3
    ast.Tuple,
    ast.List,
    list,
    type(None),
    ast.BoolOp,
])
"""Types against which sequence patterns match."""

SET_VAR_SPEC = (1, re.compile(r'\A___(?:([a-zA-Z0-9]+)___)?\Z'))
"""Spec for names of sequence pattern variables."""

if sys.version_info[0] < 3 or sys.version_info[1] < 8:
    # This was the setup before 3.8
    LIT_TYPES = {
        'int': lambda x: (Some(x.n)
                          if type(x) == ast.Num and type(x.n) == int
                          else None),
        'float': lambda x: (Some(x.n)
                            if type(x) == ast.Num and type(x.n) == float
                            else None),
        'str': lambda x: (Some(x.s)
                          if type(x) == ast.Str
                          else None),
        # 'bool': lambda x: (Some(bool(x.id))
        # Lyn notes this was a bug! Should have been Some(bool(x.id == 'True'))!
        #                    if type(x) == ast.Name and (x.id == 'True'
        #                                                or x.id == 'False')
        #                    else None),
        # -----
        # Above is Python 2 code for bools; below is Python 3 (where bools
        # are named constants)
        'bool': lambda x: (
            Some(x.value)
            if (
                type(x) == ast.NameConstant
                and (x.value is True or x.value is False)
            ) else None
        ),
    }
    """Types against which typed literal pattern variables match."""
else:
    # From 3.8, Constant is used in place of lots of previous stuff
    LIT_TYPES = {
        'int': lambda x: (Some(x.value)
                          if type(x) == ast.Constant and type(x.value) == int
                          else None),
        'float': lambda x: (Some(x.value)
                            if type(x) == ast.Constant
                            and type(x.value) == float
                            else None),
        'str': lambda x: (Some(x.value)
                          if type(x) == ast.Constant and type(x.value) == str
                          else None),
        'bool': lambda x: (
            Some(x.value)
            if (
                type(x) == ast.Constant
                and (x.value is True or x.value is False)
            ) else None
        ),
    }
    """Types against which typed literal
    pattern variables match."""

TYPED_LIT_VAR_SPEC = (
    (1, 2),
    re.compile(r'\A_([a-zA-Z0-9]+)_(' + '|'.join(LIT_TYPES.keys()) + r')_\Z')
)
"""Spec for names/types of typed literal pattern variables."""


def var_is_anonymous(identifier):
    """Determine whether a pattern variable name (string) is anonymous."""
    assert type(identifier) == str  # All Python 3 strings are unicode
    return not re.search(r'[a-zA-Z0-9]', identifier)


def node_is_name(node):
    """Determine if a node is an AST Name node."""
    return isinstance(node, ast.Name)


def identifier_key(identifier, spec):
    """Extract the name of a pattern variable from its identifier string,
    returning an option: Some(key) if identifier is valid variable
    name according to spec, otherwise None.

    Examples for NODE_VAR_SPEC:
    identifier_key('_a_') => Some('a')
    identifier_key('_') => Some('_')
    identifier_key('_a') => None

    """
    assert type(identifier) == str  # All Python 3 strings are unicode
    groups, regex = spec
    match = regex.match(identifier)
    if match:
        if var_is_anonymous(identifier):
            return identifier
        elif type(groups) == tuple:
            return tuple(match.group(i) for i in groups)
        else:
            return match.group(groups)
    else:
        return None


def node_var(pat, spec=NODE_VAR_SPEC, wrap=False):
    """Extract the key name of a scalar pattern variable,
    returning Some(key) if pat is a scalar pattern variable,
    otherwise None.

    A named or anonymous node variable pattern, written `_a_` or `_`,
    respectively, may appear in any expression or identifier context
    in a pattern. It matches any single AST in the corresponding
    position in the target program.

    """
    if wrap:
        pat = ast.Name(id=pat, ctx=None)
    elif isinstance(pat, ast.Expr):
        pat = pat.value
    elif isinstance(pat, ast.alias) and pat.asname is None:
        # [Peter Mawhorter 2021-8-29] Want to treat aliases without an
        # 'as' part kind of like normal Name nodes.
        pat = ast.Name(id=pat.name, ctx=None)
    return (
        identifier_key(pat.id, spec)
        if node_is_name(pat)
        else None
    )


def node_var_str(pat, spec=NODE_VAR_SPEC):

    return node_var(pat, spec=spec, wrap=True)


def set_var(pat, wrap=False):
    """Extract the key name of a set or sequence pattern variable,
    returning Some(key) if pat is a set or sequence pattern variable,
    otherwise None.

    A named or anonymous set or sequence pattern variable, written
    `___a___` or `___`, respectively, may appear as an element of a
    set or sequence context in a pattern. It matches 0 or more nodes
    in the corresponding context in the target program.

    """
    return node_var(pat,
                    spec=SET_VAR_SPEC,
                    wrap=wrap)


def set_var_str(pat):
    return set_var(pat, wrap=True)


def typed_lit_var(pat):
    """Extract the key name of a typed literal pattern variable,
    returning Some(key) if pat is a typed literal pattern variable,
    otherwise None.

    A typed literal variable pattern, written `_a_type_`, may appear
    in any expression context in a pattern. It matches any single AST
    node for a literal of the given primitive type in the
    corresponding position in the target program.

    """
    return node_var(pat, spec=TYPED_LIT_VAR_SPEC)

# def stmt_var(pat):
#     """Extract the key name of a sequence pattern variable, returning
#     Some(key) if pat is a sequence pattern variable appearing in a
#     statement context, otherwise None.
#     """
#     return seq_var(pat.value) if isinstance(pat, ast.Expr) else None


#---------------------#
# AST Node Properties #
#---------------------#

def is_pat(p):
    """Determine if p could be a pattern (by type)."""
    # All Python 3 strings are unicode
    return isinstance(p, ast.AST) or type(p) == str


def node_is_docstring(node):
    """Is this node a docstring node?"""
    return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)


def node_is_bindable(node):
    """Can a node pattern variable bind to this node?"""
    # return isinstance(node, ast.expr) or isinstance(node, ast.stmt)
    # Modified to allow debugging prints
    result = (
        isinstance(node, ast.expr)
        or isinstance(node, ast.stmt)
        or (isinstance(node, ast.alias) and node.asname is None)
    )
    # print('\n$ node_is_bindable({}) => {}'.format(dump(node),result))
    return result


def node_is_lit(node, ty):
    """Is this node a literal primitive node?"""
    return (
        (isinstance(node, ast.Num) and ty == type(node.n))  # noqa E721
        # All Python 3 strings are unicode
        or (isinstance(node, ast.Str) and ty == str)
        or (
            isinstance(node, ast.Name)
            and ty == bool
            and (node.id == 'True' or node.id == 'False')
        )
    )


def expr_has_type(node, ty):
    """Is this expr statically guaranteed to be of this type?

    Literals and conversions have definite types. All other types are
    conservatively statically unknown.

    """
    return node_is_lit(node, ty) or match(node, '{}(_)'.format(ty.__name__))


def node_line(node):
    """Get the line number of the source line on which this node starts.
    (best effort)"""
    try:
        return node.lineno if type(node) != list else node[0].lineno
    except Exception:
        return None


#---------------------------#
# Checkers and Transformers #
#---------------------------#

class PatternSyntaxError(BaseException):
    """Exception for errors in pattern syntax."""
    def __init__(self, node, message):
        BaseException.__init__(
            self,
            "At pattern line {}: {}".format(node_line(node), message)
        )


class PatternValidator(ast.NodeVisitor):
    """AST visitor: check pattern structure."""
    def __init__(self):
        self.parent = None
        pass

    def generic_visit(self, pat):
        oldparent = self.parent
        self.parent = pat
        try:
            return ast.NodeVisitor.generic_visit(self, pat)
        finally:
            self.parent = oldparent
            pass
        pass

    def visit_Name(self, pat):
        sn = set_var(pat)
        if sn:
            if type(self.parent) not in SEQ_TYPES:
                raise PatternSyntaxError(
                    pat,
                    "Set/sequence variable ({}) not allowed in {} node".format(
                        sn, type(self.parent)
                    )
                )
            pass
        pass

    def visit_arg(self, pat):
        '''[2019/01/22, lyn] Python 3 now has arg object for param (not Name object).
        So copied visit_Name here.'''
        sn = set_var(pat)
        if sn:
            if type(self.parent) not in SEQ_TYPES:
                raise PatternSyntaxError(
                    pat,
                    "Set/sequence variable ({}) not allowed in {} node".format(
                        sn.value, type(self.parent)
                    )
                )
            pass
        pass

    def visit_Call(self, c):
        if 1 < sum(1 for kw in c.keywords if set_var_str(kw.arg)):
            raise PatternSyntaxError(
                c.keywords,
                "Calls may use at most one keyword argument set variable."
            )
        return self.generic_visit(c)

    def visit_keyword(self, k):
        if (identifier_key(k.arg, SET_VAR_SPEC)
            and not (node_is_name(k.value)
                     and var_is_anonymous(k.value.id))):
            raise PatternSyntaxError(
                k.value,
                "Value patterns for keyword argument set variables must be _."
            )
        return self.generic_visit(k)
    pass


# TODO [Peter 2021-6-24]: This pass breaks things subtly by removing
# e.g., decorators lists, resulting in an AST tree that cannot be
# compiled and run as code! In the past there was mention that without
# this pass things would break... but for now it's disabled by default.
class RemoveDocstrings(ast.NodeTransformer):
    """AST Transformer: remove all docstring nodes."""
    def filterDocstrings(self, seq):
        # print('PREFILTERED', seq)
        filt = [self.visit(n) for n in seq
                if not node_is_docstring(n)]
        # print('FILTERED', dump(filt))
        return filt

    def visit_Expr(self, node):
        if isinstance(node.value, ast.Str):
            assert False
            # return ast.copy_location(ast.Expr(value=None), node)
        else:
            return self.generic_visit(node)

    def visit_FunctionDef(self, node):
        # print('Removing docstring: %s' % dump(node.body[0].value))
        return ast.copy_location(ast.FunctionDef(
            name=node.name,
            args=self.generic_visit(node.args),
            body=self.filterDocstrings(node.body),
        ), node)

    def visit_Module(self, node):
        return ast.copy_location(ast.Module(
            body=self.filterDocstrings(node.body)
        ), node)

    def visit_For(self, node):
        return ast.copy_location(ast.For(
            body=self.filterDocstrings(node.body),
            target=node.target,
            iter=node.iter,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_While(self, node):
        return ast.copy_location(ast.While(
            body=self.filterDocstrings(node.body),
            test=node.test,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_If(self, node):
        return ast.copy_location(ast.If(
            body=self.filterDocstrings(node.body),
            test=node.test,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_With(self, node):
        return ast.copy_location(ast.With(
            body=self.filterDocstrings(node.body),
            # Python 3 just has withitems:
            items=node.items
            # Old Python 2 stuff:
            #context_expr=node.context_expr,
            #optional_vars=node.optional_vars
        ), node)

#    def visit_TryExcept(self, node):
#        return ast.copy_location(ast.TryExcept(
#            body=self.filterDocstrings(node.body),
#            handlers=self.filterDocstrings(node.handlers),
#            orelse=self.filterDocstrings(node.orelse)
#        ), node)

#    def visit_TryFinally(self, node):
#        return ast.copy_location(ast.TryFinally(
#            body=self.filterDocstrings(node.body),
#            finalbody=self.filterDocstrings(node.finalbody)
#        ), node)
#

# In Python 3, a single Try node covers both TryExcept and TryFinally
    def visit_Try(self, node):
        return ast.copy_location(ast.Try(
            body=self.filterDocstrings(node.body),
            handlers=self.filterDocstrings(node.handlers),
            orelse=self.filterDocstrings(node.orelse),
            finalbody=self.filterDocstrings(node.finalbody)
        ), node)

    def visit_ClassDef(self, node):
        return ast.copy_location(ast.ClassDef(
            body=self.filterDocstrings(node.body),
            name=node.name,
            bases=node.bases,
            decorator_list=node.decorator_list
        ), node)


class FlattenBoolOps(ast.NodeTransformer):
    """AST transformer: flatten nested boolean expressions."""
    def visit_BoolOp(self, node):
        values = []
        for v in node.values:
            if isinstance(v, ast.BoolOp) and v.op == node.op:
                values += v.values
            else:
                values.append(v)
                pass
            pass
        return ast.copy_location(ast.BoolOp(op=node.op, values=values), node)
    pass


class ContainsCall(Exception):
    """Exception raised when prohibiting call expressions."""
    pass


class ProhibitCall(ast.NodeVisitor):
    """AST visitor: check call-freedom."""
    def visit_Call(self, c):
        raise ContainsCall
    pass


class SetLoadContexts(ast.NodeTransformer):
    """
    Transforms any AugStore contexts into Load contexts, for use with
    DesugarAugAssign.
    """
    def visit_AugStore(self, ctx):
        return ast.copy_location(ast.Load(), ctx)


class DesugarAugAssign(ast.NodeTransformer):
    """AST transformer: desugar augmented assignments (e.g., `+=`)
    when simple desugaring does not duplicate effectful expressions.

    FIXME: this desugaring and others cause surprising match
    counts. See: https://github.com/wellesleycs111/codder/issues/31

    FIXME: This desugaring should probably not happen in cases where
    .__iadd__ and .__add__ yield different results, for example with
    lists where .__iadd__ is really extend while .__add__ creates a new
    list, such that

        [1, 2] + "34"

    is an error, but

        x = [1, 2]
        x += "34"

    is not!

    A better approach might be to avoid desugaring entirely and instead
    provide common collections of match rules for common patterns.

    """
    def visit_AugAssign(self, assign):
        try:
            # Desugaring *all* AugAssigns is not sound.
            # Example: xs[f()] += 1 --> xs[f()] = xs[f()] + 1
            # Check for presence of call in target.
            ProhibitCall().visit(assign.target)
            return ast.copy_location(
                ast.Assign(
                    targets=[self.visit(assign.target)],
                    value=ast.copy_location(
                        ast.BinOp(
                            left=SetLoadContexts().visit(assign.target),
                            op=self.visit(assign.op),
                            right=self.visit(assign.value)
                        ),
                        assign
                    )
                ),
                assign
            )
        except ContainsCall:
            return self.generic_visit(assign)

    def visit_AugStore(self, ctx):
        return ast.copy_location(ast.Store(), ctx)

    def visit_AugLoad(self, ctx):
        return ast.copy_location(ast.Load(), ctx)


class ExpandExplicitElsePattern(ast.NodeTransformer):
    """AST transformer: transform patterns that include `else: ___`
    to use `else: _; ___` instead, forcing the else to have at least
    one statement."""
    def visit_If(self, ifelse):
        if 1 == len(ifelse.orelse) and set_var(ifelse.orelse[0]):
            return ast.copy_location(
                ast.If(
                    test=ifelse.test,
                    body=ifelse.body,
                    orelse=[
                        ast.copy_location(
                            ast.Expr(value=ast.Name(id='_', ctx=None)),
                            ifelse.orelse[0]
                        ),
                        ifelse.orelse[0]
                    ]
                ),
                ifelse
            )
        else:
            return self.generic_visit(ifelse)


def pipe_visit(node, visitors):
    """Send an AST through a pipeline of AST visitors/transformers."""
    if 0 == len(visitors):
        return node
    else:
        v = visitors[0]
        if isinstance(v, ast.NodeTransformer):
            # visit for the transformed result
            return pipe_visit(v.visit(node), visitors[1:])
        else:
            # visit for the effect
            v.visit(node)
            return pipe_visit(node, visitors[1:])


#---------#
# Parsing #
#---------#

def parse_file(path, docstrings=True):
    """Load and parse a program AST from the given path."""
    with open(path) as f:
        return parse(f.read(), filename=path, docstrings=docstrings)
    pass


# Note: 2021-7-29 Peter Mawhorter: A bug in DesugarAugAssign was found
# and fixed (incorrectly left Store context on RHS,
# which was apparently
# ignored in many Python versions...). However, upon consideration I've
# decided to remove it from standard + docstring passes, since the
# problem with += and lists is pretty relevant (I've seen students
# accidentally do list += string several times in office hours)!
STANDARD_PASSES = [
    RemoveDocstrings,
    #DesugarAugAssign,
    FlattenBoolOps
]
"""Thunks for AST visitors/transformers that are applied during parsing."""


DOCSTRING_PASSES = [
    #DesugarAugAssign,
    FlattenBoolOps
]
"""Extra thunks for docstring mode?"""


class MastParseError(Exception):
    def __init__(self, pattern, error):
        super(MastParseError, self).__init__(
            'Error parsing pattern source string <<<\n{}\n>>>\n{}'.format(
                pattern, str(error)
            )
        )
        self.trigger = error  # hang on to the original error
        pass
    pass


def parse(
    string,
    docstrings=True,
    filename='<unknown>',
    passes=STANDARD_PASSES
):
    """Parse an AST from a string."""
    if type(string) == str:  # All Python 3 strings are unicode
        string = str(string)
    assert type(string) == str  # All Python 3 strings are unicode
    if docstrings and RemoveDocstrings in passes:
        # Filter into a new list so that the caller's list (including the
        # shared default) is not mutated.
        passes = [p for p in passes if p is not RemoveDocstrings]
    pipe = [thunk() for thunk in passes]

    try:
        parsed_ast = ast.parse(string, filename=filename)
    except Exception as error:
        raise MastParseError(string, error)
    return pipe_visit(parsed_ast, pipe)


PATTERN_PARSE_CACHE = {}
"""a pattern parsing cache"""


PATTERN_PASSES = STANDARD_PASSES + [
    # ExpandExplicitElsePattern,
    # Not currently used because a number of existing rules and specs
    # actually use the ability of `if _: ___; else: ___` to match
    # either an if-else or an elseless if.
]


def parse_pattern(pat, toplevel=False, docstrings=True):
    """Parse and validate a pattern."""
    if isinstance(pat, ast.AST):
        PatternValidator().visit(pat)
        return pat
    elif type(pat) == list:
        # Validate each pattern in the list.
        for x in pat:
            PatternValidator().visit(x)
        return pat
    elif type(pat) == str:  # All Python 3 strings are unicode
        cache_key = (toplevel, docstrings, pat)
        pat_ast = PATTERN_PARSE_CACHE.get(cache_key)
        if not pat_ast:
            # 2021-6-16 Peter commented this out (should it be used
            # instead of passes on the line below?)
            # pipe = [thunk() for thunk in PATTERN_PASSES]
            pat_ast = parse(pat, filename='<pattern>', passes=PATTERN_PASSES)
            PatternValidator().visit(pat_ast)
            if not toplevel:
                # Unwrap as needed.
                if len(pat_ast.body) == 1:
                    # If the pattern is a single definition, statement, or
                    # expression, unwrap the Module node.
                    b = pat_ast.body[0]
                    pat_ast = b.value if isinstance(b, ast.Expr) else b
                    pass
                else:
                    # If the pattern is a sequence of definitions or
                    # statements, validate from the top, but return only
                    # the Module body.
                    pat_ast = pat_ast.body
                    pass
                pass
            PATTERN_PARSE_CACHE[cache_key] = pat_ast
            pass
        return pat_ast
    else:
        assert False, 'Cannot parse pattern of type {}: {}'.format(
            type(pat),
            dump(pat)
        )


def pat(pat, toplevel=False, docstrings=True):
    """Alias for parse_pattern."""
    return parse_pattern(pat, toplevel, docstrings)


#--------------#
# Permutations #
#--------------#

# [2019/02/09, lyn] Careful below!
# generally need list(map(...)) in Python 3
ASSOCIATIVE_OPS = list(map(type, [
    ast.Add(), ast.Mult(), ast.BitOr(), ast.BitXor(), ast.BitAnd(),
    ast.Eq(), ast.NotEq(), ast.Is(), ast.IsNot()
]))
"""Types of AST operation nodes that are associative."""


def op_is_assoc(op):
    """Determine if the given operation node is associative."""
    return type(op) in ASSOCIATIVE_OPS


MIRRORING = [
    (ast.Lt(), ast.Gt()),
    (ast.LtE(), ast.GtE()),
    (ast.Eq(), ast.Eq()),
    (ast.NotEq(), ast.NotEq()),
    (ast.Is(), ast.Is()),
    (ast.IsNot(), ast.IsNot()),
]
"""Pairs of operations that are mirrors of each other."""


def mirror(op):
    """Return the mirror operation of the given operation if any."""
    for (x, y) in MIRRORING:
        if type(x) == type(op):
            return y
        elif type(y) == type(op):
            return x
        pass
    return None


def op_has_mirror(op):
    """Determine if the given operation has a mirror."""
    # Look up the mirror of this particular op; testing the truthiness of
    # the mirror function itself would always be True.
    return mirror(op) is not None


# MIRROR_OPS = set(reduce(lambda acc,(x,y): acc + [x,y], MIRRORING, []))
# ASSOCIATIVE_IF_PURE_OPS = set([
#     ast.And(), ast.Or()
# ])


# Patterns for potentially effectful operations
FUNCALL = parse_pattern('_(___)')
METHODCALL = parse_pattern('_._(___)')
ASSIGN = parse_pattern('_ = _')
DEL = parse_pattern('del _')
PRINT = parse_pattern('print(___)')
IMPURE2 = [METHODCALL, ASSIGN, DEL, PRINT]


# Patterns for definitely pure operations
PUREFUNS = set(['len', 'round', 'int', 'str', 'float', 'list', 'tuple',
                'map', 'filter', 'reduce', 'iter', 'all', 'any'])


class PermuteBoolOps(ast.NodeTransformer):
    """AST transformer: permute topmost boolean operation
    according to the given index order."""
    def __init__(self, indices):
        self.indices = indices

    def visit_BoolOp(self, node):
        if self.indices:
            values = [
                self.generic_visit(node.values[i])
                for i in self.indices
            ]
            self.indices = None
            return ast.copy_location(
                ast.BoolOp(op=node.op, values=values),
                node
            )
        else:
            return self.generic_visit(node)

# class PermuteAST(ast.NodeVisitor):
#     def visit_BinOp(self, node):
#         yield self.generic_visit(node)
#         if op_is_assoc(node) and ispure(node):


def node_is_pure(node, purefuns=[]):
    """Determine if the given node is (conservatively) pure (effect-free)."""
    return (
        count(node, FUNCALL,
              matchpred=lambda x, env:
              x.func not in PUREFUNS and x.func not in purefuns) == 0
        and all(count(node, pat) == 0 for pat in IMPURE2)
    )


def permutations(node):
    """
    Generate all permutations of the binary and boolean operations, as
    well as of import orderings, in and below node, via associativity and
    mirroring, respecting purity. Because this is heavily exponential,
    limiting the number of imported names and the complexity of binary
    operation trees in your patterns is a good thing.
    """
    # TODO: These permutations allow matching like this:
    # Node: import math, sys, io
    # Pat: import _x_, ___
    # can bind x to math or io, but NOT sys (see TODO near
    # parse_pattern("___"))
    if isinstance(node, ast.Import):
        for perm in list_permutations(node.names):
            yield ast.Import(names=perm)
    elif isinstance(node, ast.ImportFrom):
        for perm in list_permutations(node.names):
            yield ast.ImportFrom(
                module=node.module,
                names=perm,
                level=node.level
            )
    if (
        isinstance(node, ast.BinOp)
        and op_is_assoc(node)
        and node_is_pure(node)
    ):
        for left in permutations(node.left):
            for right in permutations(node.right):
                yield ast.copy_location(ast.BinOp(
                    left=left,
                    op=node.op,
                    right=right,
                    ctx=node.ctx
                ))
                yield ast.copy_location(ast.BinOp(
                    left=right,
                    op=node.op,
                    right=left,
                    ctx=node.ctx
                ))
    elif (isinstance(node, ast.Compare)
          and len(node.ops) == 1
          and op_has_mirror(node.ops[0])
          and node_is_pure(node)):
        assert len(node.comparators) == 1
        for left in permutations(node.left):
            for right in permutations(node.comparators[0]):
                # print('PERMUTE', dump(left), dump(node.ops), dump(right))
                yield ast.copy_location(ast.Compare(
                    left=left,
                    ops=node.ops,
                    comparators=[right]
                ), node)
                yield ast.copy_location(ast.Compare(
                    left=right,
                    ops=[mirror(node.ops[0])],
                    comparators=[left]
                ), node)
    elif isinstance(node, ast.BoolOp) and node_is_pure(node):
        #print(dump(node))
        stuff = [[x for x in permutations(v)] for v in node.values]
        prod = [x for x in itertools.product(*stuff)]
        # print(prod)
        for values in prod:
            # print('VALUES', map(dump,values))
            for indices in itertools.permutations(range(len(node.values))):
                # print('BOOL', map(dump, values))
                yield PermuteBoolOps(indices).visit(
                    ast.copy_location(
                        ast.BoolOp(op=node.op, values=values),
                        node
                    )
                )
                pass
            pass
        pass
    else:
        # print('NO', dump(node))
        yield node
    pass


def list_permutations(items):
    """
    A generator which yields all possible orderings of the given list.
    """
    if len(items) <= 1:
        yield items[:]
    else:
        first = items[0]
        for subperm in list_permutations(items[1:]):
            for i in range(len(subperm) + 1):
                yield subperm[:i] + [first] + subperm[i:]


#----------#
# Matching #
#----------#

def match(node, pat, **kwargs):
    """A convenience wrapper for matchpat that accepts a pattern either
    as a pre-parsed AST or as a pattern source string to be parsed.

    """
    return matchpat(node, parse_pattern(pat), **kwargs)


def predtrue(node, matchenv):
    """The True predicate"""
    return True


def matchpat(
    node,
    pat,
    matchpred=predtrue,
    env={},
    gen=False,
    normalize=False
):
    """Match an AST against a (pre-parsed) pattern.

    The optional keyword argument matchpred gives a predicate of type
    (AST node * match environment) -> bool that filters structural
    matches by arbitrary additional criteria. The default is the true
    predicate, accepting all structural matches.

    The optional keyword argument gen determines whether this function:
    - (gen=True) yields an environment for each way pat matches node.
    - (gen=False) returns Some(the first of these environments), or
      None if there are no matches (default).

    EXPERIMENTAL
    The optional keyword argument normalize determines whether to
    rewrite the target AST node and the pattern to inline simple
    straightline variable assignments into large expressions (to the
    extent possible) for matching. The default is no normalization
    (False). Normalization is experimental. It is rather ad hoc and
    conservative and may cause unintuitive matching behavior. Use
    with caution.

    INTERNAL
    The optional keyword env gives an initial match environment which
    may be used to constrain otherwise free pattern variables.  This
    argument is mainly intended for internal use.

    """
    assert node is not None
    assert pat is not None
    if normalize:
        if isinstance(node, ast.AST) or type(node) == list:
            node = canonical_pure(node)
            pass
        # Check the pattern itself (not the node) before normalizing it.
        if isinstance(pat, ast.AST) or type(pat) == list:
            pat = canonical_pure(pat)
            pass
        pass

    # Permute the PATTERN, not the AST, so that nodes returned
    # always match real code.
    matches = (matchenv
               # outer iteration
               for permpat in permutations(pat)
               # inner iteration
               for matchenv in imatches(node, permpat,
                                        Some(env), True)
               if matchpred(node, matchenv))

    return matches if gen else takeone(matches)


def bind(env1, name, value):
    """
    Unify the environment in option env1 with a new binding of name to
    value.  Return Some(extended environment) if env1 is Some(existing
    environment) in which name is not bound or is bound to value.
    Otherwise (env1 is None, or the existing binding of name is
    incompatible with value), return None.
    """
    # Lyn modified to allow debugging prints
    assert type(name) == str
    if env1 is None:
        # Nothing to extend; reading env1.value below would fail on None.
        return None
    env = env1.value
    if var_is_anonymous(name):
        # return env1
        result = env1
    elif name in env:
        if takeone(imatches(env[name], value, Some({}), True)):
            # return env1
            result = env1
        else:
            # return None
            result = None
    else:
        env = env.copy()
        env[name] = value
        # print 'bind', name, dump(value), dump(env)
        # return Some(env)
        result = Some(env)
    # print(
    #     '\n$ bind({}, {}, {}) => {}'.format(
    #         dump(env1),
    #         name,
    #         dump(value),
    #         dump(result)
    #     )
    # )
    return result
    pass


IGNORE_FIELDS = set(['ctx', 'lineno', 'col_offset'])
"""AST node fields to be ignored when matching children."""


def argObjToName(argObj):
    '''Convert Python 3 arg object into a Name node,
    ignoring any annotation, lineno, and col_offset.'''
    if argObj is None:
        return None
    else:
        return ast.Name(id=argObj.arg, ctx=None)


def astStr(astObj):
    """
    Converts an AST object to a string, imperfectly.
    """
    # NOTE: ast.Str and ast.Num are deprecated since Python 3.8 and were
    # removed in Python 3.12 (ast.Constant replaces them).
    if isinstance(astObj, ast.Name):
        return astObj.id
    elif isinstance(astObj, ast.Str):
        return repr(astObj.s)
    elif isinstance(astObj, ast.Num):
        return str(astObj.n)
    elif isinstance(astObj, (list, tuple)):
        return str([astStr(x) for x in astObj])
    elif hasattr(astObj, "_fields") and len(astObj._fields) > 0:
        return "{}({})".format(
            type(astObj).__name__,
            ', '.join(astStr(getattr(astObj, f)) for f in astObj._fields)
        )
    elif hasattr(astObj, "_fields"):
        return type(astObj).__name__
    else:
        return '<{}>'.format(type(astObj).__name__)


def defaultToName(defObj):
    """
    Converts a default expression to a name for matching purposes.
    TODO: Actually do recursive matching on these expressions!
    """
    if isinstance(defObj, ast.Name):
        return defObj.id
    else:
        return astStr(defObj)


def field_values(node):
    """Return a list of the values of all matching-relevant fields of the
    given AST node, with fields in consistent list positions."""

    # Specializations:
    if isinstance(node, ast.FunctionDef):
        return [ast.Name(id=v, ctx=None) if k == 'name'
                # Lyn sez: commented out following, because fields are
                # only name, arguments, body
                # else sorted(v, key=lambda x: x.arg) if k == 'keywords'
                else v
                for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if isinstance(node, ast.ClassDef):
        return [ast.Name(id=v, ctx=None) if k == 'name'
                else v
                for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if isinstance(node, ast.Call):
        # Old: sorted(v, key=lambda kw: kw.arg)
        # New: keyword args use nominal (not positional) matching.
        #return [sorted(v, key=lambda kw: kw.arg) if k == 'keywords' else v
        return [
            (
                odict((kw.arg, kw.value) for kw in v)
                if k == 'keywords'
                else v
            )
            for (k, v) in ast.iter_fields(node)
            if k not in IGNORE_FIELDS
        ]

    # Lyn sez: ast.arguments handling is new for Python 3
    if isinstance(node, ast.arguments):
        argList = [
            # Need to create Name nodes to match patterns.
            argObjToName(argObj)
            for argObj in node.args
        ]  # Ignores argObj annotation, lineno, col_offset
        if (
            node.vararg is None
            and node.kwarg is None
            and node.kwonlyargs == []
            and node.kw_defaults == []
            and node.defaults == []
        ):
            # Optimize (for debugging purposes) this very common case by
            # returning a singleton list of lists
            return [argList]
        else:
            # In unoptimized case, return list with sublists and
            # argObjects/Nones:
            # TODO: treat triple underscores separately/specially!!!
            return [
                argList,
                argObjToName(node.vararg),
                argObjToName(node.kwarg),
                [argObjToName(argObj) for argObj in node.kwonlyargs],
                # Peter 2019-9-30 the defaults cannot be reliably converted
                # into names, because they are expressions!
                node.kw_defaults,
                node.defaults
            ]

    if isinstance(node, ast.keyword):
        return [ast.Name(id=v, ctx=None) if k == 'arg' else v
                for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if isinstance(node, ast.Global):
        return [ast.Name(id=n, ctx=None) for n in sorted(node.names)]

    if isinstance(node, ast.Import):
        return [ node.names ]

    if isinstance(node, ast.ImportFrom):
        return [ast.Name(id=node.module, ctx=None), node.level, node.names]

    if isinstance(node, ast.alias):
        result = [ ast.Name(id=node.name, ctx=None) ]
        if node.asname is not None:
            result.append(ast.Name(id=node.asname, ctx=None))
        return result

    # General cases:
    if isinstance(node, ast.AST):
        return [v for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if type(node) == list:
        return node

    # Added by Peter Mawhorter 2019-4-2 to fix problem where subpatterns fail
    # to match within keyword arguments of functions because field_values
    # returns an odict as one value for functions with kwargs, and asking for
    # the
    # field_values of that odict was hitting the base case below.
    if isinstance(node, (dict, odict)):
        return [
            ast.Name(id=n, ctx=None) for n in node.keys()
        ] + list(node.values())

    return []


imatchCount = 0  # For showing stack depth in debugging prints for imatches


def imatches(node, pat, env1, seq):
    """Exponential backtracking match generating 0 or more matching
    environments.  Supports multiple sequence patterns in one context,
    simple permutations of semantically equivalent but syntactically
    mirrored patterns.
    """
    # Lyn change early-return pattern to named result returned at end to
    # allow debugging prints
    # global imatchCount #$
    # print('\n$ {}Entering imatches({}, {}, {}, {})'.format(
    #        '| '*imatchCount, dump(node), dump(pat), dump(env1), seq))
    # imatchCount += 1 #$
    result = iterempty()  # default result if not overridden
    if env1 is None:
        result = iterempty()
        # imatchCount -= 1 #$
        # print(
        #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
        #         '| ' * imatchCount,
        #         dump(node),
        #         dump(pat),
        #         dump(env1),
        #         seq,
        #         dump(result)
        #     )
        # )
        return result
    env = env1.value
    assert env is not None
    if (
        (
            type(pat) == bool
            or type(pat) == str  # All Python 3 strings are unicode
            or pat is None
        )
        and node == pat
    ):
        result = iterone(env)
    elif type(pat) == int or type(pat) == float:
        # Literal int or float pattern
        if type(node) == type(pat):
            if (
                (type(pat) == int and node == pat)
                or (type(pat) == float and abs(node - pat) < 0.001)
            ):
                result = iterone(env)
            pass
    elif node_var(pat):
        # Var pattern.
        # Match and bind name to node.
        if node_is_bindable(node):
            # [Peter Mawhorter 2021-8-29] Attempting to allow import
            # aliases to unify with variable references later on.
            # If the alias has an 'as' part we unify as if it's a name
            # with that ID, otherwise we use the name part as the ID.
            if isinstance(node, ast.alias):
                if node.asname:
                    bind_as = ast.Name(id=node.asname, ctx=None)
                else:
                    bind_as = ast.Name(id=node.name, ctx=None)
            else:
                bind_as = node
            env2 = bind(env1, node_var(pat), bind_as)
            if env2:
                result = iterone(env2.value)
            pass
    elif typed_lit_var(pat):
        # Var pattern to bind only a literal of given type.
        id, ty = typed_lit_var(pat)
        lit = LIT_TYPES[ty](node)
        # Match and bind name to literal.
        if lit:
            env2 = bind(env1, id, lit.value)
            if env2:
                result = iterone(env2.value)
            pass
    elif type(node) == type(pat):
        # Node and pattern have same type.
        if type(pat) == list:
            # Node and pattern are both lists.  Do positional matching.
            if len(pat) == 0:
                # Empty list pattern.  Node must also be empty.
                if len(node) == 0:
                    result = iterone(env)
                    pass
            elif len(node) == 0:
                # Non-empty list pattern with empty node.
                # Try to match sequence subpatterns.
                if seq:
                    psn = set_var(pat[0])
                    if psn:
                        result = imatches(node, pat[1:],
                                          bind(env1, psn, []), seq)
                        pass
                    pass
            else:
                # Both are non-empty.
                psn = set_var(pat[0])
                if seq and psn:
                    # First subpattern is a sequence pattern.
                    # Try all consumption sizes, greediest first.
                    # Unsophisticated exponential backtracking search.
                    result = ichain(
                        imatches(
                            node[i:],
                            pat[1:],
                            bind(env1, psn, node[:i]),
                            seq
                        )
                        for i in range(len(node), -1, -1)
                    )
                # Lyn sez: common special case helpful for more concrete
                # debugging results (e.g., may return FiniteIterator
                # rather than itertools.chain object.)
                elif len(node) == 1 and len(pat) == 1:
                    result = imatches(node[0], pat[0], env1, True)
                else:
                    # For all matches of scalar first element sub pattern.
                    # Generate all corresponding matches of remainder.
                    # Unsophisticated exponential backtracking search.
                    result = ichain(
                        imatches(node[1:], pat[1:], Some(bs), seq)
                        for bs in imatches(node[0], pat[0], env1, True)
                    )
                pass
        elif type(node) == dict or type(node) == odict:
            result = match_dict(node, pat, env1)
        else:
            # Node and pat have same type, but are not lists.
            # Match scalar structures by matching lists of their fields.
            if isinstance(node, ast.AST):
                # TODO: DEBUG
                #if isinstance(node, ast.Import):
                #    print(
                #        "FV2i",
                #        dump(field_values(node)),
                #        dump(field_values(pat))
                #    )
                #if isinstance(node, ast.alias):
                #    print(
                #        "FV2a",
                #        dump(field_values(node)),
                #        dump(field_values(pat))
                #    )
                result = imatches(
                    field_values(node),
                    field_values(pat),
                    env1,
                    False
                )
            pass
        pass
    # return iterempty()
    # imatchCount -= 1 #$
    # print(
    #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
    #         '| ' * imatchCount,
    #         dump(node),
    #         dump(pat),
    #         dump(env1),
    #         seq,
    #         dump(result)
    #     )
    # )
    return result


def match_dict(node, pat, env1):
    # Node and pattern are both dictionaries.  Do nominal matching.
    # Match all named key patterns, then all single-key pattern variables,
    # and finally the multi-key pattern variables.
    assert all(type(k) == str  # All Python 3 strings are unicode
               for k in pat)
    assert all(type(k) == str  # All Python 3 strings are unicode
               for k in node)

    def match_keys(node, pat, envopt):
        """Match literal keys."""
        keyopt = takeone(k for k in pat
                         if not node_var_str(k)
                         and not set_var_str(k))
        if keyopt:
            # There is at least one named key in the pattern.
            # If this key is also in the program node, then for each
            # match of the corresponding node value and pattern value,
            # generate all matches for the remaining keys.
            key = keyopt.value
            if key in node:
                return ichain(match_keys(dict_unbind(node, key),
                                         dict_unbind(pat, key),
                                         Some(kenv))
                              for kenv in imatches(node[key], pat[key],
                                                   envopt, False))
            else:
                return iterempty()
            pass
        else:
            # The pattern contains no literal keys.
            # Generate all matches for the node and set key variables.
            return match_var_keys(node, pat, envopt)
        pass

    def match_var_keys(node, pat, envopt):
        """Match node variable keys."""
        keyvaropt = takeone(k for k in pat if node_var_str(k))
        if keyvaropt:
            # There is at least one single-key variable in the pattern.
            # For each key-value pair in the node whose value matches
            # this single-key variable's associated value pattern,
            # generate all matches for the remaining keys.
            keyvar = keyvaropt.value
            return ichain(match_var_keys(dict_unbind(node, nkey),
                                         dict_unbind(pat, keyvar),
                                         bind(Some(kenv),
                                              node_var_str(keyvar),
                                              ast.Name(id=nkey, ctx=None)))
                          # outer iteration:
                          for nkey, nval in node.items()
                          # inner iteration:
                          for kenv in imatches(nval, pat[keyvar],
                                               envopt, False))
        else:
            # The pattern contains no single-key variables.
            # Generate all matches for the set key variables.
            return match_set_var_keys(node, pat, envopt)
        pass

    def match_set_var_keys(node, pat, envopt):
        """Match set variable keys."""
        # NOTE: see discussion of match environments for this case:
        # https://github.com/wellesleycs111/codder/issues/25
        assert envopt
        keysetvaropt = takeone(k for k in pat if set_var_str(k))
        if keysetvaropt:
            # There is a multi-key variable in the pattern.
            # Capture all remaining key-value pairs in the node.
            e = bind(envopt, set_var_str(keysetvaropt.value),
                     [(ast.Name(id=kw, ctx=None), karg)
                      for kw, karg in node.items()])
            return iterone(e.value) if e else iterempty()
        elif 0 == len(node):
            # There is no multi-key variable in the pattern.
            # There is a match only if there are no remaining
            # keys in the node.
            # There should also be no remaining keys in the pattern.
            assert 0 == len(pat)
            return iterone(envopt.value)
        else:
            return iterempty()

    return match_keys(node, pat, env1)


def find(node, pat, **kwargs):
    """Pre-order search for first sub-AST matching pattern, returning
    (matched node, bindings)."""
    kwargs['gen'] = True
    return takeone(findall(node, parse_pattern(pat), **kwargs))


def findall(node, pat, outside=[], **kwargs):
    """
    Search for all sub-ASTs matching pattern, returning list of (matched
    node, bindings).
    """
    assert node is not None
    assert pat is not None
    gen = kwargs.get('gen', False)
    kwargs['gen'] = True
    pat = parse_pattern(pat)
    # Top-level sequence patterns are not "anchored" to the ends of
    # the containing block when *finding* a submatch within a node (as
    # opposed to matching a node exactly).  They may match any
    # contiguous subsequence.
    # - To allow sequence patterns to match starting later than the
    #   beginning of a program sequence, matching is attempted
    #   recursively with smaller and smaller suffixes of the program
    #   sequence.
    # - To allow sequence patterns to match ending earlier
    #   than the end of a program block, we implicitly ensure that a
    #   sequence wildcard pattern terminates every top-level sequence
    #   pattern.
    # TODO: Because permutations are applied later, this doesn't allow
    # all the matching we'd like in the following scenario
    # Node: [ alias(name="x"), alias(name="y"), alias(name="z") ]
    # Pat: [ alias(name="_a_"), alias(name="___") ]
    # here because order of aliases doesn't matter, we *should* be able
    # to bind _a_ to x, y, OR z, but it can only bind to x or z, NOT y
    if type(pat) == list and not set_var(pat[-1]):
        pat = pat + [parse_pattern('___')]
        pass

    def findall_scalar_pat_iter(node):
        """Generate all matches of a scalar (non-list) pattern at this node
        or any non-excluded descendant of this node."""
        assert type(pat) != list
        assert node is not None
        # Yield any environment(s) for match(es) of pattern at node.
        envs = [e for e in matchpat(node, pat, **kwargs)]
        if 0 < len(envs):
            yield (node, envs)
        # Continue the search for matches in sub-ASTs of node, only if
        # node is not excluded by a "match outside" pattern.
        if not any(match(node, op) for op in outside):
            for n in field_values(node):
                if n:  # Search only within non-None children.
                    for result in findall_scalar_pat_iter(n):
                        yield result
                        pass
                    pass
                pass
            pass
        pass

    def findall_list_pat_iter(node):
        """Generate all matches of a list pattern at this node or any
        non-excluded descendant of this node."""
        assert type(pat) == list
        assert 0 < len(pat)
        if type(node) == list:
            # If searching against a list:
            # - match against the list itself
            # - search against the first child of the list
            # - search against the tail of the list

            # Match against the list itself.
            # Yield any environment(s) for match(es) of pattern at node.
            envs = [e for e in matchpat(node, pat, **kwargs)]
            if 0 < len(envs):
                yield (node, envs)
            # Continue the search for matches in sub-ASTs of node,
            # only if node is not excluded by a "match outside"
            # pattern.
            if (
                not any(match(node, op) for op in outside)
                and 0 < len(node)  # only in nonempty nodes...
            ):
                # Search for matches in the first sub-AST.
                for m in findall_list_pat_iter(node[0]):
                    yield m
                    pass
                if not set_var(pat[0]):
                    # If this sequence pattern does not start with
                    # a sequence wildcard, then:
                    # Search for matches in the tail of the list.
                    # (Includes matches against the entire tail.)
                    for m in findall_list_pat_iter(node[1:]):
                        yield m
                        pass
                    pass
                pass
            pass
        elif not any(match(node, op) for op in outside):  # and node is not list
            # A list pattern cannot match against a scalar node.
            # Search for matches in children of this scalar (non-list)
            # node, only if node is not excluded by a "match outside"
            # pattern.

            # Optimize to search only where list patterns could match.

            # Body blocks
            for ty in [ast.ClassDef, ast.FunctionDef, ast.With, ast.Module,
                       ast.If, ast.For, ast.While,
                       # ast.TryExcept, ast.TryFinally,
                       # In Python 3, a single Try node covers both
                       # TryExcept and TryFinally
                       ast.Try,
                       ast.ExceptHandler]:
                if isinstance(node, ty):
                    for m in findall_list_pat_iter(node.body):
                        yield m
                        pass
                    break
                pass
            # Block else blocks
            for ty in [ast.If, ast.For, ast.While,
                       # ast.TryExcept
                       # In Python 3, a single Try node covers both
                       # TryExcept and TryFinally
                       ast.Try
                       ]:
                if isinstance(node, ty) and node.orelse:
                    for m in findall_list_pat_iter(node.orelse):
                        yield m
                        pass
                    break
                pass

#            # Except handler blocks
#            if isinstance(node, ast.TryExcept):
#                for h in node.handlers:
#                    for m in findall_list_pat_iter(h.body):
#                        yield m
#                        pass
#                    pass
#                pass
#            # finally blocks
#            if isinstance(node, ast.TryFinally):
#                for m in findall_list_pat_iter(node.finalbody):
#                    yield m
#                    pass
#                pass

            # In Python 3, a single Try node covers both TryExcept and
            # TryFinally
            if isinstance(node, ast.Try):
                for h in node.handlers:
                    for m in findall_list_pat_iter(h.body):
                        yield m
                        pass
                    pass
                # Also search the finally block, as the Python 2
                # TryFinally branch above did.
                if node.finalbody:
                    for m in findall_list_pat_iter(node.finalbody):
                        yield m
                        pass
                    pass
                pass

            # General non-optimized version.
            # Must be mutually exclusive with the above if used.
            # for n in field_values(node):
            #     if n:
            #         for result in findall_list_pat_iter(n):
            #             yield result
            #             pass
            #         pass
            #     pass
            pass
        pass
    # Apply the right search based on pattern type.
    matches = (findall_list_pat_iter if type(pat) == list
               else findall_scalar_pat_iter)(node)
    # Return the generator or a list of all generated matches,
    # depending on gen.
    return matches if gen else list(matches)


def count(node, pat, **kwargs):
    """
    Count all sub-ASTs matching pattern.
    Does NOT count individual
    environments that match (i.e., ways that bindings could attach at a
    given node), but rather counts nodes at which one or more bindings
    are possible.
    """
    assert 'gen' not in kwargs
    return sum(1 for x in findall(node, pat,
                                  gen=True, **kwargs))


#---------------------------------------------------------#
# EXPERIMENTAL: Normalize/Inline Simple Straightline Code #
#---------------------------------------------------------#

class Unimplemented(Exception):
    pass


SIMPLE_INLINE_TYPES = [
    ast.Expr, ast.Return,
    # ast.Print, # Removed from ast module in Python 3
    ast.Name, ast.Store, ast.Load, ast.Param,
    # Simple augs should be desugared already.
    # Those remaining are too tricky for a simple implementation.
]


class InlineAvailableExpressions(ast.NodeTransformer):
    def __init__(self, other=None):
        self.available = dict(other.available) if other else {}

    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Load) and name.id in self.available:
            return self.available[name.id]
        else:
            return self.generic_visit(name)

    def visit_Assign(self, assign):
        raise Unimplemented
        new = ast.copy_location(ast.Assign(
            targets=assign.targets,
            value=self.visit(assign.value)
        ), assign)
        self.available[assign.targets[0].id] = new.value
        return new

    def visit_If(self, ifelse):
        # Inline into the test.
        test = self.visit(ifelse.test)
        # Inline and accumulate in the then and else independently.
        body_inliner = InlineAvailableExpressions(self)
        orelse_inliner = InlineAvailableExpressions(self)
        body = body_inliner.inline_block(ifelse.body)
        orelse = orelse_inliner.inline_block(ifelse.orelse)
        # Any var->expression that is available after both branches
        # is available after.
        self.available = {
            name: body_inliner.available[name]
            for name in (set(body_inliner.available)
                         & set(orelse_inliner.available))
            if (body_inliner.available[name] == orelse_inliner.available[name])
        }
        return ast.copy_location(
            ast.If(test=test, body=body, orelse=orelse),
            ifelse
        )

    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)

    def inline_block(self, block):
        # Introduce duplicate common subexpressions...
        return [self.visit(stmt) for stmt in block]


class DeadCodeElim(ast.NodeTransformer):
    def __init__(self, other=None):
        # Copy the other eliminator's `used` set (the only attribute
        # this class defines).
        self.used = set(other.used) if other else set()
        pass

    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Store):
            # Var store defines, removing from the set.
            self.used = self.used - {name.id}
            return name
        elif isinstance(name.ctx, ast.Load):
            # Var use uses, adding to the set.
            self.used = self.used | {name.id}
            return name
        else:
            return name
        pass

    def visit_Assign(self, assign):
        # This restriction prevents worries about things like:
        # x, x[0] = [[1], 2]
        # By using this restriction it is safe to keep the single set,
        # thus order of removals and additions will not be a problem
        # since defs are discovered first (always to left of =), then
        # uses are discovered next (to right of =).
        assert all(
            (
                node_is_name(t)
                or (
                    isinstance(t, ast.Tuple)
                    and all(node_is_name(elt) for elt in t.elts)
                )
            )
            for t in assign.targets
        )
        # Now handled by visit_Name
        # self.used = self.used - set(n.id for n in assign.targets)
        # Tuple targets are ast.Tuple nodes; their names live in .elts.
        if (any(t.id in self.used for t in assign.targets if node_is_name(t))
            or any(t.id in self.used
                   for tup in assign.targets
                   if isinstance(tup, ast.Tuple)
                   for t in tup.elts
                   if node_is_name(t))):
            return ast.copy_location(ast.Assign(
                targets=[self.visit(t) for t in assign.targets],
                value=self.visit(assign.value)
            ), assign)
        else:
            return None

    def visit_If(self, ifelse):
        body_elim = DeadCodeElim(self)
        orelse_elim = DeadCodeElim(self)
        # DCE the body
        body = body_elim.elim_block(ifelse.body)
        # DCE the else
        orelse = orelse_elim.elim_block(ifelse.orelse)
        # Use the test -- TODO: could eliminate entire if sometimes.
        # Keep it for now for clarity.
        self.used = body_elim.used | orelse_elim.used
        test = self.visit(ifelse.test)
        return ast.copy_location(ast.If(test=test, body=body, orelse=orelse),
                                 ifelse)

    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)

    def elim_block(self, block):
        # Introduce duplicate common subexpressions...
        return [s for s in
                (self.visit(stmt) for stmt in block[::-1])
                if s][::-1]


class NormalizePure(ast.NodeTransformer):
    """AST transformer: normalize/inline straightline assignments into
    single expressions where possible."""
    def normalize_block(self, block):
        try:
            return DeadCodeElim().elim_block(
                InlineAvailableExpressions().inline_block(block)
            )
        except Unimplemented:
            return block

    def visit_FunctionDef(self, fun):
        # Lyn warning: previously commented code below is from Python 2
        # and won't work in Python 3 (args have changed).
        # assert 0 < len(
        #     result.value[0].intersection(set(a.id for a in fun.args.args))
        # )
        normbody = self.normalize_block(fun.body)
        if normbody != fun.body:
            return ast.copy_location(ast.FunctionDef(
                name=fun.name,
                args=self.generic_visit(fun.args),
                body=normbody,
            ), fun)
        else:
            return fun


# Note 2021-7-29 Peter Mawhorter: Disabled use of DesugarAugAssign here
# because of concerns about students using list += string. If you want
# to match += you'll have to do so explicitly (and if that's a common use
# case, we can add something to specifications and/or patterns for that).
def canonical_pure(node):
    """Return the normalized/inlined version of an AST node."""
    # print(type(fun))
    # assert isinstance(fun, ast.FunctionDef)
    # print()
    # print('FUN', dump(fun))
    if type(node) == list:
        return NormalizePure().normalize_block(node)
        #    [DesugarAugAssign().visit(stmt) for stmt in node]
        #)
    else:
        assert isinstance(node, ast.AST)
        #return NormalizePure().visit(DesugarAugAssign().visit(node))
        return NormalizePure().visit(node)


CANON_CACHE = {}


def parse_canonical_pure(string, toplevel=False):
    """Parse a normalized/inlined version of a program."""
    if type(string) == list:
        return [parse_canonical_pure(x) for x in string]
    elif string not in CANON_CACHE:
        CANON_CACHE[string] = canonical_pure(
            parse_pattern(string, toplevel=toplevel)
        )
        pass
    return CANON_CACHE[string]

#------------------------------------#
# Pretty Print ASTs to Python Source #
#------------------------------------#


# Four spaces; unindent() below strips exactly one INDENT.
INDENT = '    '


def indent(pat, indent=INDENT):
    """Apply indents to a source string."""
    return indent + pat.replace('\n', '\n' + indent)


class SourceFormatter(ast.NodeVisitor):
    """AST visitor: pretty print AST to python source string"""
    def __init__(self):
        ast.NodeVisitor.__init__(self)
        self._indent = ''
        pass

    def indent(self):
        self._indent += INDENT
        pass

    def unindent(self):
        self._indent = self._indent[:-len(INDENT)]
        pass

    def line(self, ln):
        return self._indent + ln + '\n'

    def lines(self, lst):
        return ''.join(lst)

    def generic_visit(self, node):
        assert False, 'visiting {}'.format(ast.dump(node))
        pass

    def visit_Module(self, m):
        return self.lines(self.visit(n) for n in m.body)

    def visit_Interactive(self, i):
        return self.lines(self.visit(n) for n in
            i.body)

    def visit_Expression(self, e):
        return self.line(self.visit(e.body))

    def visit_FunctionDef(self, f):
        assert not f.decorator_list
        header = self.line('def {name}({args}):'.format(
            name=f.name,
            args=self.visit(f.args))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        self.unindent()
        return header + body + '\n'

    def visit_ClassDef(self, c):
        assert not c.decorator_list
        header = self.line('class {name}({bases}):'.format(
            name=c.name,
            bases=', '.join(self.visit(b) for b in c.bases)
        ))
        self.indent()
        body = self.lines(self.visit(s) for s in c.body)
        self.unindent()
        return header + body + '\n'

    def visit_Return(self, r):
        return self.line('return' if r.value is None
                         else 'return {}'.format(self.visit(r.value)))

    def visit_Delete(self, d):
        # Multiple del targets are comma-separated: del a, b
        return self.line('del ' + ', '.join(self.visit(e) for e in d.targets))

    def visit_Assign(self, a):
        # Multiple targets mean a chained assignment: a = b = value
        return self.line(' = '.join(self.visit(e)
                                    for e in a.targets)
                         + ' = ' + self.visit(a.value))

    def visit_AugAssign(self, a):
        return self.line('{target} {op}= {expr}'.format(
            target=self.visit(a.target),
            op=self.visit(a.op),
            expr=self.visit(a.value))
        )

# Print removed as ast node on Python 3
#    def visit_Print(self, p):
#        assert p.dest == None
#        return self.line('print {}{}'.format(
#            ', '.join(self.visit(e) for e in p.values),
#            ',' if p.values and not p.nl else ''
#        ))

    def visit_For(self, f):
        header = self.line('for {} in {}:'.format(
            self.visit(f.target),
            self.visit(f.iter))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        orelse = self.lines(self.visit(s) for s in f.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if f.orelse else ''
        )

    def visit_While(self,
                    w):
        # Peter 2021-6-16: Removed this assert; orelse isn't defined here
        # assert not orelse
        header = self.line('while {}:'.format(self.visit(w.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        orelse = self.lines(self.visit(s) for s in w.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if w.orelse else ''
        )

    def visit_If(self, i):
        header = self.line('if {}:'.format(self.visit(i.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in i.body)
        orelse = self.lines(self.visit(s) for s in i.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if i.orelse else ''
        )

    def visit_With(self, w):
        # Converted to Python3 withitems:
        header = self.line(
            'with {items}:'.format(
                items=', '.join(
                    '{expr}{asnames}'.format(
                        expr=self.visit(item.context_expr),
                        # Leading space keeps 'as' separate from the
                        # context expression.
                        asnames=(' as ' + self.visit(item.optional_vars)
                                 if item.optional_vars else '')
                    )
                    for item in w.items
                )
            )
        )
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        self.unindent()
        return header + body

    # Python 3: raise has new abstract syntax
    def visit_Raise(self, r):
        return self.line('raise{}{}'.format(
            (' ' + self.visit(r.exc)) if r.exc else '',
            (' from ' + self.visit(r.cause)) if r.cause else ''
        ))

#    def visit_Raise(self, r):
#        return self.line('raise{}{}{}{}{}'.format(
#            self.visit(r.type) if r.type else '',
#            ', ' if r.type and r.inst else '',
#            self.visit(r.inst) if r.inst else '',
#            ', ' if (r.type or r.inst) and r.tback else '',
#            self.visit(r.tback) if r.tback else ''
#        ))

#    def visit_TryExcept(self, te):
#        self.indent()
#        tblock = self.lines(self.visit(s) for s in te.body)
#        orelse = self.lines(self.visit(s) for s in
te.orelse) 2093# self.unindent() 2094# return ( 2095# self.line('try:') 2096# + tblock 2097# + ''.join(self.visit(eh) for eh in te.handlers) 2098# + (self.line('else:') + orelse if orelse else '' ) 2099# ) 2100 2101# def visit_TryFinally(self, tf): 2102# self.indent() 2103# tblock = self.lines(self.visit(s) for s in tf.body) 2104# fblock = self.lines(self.visit(s) for s in tf.finalbody) 2105# self.unindent() 2106# return ( 2107# self.line('try:') 2108# + tblock 2109# + self.line('finally:') 2110# + fblock 2111# ) 2112 2113 # In Python 3, a single Try node covers both TryExcept and TryFinally 2114 def visit_Try(self, t): 2115 self.indent() 2116 tblock = self.lines(self.visit(s) for s in t.body) 2117 orelse = self.lines(self.visit(s) for s in t.orelse) 2118 fblock = self.lines(self.visit(s) for s in t.finalbody) 2119 self.unindent() 2120 return ( 2121 self.line('try:') 2122 + tblock 2123 + ''.join(self.visit(eh) for eh in t.handlers) 2124 + (self.line('else:') + orelse if orelse else '' ) 2125 + (self.line('finally:') + fblock if fblock else '' ) 2126 ) 2127 2128 def visit_ExceptHandler(self, eh): 2129 header = self.line('except{}{}{}{}:'.format( 2130 ' ' if eh.type else '', 2131 self.visit(eh.type) if eh.type else '', 2132 ' as ' if eh.type and eh.name else ' ' if eh.name else '', 2133 self.visit(eh.name) if eh.name and isinstance(eh.name, ast.AST) 2134 else (eh.name if eh.name else '') 2135 )) 2136 self.indent() 2137 body = self.lines(self.visit(s) for s in eh.body) 2138 self.unindent() 2139 return header + body 2140 2141 def visit_Assert(self, a): 2142 return self.line('assert {}{}{}'.format( 2143 self.visit(a.test), 2144 ', ' if a.msg else '', 2145 self.visit(a.msg) if a.msg else '' 2146 )) 2147 2148 def visit_Import(self, i): 2149 return self.line( 2150 'import {}'.format(', '.join(self.visit(n) for n in i.names)) 2151 ) 2152 2153 def visit_ImportFrom(self, f): 2154 return self.line('from {}{} import {}'.format( 2155 '.' 
* f.level, 2156 f.module if f.module else '', 2157 ', '.join(self.visit(n) for n in f.names) 2158 )) 2159 2160 def visit_Exec(self, e): 2161 return self.line('exec {}{}{}{}{}'.format( 2162 self.visit(e.body), 2163 ' in ' if e.globals else '', 2164 self.visit(e.globals) if e.globals else '', 2165 ', ' if e.locals else '', 2166 self.visit(e.locals) if e.locals else '' 2167 )) 2168 2169 def visit_Global(self, g): 2170 return self.line('global {}'.format(', '.join(g.names))) 2171 2172 def visit_Expr(self, e): 2173 return self.line(self.visit(e.value)) 2174 2175 def visit_Pass(self, p): 2176 return self.line('pass') 2177 2178 def visit_Break(self, b): 2179 return self.line('break') 2180 2181 def visit_Continue(self, c): 2182 return self.line('continue') 2183 2184 def visit_BoolOp(self, b): 2185 return ' {} '.format( 2186 self.visit(b.op) 2187 ).join('({})'.format(self.visit(e)) for e in b.values) 2188 2189 def visit_BinOp(self, b): 2190 return '({}) {} ({})'.format( 2191 self.visit(b.left), 2192 self.visit(b.op), 2193 self.visit(b.right) 2194 ) 2195 2196 def visit_UnaryOp(self, u): 2197 return '{} ({})'.format( 2198 self.visit(u.op), 2199 self.visit(u.operand) 2200 ) 2201 2202 def visit_Lambda(self, ld): 2203 return '(lambda {}: {})'.format( 2204 self.visit(ld.args), 2205 self.visit(ld.body) 2206 ) 2207 2208 def visit_IfExp(self, i): 2209 return '({} if {} else {})'.format( 2210 self.visit(i.body), 2211 self.visit(i.test), 2212 self.visit(i.orelse) 2213 ) 2214 2215 def visit_Dict(self, d): 2216 return '{{ {} }}'.format( 2217 ', '.join('{}: {}'.format(self.visit(k), self.visit(v)) 2218 for k, v in zip(d.keys, d.values)) 2219 ) 2220 2221 def visit_Set(self, s): 2222 return '{{ {} }}'.format(', '.join(self.visit(e) for e in s.elts)) 2223 2224 def visit_ListComp(self, lc): 2225 return '[{} {}]'.format( 2226 self.visit(lc.elt), 2227 ' '.join(self.visit(g) for g in lc.generators) 2228 ) 2229 2230 def visit_SetComp(self, sc): 2231 return '{{{} {}}}'.format( 2232 
self.visit(sc.elt), 2233 ' '.join(self.visit(g) for g in sc.generators) 2234 ) 2235 2236 def visit_DictComp(self, dc): 2237 return '{{{} {}}}'.format( 2238 '{}: {}'.format(self.visit(dc.key), self.visit(dc.value)), 2239 ' '.join(self.visit(g) for g in dc.generators) 2240 ) 2241 2242 def visit_GeneratorExp(self, ge): 2243 return '({} {})'.format( 2244 self.visit(ge.elt), 2245 ' '.join(self.visit(g) for g in ge.generators) 2246 ) 2247 2248 def visit_Yield(self, y): 2249 return 'yield {}'.format(self.visit(y.value) if y.value else '') 2250 2251 def visit_Compare(self, c): 2252 assert len(c.ops) == len(c.comparators) 2253 return '{} {}'.format( 2254 '({})'.format(self.visit(c.left)), 2255 ' '.join( 2256 '{} ({})'.format(self.visit(op), self.visit(expr)) 2257 for op, expr in zip(c.ops, c.comparators) 2258 ) 2259 ) 2260 2261 def visit_Call(self, c): 2262 # return '{fun}({args}{keys}{starargs}{starstarargs})'.format( 2263 # Unlike Python 2, Python 3 has no starargs or startstarargs 2264 return '{fun}({args}{keys})'.format( 2265 fun=self.visit(c.func), 2266 args=', '.join(self.visit(a) for a in c.args), 2267 keys=( 2268 (', ' if c.args else '') 2269 + ( 2270 ', '.join(self.visit(ka) for ka in c.keywords) 2271 if c.keywords else '' 2272 ) 2273 ) 2274 ) 2275 2276 def visit_Repr(self, r): 2277 return 'repr({})'.format(self.visit(r.expr)) 2278 2279 def visit_Num(self, n): 2280 return repr(n.n) 2281 2282 def visit_Str(self, s): 2283 return repr(s.s) 2284 2285 def visit_Attribute(self, a): 2286 return '{}.{}'.format(self.visit(a.value), a.attr) 2287 2288 def visit_Subscript(self, s): 2289 return '{}[{}]'.format(self.visit(s.value), self.visit(s.slice)) 2290 2291 def visit_Name(self, n): 2292 return n.id 2293 2294 def visit_List(self, ls): 2295 return '[{}]'.format(', '.join(self.visit(e) for e in ls.elts)) 2296 2297 def visit_Tuple(self, tp): 2298 return '({})'.format(', '.join(self.visit(e) for e in tp.elts)) 2299 2300 def visit_Ellipsis(self, s): 2301 return '...' 
2302 2303 def visit_Slice(self, s): 2304 return '{}:{}{}{}'.format( 2305 self.visit(s.lower) if s.lower else '', 2306 self.visit(s.upper) if s.upper else '', 2307 ':' if s.step else '', 2308 self.visit(s.step) if s.step else '' 2309 ) 2310 2311 def visit_ExtSlice(self, es): 2312 return ', '.join(self.visit(s) for s in es.dims) 2313 2314 def visit_Index(self, i): 2315 return self.visit(i.value) 2316 2317 def visit_And(self, a): 2318 return 'and' 2319 2320 def visit_Or(self, o): 2321 return 'or' 2322 2323 def visit_Add(self, a): 2324 return '+' 2325 2326 def visit_Sub(self, a): 2327 return '-' 2328 2329 def visit_Mult(self, a): 2330 return '*' 2331 2332 def visit_Div(self, a): 2333 return '/' 2334 2335 def visit_Mod(self, a): 2336 return '%' 2337 2338 def visit_Pow(self, a): 2339 return '**' 2340 2341 def visit_LShift(self, a): 2342 return '<<' 2343 2344 def visit_RShift(self, a): 2345 return '>>' 2346 2347 def visit_BitOr(self, a): 2348 return '|' 2349 2350 def visit_BixXor(self, a): 2351 return '^' 2352 2353 def visit_BitAnd(self, a): 2354 return '&' 2355 2356 def visit_FloorDiv(self, a): 2357 return '//' 2358 2359 def visit_Invert(self, a): 2360 return '~' 2361 2362 def visit_Not(self, a): 2363 return 'not' 2364 2365 def visit_UAdd(self, a): 2366 return '+' 2367 2368 def visit_USub(self, a): 2369 return '-' 2370 2371 def visit_Eq(self, a): 2372 return '==' 2373 2374 def visit_NotEq(self, a): 2375 return '!=' 2376 2377 def visit_Lt(self, a): 2378 return '<' 2379 2380 def visit_LtE(self, a): 2381 return '<=' 2382 2383 def visit_Gt(self, a): 2384 return '>' 2385 2386 def visit_GtE(self, a): 2387 return '>=' 2388 2389 def visit_Is(self, a): 2390 return 'is' 2391 2392 def visit_IsNot(self, a): 2393 return 'is not' 2394 2395 def visit_In(self, a): 2396 return 'in' 2397 2398 def visit_NotIn(self, a): 2399 return 'not in' 2400 2401 def visit_comprehension(self, c): 2402 return 'for {} in {}{}{}'.format( 2403 self.visit(c.target), 2404 self.visit(c.iter), 2405 ' ' if c.ifs 
else '', 2406 ' '.join('if {}'.format(self.visit(i)) for i in c.ifs) 2407 ) 2408 2409 def visit_arg(self, a): 2410 '''[2019/01/22, lyn] Handle new arg objects in Python 3.''' 2411 return a.arg # The name of the argument 2412 2413 def visit_keyword(self, k): 2414 return '{}={}'.format(k.arg, self.visit(k.value)) 2415 2416 def visit_alias(self, a): 2417 return '{} as {}'.format(a.name, a.asname) if a.asname else a.name 2418 2419 def visit_arguments(self, a): 2420 # [2019/01/22, lyn] Note: This does *not* handle Python 3's 2421 # keyword-only arguments (probably moot for 111, but not 2422 # beyond). 2423 stdargs = a.args[:-len(a.defaults)] if a.defaults else a.args 2424 defargs = ( 2425 zip(a.args[-len(a.defaults):], a.defaults) 2426 if a.defaults else [] 2427 ) 2428 return '{stdargs}{sep1}{defargs}{sep2}{varargs}{sep3}{kwargs}'.format( 2429 stdargs=', '.join(self.visit(sa) for sa in stdargs), 2430 sep1=', ' if 0 < len(stdargs) and defargs else '', 2431 defargs=', '.join('{}={}'.format(self.visit(da), self.visit(dd)) 2432 for da, dd in defargs), 2433 sep2=', ' if 0 < len(a.args) and a.vararg else '', 2434 varargs='*{}'.format(a.vararg) if a.vararg else '', 2435 sep3=', ' if (0 < len(a.args) or a.vararg) and a.kwarg else '', 2436 kwargs='**{}'.format(a.kwarg) if a.kwarg else '' 2437 ) 2438 2439 def visit_NameConstant(self, nc): 2440 return str(nc.value) 2441 2442 def visit_Starred(self, st): 2443 # TODO: Is this correct? 
2444 return '*' + st.value.id 2445 2446 2447def ast2source(node): 2448 """Pretty print an AST as a python source string""" 2449 return SourceFormatter().visit(node) 2450 2451 2452def source(node): 2453 """Alias for ast2source""" 2454 return ast2source(node) 2455 2456 2457#---------# 2458# Testing # 2459#---------# 2460 2461if __name__ == '__main__': 2462 tests = [ 2463 (True, 1, '2', '2'), 2464 (False, 0, '2', '3'), 2465 (True, 1, 'x', '_a_'), 2466 (True, 1, '(x, y)', '(_a_, _b_)'), 2467 (False, 0, '(x, y)', '(_a_, _a_)'), 2468 (True, 1, '(x, x)', '(_a_, _a_)'), 2469 (True, 4, 'max(7,3)', '_x_'), 2470 (True, 1, 'max(7,3)', 'max(7,3)'), 2471 (False, 0, 'max(7,2)', 'max(7,3)'), 2472 (True, 1, 'max(7,3,5)', 'max(___args___)'), 2473 (False, 1, 'min(max(7,3),5)', 'max(___args___)'), 2474 (True, 2, 'min(max(7,3),5)', '_f_(___args___)'), 2475 (True, 1, 'min(max(7,3),5)', 'min(max(___maxargs___),___minargs___)'), 2476 (True, 1, 'max()', 'max(___args___)'), 2477 (True, 1, 'max(4)', 'max(4,___args___)'), 2478 (True, 1, 'max(4,5,6)', 'max(4,___args___)'), 2479 (True, 1, '"hello %s" % x', '_a_str_ % _b_'), 2480 (False, 0, 'y % x', '_a_str_ % _b_'), 2481 (False, 0, '7 % x', '_a_str_ % _b_'), 2482 (True, 1, '3', '_a_int_'), 2483 (True, 1, '3.4', '_a_float_'), 2484 (False, 0, '3', '_a_float_'), 2485 (False, 0, '3.4', '_a_int_'), 2486 (True, 1, 'True', '_a_bool_'), 2487 (True, 1, 'False', '_a_bool_'), 2488 (True, 1, 'None', 'None'), 2489 # node vard can bind statements or exprs based on context. 
        (True, 7, 'print("hello"+str(3))', '_x_'),  # 6 => 7 in Python 3
        (True, 1, 'print("hello"+str(3))', 'print(_x_)'),
        (True, 1, 'print(1)', 'print(_x_, ___args___)'),
        (False, 0, 'print(1)', 'print(_x_, _y_, ___args___)'),
        (False, 0, 'print(1, 2)', 'print(_x_)'),
        (True, 1, 'print(1, 2)', 'print(_x_, ___args___)'),
        (True, 1, 'print(1, 2)', 'print(_x_, _y_, ___args___)'),
        (True, 1, 'print(1, 2, 3)', 'print(_x_, ___args___)'),
        (True, 1, 'print(1, 2, 3)', 'print(_x_, _y_, ___args___)'),
        (True, 1,
         '''
def f(x):
    return 17
         ''',
         '''
def f(_a_):
    return _b_
         '''),
        (True, 1,
         '''
def f(x):
    return x
         ''',
         '''
def f(_a_):
    return _b_
         '''),
        (True, 1,
         '''
def f(x):
    return x
         ''',
         '''
def f(_a_):
    return _a_
         '''),
        (False, 0,
         '''
def f(x):
    return 17
         ''',
         '''
def f(_a_):
    return _a_
         '''),
        (True, 1,
         '''
def f(x):
    return 17
         ''',
         '''
def f(_x_):
    return _y_
         '''),
        (True, 1,
         '''
def f(x,y):
    print('hi')
    return x
         ''',
         '''
def f(_x_,_y_):
    print('hi')
    return _x_
         '''),
        (True, 1,
         '''
def f(x,y):
    print('hi')
    return x
         ''',
         '''
def _f_(_x_,_y_):
    print('hi')
    return _x_
         '''),
        (False, 0,
         '''
def f(x,y):
    print('hi')
    return y
         ''',
         '''
def f(_a_,_b_):
    print('hi')
    return _a_
         '''),
        (False, 0, 'x', 'y'),
        (True, 1,
         '''
def f(x,y):
    print('hi')
    return y
         ''',
         '''
def _f_(_x_,_y_):
    ___
    return _y_
         '''),
        (True, 1,
         '''
def f(x,y):
    print('hi')
    print('world')
    print('bye')
    return y
         ''',
         '''
def _f_(_x_,_y_):
    ___stmts___
    return _y_
         '''),
        (False, 0,
         '''
def f(x,y):
    print('hi')
    print('world')
    x = 4
    print('really')
    y = 7
    print('bye')
    return y
         ''',
         '''
def _f_(_x_,_y_):
    ___stmts___
    print(_z_)
    _y_ = _a_int_
    return _y_
         '''),
        (True, 1,
         '''
def f(x,y):
    print('hi')
    print('world')
    x = 4
    print('really')
    y = 7
    print('bye')
    return y
         ''',
         '''
def _f_(_x_,_y_):
    ___stmts___
    print(_z_)
    _y_ = _a_int_
    ___more___
    return _y_
         '''),
        (True, 1,
         '''
def f(x,y):
    print('hi')
    print('world')
    x = 4
    print('really')
    y = 7
    print('bye')
    return y
         ''',
         '''
def _f_(_x_,_y_):
    ___stmts___
    print(_a_)
    _b_ = _c_
    ___more___
    return _d_
         '''),
        (False, 1,
         '''
def f(x,y):
    print('hi')
    print('world')
    x = 4
    print('really')
    y = 7
    print('bye')
    return y
         ''',
         '''
___stmts___
print(_a_)
_b_ = _c_
___more___
return _d_
         '''),
        (True, 1,
         '''
def eyes():
    eye1 = Layer()
    eye2 = Layer()
    face = Layer()
    face.add(eye)
    face.add(eye2)
    return face
         ''',
         '''
def eyes():
    ___
    _face_ = Layer()
    ___
    return _face_
         '''),
        (True, 1,
         '''
def f(x,y):
    """I have a docstring.
    It is a couple lines long."""
    eye = Layer()
    if eye:
        face = Layer()
    else:
        face = Layer()
    eye2 = Layer()
    face.add(eye)
    face.add(eye2)
    return face
         ''',
         '''
def f(___args___):
    eye = Layer()
    ___
         '''),
        (True, 1, '1 == 2', '2 == 1'),
        (True, 1, '1 <= 2', '2 >= 1'),
        (False, 0, '1 <= 2', '2 <= 1'),
        (False, 0, 'f() <= 2', '2 >= f()'),
        # Hmm, is this the semantics we want for `and`?
        (True, 1, 'a and b and c', 'b and a and c'),
        (True, 1, '(a == b) == (b == c)', '(a == b) == (c == b)'),
        (True, 1, '(a and b) and c', 'a and (b and c)'),
        (True, 1, 'a and b', 'a and b'),
        (True, 1, 'g == "a" or g == "b" or g == "c"',
         '_g_ == _a_ or _g_ == _b_ or _c_ == _g_'),
        (True, 1, '''
x = 1
y = 2
''', '''
x = 1
y = 2
'''),
        (True, 1, '''
x = 1
y = 2
''', '''
___
y = 2
'''),
        (True, 1, '''
x = 1
if (a or b or c):
    return True
else:
    return False
         ''',
         '''
___
if _:
    return _a_bool_
else:
    return _b_bool_
___
         '''),
        (True, 1, '''
if (a or b or c):
    return True
else:
    return False
         ''',
         '''
if _:
    return _a_bool_
else:
    return _b_bool_
         '''),
        (True, 1, '''
x = 1
if (a or b or c):
    return True
else:
    return False
         ''',
         '''
___
if _:
    return _a_bool_
else:
    return _b_bool_
         '''),
        (False, 1, '''
def f():
    if (a or b or c):
        return True
    else:
        return False
         ''',
         '''
___
if _:
    return _a_bool_
else:
    return _b_bool_
___
         '''),
        (False, 1, '''
def isValidGesture(gesture):
    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
        return True
    else:
        return False
        ''', '''
if _:
    return _a_bool_
else:
    return _b_bool_
        '''),
        (False, 1, '''
def isValidGesture(gesture):
    print('blah')
    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
        return True
    return False
        ''', '''
___
if _:
    return _a_bool_
return _b_bool_
        '''),

        (False, 1, '''
def isValidGesture(gesture):
    print('blah')
    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
        return True
    return False
        ''', '''
if _:
    return _a_bool_
return _b_bool_
        '''),
        (False, 1, '''
def isValidGesture(gesture):
    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
        return True
    return False
        ''', '''
if _:
    return _a_bool_
return _b_bool_
        '''),
        (False, 1, '''
def isValidGesture(gesture):
    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
        x = True
    x = False
    return x
        ''', '''
if _:
    _x_ = _a_bool_
_x_ = _b_bool_
        '''),
        (True, 1, '''
def isValidGesture(gesture):
    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
        return True
    return False
        ''', '''
def _(_):
    if _:
        return _a_bool_
    return _b_bool_
        '''),
        (True, 1, 'x, y = f()', '___vars___ = f()'),
        (True, 1, 'x, y = f()', '_vars_ = f()'),
        (True, 1, 'f(a=1, b=2)', 'f(b=2, a=1)'),
        (True, 1, '''
def f(x,y):
    """with a docstring"""
    if level <= 0:
        pass
    else:
        fd(3)
        lt(90)
        ''', '''
def f(_, _):
    ___
    if _:
        ___t___
    else:
        ___e___
    ___s___
'''),
        (True, 1, '''
class A(B):
    def f(self, x):
        pass
        ''', 'class _(_): ___'),
        # (True, True,
        # 'for x, y in f():\n ___',
        # 'for ___vars___ in f():\n ___'),
        (False, 0, 'drawLs(size/2, level - 1)', 'drawLs(_size_/2.0, _)'),
        (False, 0, '2', '2.0'),
        (False, 0, '2', '_d_float_'),
        (True, 1, '''
def keepFirstLetter(phrase):
    """Returns a new string that contains only the first occurrence of a
    letter from the original phrase.
    The first letter occurrence can be upper or lower case.
    Non-alpha characters (such as punctuations and space) are left unchanged.
    """
    #this list holds lower-cased versions of all of the letters already used
    usedCharacters = []

    finalPhrase = ""
    for n in phrase:
        if n.isalpha():
            #we need to create a temporary lower-cased version of the letter,
            #so that we can check and see if we've seen an upper or lower-cased
            #version of this letter before
            tempN = n.lower()
            if tempN not in usedCharacters:
                usedCharacters.append(tempN)
                #but we need to add the original n, so that we can preserve
                #if it was upper cased or not
                finalPhrase = finalPhrase + n

        #this adds all non-letter characters into the final phrase list
        else:
            finalPhrase = finalPhrase + n

    return finalPhrase
        ''', '''
def _(___):
    ___
    _acc_ = ""
    ___
    for _ in _:
        ___
    return _acc_
    ___
        '''),
        (False, # Pattern should not match program
         1, # Pattern should be found within program
         # Program
         '''
def keepFirstLetter(phrase):
    #this list holds lower-cased versions of all of the letters already used
    usedCharacters = []

    finalPhrase = ""
    for n in phrase:
        if n.isalpha():
            #we need to create a temporary lower-cased version of the letter,
            #so that we can check and see if we've seen an upper or lower-cased
            #version of this letter before
            tempN = n.lower()
            if tempN not in usedCharacters:
                usedCharacters.append(tempN)
                #but we need to add the original n, so that we can preserve
                #if it was upper cased or not
                finalPhrase = finalPhrase + n

        #this adds all non-letter characters into the final phrase list
        else:
            finalPhrase = finalPhrase + n

    return finalPhrase
         ''',
         # Pattern
         '''
___prefix___
_acc_ = ""
___middle___
for _ in _:
    ___
return _acc_
___suffix___
         '''),
        (True, 1, '''
a = 1
b = 2
c = 3
        ''', '''___; _x_ = _n_'''),
        (False, # Pattern should not match program
         1, # Pattern should be found within program
         # Program
         '''
def f(a):
    x = ""
    return x
         ''',
         # Pattern
         '''
_acc_ = ""
return _acc_
         '''),

        # Treatment of elses:
        (True, 1, 'if x: print(1)', 'if _: _'),
        (False, 0, 'if x: print(1)\nelse: pass', 'if _: _'),
        (True, 1, 'if x: print(1)\nelse: pass', 'if _: _\nelse: ___'),
        (False, 0, 'if x: print(1)', '''
if _: _
else:
    _
    ___
'''),
        (True, 1, 'if x: print(1)\nelse: pass', '''
if _: _
else:
    _
    ___
'''),

        # If ExpandExplicitElsePattern is used:
        # (False, 0, 'if x: print(1)', 'if _: _\nelse: ___'),

        # If ExpandExplicitElsePattern is NOT used:
        (True, 1, 'if x: print(1)', 'if _: _\nelse: ___'),

        # Keyword arguments
        (True, 1, 'f(a=1)', 'f(a=1)'),
        (True, 1, 'f(a=1)', 'f(_kw_=1)'),
        (True, 1, 'f(a=1)', 'f(_kw_=_arg_)'),
        (True, 1, 'f(a=1)', 'f(_=1)'),
        (True, 1, 'f(a=1)', 'f(_=_arg_)'),
        (True, 1, 'f(a=1, b=2)', 'f(_x_=_, _y_=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(_x_=2, _y_=_)'),
        (False, 0, 'f(a=1, b=2)', 'f(_x_=_)'),
        (False, 0, 'f(a=1, b=2)', 'f(_x_=_, _y_=_, _z_=_)'),
        (False, 0, 'f(a=1, b=2)', 'f(_x_=2, _y_=2)'),
        (True, 1, 'f(a=1, b=2)', 'f(_x_=_, b=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(b=_, _x_=_)'),
        (True, 1, 'f(a=1+1, b=1+1)', 'f(_c_=_x_+_x_, _d_=_y_+_y_)'),
        (True, 1, 'f(a=1+1, b=2+2)', 'f(_c_=_x_+_x_, _d_=_y_+_y_)'),
        (True, 1, 'f(a=1+2, b=2+1)', 'f(_c_=_x_+_y_, _d_=_y_+_x_)'),
        (True, 1, 'f(a=1+1, b=1+1)', 'f(_c_=_x_+_x_, _d_=_x_+_x_)'),
        (False, 0, 'f(a=1+1, b=2+2)', 'f(_c_=_x_+_x_, _d_=_x_+_x_)'),
        (True, 1, 'f(a=1, b=2)', 'f(___=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_, b=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(b=_, ___kwargs___=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_, a=_, b=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(a=_, b=_, ___kwargs___=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_, b=_, a=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(b=_, a=_, ___kwargs___=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(a=_, ___kwargs___=_, b=_)'),
        (True, 1, 'f(a=1, b=2)', 'f(b=_, ___kwargs___=_, a=_)'),
        (True, 1, 'b = 7; f(a=1, b=2)', '_x_ = _; f(_x_=_, _y_=_)'),
        (False, 0, 'b = 7; f(a=1, b=2)', '_x_ = _; f(_x_=1, _y_=_)'),

        # and/or set wildcards
        (True, 1, 'x or y or z', '_x_ or ___rest___'),
        (True, 1, 'x or y or z', '_x_ or _y_ or ___rest___'),
        (True, 1, 'x or y or z', '_x_ or _y_ or _z_ or ___rest___'),

        # kwarg matching
        (True, 1, 'def f(x=3):\n return x', 'def _f_(x=3):\n return x'),
        (True, 1, 'def f(x=3):\n return x', 'def _f_(_=3):\n return _'),
        (True, 1, 'def f(x=3):\n return x', 'def _f_(_=3):\n ___'),
        (True, 1, 'def f(x=3):\n return x', 'def _f_(_x_=3):\n return _x_'),
        (True, 1,
         'def f(y=7):\n return y', 'def _f_(_x_=_y_):\n return _x_'),
        (False, 0, 'def f(x=3):\n return x', 'def _f_(_y_=7):\n return _x_'),

        # Succeeds!
        (True, 1, 'def f(x=12):\n return x', 'def _f_(_=_):\n ___'),

        # TODO: Fails because we compare all defaults as one block
        # Zero extra keyword arguments:
        (True, 1, 'def f(x=17):\n return x', 'def _f_(_=17, ___=_):\n ___'),

        # Should match because ___ has no default
        (True, 1, 'def f(x=17):\n return x', 'def _f_(___, _=17):\n ___'),

        # Multiple kwargs
        (True, 1, 'def f(x=3, y=4):\n return x', 'def _f_(_=3, _=4):\n ___'),
        (True, 1, 'def f(x=5, y=6):\n return x', 'def _f_(_=_, _=_):\n ___'),
        # ___ doesn't match kwargs
        (False, 0, 'def f(x=7, y=8):\n return x', 'def _f_(___):\n ___'),
        # TODO: Fails because of default matching bug
        (True, 1, 'def f(x=9, y=10):\n return x', 'def _f_(___=_):\n ___'),

        # Exact matching of kwarg expressions
        (True, 1, 'def f(x=y+3):\n return x', 'def _f_(_=y+3):\n ___'),

        # Matching of kw-only args
        (True, 1,
         'def f(*a, x=5):\n return x',
         'def _f_(*_, _x_=5):\n return _x_'),
        # ___ does not match *_
        (False, 0,
         'def f(*a, x=6):\n return x',
         'def _f_(___, _x_=6):\n return _x_'),

        # Multiple kw-only args
        (True, 1,
         'def f(*a, x=5, y=6):\n return x, y',
         'def _f_(*_, _x_=5, _y_=6):\n return _x_, _y_'),
        (False, 0,
         'def f(*a, x=7, y=8):\n return x, y',
         'def _f_(___, _x_=7, _y_=8):\n return _x_, _y_'),

        # Function with docstring (must use ast.get_docstring!)
        (False, 0, 'def f():\n """docstring"""', 'def _f_(___):\n _a_str_'),

        # Function with docstring (using ast.get_docstring)
        (True, 1,
         'def f(x):\n """doc1"""\n return x',
         'def _f_(___):\n ___',
         lambda node, env: ast.get_docstring(node) is not None),

        # Function without docstring (using ast.get_docstring)
        (False, 0,
         'def f(x):\n """doc2"""\n return x',
         'def _f_(___):\n ___',
         lambda node, env:(
             ast.get_docstring(node) is None
             or ast.get_docstring(node).strip() == ''
         )),
        (True, 1,
         'def f(x):\n """"""\n return x',
         'def _f_(___):\n ___',
         lambda node, env:(
             ast.get_docstring(node) is None
             or ast.get_docstring(node).strip() == ''
         )),
        (True, 1,
         'def nodoc(x):\n return x',
         'def _f_(___):\n ___',
         lambda node, env:(
             ast.get_docstring(node) is None
             or ast.get_docstring(node).strip() == ''
         )),

        # TODO: Recursive matching of kwarg expressions
        (True, 1, 'def f(x=y+3):\n return x', 'def _f_(_=_+3):\n ___'),

        # Function with multiple normal arguments
        (True, 1, 'def f(x, y, z):\n return x', 'def _f_(___):\n ___'),

        # Matching redundant elif conditions
        (
            True,
            1,
            'if x == 3:\n return x\nelif not x == 3 and x > 5:\n return x-2',
            'if _cond_:\n ___\nelif not _cond_ and ___:\n ___'
        ),
        ( # order matters
            False,
            0,
            'if x == 3:\n return x\nelif x > 5 and not x == 3:\n return x-2',
            'if _cond_:\n ___\nelif not _cond_ and ___:\n ___'
        ),
        ( # not == is not the same as !=
            False,
            0,
            'if x == 3:\n return x\nelif x > 5 and x != 3:\n return x-2',
            'if _cond_:\n ___\nelif not _cond_ and ___:\n ___'
        ),
        ( # not == is not the same as !=
            True,
            1,
            'if x == 3:\n return x\nelif x > 5 and not x == 3:\n return x-2',
            'if _cond_:\n ___\nelif ___ and not _cond_:\n ___'
        ),
        ( # not == is not the same as !=
            True,
            1,
            'if x == 3:\n return x\nelif x > 5 and x != 3:\n return x-2',
            'if _n_ == _v_:\n ___\nelif ___ and _n_ != _v_:\n ___'
        ),
        ( # extra conditions do matter!
            False,
            0,
            'if x == 3:\n return x\nelif not x == 3 and x > 5:\n return x-2\n'
            + 'elif not x == 3 and x <= 5 and x < 0:\n return 0',
            'if _cond_:\n ___\nelif not _cond_ and ___:\n ___'
        ),
        ( # match extra conditions:
            True,
            1,
            'if x == 3:\n return x\nelif not x == 3 and x > 5:\n return x-2\n'
            + 'elif not x == 3 and x <= 5 and x < 0:\n return 0',
            'if _cond_:\n ___\nelif not _cond_ and ___:\n ___\nelif _:\n ___'
        ),
        ( # number of conditions must match exactly:
            False,
            0,
            'if x == 3:\n return x\nelif not x == 3 and x > 5:\n return x-2\n'
            + 'elif not x == 3 and x <= 5 and x < 0:\n return 0\n'
            + 'elif not x == 3 and x <= 5 and x >= 0 and x == 1:\n return 1.5',
            'if _cond_:\n ___\nelif not _cond_ and ___:\n ___\nelif _:\n ___'
        ),
        ( # matching with elif:
            True,
            2,
            'if x < 0:\n x += 1\nelif x < 10:\n x += 0.5\nelse:\n x += 0.25',
            'if _:\n ___\nelse:\n ___'
        ),
        ( # left/right branches are both okay
            True,
            1,
            'x == 3',
            '3 == x',
        ),
        ( # order does matter for some operators
            False,
            0,
            '1 + 2 + 3',
            '3 + 2 + 1'
        ),
        ( # but not for others
            True,
            1,
            '1 and 2 and 3',
            '3 and 2 and 1'
        ),
    ]
    testNum = 1
    passes = 0
    fails = 0
    #for (expect_match, expect_count, ns, ps) in tests:
    for test_spec in tests:
        if len(test_spec) == 4:
            expect_match, expect_count, ns, ps = test_spec
            matchpred = predtrue
        elif len(test_spec) == 5:
            expect_match, expect_count, ns, ps, matchpred = test_spec
        else:
            print("Tests must have 4 or 5 parts!")
            print(test_spec)
            exit(1)

        print('$' + '-' * 60)
        print('Test #{}'.format(testNum))
        testNum += 1
        n = parse(ns)
        n = (
            n.body[0].value
            if type(n.body[0]) == ast.Expr and len(n.body) == 1
            else (
                n.body[0]
                if len(n.body) == 1
                else n.body
            )
        )
        p = parse_pattern(ps)
        print('Program: %s => %s' % (ns, dump(n)))
        # Check the parsed node (n), not the source string (ns)
        if isinstance(n, ast.FunctionDef):
            print('Docstring: %s' % (ast.get_docstring(n)))
        print('Pattern: %s => %s' % (ps, dump(p)))
        if matchpred != predtrue:
            print('Matchpred: {}'.format(matchpred.__name__))
        for gen in [False, True]:
            result = match(n, p, gen=gen, matchpred=matchpred)
            if gen:
                result = takeone(result)
                pass
            passed = bool(result) == expect_match
            if passed:
                passes += 1
            else:
                fails += 1
                pass
            print('match(gen=%s): %s [%s]' % (
                gen, bool(result), "PASS" if passed else "FAIL"
            ))
            # if result:
            #     for (k,v) in result.value.items():
            #         print(" %s = %s" % (k, dump(v)))
            #         pass
            #     pass
            # pass
        opt = find(n, p, matchpred=matchpred)
        findpassed = bool(opt) == (0 < expect_count)
        if findpassed:
            passes += 1
        else:
            fails += 1
            pass
        print(
            'find: {} [{}]'.format(
                bool(opt),
                # dump(opt.value[0]) if opt else None,
                "PASS" if findpassed else "FAIL"
            )
        )
        # if opt:
        #     for (k,v) in opt.value[1].items():
        #         print(" %s = %s" % (k, dump(v)))
        #         pass
        c = count(n, p, matchpred=matchpred)
        if c == expect_count:
            passes += 1
            print('count: %s [PASS]' % c)
        else:
            fails += 1
            print('count: %s [FAIL], expected %d' % (c, expect_count))
            pass
        print('findall:')
        nmatches = 0
        for (node, envs) in findall(n, p, matchpred=matchpred):
            print(" %d. %s" % (nmatches, dump(node)))
            for (i, env) in enumerate(envs):
                print(" %s. " % chr(ord('a') + i))
                for (k, v) in env.items():
                    print(" %s = %s" % (k, dump(v)))
                    pass
                pass
            # find should return first findall result.
            # assert 0 < nmatches or (opt and opt.value == (node, bindings))
            nmatches += 1
            pass
        # count should count the same number of matches that findall finds.
        assert nmatches == c
        print()
        pass
    print("%d passed, %d failed" % (passes, fails))

    with open(__file__) as src:
        with open('self_printed_' + os.path.basename(__file__), 'w') as dest:
            dest.write(source(ast.parse(src.read())))
Alternative chain() constructor taking a single iterable argument that evaluates lazily.
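This wording matches the docstring of the standard library's `itertools.chain.from_iterable`, so the sketch below assumes that is the callable being documented here. It illustrates what "evaluates lazily" means in practice: the outer iterable is consumed one inner iterable at a time, so even an unbounded generator can be flattened as long as only a finite prefix is taken. The helper `numbered_groups` is hypothetical and exists only for this illustration.

```python
from itertools import chain, islice


def numbered_groups():
    """Hypothetical endless generator of two-element lists."""
    n = 0
    while True:
        yield [n, n + 1]
        n += 2


# chain.from_iterable asks the outer generator for a new group only when
# the previous group is exhausted, so this never tries to build them all.
flat = chain.from_iterable(numbered_groups())
first_five = list(islice(flat, 5))
print(first_five)
```

By contrast, `chain(*numbered_groups())` would try to unpack the whole generator before the call and would never return.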
def dump(node):
    """Nicer display of ASTs."""
    return (
        ast.dump(node) if isinstance(node, ast.AST)
        else '[' + ', '.join(map(dump, node)) + ']' if type(node) == list
        else '(' + ', '.join(map(dump, node)) + ')' if type(node) == tuple
        else '{' + ', '.join(dump(key) + ': ' + dump(value)
                             for key, value in node.items()) + '}'
        if type(node) == dict
        # Sneaky way to have dump descend into Some and FiniteIterator objects
        # for debugging
        else node.dump(dump)
        if (type(node) == Some or type(node) == FiniteIterator)
        else repr(node)
    )
Nicer display of ASTs.
115def showret(x, prefix=''): 116 """Shim for use debugging AST return values.""" 117 print(prefix, dump(x)) 118 return x
Shim for use debugging AST return values.
Spec for names of scalar pattern variables.
Types against which sequence patterns match.
Spec for names of sequence pattern variables.
Spec for names/types of typed literal pattern variables.
205def var_is_anonymous(identifier): 206 """Determine whether a pattern variable name (string) is anonymous.""" 207 assert type(identifier) == str # All Python 3 strings are unicode 208 return not re.search(r'[a-zA-Z0-9]', identifier)
Determine whether a pattern variable name (string) is anonymous.
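The anonymity rule can be tried in isolation. The snippet below is a small standalone sketch that repeats the same regular-expression test outside of mast; the sample identifiers are made up for illustration.

```python
import re


def var_is_anonymous(identifier):
    """Same test as above: no ASCII letter or digit means anonymous."""
    return not re.search(r'[a-zA-Z0-9]', identifier)


# Illustrative identifiers (not taken from mast's test suite).
samples = ['_', '___', '_a_', '___rest___', '_1_']

# Keep only the identifiers that consist purely of underscores.
anonymous = [name for name in samples if var_is_anonymous(name)]
```

Identifiers made only of underscores count as anonymous, so they never create a binding that later occurrences must agree with.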
211def node_is_name(node): 212 """Determine if a node is an AST Name node.""" 213 return isinstance(node, ast.Name)
Determine if a node is an AST Name node.
def identifier_key(identifier, spec):
    """Extract the name of a pattern variable from its identifier string,
    returning an option: Some(key) if identifier is a valid variable
    name according to spec, otherwise None.

    Examples for NODE_VAR_SPEC:
    identifier_key('_a_') => Some('a')
    identifier_key('_') => Some('_')
    identifier_key('_a') => None

    """
    assert type(identifier) == str  # All Python 3 strings are unicode
    groups, regex = spec
    match = regex.match(identifier)
    if match:
        if var_is_anonymous(identifier):
            return identifier
        elif type(groups) == tuple:
            return tuple(match.group(i) for i in groups)
        else:
            return match.group(groups)
    else:
        return None
Extract the name of a pattern variable from its identifier string, returning an option: Some(key) if identifier is a valid variable name according to spec, otherwise None.
Examples for NODE_VAR_SPEC:
    identifier_key('_a_') => Some('a')
    identifier_key('_') => Some('_')
    identifier_key('_a') => None
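To make the `(groups, regex)` shape of a spec concrete, here is a standalone sketch. `DEMO_NODE_VAR_SPEC` and `demo_identifier_key` are hypothetical names, and the regular expression is an assumption chosen to reproduce the three examples; the real `NODE_VAR_SPEC` is defined elsewhere in mast.py and may differ. Like the code shown above, the sketch returns the key itself rather than a wrapper object.

```python
import re

# Hypothetical spec, for illustration only: (group index, compiled regex).
DEMO_NODE_VAR_SPEC = (1, re.compile(r'^_(?:([a-zA-Z0-9]+)_)?$'))


def demo_identifier_key(identifier, spec):
    """Return the variable key for identifier, or None if it is not one."""
    groups, regex = spec
    found = regex.match(identifier)
    if not found:
        return None
    if not re.search(r'[a-zA-Z0-9]', identifier):
        # Anonymous variables are keyed by their own spelling.
        return identifier
    if type(groups) == tuple:
        # A spec may name several groups, e.g. a (name, type) pair.
        return tuple(found.group(i) for i in groups)
    return found.group(groups)


named = demo_identifier_key('_a_', DEMO_NODE_VAR_SPEC)
anonymous = demo_identifier_key('_', DEMO_NODE_VAR_SPEC)
invalid = demo_identifier_key('_a', DEMO_NODE_VAR_SPEC)
```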
241def node_var(pat, spec=NODE_VAR_SPEC, wrap=False): 242 """Extract the key name of a scalar pattern variable, 243 returning Some(key) if pat is a scalar pattern variable, 244 otherwise None. 245 246 A named or anonymous node variable pattern, written `_a_` or `_`, 247 respectively, may appear in any expression or identifier context 248 in a pattern. It matches any single AST in the corresponding 249 position in the target program. 250 251 """ 252 if wrap: 253 pat = ast.Name(id=pat, ctx=None) 254 elif isinstance(pat, ast.Expr): 255 pat = pat.value 256 elif isinstance(pat, ast.alias) and pat.asname is None: 257 # [Peter Mawhorter 2021-8-29] Want to treat aliases without an 258 # 'as' part kind of like normal Name nodes. 259 pat = ast.Name(id=pat.name, ctx=None) 260 return ( 261 identifier_key(pat.id, spec) 262 if node_is_name(pat) 263 else None 264 )
Extract the key name of a scalar pattern variable, returning Some(key) if pat is a scalar pattern variable, otherwise None.
A named or anonymous node variable pattern, written `_a_` or `_`, respectively, may appear in any expression or identifier context in a pattern. It matches any single AST in the corresponding position in the target program.
272def set_var(pat, wrap=False): 273 """Extract the key name of a set or sequence pattern variable, 274 returning Some(key) if pat is a set or sequence pattern variable, 275 otherwise None. 276 277 A named or anonymous set or sequence pattern variable, written 278 `___a___` or `___`, respectively, may appear as an element of a 279 set or sequence context in a pattern. It matches 0 or more nodes 280 in the corresponding context in the target program. 281 282 """ 283 return node_var(pat, 284 spec=SET_VAR_SPEC, 285 wrap=wrap)
Extract the key name of a set or sequence pattern variable, returning Some(key) if pat is a set or sequence pattern variable, otherwise None.
A named or anonymous set or sequence pattern variable, written `___a___` or `___`, respectively, may appear as an element of a set or sequence context in a pattern. It matches 0 or more nodes in the corresponding context in the target program.
292def typed_lit_var(pat): 293 """Extract the key name of a typed literal pattern variable, 294 returning Some(key) if pat is a typed literal pattern variable, 295 otherwise None. 296 297 A typed literal variable pattern, written `_a_type_`, may appear 298 in any expression context in a pattern. It matches any single AST 299 node for a literal of the given primitive type in the 300 corresponding position in the target program. 301 302 """ 303 return node_var(pat, spec=TYPED_LIT_VAR_SPEC)
Extract the key name of a typed literal pattern variable, returning Some(key) if pat is a typed literal pattern variable, otherwise None.
A typed literal variable pattern, written `_a_type_`, may appear in any expression context in a pattern. It matches any single AST node for a literal of the given primitive type in the corresponding position in the target program.
317def is_pat(p): 318 """Determine if p could be a pattern (by type).""" 319 # All Python 3 strings are unicode 320 return isinstance(p, ast.AST) or type(p) == str
Determine if p could be a pattern (by type).
323def node_is_docstring(node): 324 """Is this node a docstring node?""" 325 return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)
Is this node a docstring node?
328def node_is_bindable(node): 329 """Can a node pattern variable bind to this node?""" 330 # return isinstance(node, ast.expr) or isinstance(node, ast.stmt) 331 # Modified to allow debugging prints 332 result = ( 333 isinstance(node, ast.expr) 334 or isinstance(node, ast.stmt) 335 or (isinstance(node, ast.alias) and node.asname is None) 336 ) 337 # print('\n$ node_is_bindable({}) => {}'.format(dump(node),result)) 338 return result
Can a node pattern variable bind to this node?
341def node_is_lit(node, ty): 342 """Is this node a literal primitive node?""" 343 return ( 344 (isinstance(node, ast.Num) and ty == type(node.n)) # noqa E721 345 # All Python 3 strings are unicode 346 or (isinstance(node, ast.Str) and ty == str) 347 or ( 348 isinstance(node, ast.Name) 349 and ty == bool 350 and (node.id == 'True' or node.id == 'False') 351 ) 352 )
Is this node a literal primitive node?
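The check above relies on `ast.Num` and `ast.Str`, which are deprecated since Python 3.8 and removed in Python 3.12, where all literals (including `True` and `False`) parse to `ast.Constant`. The following is a hedged sketch of an equivalent check for such interpreters; `node_is_lit_constant` and `parse_expr` are hypothetical helper names, not part of mast.

```python
import ast


def node_is_lit_constant(node, ty):
    """Variant of the literal check for ASTs built from ast.Constant."""
    # An exact type comparison keeps True/False from counting as int.
    return isinstance(node, ast.Constant) and type(node.value) == ty


def parse_expr(source):
    """Parse a single expression and return its AST node."""
    return ast.parse(source, mode='eval').body


int_is_int = node_is_lit_constant(parse_expr('3'), int)
int_is_float = node_is_lit_constant(parse_expr('3'), float)
str_is_str = node_is_lit_constant(parse_expr("'hi'"), str)
true_is_bool = node_is_lit_constant(parse_expr('True'), bool)
true_is_int = node_is_lit_constant(parse_expr('True'), int)
name_is_int = node_is_lit_constant(parse_expr('x'), int)
```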
355def expr_has_type(node, ty): 356 """Is this expr statically guaranteed to be of this type? 357 358 Literals and conversions have definite types. All other types are 359 conservatively statically unknown. 360 361 """ 362 return node_is_lit(node, ty) or match(node, '{}(_)'.format(ty.__name__))
Is this expr statically guaranteed to be of this type?
Literals and conversions have definite types. All other types are conservatively statically unknown.
def node_line(node):
    """Get the line number of the source line on which this node starts.
    (best effort)"""
    try:
        return node.lineno if type(node) != list else node[0].lineno
    except Exception:
        return None
Get the line number of the source line on which this node starts (best effort).
378class PatternSyntaxError(BaseException): 379 """Exception for errors in pattern syntax.""" 380 def __init__(self, node, message): 381 BaseException.__init__( 382 self, 383 "At pattern line {}: {}".format(node_line(node), message) 384 )
Exception for errors in pattern syntax.
Inherited Members
- builtins.BaseException
- with_traceback
387class PatternValidator(ast.NodeVisitor): 388 """AST visitor: check pattern structure.""" 389 def __init__(self): 390 self.parent = None 391 pass 392 393 def generic_visit(self, pat): 394 oldparent = self.parent 395 self.parent = pat 396 try: 397 return ast.NodeVisitor.generic_visit(self, pat) 398 finally: 399 self.parent = oldparent 400 pass 401 pass 402 403 def visit_Name(self, pat): 404 sn = set_var(pat) 405 if sn: 406 if type(self.parent) not in SEQ_TYPES: 407 raise PatternSyntaxError( 408 pat, 409 "Set/sequence variable ({}) not allowed in {} node".format( 410 sn, type(self.parent) 411 ) 412 ) 413 pass 414 pass 415 416 def visit_arg(self, pat): 417 '''[2019/01/22, lyn] Python 3 now has arg object for param (not Name object). 418 So copied visit_Name here.''' 419 sn = set_var(pat) 420 if sn: 421 if type(self.parent) not in SEQ_TYPES: 422 raise PatternSyntaxError( 423 pat, 424 "Set/sequence variable ({}) not allowed in {} node".format( 425 sn.value, type(self.parent) 426 ) 427 ) 428 pass 429 pass 430 431 def visit_Call(self, c): 432 if 1 < sum(1 for kw in c.keywords if set_var_str(kw.arg)): 433 raise PatternSyntaxError( 434 c.keywords, 435 "Calls may use at most one keyword argument set variable." 436 ) 437 return self.generic_visit(c) 438 439 def visit_keyword(self, k): 440 if (identifier_key(k.arg, SET_VAR_SPEC) 441 and not (node_is_name(k.value) 442 and var_is_anonymous(k.value.id))): 443 raise PatternSyntaxError( 444 k.value, 445 "Value patterns for keyword argument set variables must be _." 446 ) 447 return self.generic_visit(k) 448 pass
AST visitor: check pattern structure.
393 def generic_visit(self, pat): 394 oldparent = self.parent 395 self.parent = pat 396 try: 397 return ast.NodeVisitor.generic_visit(self, pat) 398 finally: 399 self.parent = oldparent 400 pass 401 pass
Called if no explicit visitor function exists for a node.
416 def visit_arg(self, pat): 417 '''[2019/01/22, lyn] Python 3 now has arg object for param (not Name object). 418 So copied visit_Name here.''' 419 sn = set_var(pat) 420 if sn: 421 if type(self.parent) not in SEQ_TYPES: 422 raise PatternSyntaxError( 423 pat, 424 "Set/sequence variable ({}) not allowed in {} node".format( 425 sn.value, type(self.parent) 426 ) 427 ) 428 pass 429 pass
[2019/01/22, lyn] Python 3 now has arg object for param (not Name object). So copied visit_Name here.
Inherited Members
- ast.NodeVisitor
- visit
- visit_Constant
455class RemoveDocstrings(ast.NodeTransformer): 456 """AST Transformer: remove all docstring nodes.""" 457 def filterDocstrings(self, seq): 458 # print('PREFILTERED', seq) 459 filt = [self.visit(n) for n in seq 460 if not node_is_docstring(n)] 461 # print('FILTERED', dump(filt)) 462 return filt 463 464 def visit_Expr(self, node): 465 if isinstance(node.value, ast.Str): 466 assert False 467 # return ast.copy_location(ast.Expr(value=None), node) 468 else: 469 return self.generic_visit(node) 470 471 def visit_FunctionDef(self, node): 472 # print('Removing docstring: %s' % dump(node.body[0].value)) 473 return ast.copy_location(ast.FunctionDef( 474 name=node.name, 475 args=self.generic_visit(node.args), 476 body=self.filterDocstrings(node.body), 477 ), node) 478 479 def visit_Module(self, node): 480 return ast.copy_location(ast.Module( 481 body=self.filterDocstrings(node.body) 482 ), node) 483 484 def visit_For(self, node): 485 return ast.copy_location(ast.For( 486 body=self.filterDocstrings(node.body), 487 target=node.target, 488 iter=node.iter, 489 orelse=self.filterDocstrings(node.orelse) 490 ), node) 491 492 def visit_While(self, node): 493 return ast.copy_location(ast.While( 494 body=self.filterDocstrings(node.body), 495 test=node.test, 496 orelse=self.filterDocstrings(node.orelse) 497 ), node) 498 499 def visit_If(self, node): 500 return ast.copy_location(ast.If( 501 body=self.filterDocstrings(node.body), 502 test=node.test, 503 orelse=self.filterDocstrings(node.orelse) 504 ), node) 505 506 def visit_With(self, node): 507 return ast.copy_location(ast.With( 508 body=self.filterDocstrings(node.body), 509 # Python 3 just has withitems: 510 items=node.items 511 # Old Python 2 stuff: 512 #context_expr=node.context_expr, 513 #optional_vars=node.optional_vars 514 ), node) 515 516# def visit_TryExcept(self, node): 517# return ast.copy_location(ast.TryExcept( 518# body=self.filterDocstrings(node.body), 519# handlers=self.filterDocstrings(node.handlers), 520# 
orelse=self.filterDocstrings(node.orelse) 521# ), node) 522 523# def visit_TryFinally(self, node): 524# return ast.copy_location(ast.TryFinally( 525# body=self.filterDocstrings(node.body), 526# finalbody=self.filterDocstrings(node.finalbody) 527# ), node) 528# 529 530# In Python 3, a single Try node covers both TryExcept and TryFinally 531 def visit_Try(self, node): 532 return ast.copy_location(ast.Try( 533 body=self.filterDocstrings(node.body), 534 handlers=self.filterDocstrings(node.handlers), 535 orelse=self.filterDocstrings(node.orelse), 536 finalbody=self.filterDocstrings(node.finalbody) 537 ), node) 538 539 def visit_ClassDef(self, node): 540 return ast.copy_location(ast.ClassDef( 541 body=self.filterDocstrings(node.body), 542 name=node.name, 543 bases=node.bases, 544 decorator_list=node.decorator_list 545 ), node)
AST Transformer: remove all docstring nodes.
Inherited Members
- ast.NodeTransformer
- generic_visit
- ast.NodeVisitor
- visit
- visit_Constant
548class FlattenBoolOps(ast.NodeTransformer): 549 """AST transformer: flatten nested boolean expressions.""" 550 def visit_BoolOp(self, node): 551 values = [] 552 for v in node.values: 553 if isinstance(v, ast.BoolOp) and v.op == node.op: 554 values += v.values 555 else: 556 values.append(v) 557 pass 558 pass 559 return ast.copy_location(ast.BoolOp(op=node.op, values=values), node) 560 pass
AST transformer: flatten nested boolean expressions.
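As a standalone illustration of the flattening idea, the sketch below uses a hypothetical `FlattenNestedBoolOps` class. It differs from the transformer above in two labeled ways: it visits children first so that deeper nesting is flattened in one pass, and it compares operator classes rather than operator objects.

```python
import ast


class FlattenNestedBoolOps(ast.NodeTransformer):
    """Illustrative variant: merge same-operator BoolOps into one node."""

    def visit_BoolOp(self, node):
        # Assumption of this sketch: flatten the children before the parent.
        self.generic_visit(node)
        values = []
        for v in node.values:
            if isinstance(v, ast.BoolOp) and type(v.op) is type(node.op):
                values.extend(v.values)
            else:
                values.append(v)
        return ast.copy_location(ast.BoolOp(op=node.op, values=values), node)


def flattened_operand_count(source):
    """Count the operands of the top-level BoolOp after flattening."""
    tree = ast.parse(source, mode='eval')
    return len(FlattenNestedBoolOps().visit(tree).body.values)


same_op = flattened_operand_count('a and (b and (c and d))')
mixed_op = flattened_operand_count('a and (b or c)')
```

Only operands joined by the same operator are merged, so `a and (b or c)` keeps its nested `or` as a single operand.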
Inherited Members
- ast.NodeTransformer
- generic_visit
- ast.NodeVisitor
- visit
- visit_Constant
563class ContainsCall(Exception): 564 """Exception raised when prohibiting call expressions.""" 565 pass
Exception raised when prohibiting call expressions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
568class ProhibitCall(ast.NodeVisitor): 569 """AST visitor: check call-freedom.""" 570 def visit_Call(self, c): 571 raise ContainsCall 572 pass
AST visitor: check call-freedom.
Inherited Members
- ast.NodeVisitor
- visit
- generic_visit
- visit_Constant
575class SetLoadContexts(ast.NodeTransformer): 576 """ 577 Transforms any AugStore contexts into Load contexts, for use with 578 DesugarAugAssign. 579 """ 580 def visit_AugStore(self, ctx): 581 return ast.copy_location(ast.Load(), ctx)
Transforms any AugStore contexts into Load contexts, for use with DesugarAugAssign.
Inherited Members
- ast.NodeTransformer
- generic_visit
- ast.NodeVisitor
- visit
- visit_Constant
584class DesugarAugAssign(ast.NodeTransformer): 585 """AST transformer: desugar augmented assignments (e.g., `+=`) 586 when simple desugaring does not duplicate effectful expressions. 587 588 FIXME: this desugaring and others cause surprising match 589 counts. See: https://github.com/wellesleycs111/codder/issues/31 590 591 FIXME: This desugaring should probably not happen in cases where 592 .__iadd__ and .__add__ yield different results, for example with 593 lists where .__iadd__ is really extend while .__add__ creates a new 594 list, such that 595 596 [1, 2] + "34" 597 598 is an error, but 599 600 x = [1, 2] 601 x += "34" 602 603 is not! 604 605 A better approach might be to avoid desugaring entirely and instead 606 provide common collections of match rules for common patterns. 607 608 """ 609 def visit_AugAssign(self, assign): 610 try: 611 # Desugaring *all* AugAssigns is not sound. 612 # Example: xs[f()] += 1 --> xs[f()] = xs[f()] + 1 613 # Check for presence of call in target. 614 ProhibitCall().visit(assign.target) 615 return ast.copy_location( 616 ast.Assign( 617 targets=[self.visit(assign.target)], 618 value=ast.copy_location( 619 ast.BinOp( 620 left=SetLoadContexts().visit(assign.target), 621 op=self.visit(assign.op), 622 right=self.visit(assign.value) 623 ), 624 assign 625 ) 626 ), 627 assign 628 ) 629 except ContainsCall: 630 return self.generic_visit(assign) 631 632 def visit_AugStore(self, ctx): 633 return ast.copy_location(ast.Store(), ctx) 634 635 def visit_AugLoad(self, ctx): 636 return ast.copy_location(ast.Load(), ctx)
AST transformer: desugar augmented assignments (e.g., `+=`) when simple desugaring does not duplicate effectful expressions.
FIXME: this desugaring and others cause surprising match counts. See: https://github.com/wellesleycs111/codder/issues/31
FIXME: This desugaring should probably not happen in cases where .__iadd__ and .__add__ yield different results, for example with lists where .__iadd__ is really extend while .__add__ creates a new list, such that
[1, 2] + "34"
is an error, but
x = [1, 2]
x += "34"
is not!
A better approach might be to avoid desugaring entirely and instead provide common collections of match rules for common patterns.
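The list example from the second FIXME can be checked directly in plain Python, without mast. This snippet only demonstrates the language behavior that makes the desugaring unsound for lists.

```python
# In-place addition on a list behaves like extend(): any iterable works.
x = [1, 2]
x += "34"

# Binary addition on a list requires another list.
try:
    [1, 2] + "34"
    add_failed = False
except TypeError:
    add_failed = True
```

Rewriting `x += "34"` as `x = x + "34"` would therefore turn a working statement into one that raises `TypeError`.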
609 def visit_AugAssign(self, assign): 610 try: 611 # Desugaring *all* AugAssigns is not sound. 612 # Example: xs[f()] += 1 --> xs[f()] = xs[f()] + 1 613 # Check for presence of call in target. 614 ProhibitCall().visit(assign.target) 615 return ast.copy_location( 616 ast.Assign( 617 targets=[self.visit(assign.target)], 618 value=ast.copy_location( 619 ast.BinOp( 620 left=SetLoadContexts().visit(assign.target), 621 op=self.visit(assign.op), 622 right=self.visit(assign.value) 623 ), 624 assign 625 ) 626 ), 627 assign 628 ) 629 except ContainsCall: 630 return self.generic_visit(assign)
Inherited Members
- ast.NodeTransformer
- generic_visit
- ast.NodeVisitor
- visit
- visit_Constant
class ExpandExplicitElsePattern(ast.NodeTransformer):
    """AST transformer: transform patterns that include `else: ___`
    to use `else: _; ___` instead, forcing the else to have at least
    one statement."""
    def visit_If(self, ifelse):
        if 1 == len(ifelse.orelse) and set_var(ifelse.orelse[0]):
            return ast.copy_location(
                ast.If(
                    test=ifelse.test,
                    body=ifelse.body,
                    orelse=[
                        ast.copy_location(
                            ast.Expr(value=ast.Name(id='_', ctx=None)),
                            ifelse.orelse[0]
                        ),
                        ifelse.orelse[0]
                    ]
                ),
                ifelse
            )
        else:
            return self.generic_visit(ifelse)
AST transformer: transform patterns that include `else: ___` to use `else: _; ___` instead, forcing the else to have at least one statement.
643 def visit_If(self, ifelse): 644 if 1 == len(ifelse.orelse) and set_var(ifelse.orelse[0]): 645 return ast.copy_location( 646 ast.If( 647 test=ifelse.test, 648 body=ifelse.body, 649 orelse=[ 650 ast.copy_location( 651 ast.Expr(value=ast.Name(id='_', ctx=None)), 652 ifelse.orelse[0] 653 ), 654 ifelse.orelse[0] 655 ] 656 ), 657 ifelse 658 ) 659 else: 660 return self.generic_visit(ifelse)
Inherited Members
- ast.NodeTransformer
- generic_visit
- ast.NodeVisitor
- visit
- visit_Constant
663def pipe_visit(node, visitors): 664 """Send an AST through a pipeline of AST visitors/transformers.""" 665 if 0 == len(visitors): 666 return node 667 else: 668 v = visitors[0] 669 if isinstance(v, ast.NodeTransformer): 670 # visit for the transformed result 671 return pipe_visit(v.visit(node), visitors[1:]) 672 else: 673 # visit for the effect 674 v.visit(node) 675 return pipe_visit(node, visitors[1:])
Send an AST through a pipeline of AST visitors/transformers.
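The distinction between visiting for a result and visiting for an effect can be shown with a small standalone pipeline. `CountNames`, `RenameXToY`, and `pipe_visit_loop` are hypothetical names for this sketch, which uses a loop instead of recursion and needs Python 3.9+ for `ast.unparse`.

```python
import ast


class CountNames(ast.NodeVisitor):
    """Plain visitor: run for its side effect (a counter)."""

    def __init__(self):
        self.count = 0

    def visit_Name(self, node):
        self.count += 1


class RenameXToY(ast.NodeTransformer):
    """Transformer: run for its returned (rewritten) tree."""

    def visit_Name(self, node):
        if node.id == 'x':
            return ast.copy_location(ast.Name(id='y', ctx=node.ctx), node)
        return node


def pipe_visit_loop(node, visitors):
    """Iterative equivalent of the recursive pipeline above."""
    for v in visitors:
        if isinstance(v, ast.NodeTransformer):
            node = v.visit(node)   # keep the transformed result
        else:
            v.visit(node)          # effect only; the tree is unchanged
    return node


counter = CountNames()
tree = pipe_visit_loop(ast.parse('x = x + z'), [counter, RenameXToY()])
result = ast.unparse(tree)
```

Because `NodeTransformer` is a subclass of `NodeVisitor`, the transformer test has to come first, as it does in `pipe_visit`.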
682def parse_file(path, docstrings=True): 683 """Load and parse a program AST from the given path.""" 684 with open(path) as f: 685 return parse(f.read(), filename=path, docstrings=docstrings) 686 pass
Load and parse a program AST from the given path.
Thunks for AST visitors/transformers that are applied during parsing.
Extra thunks for docstring mode?
710class MastParseError(Exception): 711 def __init__(self, pattern, error): 712 super(MastParseError, self).__init__( 713 'Error parsing pattern source string <<<\n{}\n>>>\n{}'.format( 714 pattern, str(error) 715 ) 716 ) 717 self.trigger = error # hang on to the original error 718 pass 719 pass
Common base class for all non-exit exceptions.
Inherited Members
- builtins.BaseException
- with_traceback
def parse(
    string,
    docstrings=True,
    filename='<unknown>',
    passes=STANDARD_PASSES
):
    """Parse an AST from a string."""
    assert type(string) == str  # All Python 3 strings are unicode
    if docstrings:
        # Keep docstrings: skip the RemoveDocstrings pass, working on a
        # copy so that the shared default list is never mutated.
        passes = [thunk for thunk in passes if thunk is not RemoveDocstrings]
    pipe = [thunk() for thunk in passes]

    try:
        parsed_ast = ast.parse(string, filename=filename)
    except Exception as error:
        raise MastParseError(string, error)
    return pipe_visit(parsed_ast, pipe)
Parse an AST from a string.
a pattern parsing cache
def parse_pattern(pat, toplevel=False, docstrings=True):
    """Parse and validate a pattern."""
    if isinstance(pat, ast.AST):
        PatternValidator().visit(pat)
        return pat
    elif type(pat) == list:
        # Validate every pattern in the list.
        for p in pat:
            PatternValidator().visit(p)
        return pat
    elif type(pat) == str:  # All Python 3 strings are unicode
        cache_key = (toplevel, docstrings, pat)
        pat_ast = PATTERN_PARSE_CACHE.get(cache_key)
        if not pat_ast:
            # 2021-6-16 Peter commented this out (should it be used
            # instead of passes on the line below?)
            # pipe = [thunk() for thunk in PATTERN_PASSES]
            pat_ast = parse(pat, filename='<pattern>', passes=PATTERN_PASSES)
            PatternValidator().visit(pat_ast)
            if not toplevel:
                # Unwrap as needed.
                if len(pat_ast.body) == 1:
                    # If the pattern is a single definition, statement, or
                    # expression, unwrap the Module node.
                    b = pat_ast.body[0]
                    pat_ast = b.value if isinstance(b, ast.Expr) else b
                    pass
                else:
                    # If the pattern is a sequence of definitions or
                    # statements, validate from the top, but return only
                    # the Module body.
                    pat_ast = pat_ast.body
                    pass
                pass
            PATTERN_PARSE_CACHE[cache_key] = pat_ast
            pass
        return pat_ast
    else:
        assert False, 'Cannot parse pattern of type {}: {}'.format(
            type(pat),
            dump(pat)
        )
Parse and validate a pattern.
797def pat(pat, toplevel=False, docstrings=True): 798 """Alias for parse_pattern.""" 799 return parse_pattern(pat, toplevel, docstrings)
Alias for parse_pattern.
Types of AST operation nodes that are associative.
814def op_is_assoc(op): 815 """Determine if the given operation node is associative.""" 816 return type(op) in ASSOCIATIVE_OPS
Determine if the given operation node is associative.
Pairs of operations that are mirrors of each other.
830def mirror(op): 831 """Return the mirror operation of the given operation if any.""" 832 for (x, y) in MIRRORING: 833 if type(x) == type(op): 834 return y 835 elif type(y) == type(op): 836 return x 837 pass 838 return None
Return the mirror operation of the given operation if any.
def op_has_mirror(op):
    """Determine if the given operation has a mirror."""
    return bool(mirror(op))
Determine if the given operation has a mirror.
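Mirroring lets a comparison be matched with its operands swapped, for example `a < b` against `b > a`. The sketch below is standalone; `DEMO_MIRRORING` is a hypothetical table of operator classes chosen for illustration, and the actual contents of `MIRRORING` may differ. It needs Python 3.9+ for `ast.unparse`.

```python
import ast

# Hypothetical mirror table (pairs of operator classes), illustration only.
DEMO_MIRRORING = [
    (ast.Lt, ast.Gt),
    (ast.LtE, ast.GtE),
    (ast.Eq, ast.Eq),
    (ast.NotEq, ast.NotEq),
]


def demo_mirror(op):
    """Return a mirrored operator node, or None if op has no mirror."""
    for x, y in DEMO_MIRRORING:
        if isinstance(op, x):
            return y()
        if isinstance(op, y):
            return x()
    return None


def demo_has_mirror(op):
    return demo_mirror(op) is not None


def mirrored_source(source):
    """Swap the operands of a single comparison and mirror its operator."""
    cmp = ast.parse(source, mode='eval').body
    swapped = ast.Compare(
        left=cmp.comparators[0],
        ops=[demo_mirror(cmp.ops[0])],
        comparators=[cmp.left],
    )
    return ast.unparse(swapped)


less_than = mirrored_source('a < b')
at_least = mirrored_source('a >= b')
```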
866class PermuteBoolOps(ast.NodeTransformer): 867 """AST transformer: permute topmost boolean operation 868 according to the given index order.""" 869 def __init__(self, indices): 870 self.indices = indices 871 872 def visit_BoolOp(self, node): 873 if self.indices: 874 values = [ 875 self.generic_visit(node.values[i]) 876 for i in self.indices 877 ] 878 self.indices = None 879 return ast.copy_location( 880 ast.BoolOp(op=node.op, values=values), 881 node 882 ) 883 else: 884 return self.generic_visit(node)
AST transformer: permute topmost boolean operation according to the given index order.
Inherited Members
- ast.NodeTransformer
- generic_visit
- ast.NodeVisitor
- visit
- visit_Constant
def node_is_pure(node, purefuns=[]):
    """Determine if the given node is (conservatively) pure (effect-free)."""
    return (
        count(node, FUNCALL,
              matchpred=lambda x, env:
              x.func not in PUREFUNS and x.func not in purefuns) == 0
        and all(count(node, pat) == 0 for pat in IMPURE2)
    )
Determine if the given node is (conservatively) pure (effect-free).
902def permutations(node): 903 """ 904 Generate all permutations of the binary and boolean operations, as 905 well as of import orderings, in and below node, via associativity and 906 mirroring, respecting purity. Because this is heavily exponential, 907 limiting the number of imported names and the complexity of binary 908 operation trees in your patterns is a good thing. 909 """ 910 # TODO: These permutations allow matching like this: 911 # Node: import math, sys, io 912 # Pat: import _x_, ___ 913 # can bind x to math or io, but NOT sys (see TODO near 914 # parse_pattern("___")) 915 if isinstance(node, ast.Import): 916 for perm in list_permutations(node.names): 917 yield ast.Import(names=perm) 918 elif isinstance(node, ast.ImportFrom): 919 for perm in list_permutations(node.names): 920 yield ast.ImportFrom( 921 module=node.module, 922 names=perm, 923 level=node.level 924 ) 925 if ( 926 isinstance(node, ast.BinOp) 927 and op_is_assoc(node) 928 and node_is_pure(node) 929 ): 930 for left in permutations(node.left): 931 for right in permutations(node.right): 932 yield ast.copy_location(ast.BinOp( 933 left=left, 934 op=node.op, 935 right=right, 936 ctx=node.ctx 937 )) 938 yield ast.copy_location(ast.BinOp( 939 left=right, 940 op=node.op, 941 right=left, 942 ctx=node.ctx 943 )) 944 elif (isinstance(node, ast.Compare) 945 and len(node.ops) == 1 946 and op_has_mirror(node.ops[0]) 947 and node_is_pure(node)): 948 assert len(node.comparators) == 1 949 for left in permutations(node.left): 950 for right in permutations(node.comparators[0]): 951 # print('PERMUTE', dump(left), dump(node.ops), dump(right)) 952 yield ast.copy_location(ast.Compare( 953 left=left, 954 ops=node.ops, 955 comparators=[right] 956 ), node) 957 yield ast.copy_location(ast.Compare( 958 left=right, 959 ops=[mirror(node.ops[0])], 960 comparators=[left] 961 ), node) 962 elif isinstance(node, ast.BoolOp) and node_is_pure(node): 963 #print(dump(node)) 964 stuff = [[x for x in permutations(v)] for v in 
node.values] 965 prod = [x for x in itertools.product(*stuff)] 966 # print(prod) 967 for values in prod: 968 # print('VALUES', map(dump,values)) 969 for indices in itertools.permutations(range(len(node.values))): 970 # print('BOOL', map(dump, values)) 971 yield PermuteBoolOps(indices).visit( 972 ast.copy_location( 973 ast.BoolOp(op=node.op, values=values), 974 node 975 ) 976 ) 977 pass 978 pass 979 pass 980 else: 981 # print('NO', dump(node)) 982 yield node 983 pass
Generate all permutations of the binary and boolean operations, as well as of import orderings, in and below node, via associativity and mirroring, respecting purity. Because this is heavily exponential, limiting the number of imported names and the complexity of binary operation trees in your patterns is a good thing.
986def list_permutations(items): 987 """ 988 A generator which yields all possible orderings of the given list. 989 """ 990 if len(items) <= 1: 991 yield items[:] 992 else: 993 first = items[0] 994 for subperm in list_permutations(items[1:]): 995 for i in range(len(subperm) + 1): 996 yield subperm[:i] + [first] + subperm[i:]
A generator which yields all possible orderings of the given list.
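The insertion-based recursion can be checked against the standard library. The snippet below repeats the generator so that it runs on its own and compares its output with `itertools.permutations`; the sample list is made up for illustration.

```python
import itertools


def list_permutations(items):
    """Yield every ordering of items by inserting the first element at
    each position of every ordering of the remaining elements."""
    if len(items) <= 1:
        yield items[:]
    else:
        first = items[0]
        for subperm in list_permutations(items[1:]):
            for i in range(len(subperm) + 1):
                yield subperm[:i] + [first] + subperm[i:]


orderings = list(list_permutations(['a', 'b', 'c']))
reference = [list(p) for p in itertools.permutations(['a', 'b', 'c'])]
```

The two produce the same orderings, though not necessarily in the same sequence. A list of n items has n! orderings, which is why long import lists in patterns become expensive quickly.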
def match(node, pat, **kwargs):
    """A convenience wrapper for matchpat that accepts a pattern either
    as a pre-parsed AST or as a pattern source string to be parsed.

    """
    return matchpat(node, parse_pattern(pat), **kwargs)
A convenience wrapper for matchpat that accepts a pattern either as a pre-parsed AST or as a pattern source string to be parsed.
The True predicate
def matchpat(
    node,
    pat,
    matchpred=predtrue,
    env={},
    gen=False,
    normalize=False
):
    """Match an AST against a (pre-parsed) pattern.

    The optional keyword argument matchpred gives a predicate of type
    (AST node * match environment) -> bool that filters structural
    matches by arbitrary additional criteria. The default is the true
    predicate, accepting all structural matches.

    The optional keyword argument gen determines whether this function:
    - (gen=True) yields an environment for each way pat matches node.
    - (gen=False) returns Some(the first of these environments), or
      None if there are no matches (default).

    EXPERIMENTAL
    The optional keyword argument normalize determines whether to
    rewrite the target AST node and the pattern to inline simple
    straightline variable assignments into large expressions (to the
    extent possible) for matching. The default is no normalization
    (False). Normalization is experimental. It is rather ad hoc and
    conservative and may cause unintuitive matching behavior. Use
    with caution.

    INTERNAL
    The optional keyword env gives an initial match environment which
    may be used to constrain otherwise free pattern variables. This
    argument is mainly intended for internal use.

    """
    assert node is not None
    assert pat is not None
    if normalize:
        if isinstance(node, ast.AST) or type(node) == list:
            node = canonical_pure(node)
            pass
        if isinstance(pat, ast.AST) or type(pat) == list:
            pat = canonical_pure(pat)
            pass
        pass

    # Permute the PATTERN, not the AST, so that nodes returned
    # always match real code.
    matches = (matchenv
               # outer iteration
               for permpat in permutations(pat)
               # inner iteration
               for matchenv in imatches(node, permpat,
                                        Some(env), True)
               if matchpred(node, matchenv))

    return matches if gen else takeone(matches)
Match an AST against a (pre-parsed) pattern.
The optional keyword argument matchpred gives a predicate of type (AST node * match environment) -> bool that filters structural matches by arbitrary additional criteria. The default is the true predicate, accepting all structural matches.
The optional keyword argument gen determines whether this function:
- (gen=True) yields an environment for each way pat matches node.
- (gen=False) returns Some(the first of these environments), or None if there are no matches (default).
EXPERIMENTAL The optional keyword argument normalize determines whether to rewrite the target AST node and the pattern to inline simple straightline variable assignments into large expressions (to the extent possible) for matching. The default is no normalization (False). Normalization is experimental. It is rather ad hoc and conservative and may cause unintuitive matching behavior. Use with caution.
INTERNAL The optional keyword env gives an initial match environment which may be used to constrain otherwise free pattern variables. This argument is mainly intended for internal use.
def bind(env1, name, value):
    """
    Unify the environment in option env1 with a new binding of name to
    value. Return Some(extended environment) if env1 is Some(existing
    environment) in which name is not bound or is bound to value.
    Otherwise (env1 is None, or the existing binding of name is
    incompatible with value), return None.
    """
    # Lyn modified to allow debugging prints
    assert type(name) == str
    if env1 is None:
        # Nothing to extend: a failed match stays failed.
        return None
    env = env1.value
    if var_is_anonymous(name):
        # Anonymous variables never create bindings.
        result = env1
    elif name in env:
        if takeone(imatches(env[name], value, Some({}), True)):
            result = env1
        else:
            result = None
    else:
        env = env.copy()
        env[name] = value
        result = Some(env)
    # print(
    #     '\n$ bind({}, {}, {}) => {}'.format(
    #         dump(env1),
    #         name,
    #         dump(value),
    #         dump(result)
    #     )
    # )
    return result
Unify the environment in option env1 with a new binding of name to value. Return Some(extended environment) if env1 is Some(existing environment) in which name is not bound or is bound to value. Otherwise (env1 is None, or the existing binding of name is incompatible with value), return None.
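The unification rule can be sketched without mast's option and matcher types. In the simplified, hypothetical `demo_bind` below, a plain dict (or `None` for failure) stands in for the option-wrapped environment, and `==` stands in for the structural comparison done with `imatches`.

```python
import re


def demo_bind(env, name, value):
    """Simplified binding: env is a dict, or None for a failed match."""
    if env is None:
        return None                      # failure propagates
    if not re.search(r'[a-zA-Z0-9]', name):
        return env                       # anonymous: nothing to record
    if name in env:
        # A repeated variable must agree with its earlier binding.
        return env if env[name] == value else None
    extended = dict(env)                 # copy, so callers keep their env
    extended[name] = value
    return extended


start = {'x': 1}
fresh = demo_bind(start, 'y', 2)
agree = demo_bind(start, 'x', 1)
clash = demo_bind(start, 'x', 2)
anonymous = demo_bind(start, '_', 99)
failed = demo_bind(None, 'x', 1)
```

Copying before extending matters because the same environment may be reused while exploring alternative matches.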
AST node fields to be ignored when matching children.
1121def argObjToName(argObj): 1122 '''Convert Python 3 arg object into a Name node, 1123 ignoring any annotation, lineno, and col_offset.''' 1124 if argObj is None: 1125 return None 1126 else: 1127 return ast.Name(id=argObj.arg, ctx=None)
Convert Python 3 arg object into a Name node, ignoring any annotation, lineno, and col_offset.
1130def astStr(astObj): 1131 """ 1132 Converts an AST object to a string, imperfectly. 1133 """ 1134 if isinstance(astObj, ast.Name): 1135 return astObj.id 1136 elif isinstance(astObj, ast.Str): 1137 return repr(astObj.s) 1138 elif isinstance(astObj, ast.Num): 1139 return str(astObj.n) 1140 elif isinstance(astObj, (list, tuple)): 1141 return str([astStr(x) for x in astObj]) 1142 elif hasattr(astObj, "_fields") and len(astObj._fields) > 0: 1143 return "{}({})".format( 1144 type(astObj).__name__, 1145 ', '.join(astStr(getattr(astObj, f)) for f in astObj._fields) 1146 ) 1147 elif hasattr(astObj, "_fields"): 1148 return type(astObj).__name__ 1149 else: 1150 return '<{}>'.format(type(astObj).__name__)
Converts an AST object to a string, imperfectly.
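`ast.Str` and `ast.Num` have been deprecated since Python 3.8, where string and number literals parse to `ast.Constant`. The following is a hedged sketch of the same stringification idea written against `ast.Constant`; the name `ast_str` is hypothetical and this is not the code shipped in mast.

```python
import ast


def ast_str(obj):
    """Render an AST object as a compact string (imperfect, like astStr)."""
    if isinstance(obj, ast.Name):
        return obj.id
    elif isinstance(obj, ast.Constant):
        return repr(obj.value)
    elif isinstance(obj, (list, tuple)):
        return str([ast_str(x) for x in obj])
    elif isinstance(obj, ast.AST) and obj._fields:
        return "{}({})".format(
            type(obj).__name__,
            ", ".join(ast_str(getattr(obj, f, None)) for f in obj._fields),
        )
    elif isinstance(obj, ast.AST):
        # Nodes without fields, such as operators, print as their type name.
        return type(obj).__name__
    else:
        return "<{}>".format(type(obj).__name__)


binop = ast_str(ast.parse("x + 1", mode="eval").body)
text = ast_str(ast.parse("'hi'", mode="eval").body)
```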
def defaultToName(defObj):
    """
    Converts a default expression to a name for matching purposes.
    TODO: Actually do recursive matching on these expressions!
    """
    if isinstance(defObj, ast.Name):
        return defObj.id
    else:
        return astStr(defObj)
Converts a default expression to a name for matching purposes. TODO: Actually do recursive matching on these expressions!
def field_values(node):
    """Return a list of the values of all matching-relevant fields of the
    given AST node, with fields in consistent list positions."""

    # Specializations:
    if isinstance(node, ast.FunctionDef):
        return [ast.Name(id=v, ctx=None) if k == 'name'
                # Lyn sez: commented out following, because fields are
                # only name, arguments, body
                # else sorted(v, key=lambda x: x.arg) if k == 'keywords'
                else v
                for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if isinstance(node, ast.ClassDef):
        return [ast.Name(id=v, ctx=None) if k == 'name'
                else v
                for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if isinstance(node, ast.Call):
        # Old: sorted(v, key=lambda kw: kw.arg)
        # New: keyword args use nominal (not positional) matching.
        #return [sorted(v, key=lambda kw: kw.arg) if k == 'keywords' else v
        return [
            (
                odict((kw.arg, kw.value) for kw in v)
                if k == 'keywords'
                else v
            )
            for (k, v) in ast.iter_fields(node)
            if k not in IGNORE_FIELDS
        ]

    # Lyn sez: ast.arguments handling is new for Python 3
    if isinstance(node, ast.arguments):
        argList = [
            # Need to create Name nodes to match patterns.
            argObjToName(argObj)
            for argObj in node.args
        ]  # Ignores argObj annotation, lineno, col_offset
        if (
            node.vararg is None
            and node.kwarg is None
            and node.kwonlyargs == []
            and node.kw_defaults == []
            and node.defaults == []
        ):
            # Optimize (for debugging purposes) this very common case by
            # returning a singleton list of lists
            return [argList]
        else:
            # In unoptimized case, return list with sublists and
            # argObjects/Nones:
            # TODO: treat triple underscores separately/specially!!!
            return [
                argList,
                argObjToName(node.vararg),
                argObjToName(node.kwarg),
                [argObjToName(argObj) for argObj in node.kwonlyargs],
                # Peter 2019-9-30 the defaults cannot be reliably converted
                # into names, because they are expressions!
                node.kw_defaults,
                node.defaults
            ]

    if isinstance(node, ast.keyword):
        return [ast.Name(id=v, ctx=None) if k == 'arg' else v
                for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if isinstance(node, ast.Global):
        return [ast.Name(id=n, ctx=None) for n in sorted(node.names)]

    if isinstance(node, ast.Import):
        return [ node.names ]

    if isinstance(node, ast.ImportFrom):
        return [ast.Name(id=node.module, ctx=None), node.level, node.names]

    if isinstance(node, ast.alias):
        result = [ ast.Name(id=node.name, ctx=None) ]
        if node.asname is not None:
            result.append(ast.Name(id=node.asname, ctx=None))
        return result

    # General cases:
    if isinstance(node, ast.AST):
        return [v for (k, v) in ast.iter_fields(node)
                if k not in IGNORE_FIELDS]

    if type(node) == list:
        return node

    # Added by Peter Mawhorter 2019-4-2 to fix problem where subpatterns fail
    # to match within keyword arguments of functions because field_values
    # returns an odict as one value for functions with kwargs, and asking for
    # the field_values of that odict was hitting the base case below.
    if isinstance(node, (dict, odict)):
        return [
            ast.Name(id=n, ctx=None) for n in node.keys()
        ] + list(node.values())

    return []
Return a list of the values of all matching-relevant fields of the given AST node, with fields in consistent list positions.
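To make the "consistent list positions" idea concrete, here is a reduced sketch covering only the `Call` specialization and the general cases. The `IGNORED` set is a guess made for this example (the real `IGNORE_FIELDS` is not shown in this section), a plain `dict` stands in for `odict`, and `field_values_sketch` is a hypothetical name.

```python
import ast

IGNORED = {"ctx", "type_comment"}  # hypothetical; mast's IGNORE_FIELDS may differ


def field_values_sketch(node):
    """List the matching-relevant field values of node in a fixed order."""
    if isinstance(node, ast.Call):
        # Keyword arguments become a name -> value mapping so they can be
        # matched by name instead of by position.
        return [
            {kw.arg: kw.value for kw in v} if k == "keywords" else v
            for k, v in ast.iter_fields(node)
            if k not in IGNORED
        ]
    if isinstance(node, ast.AST):
        return [v for k, v in ast.iter_fields(node) if k not in IGNORED]
    if isinstance(node, list):
        return node
    return []  # scalars such as ints and strings have no children


call = ast.parse("print(a, sep=', ')", mode="eval").body
func, args, keywords = field_values_sketch(call)
```

Matching two nodes of the same type then reduces to matching these two lists element by element.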
def imatches(node, pat, env1, seq):
    """Exponential backtracking match generating 0 or more matching
    environments.  Supports multiple sequence patterns in one context,
    simple permutations of semantically equivalent but syntactically
    mirrored patterns.
    """
    # Lyn change early-return pattern to named result returned at end to
    # allow debugging prints
    # global imatchCount #$
    # print('\n$ {}Entering imatches({}, {}, {}, {})'.format(
    #        '| '*imatchCount, dump(node), dump(pat), dump(env1), seq))
    # imatchCount += 1 #$
    result = iterempty()  # default result if not overridden
    if env1 is None:
        result = iterempty()
        # imatchCount -= 1 #$
        # print(
        #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
        #         '| ' * imatchCount,
        #         dump(node),
        #         dump(pat),
        #         dump(env1),
        #         seq,
        #         dump(result)
        #     )
        # )
        return result
    env = env1.value
    assert env is not None
    if (
        (
            type(pat) == bool
            or type(pat) == str  # All Python 3 strings are unicode
            or pat is None
        )
        and node == pat
    ):
        result = iterone(env)
    elif type(pat) == int or type(pat) == float:
        # Literal int or float pattern
        if type(node) == type(pat):
            if (
                (type(pat) == int and node == pat)
                or (type(pat) == float and abs(node - pat) < 0.001)
            ):
                result = iterone(env)
        pass
    elif node_var(pat):
        # Var pattern.
        # Match and bind name to node.
        if node_is_bindable(node):
            # [Peter Mawhorter 2021-8-29] Attempting to allow import
            # aliases to unify with variable references later on. If the
            # alias has an 'as' part we unify as if it's a name with that
            # ID, otherwise we use the name part as the ID.
            if isinstance(node, ast.alias):
                if node.asname:
                    bind_as = ast.Name(id=node.asname, ctx=None)
                else:
                    bind_as = ast.Name(id=node.name, ctx=None)
            else:
                bind_as = node
            env2 = bind(env1, node_var(pat), bind_as)
            if env2:
                result = iterone(env2.value)
        pass
    elif typed_lit_var(pat):
        # Var pattern to bind only a literal of given type.
        id, ty = typed_lit_var(pat)
        lit = LIT_TYPES[ty](node)
        # Match and bind name to literal.
        if lit:
            env2 = bind(env1, id, lit.value)
            if env2:
                result = iterone(env2.value)
        pass
    elif type(node) == type(pat):
        # Node and pattern have same type.
        if type(pat) == list:
            # Node and pattern are both lists.  Do positional matching.
            if len(pat) == 0:
                # Empty list pattern.  Node must also be empty.
                if len(node) == 0:
                    result = iterone(env)
                pass
            elif len(node) == 0:
                # Non-empty list pattern with empty node.
                # Try to match sequence subpatterns.
                if seq:
                    psn = set_var(pat[0])
                    if psn:
                        result = imatches(node, pat[1:],
                                          bind(env1, psn, []), seq)
                    pass
                pass
            else:
                # Both are non-empty.
                psn = set_var(pat[0])
                if seq and psn:
                    # First subpattern is a sequence pattern.
                    # Try all consumption sizes, greediest first.
                    # Unsophisticated exponential backtracking search.
                    result = ichain(
                        imatches(
                            node[i:],
                            pat[1:],
                            bind(env1, psn, node[:i]),
                            seq
                        )
                        for i in range(len(node), -1, -1)
                    )
                # Lyn sez: common special case helpful for more concrete
                # debugging results (e.g., may return FiniteIterator
                # rather than itertools.chain object.)
                elif len(node) == 1 and len(pat) == 1:
                    result = imatches(node[0], pat[0], env1, True)
                else:
                    # For all matches of scalar first element sub pattern.
                    # Generate all corresponding matches of remainder.
                    # Unsophisticated exponential backtracking search.
                    result = ichain(
                        imatches(node[1:], pat[1:], Some(bs), seq)
                        for bs in imatches(node[0], pat[0], env1, True)
                    )
                pass
        elif type(node) == dict or type(node) == odict:
            result = match_dict(node, pat, env1)
        else:
            # Node and pat have same type, but are not lists.
            # Match scalar structures by matching lists of their fields.
            if isinstance(node, ast.AST):
                # TODO: DEBUG
                #if isinstance(node, ast.Import):
                #    print(
                #        "FV2i",
                #        dump(field_values(node)),
                #        dump(field_values(pat))
                #    )
                #if isinstance(node, ast.alias):
                #    print(
                #        "FV2a",
                #        dump(field_values(node)),
                #        dump(field_values(pat))
                #    )
                result = imatches(
                    field_values(node),
                    field_values(pat),
                    env1,
                    False
                )
            pass
        pass
    # return iterempty()
    # imatchCount -= 1 #$
    # print(
    #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
    #         '| ' * imatchCount,
    #         dump(node),
    #         dump(pat),
    #         dump(env1),
    #         seq,
    #         dump(result)
    #     )
    # )
    return result
Exponential backtracking match generating 0 or more matching environments. Supports multiple sequence patterns in one context, simple permutations of semantically equivalent but syntactically mirrored patterns.
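The greediest-first backtracking over sequence patterns can be sketched on plain Python lists. Everything here is invented for illustration: `"?name"` marks a single-item variable, `"*name"` marks a sequence variable, and `bind_plain` and `seq_matches` are hypothetical helpers, not mast's pattern syntax or API.

```python
def bind_plain(env, name, value):
    """Return env extended with name -> value, or None on a conflict."""
    if name in env:
        return env if env[name] == value else None
    extended = dict(env)
    extended[name] = value
    return extended


def seq_matches(items, pats, env):
    """Yield every environment under which the pattern list matches items."""
    if not pats:
        if not items:
            yield env
        return
    head, rest = pats[0], pats[1:]
    if isinstance(head, str) and head.startswith("*"):
        # Sequence variable: try every prefix length, greediest first.
        for i in range(len(items), -1, -1):
            env2 = bind_plain(env, head[1:], items[:i])
            if env2 is not None:
                yield from seq_matches(items[i:], rest, env2)
    elif items:
        if isinstance(head, str) and head.startswith("?"):
            env2 = bind_plain(env, head[1:], items[0])
        else:
            env2 = env if items[0] == head else None
        if env2 is not None:
            yield from seq_matches(items[1:], rest, env2)


splits = list(seq_matches([1, 2, 3], ["*a", "*b"], {}))
framed = list(seq_matches([1, 2, 1], ["?x", "*mid", "?x"], {}))
```

Two adjacent sequence variables produce one environment per split point, which is where the exponential worst case comes from.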
def match_dict(node, pat, env1):
    # Node and pattern are both dictionaries.  Do nominal matching.
    # Match all named key patterns, then all single-key pattern variables,
    # and finally the multi-key pattern variables.
    assert all(type(k) == str  # All Python 3 strings are unicode
               for k in pat)
    assert all(type(k) == str  # All Python 3 strings are unicode
               for k in node)

    def match_keys(node, pat, envopt):
        """Match literal keys."""
        keyopt = takeone(k for k in pat
                         if not node_var_str(k)
                         and not set_var_str(k))
        if keyopt:
            # There is at least one named key in the pattern.
            # If this key is also in the program node, then for each
            # match of the corresponding node value and pattern value,
            # generate all matches for the remaining keys.
            key = keyopt.value
            if key in node:
                return ichain(match_keys(dict_unbind(node, key),
                                         dict_unbind(pat, key),
                                         Some(kenv))
                              for kenv in imatches(node[key], pat[key],
                                                   envopt, False))
            else:
                return iterempty()
            pass
        else:
            # The pattern contains no literal keys.
            # Generate all matches for the node and set key variables.
            return match_var_keys(node, pat, envopt)
        pass

    def match_var_keys(node, pat, envopt):
        """Match node variable keys."""
        keyvaropt = takeone(k for k in pat if node_var_str(k))
        if keyvaropt:
            # There is at least one single-key variable in the pattern.
            # For each key-value pair in the node whose value matches
            # this single-key variable's associated value pattern,
            # generate all matches for the remaining keys.
            keyvar = keyvaropt.value
            return ichain(match_var_keys(dict_unbind(node, nkey),
                                         dict_unbind(pat, keyvar),
                                         bind(Some(kenv),
                                              node_var_str(keyvar),
                                              ast.Name(id=nkey, ctx=None)))
                          # outer iteration:
                          for nkey, nval in node.items()
                          # inner iteration:
                          for kenv in imatches(nval, pat[keyvar],
                                               envopt, False))
        else:
            # The pattern contains no single-key variables.
            # Generate all matches for the set key variables.
            return match_set_var_keys(node, pat, envopt)
        pass

    def match_set_var_keys(node, pat, envopt):
        """Match set variable keys."""
        # NOTE: see discussion of match environments for this case:
        # https://github.com/wellesleycs111/codder/issues/25
        assert envopt
        keysetvaropt = takeone(k for k in pat if set_var_str(k))
        if keysetvaropt:
            # There is a multi-key variable in the pattern.
            # Capture all remaining key-value pairs in the node.
            e = bind(envopt, set_var_str(keysetvaropt.value),
                     [(ast.Name(id=kw, ctx=None), karg)
                      for kw, karg in node.items()])
            return iterone(e.value) if e else iterempty()
        elif 0 == len(node):
            # There is no multi-key variable in the pattern.
            # There is a match only if there are no remaining
            # keys in the node.
            # There should also be no remaining keys in the pattern.
            assert 0 == len(pat)
            return iterone(envopt.value)
        else:
            return iterempty()

    return match_keys(node, pat, env1)
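Nominal (by-name) matching of keyword dictionaries can be shown in reduced form. This sketch handles only literal keys and one "rest" variable, written here as a key starting with `*`; that spelling and the name `match_named` are inventions for the example, values are compared with `==`, and single-key variables are left out entirely.

```python
def match_named(node, pat):
    """Match pattern keys against node keys by name.

    Returns an environment dict on success or None on failure.
    """
    env = {}
    remaining = dict(node)
    rest_var = None
    for key, pval in pat.items():
        if key.startswith("*"):
            rest_var = key[1:]  # captures whatever keys are left over
        elif key in remaining and remaining.pop(key) == pval:
            continue
        else:
            return None
    if rest_var is not None:
        env[rest_var] = sorted(remaining.items())
    elif remaining:
        # Without a rest variable, unmatched node keys mean no match.
        return None
    return env


captured = match_named({"sep": ",", "end": ""}, {"sep": ",", "*others": None})
absent = match_named({"sep": ","}, {"end": ""})
leftover = match_named({"a": 1, "b": 2}, {"a": 1})
```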
def find(node, pat, **kwargs):
    """Pre-order search for first sub-AST matching pattern, returning
    (matched node, bindings)."""
    kwargs['gen'] = True
    return takeone(findall(node, parse_pattern(pat), **kwargs))
Pre-order search for first sub-AST matching pattern, returning (matched node, bindings).
def findall(node, pat, outside=[], **kwargs):
    """
    Search for all sub-ASTs matching pattern, returning list of (matched
    node, bindings).
    """
    assert node is not None
    assert pat is not None
    gen = kwargs.get('gen', False)
    kwargs['gen'] = True
    pat = parse_pattern(pat)
    # Top-level sequence patterns are not "anchored" to the ends of
    # the containing block when *finding* a submatch within a node (as
    # opposed to matching a node exactly).  They may match any
    # contiguous subsequence.
    # - To allow sequence patterns to match starting later than the
    #   beginning of a program sequence, matching is attempted
    #   recursively with smaller and smaller suffixes of the program
    #   sequence.
    # - To allow sequence patterns to match ending earlier
    #   than the end of a program block, we implicitly ensure that a
    #   sequence wildcard pattern terminates every top-level sequence
    #   pattern.
    # TODO: Because permutations are applied later, this doesn't allow
    # all the matching we'd like in the following scenario
    # Node: [ alias(name="x"), alias(name="y"), alias(name="z") ]
    # Pat: [ alias(name="_a_"), alias(name="___") ]
    # here because order of aliases doesn't matter, we *should* be able
    # to bind _a_ to x, y, OR z, but it can only bind to x or z, NOT y
    if type(pat) == list and not set_var(pat[-1]):
        pat = pat + [parse_pattern('___')]
        pass

    def findall_scalar_pat_iter(node):
        """Generate all matches of a scalar (non-list) pattern at this node
        or any non-excluded descendant of this node."""
        assert type(pat) != list
        assert node is not None
        # Yield any environment(s) for match(es) of pattern at node.
        envs = [e for e in matchpat(node, pat, **kwargs)]
        if 0 < len(envs):
            yield (node, envs)
        # Continue the search for matches in sub-ASTs of node, only if
        # node is not excluded by a "match outside" pattern.
        if not any(match(node, op) for op in outside):
            for n in field_values(node):
                if n:  # Search only within non-None children.
                    for result in findall_scalar_pat_iter(n):
                        yield result
                        pass
                    pass
                pass
            pass
        pass

    def findall_list_pat_iter(node):
        """Generate all matches of a list pattern at this node or any
        non-excluded descendant of this node."""
        assert type(pat) == list
        assert 0 < len(pat)
        if type(node) == list:
            # If searching against a list:
            # - match against the list itself
            # - search against the first child of the list
            # - search against the tail of the list

            # Match against the list itself.
            # Yield any environment(s) for match(es) of pattern at node.
            envs = [e for e in matchpat(node, pat, **kwargs)]
            if 0 < len(envs):
                yield (node, envs)
            # Continue the search for matches in sub-ASTs of node,
            # only if node is not excluded by a "match outside"
            # pattern.
            if (
                not any(match(node, op) for op in outside)
                and 0 < len(node)  # only in nonempty nodes...
            ):
                # Search for matches in the first sub-AST.
                for m in findall_list_pat_iter(node[0]):
                    yield m
                    pass
                if not set_var(pat[0]):
                    # If this sequence pattern does not start with
                    # a sequence wildcard, then:
                    # Search for matches in the tail of the list.
                    # (Includes matches against the entire tail.)
                    for m in findall_list_pat_iter(node[1:]):
                        yield m
                        pass
                    pass
                pass
            pass
        elif not any(match(node, op) for op in outside):  # and node is not list
            # A list pattern cannot match against a scalar node.
            # Search for matches in children of this scalar (non-list)
            # node, only if node is not excluded by a "match outside"
            # pattern.

            # Optimize to search only where list patterns could match.

            # Body blocks
            for ty in [ast.ClassDef, ast.FunctionDef, ast.With, ast.Module,
                       ast.If, ast.For, ast.While,
                       # ast.TryExcept, ast.TryFinally,
                       # In Python 3, a single Try node covers both
                       # TryExcept and TryFinally
                       ast.Try,
                       ast.ExceptHandler]:
                if isinstance(node, ty):
                    for m in findall_list_pat_iter(node.body):
                        yield m
                        pass
                    break
                pass
            # Block else blocks
            for ty in [ast.If, ast.For, ast.While,
                       # ast.TryExcept
                       # In Python 3, a single Try node covers both
                       # TryExcept and TryFinally
                       ast.Try
                       ]:
                if isinstance(node, ty) and node.orelse:
                    for m in findall_list_pat_iter(node.orelse):
                        yield m
                        pass
                    break
                pass

#            # Except handler blocks
#            if isinstance(node, ast.TryExcept):
#                for h in node.handlers:
#                    for m in findall_list_pat_iter(h.body):
#                        yield m
#                        pass
#                    pass
#                pass
#            # finally blocks
#            if isinstance(node, ast.TryFinally):
#                for m in findall_list_pat_iter(node.finalbody):
#                    yield m
#                    pass
#                pass

            # In Python 3, a single Try node covers both TryExcept and
            # TryFinally
            if isinstance(node, ast.Try):
                for h in node.handlers:
                    for m in findall_list_pat_iter(h.body):
                        yield m
                        pass
                    pass
                pass

            # General non-optimized version.
            # Must be mutually exclusive with the above if used.
            # for n in field_values(node):
            #     if n:
            #         for result in findall_list_pat_iter(n):
            #             yield result
            #             pass
            #         pass
            #     pass
            pass
        pass
    # Apply the right search based on pattern type.
    matches = (findall_list_pat_iter if type(pat) == list
               else findall_scalar_pat_iter)(node)
    # Return the generator or a list of all generated matches,
    # depending on gen.
    return matches if gen else list(matches)
Search for all sub-ASTs matching pattern, returning list of (matched node, bindings).
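The overall shape of the search (visit a node, then descend into its children, optionally stopping at excluded nodes) can be sketched with the standard library alone. Here a Python predicate stands in for a mast pattern, and `preorder` and its `exclude` parameter are hypothetical names that loosely mirror the `outside` argument above.

```python
import ast


def preorder(node, exclude=lambda n: False):
    """Yield node, then its descendants in pre-order.

    Children of a node for which exclude(node) is true are not searched.
    """
    yield node
    if not exclude(node):
        for child in ast.iter_child_nodes(node):
            yield from preorder(child, exclude)


tree = ast.parse("def f(x):\n    return g(x) + h(1)\n")
calls = [n.func.id for n in preorder(tree) if isinstance(n, ast.Call)]
skipped = [
    n.func.id
    for n in preorder(tree, exclude=lambda n: isinstance(n, ast.FunctionDef))
    if isinstance(n, ast.Call)
]
```

Because the walk is a generator, taking only the first result (as `find` does) avoids traversing the rest of the tree.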
def count(node, pat, **kwargs):
    """
    Count all sub-ASTs matching pattern. Does NOT count individual
    environments that match (i.e., ways that bindings could attach at a
    given node), but rather counts nodes at which one or more bindings
    are possible.
    """
    assert 'gen' not in kwargs
    return sum(1 for x in findall(node, pat,
                                  gen=True, **kwargs))
Count all sub-ASTs matching pattern. Does NOT count individual environments that match (i.e., ways that bindings could attach at a given node), but rather counts nodes at which one or more bindings are possible.
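The distinction between counting nodes and counting environments is easy to see on the (node, environments) pairs that the search yields. The data below is invented purely for illustration, with strings standing in for AST nodes.

```python
# Each entry pairs a matched node with every environment that matched there.
matches = [
    ("node_a", [{"_x_": 1}, {"_x_": 2}]),  # one node, two possible bindings
    ("node_b", [{"_x_": 3}]),
]

node_count = sum(1 for _ in matches)                  # what count reports
env_count = sum(len(envs) for _, envs in matches)     # not what count reports
```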
Common base class for all non-exit exceptions.
class InlineAvailableExpressions(ast.NodeTransformer):
    def __init__(self, other=None):
        self.available = dict(other.available) if other else {}

    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Load) and name.id in self.available:
            return self.available[name.id]
        else:
            return self.generic_visit(name)

    def visit_Assign(self, assign):
        # Inlining through assignments is disabled: the raise below makes
        # the rest of this method unreachable.
        raise Unimplemented
        new = ast.copy_location(ast.Assign(
            targets=assign.targets,
            value=self.visit(assign.value)
        ), assign)
        self.available[assign.targets[0].id] = new.value
        return new

    def visit_If(self, ifelse):
        # Inline into the test.
        test = self.visit(ifelse.test)
        # Inline and accumulate in the then and else independently.
        body_inliner = InlineAvailableExpressions(self)
        orelse_inliner = InlineAvailableExpressions(self)
        body = body_inliner.inline_block(ifelse.body)
        orelse = orelse_inliner.inline_block(ifelse.orelse)
        # Any var->expression that is available after both branches
        # is available after.
        self.available = {
            name: body_inliner.available[name]
            for name in (set(body_inliner.available)
                         & set(orelse_inliner.available))
            if (body_inliner.available[name] == orelse_inliner.available[name])
        }
        return ast.copy_location(
            ast.If(test=test, body=body, orelse=orelse),
            ifelse
        )

    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)

    def inline_block(self, block):
        # Introduce duplicate common subexpressions...
        return [self.visit(stmt) for stmt in block]
A NodeVisitor subclass that walks the abstract syntax tree and allows modification of nodes.
The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.
Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

    class RewriteName(NodeTransformer):
        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the generic_visit() method for the node first.
For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this::

    node = YourTransformer().visit(node)
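The `RewriteName` example from the inherited docstring can be run end to end as follows. This is a sketch that assumes Python 3.9 or later (for `ast.unparse` and for `Subscript.slice` being a plain expression); the sample expression and the `data` dictionary are made up for the demonstration.

```python
import ast


class RewriteName(ast.NodeTransformer):
    def visit_Name(self, node):
        # Replace a name lookup `foo` with `data['foo']`.
        return ast.copy_location(
            ast.Subscript(
                value=ast.Name(id="data", ctx=ast.Load()),
                slice=ast.Constant(value=node.id),
                ctx=node.ctx,
            ),
            node,
        )


tree = ast.parse("foo + bar", mode="eval")
tree = ast.fix_missing_locations(RewriteName().visit(tree))
rewritten = ast.unparse(tree)
value = eval(compile(tree, "<ast>", "eval"), {"data": {"foo": 1, "bar": 2}})
```

The freshly created `data` name is not visited again because `visit_Name` does not call `generic_visit` on its result.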
    def visit_If(self, ifelse):
        # Inline into the test.
        test = self.visit(ifelse.test)
        # Inline and accumulate in the then and else independently.
        body_inliner = InlineAvailableExpressions(self)
        orelse_inliner = InlineAvailableExpressions(self)
        body = body_inliner.inline_block(ifelse.body)
        orelse = orelse_inliner.inline_block(ifelse.orelse)
        # Any var->expression that is available after both branches
        # is available after.
        self.available = {
            name: body_inliner.available[name]
            for name in (set(body_inliner.available)
                         & set(orelse_inliner.available))
            if (body_inliner.available[name] == orelse_inliner.available[name])
        }
        return ast.copy_location(
            ast.If(test=test, body=body, orelse=orelse),
            ifelse
        )
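The merge rule at the end of `visit_If` (an expression stays available only if both branches agree on it) is sketched below in isolation. One caveat worth noting: `ast` nodes do not define structural equality, so `==` between two separately built nodes is an identity test. This sketch compares `ast.dump` output instead; that substitution, along with the names `merge_available` and `expr`, is a choice made for this example and not something the code above does.

```python
import ast


def merge_available(body_avail, orelse_avail):
    """Keep name -> expression entries on which both branches agree."""
    return {
        name: body_avail[name]
        for name in set(body_avail) & set(orelse_avail)
        if ast.dump(body_avail[name]) == ast.dump(orelse_avail[name])
    }


def expr(src):
    """Parse a single expression into its AST node."""
    return ast.parse(src, mode="eval").body


after_then = {"a": expr("x + 1"), "b": expr("y")}
after_else = {"a": expr("x + 1"), "b": expr("z"), "c": expr("0")}
merged = merge_available(after_then, after_else)
```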
    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)
Called if no explicit visitor function exists for a node.
class DeadCodeElim(ast.NodeTransformer):
    def __init__(self, other=None):
        # Copy the other eliminator's set of used names (attribute `used`).
        self.used = set(other.used) if other else set()
        pass

    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Store):
            # Var store defines, removing from the set.
            self.used = self.used - {name.id}
            return name
        elif isinstance(name.ctx, ast.Load):
            # Var use uses, adding to the set.
            self.used = self.used | {name.id}
            return name
        else:
            return name
        pass

    def visit_Assign(self, assign):
        # This restriction prevents worries about things like:
        # x, x[0] = [[1], 2]
        # By using this restriction it is safe to keep the single set,
        # thus order of removals and additions will not be a problem
        # since defs are discovered first (always to left of =), then
        # uses are discovered next (to right of =).
        assert all(
            (
                node_is_name(t)
                or (
                    isinstance(t, ast.Tuple)
                    and all(node_is_name(t) for t in t.elts)
                )
            )
            for t in assign.targets
        )
        # Now handled by visit_Name
        # self.used = self.used - set(n.id for n in assign.targets)
        if (any(t.id in self.used for t in assign.targets if node_is_name(t))
            or any(t.id in self.used
                   # Tuple targets are ast.Tuple nodes; look at their elts.
                   for tup in assign.targets
                   if isinstance(tup, ast.Tuple)
                   for t in tup.elts
                   if node_is_name(t))):
            return ast.copy_location(ast.Assign(
                targets=[self.visit(t) for t in assign.targets],
                value=self.visit(assign.value)
            ), assign)
        else:
            return None

    def visit_If(self, ifelse):
        body_elim = DeadCodeElim(self)
        orelse_elim = DeadCodeElim(self)
        # DCE the body
        body = body_elim.elim_block(ifelse.body)
        # DCE the else
        orelse = orelse_elim.elim_block(ifelse.orelse)
        # Use the test -- TODO: could eliminate entire if sometimes.
        # Keep it for now for clarity.
        self.used = body_elim.used | orelse_elim.used
        test = self.visit(ifelse.test)
        return ast.copy_location(ast.If(test=test, body=body, orelse=orelse),
                                 ifelse)

    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)

    def elim_block(self, block):
        # Visit statements last to first so that uses are seen before the
        # definitions they depend on; drop eliminated (None) statements and
        # restore the original order.
        return [s for s in
                (self.visit(stmt) for stmt in block[::-1])
                if s][::-1]
    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Store):
            # Var store defines, removing from the set.
            self.used = self.used - {name.id}
            return name
        elif isinstance(name.ctx, ast.Load):
            # Var use uses, adding to the set.
            self.used = self.used | {name.id}
            return name
        else:
            return name
        pass
    def visit_Assign(self, assign):
        # This restriction prevents worries about things like:
        # x, x[0] = [[1], 2]
        # By using this restriction it is safe to keep the single set,
        # thus order of removals and additions will not be a problem
        # since defs are discovered first (always to left of =), then
        # uses are discovered next (to right of =).
        assert all(
            (
                node_is_name(t)
                or (
                    isinstance(t, ast.Tuple)
                    and all(node_is_name(t) for t in t.elts)
                )
            )
            for t in assign.targets
        )
        # Now handled by visit_Name
        # self.used = self.used - set(n.id for n in assign.targets)
        if (any(t.id in self.used for t in assign.targets if node_is_name(t))
            or any(t.id in self.used
                   # Tuple targets are ast.Tuple nodes; look at their elts.
                   for tup in assign.targets
                   if isinstance(tup, ast.Tuple)
                   for t in tup.elts
                   if node_is_name(t))):
            return ast.copy_location(ast.Assign(
                targets=[self.visit(t) for t in assign.targets],
                value=self.visit(assign.value)
            ), assign)
        else:
            return None
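The backward pass behind this kind of elimination can be sketched without any AST machinery. In the sketch each statement is reduced to a `(target, used_names)` pair, and `eliminate_dead` and `live_out` are hypothetical names; it illustrates the general liveness technique and is not a drop-in replacement for the class above.

```python
def eliminate_dead(stmts, live_out):
    """Drop assignments whose target is never used afterwards.

    stmts is a list of (target, used_names) pairs in program order and
    live_out is the set of names needed after the block.
    """
    live = set(live_out)
    kept = []
    for target, uses in reversed(stmts):
        if target in live:
            live.discard(target)  # the definition satisfies later uses
            live |= set(uses)     # its operands are now needed earlier
            kept.append((target, uses))
    kept.reverse()
    return kept, live


block = [("a", []), ("b", ["a"]), ("c", []), ("d", ["b"])]
kept, live_in = eliminate_dead(block, {"d"})
```

Walking the block from the last statement to the first is what lets a single set of live names suffice.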
    def visit_If(self, ifelse):
        body_elim = DeadCodeElim(self)
        orelse_elim = DeadCodeElim(self)
        # DCE the body
        body = body_elim.elim_block(ifelse.body)
        # DCE the else
        orelse = orelse_elim.elim_block(ifelse.orelse)
        # Use the test -- TODO: could eliminate entire if sometimes.
        # Keep it for now for clarity.
        self.used = body_elim.used | orelse_elim.used
        test = self.visit(ifelse.test)
        return ast.copy_location(ast.If(test=test, body=body, orelse=orelse),
                                 ifelse)
    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)
Called if no explicit visitor function exists for a node.
class NormalizePure(ast.NodeTransformer):
    """AST transformer: normalize/inline straightline assignments into
    single expressions as possible."""
    def normalize_block(self, block):
        try:
            return DeadCodeElim().elim_block(
                InlineAvailableExpressions().inline_block(block)
            )
        except Unimplemented:
            return block

    def visit_FunctionDef(self, fun):
        # Lyn warning: previously commented code below is from Python 2
        # and won't work in Python 3 (args have changed).
        # assert 0 < len(
        #     result.value[0].intersection(set(a.id for a in fun.args.args))
        # )
        normbody = self.normalize_block(fun.body)
        if normbody != fun.body:
            return ast.copy_location(ast.FunctionDef(
                name=fun.name,
                args=self.generic_visit(fun.args),
                body=normbody,
            ), fun)
        else:
            return fun
AST transformer: normalize/inline straightline assignments into single expressions where possible.
    def visit_FunctionDef(self, fun):
        # Lyn warning: previously commented code below is from Python 2
        # and won't work in Python 3 (args have changed).
        # assert 0 < len(
        #     result.value[0].intersection(set(a.id for a in fun.args.args))
        # )
        normbody = self.normalize_block(fun.body)
        if normbody != fun.body:
            return ast.copy_location(ast.FunctionDef(
                name=fun.name,
                args=self.generic_visit(fun.args),
                body=normbody,
            ), fun)
        else:
            return fun
def canonical_pure(node):
    """Return the normalized/inlined version of an AST node."""
    # print(type(fun))
    # assert isinstance(fun, ast.FunctionDef)
    # print()
    # print('FUN', dump(fun))
    if type(node) == list:
        return NormalizePure().normalize_block(node)
        # [DesugarAugAssign().visit(stmt) for stmt in node]
        #)
    else:
        assert isinstance(node, ast.AST)
        #return NormalizePure().visit(DesugarAugAssign().visit(node))
        return NormalizePure().visit(node)
Return the normalized/inlined version of an AST node.
def parse_canonical_pure(string, toplevel=False):
    """Parse a normalized/inlined version of a program."""
    if type(string) == list:
        return [parse_canonical_pure(x) for x in string]
    elif string not in CANON_CACHE:
        CANON_CACHE[string] = canonical_pure(
            parse_pattern(string, toplevel=toplevel)
        )
        pass
    return CANON_CACHE[string]
Parse a normalized/inlined version of a program.
def indent(pat, indent=INDENT):
    """Apply indents to a source string."""
    return indent + pat.replace('\n', '\n' + indent)
Apply indents to a source string.
class SourceFormatter(ast.NodeVisitor):
    """AST visitor: pretty print AST to python source string"""
    def __init__(self):
        ast.NodeVisitor.__init__(self)
        self._indent = ''
        pass

    def indent(self):
        self._indent += INDENT
        pass

    def unindent(self):
        self._indent = self._indent[:-4]
        pass

    def line(self, ln):
        return self._indent + ln + '\n'

    def lines(self, lst):
        return ''.join(lst)

    def generic_visit(self, node):
        assert False, 'visiting {}'.format(ast.dump(node))
        pass

    def visit_Module(self, m):
        return self.lines(self.visit(n) for n in m.body)

    def visit_Interactive(self, i):
        return self.lines(self.visit(n) for n in i.body)

    def visit_Expression(self, e):
        return self.line(self.visit(e.body))

    def visit_FunctionDef(self, f):
        assert not f.decorator_list
        header = self.line('def {name}({args}):'.format(
            name=f.name,
            args=self.visit(f.args))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        self.unindent()
        return header + body + '\n'

    def visit_ClassDef(self, c):
        assert not c.decorator_list
        header = self.line('class {name}({bases}):'.format(
            name=c.name,
            bases=', '.join(self.visit(b) for b in c.bases)
        ))
        self.indent()
        body = self.lines(self.visit(s) for s in c.body)
        self.unindent()
        return header + body + '\n'

    def visit_Return(self, r):
        return self.line('return' if r.value is None
                         else 'return {}'.format(self.visit(r.value)))

    def visit_Delete(self, d):
        # Multiple deletion targets are separated by commas.
        return self.line('del ' + ', '.join(self.visit(e) for e in d.targets))

    def visit_Assign(self, a):
        # Multiple targets represent chained assignment (a = b = value);
        # tuple unpacking is a single Tuple target.
        return self.line(' = '.join(self.visit(e)
                                    for e in a.targets)
                         + ' = ' + self.visit(a.value))

    def visit_AugAssign(self, a):
        return self.line('{target} {op}= {expr}'.format(
            target=self.visit(a.target),
            op=self.visit(a.op),
            expr=self.visit(a.value))
        )

# Print removed as ast node on Python 3
#    def visit_Print(self, p):
#        assert p.dest == None
#        return self.line('print {}{}'.format(
#            ', '.join(self.visit(e) for e in p.values),
#            ',' if p.values and not p.nl else ''
#        ))

    def visit_For(self, f):
        header = self.line('for {} in {}:'.format(
            self.visit(f.target),
            self.visit(f.iter))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        orelse = self.lines(self.visit(s) for s in f.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if f.orelse else ''
        )

    def visit_While(self, w):
        # Peter 2021-6-16: Removed this assert; orelse isn't defined here
        # assert not orelse
        header = self.line('while {}:'.format(self.visit(w.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        orelse = self.lines(self.visit(s) for s in w.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if w.orelse else ''
        )

    def visit_If(self, i):
        header = self.line('if {}:'.format(self.visit(i.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in i.body)
        orelse = self.lines(self.visit(s) for s in i.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if i.orelse else ''
        )

    def visit_With(self, w):
        # Converted to Python3 withitems:
        header = self.line(
            'with {items}:'.format(
                items=', '.join(
                    '{expr}{asnames}'.format(
                        expr=self.visit(item.context_expr),
                        # Leading space keeps `as` apart from the expression.
                        asnames=(' as ' + self.visit(item.optional_vars)
                                 if item.optional_vars else '')
                    )
                    for item in w.items
                )
            )
        )
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
self.unindent() 2072 return header + body 2073 2074 # Python 3: raise has new abstract syntax 2075 def visit_Raise(self, r): 2076 return self.line('raise{}{}'.format( 2077 (' ' + self.visit(r.exc)) if r.exc else '', 2078 (' from ' + self.visit(r.cause)) if r.cause else '' 2079 )) 2080 2081# def visit_Raise(self, r): 2082# return self.line('raise{}{}{}{}{}'.format( 2083# self.visit(r.type) if r.type else '', 2084# ', ' if r.type and r.inst else '', 2085# self.visit(r.inst) if r.inst else '', 2086# ', ' if (r.type or r.inst) and r.tback else '', 2087# self.visit(r.tback) if r.tback else '' 2088# )) 2089 2090# def visit_TryExcept(self, te): 2091# self.indent() 2092# tblock = self.lines(self.visit(s) for s in te.body) 2093# orelse = self.lines(self.visit(s) for s in te.orelse) 2094# self.unindent() 2095# return ( 2096# self.line('try:') 2097# + tblock 2098# + ''.join(self.visit(eh) for eh in te.handlers) 2099# + (self.line('else:') + orelse if orelse else '' ) 2100# ) 2101 2102# def visit_TryFinally(self, tf): 2103# self.indent() 2104# tblock = self.lines(self.visit(s) for s in tf.body) 2105# fblock = self.lines(self.visit(s) for s in tf.finalbody) 2106# self.unindent() 2107# return ( 2108# self.line('try:') 2109# + tblock 2110# + self.line('finally:') 2111# + fblock 2112# ) 2113 2114 # In Python 3, a single Try node covers both TryExcept and TryFinally 2115 def visit_Try(self, t): 2116 self.indent() 2117 tblock = self.lines(self.visit(s) for s in t.body) 2118 orelse = self.lines(self.visit(s) for s in t.orelse) 2119 fblock = self.lines(self.visit(s) for s in t.finalbody) 2120 self.unindent() 2121 return ( 2122 self.line('try:') 2123 + tblock 2124 + ''.join(self.visit(eh) for eh in t.handlers) 2125 + (self.line('else:') + orelse if orelse else '' ) 2126 + (self.line('finally:') + fblock if fblock else '' ) 2127 ) 2128 2129 def visit_ExceptHandler(self, eh): 2130 header = self.line('except{}{}{}{}:'.format( 2131 ' ' if eh.type else '', 2132 self.visit(eh.type) if 
eh.type else '', 2133 ' as ' if eh.type and eh.name else ' ' if eh.name else '', 2134 self.visit(eh.name) if eh.name and isinstance(eh.name, ast.AST) 2135 else (eh.name if eh.name else '') 2136 )) 2137 self.indent() 2138 body = self.lines(self.visit(s) for s in eh.body) 2139 self.unindent() 2140 return header + body 2141 2142 def visit_Assert(self, a): 2143 return self.line('assert {}{}{}'.format( 2144 self.visit(a.test), 2145 ', ' if a.msg else '', 2146 self.visit(a.msg) if a.msg else '' 2147 )) 2148 2149 def visit_Import(self, i): 2150 return self.line( 2151 'import {}'.format(', '.join(self.visit(n) for n in i.names)) 2152 ) 2153 2154 def visit_ImportFrom(self, f): 2155 return self.line('from {}{} import {}'.format( 2156 '.' * f.level, 2157 f.module if f.module else '', 2158 ', '.join(self.visit(n) for n in f.names) 2159 )) 2160 2161 def visit_Exec(self, e): 2162 return self.line('exec {}{}{}{}{}'.format( 2163 self.visit(e.body), 2164 ' in ' if e.globals else '', 2165 self.visit(e.globals) if e.globals else '', 2166 ', ' if e.locals else '', 2167 self.visit(e.locals) if e.locals else '' 2168 )) 2169 2170 def visit_Global(self, g): 2171 return self.line('global {}'.format(', '.join(g.names))) 2172 2173 def visit_Expr(self, e): 2174 return self.line(self.visit(e.value)) 2175 2176 def visit_Pass(self, p): 2177 return self.line('pass') 2178 2179 def visit_Break(self, b): 2180 return self.line('break') 2181 2182 def visit_Continue(self, c): 2183 return self.line('continue') 2184 2185 def visit_BoolOp(self, b): 2186 return ' {} '.format( 2187 self.visit(b.op) 2188 ).join('({})'.format(self.visit(e)) for e in b.values) 2189 2190 def visit_BinOp(self, b): 2191 return '({}) {} ({})'.format( 2192 self.visit(b.left), 2193 self.visit(b.op), 2194 self.visit(b.right) 2195 ) 2196 2197 def visit_UnaryOp(self, u): 2198 return '{} ({})'.format( 2199 self.visit(u.op), 2200 self.visit(u.operand) 2201 ) 2202 2203 def visit_Lambda(self, ld): 2204 return '(lambda {}: {})'.format( 2205 
self.visit(ld.args), 2206 self.visit(ld.body) 2207 ) 2208 2209 def visit_IfExp(self, i): 2210 return '({} if {} else {})'.format( 2211 self.visit(i.body), 2212 self.visit(i.test), 2213 self.visit(i.orelse) 2214 ) 2215 2216 def visit_Dict(self, d): 2217 return '{{ {} }}'.format( 2218 ', '.join('{}: {}'.format(self.visit(k), self.visit(v)) 2219 for k, v in zip(d.keys, d.values)) 2220 ) 2221 2222 def visit_Set(self, s): 2223 return '{{ {} }}'.format(', '.join(self.visit(e) for e in s.elts)) 2224 2225 def visit_ListComp(self, lc): 2226 return '[{} {}]'.format( 2227 self.visit(lc.elt), 2228 ' '.join(self.visit(g) for g in lc.generators) 2229 ) 2230 2231 def visit_SetComp(self, sc): 2232 return '{{{} {}}}'.format( 2233 self.visit(sc.elt), 2234 ' '.join(self.visit(g) for g in sc.generators) 2235 ) 2236 2237 def visit_DictComp(self, dc): 2238 return '{{{} {}}}'.format( 2239 '{}: {}'.format(self.visit(dc.key), self.visit(dc.value)), 2240 ' '.join(self.visit(g) for g in dc.generators) 2241 ) 2242 2243 def visit_GeneratorExp(self, ge): 2244 return '({} {})'.format( 2245 self.visit(ge.elt), 2246 ' '.join(self.visit(g) for g in ge.generators) 2247 ) 2248 2249 def visit_Yield(self, y): 2250 return 'yield {}'.format(self.visit(y.value) if y.value else '') 2251 2252 def visit_Compare(self, c): 2253 assert len(c.ops) == len(c.comparators) 2254 return '{} {}'.format( 2255 '({})'.format(self.visit(c.left)), 2256 ' '.join( 2257 '{} ({})'.format(self.visit(op), self.visit(expr)) 2258 for op, expr in zip(c.ops, c.comparators) 2259 ) 2260 ) 2261 2262 def visit_Call(self, c): 2263 # return '{fun}({args}{keys}{starargs}{starstarargs})'.format( 2264 # Unlike Python 2, Python 3 has no starargs or startstarargs 2265 return '{fun}({args}{keys})'.format( 2266 fun=self.visit(c.func), 2267 args=', '.join(self.visit(a) for a in c.args), 2268 keys=( 2269 (', ' if c.args else '') 2270 + ( 2271 ', '.join(self.visit(ka) for ka in c.keywords) 2272 if c.keywords else '' 2273 ) 2274 ) 2275 ) 2276 2277 
def visit_Repr(self, r): 2278 return 'repr({})'.format(self.visit(r.expr)) 2279 2280 def visit_Num(self, n): 2281 return repr(n.n) 2282 2283 def visit_Str(self, s): 2284 return repr(s.s) 2285 2286 def visit_Attribute(self, a): 2287 return '{}.{}'.format(self.visit(a.value), a.attr) 2288 2289 def visit_Subscript(self, s): 2290 return '{}[{}]'.format(self.visit(s.value), self.visit(s.slice)) 2291 2292 def visit_Name(self, n): 2293 return n.id 2294 2295 def visit_List(self, ls): 2296 return '[{}]'.format(', '.join(self.visit(e) for e in ls.elts)) 2297 2298 def visit_Tuple(self, tp): 2299 return '({})'.format(', '.join(self.visit(e) for e in tp.elts)) 2300 2301 def visit_Ellipsis(self, s): 2302 return '...' 2303 2304 def visit_Slice(self, s): 2305 return '{}:{}{}{}'.format( 2306 self.visit(s.lower) if s.lower else '', 2307 self.visit(s.upper) if s.upper else '', 2308 ':' if s.step else '', 2309 self.visit(s.step) if s.step else '' 2310 ) 2311 2312 def visit_ExtSlice(self, es): 2313 return ', '.join(self.visit(s) for s in es.dims) 2314 2315 def visit_Index(self, i): 2316 return self.visit(i.value) 2317 2318 def visit_And(self, a): 2319 return 'and' 2320 2321 def visit_Or(self, o): 2322 return 'or' 2323 2324 def visit_Add(self, a): 2325 return '+' 2326 2327 def visit_Sub(self, a): 2328 return '-' 2329 2330 def visit_Mult(self, a): 2331 return '*' 2332 2333 def visit_Div(self, a): 2334 return '/' 2335 2336 def visit_Mod(self, a): 2337 return '%' 2338 2339 def visit_Pow(self, a): 2340 return '**' 2341 2342 def visit_LShift(self, a): 2343 return '<<' 2344 2345 def visit_RShift(self, a): 2346 return '>>' 2347 2348 def visit_BitOr(self, a): 2349 return '|' 2350 2351 def visit_BixXor(self, a): 2352 return '^' 2353 2354 def visit_BitAnd(self, a): 2355 return '&' 2356 2357 def visit_FloorDiv(self, a): 2358 return '//' 2359 2360 def visit_Invert(self, a): 2361 return '~' 2362 2363 def visit_Not(self, a): 2364 return 'not' 2365 2366 def visit_UAdd(self, a): 2367 return '+' 2368 
2369 def visit_USub(self, a): 2370 return '-' 2371 2372 def visit_Eq(self, a): 2373 return '==' 2374 2375 def visit_NotEq(self, a): 2376 return '!=' 2377 2378 def visit_Lt(self, a): 2379 return '<' 2380 2381 def visit_LtE(self, a): 2382 return '<=' 2383 2384 def visit_Gt(self, a): 2385 return '>' 2386 2387 def visit_GtE(self, a): 2388 return '>=' 2389 2390 def visit_Is(self, a): 2391 return 'is' 2392 2393 def visit_IsNot(self, a): 2394 return 'is not' 2395 2396 def visit_In(self, a): 2397 return 'in' 2398 2399 def visit_NotIn(self, a): 2400 return 'not in' 2401 2402 def visit_comprehension(self, c): 2403 return 'for {} in {}{}{}'.format( 2404 self.visit(c.target), 2405 self.visit(c.iter), 2406 ' ' if c.ifs else '', 2407 ' '.join('if {}'.format(self.visit(i)) for i in c.ifs) 2408 ) 2409 2410 def visit_arg(self, a): 2411 '''[2019/01/22, lyn] Handle new arg objects in Python 3.''' 2412 return a.arg # The name of the argument 2413 2414 def visit_keyword(self, k): 2415 return '{}={}'.format(k.arg, self.visit(k.value)) 2416 2417 def visit_alias(self, a): 2418 return '{} as {}'.format(a.name, a.asname) if a.asname else a.name 2419 2420 def visit_arguments(self, a): 2421 # [2019/01/22, lyn] Note: This does *not* handle Python 3's 2422 # keyword-only arguments (probably moot for 111, but not 2423 # beyond). 
2424 stdargs = a.args[:-len(a.defaults)] if a.defaults else a.args 2425 defargs = ( 2426 zip(a.args[-len(a.defaults):], a.defaults) 2427 if a.defaults else [] 2428 ) 2429 return '{stdargs}{sep1}{defargs}{sep2}{varargs}{sep3}{kwargs}'.format( 2430 stdargs=', '.join(self.visit(sa) for sa in stdargs), 2431 sep1=', ' if 0 < len(stdargs) and defargs else '', 2432 defargs=', '.join('{}={}'.format(self.visit(da), self.visit(dd)) 2433 for da, dd in defargs), 2434 sep2=', ' if 0 < len(a.args) and a.vararg else '', 2435 varargs='*{}'.format(a.vararg) if a.vararg else '', 2436 sep3=', ' if (0 < len(a.args) or a.vararg) and a.kwarg else '', 2437 kwargs='**{}'.format(a.kwarg) if a.kwarg else '' 2438 ) 2439 2440 def visit_NameConstant(self, nc): 2441 return str(nc.value) 2442 2443 def visit_Starred(self, st): 2444 # TODO: Is this correct? 2445 return '*' + st.value.id
AST visitor: pretty print AST to python source string
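The pattern behind the class is that every `visit_*` method returns a string, and statement-level methods prefix their output with the current indentation. The following is a minimal sketch of that pattern, not the class itself: `TinyFormatter` is a hypothetical cut-down visitor covering only a few node types, and the four-space value of `INDENT` is an assumption, since the real constant is defined elsewhere in the module.

```python
import ast

INDENT = '    '  # assumed value; the real constant lives elsewhere in mast.py


class TinyFormatter(ast.NodeVisitor):
    """Hypothetical cut-down formatter handling only a few node types."""

    def __init__(self):
        super().__init__()
        self._indent = ''

    def line(self, text):
        # Prefix one logical line with the current indentation
        return self._indent + text + '\n'

    def visit_Module(self, m):
        return ''.join(self.visit(s) for s in m.body)

    def visit_While(self, w):
        # The header is rendered at the outer indentation level
        header = self.line('while {}:'.format(self.visit(w.test)))
        self._indent += INDENT
        body = ''.join(self.visit(s) for s in w.body)
        self._indent = self._indent[:-len(INDENT)]
        return header + body

    def visit_Pass(self, p):
        return self.line('pass')

    def visit_Break(self, b):
        return self.line('break')

    def visit_Name(self, n):
        # Expressions return bare text with no indentation or newline
        return n.id


# The input uses two-space indents; the output is re-indented with INDENT
tree = ast.parse('while x:\n  while y:\n    break\n  pass\n')
formatted = TinyFormatter().visit(tree)
print(formatted)
```

Because indentation lives in visitor state rather than in the returned strings, nested blocks come out correctly indented without any post-processing.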
    def generic_visit(self, node):
        assert False, 'visiting {}'.format(ast.dump(node))
Called if no explicit visitor function exists for a node.
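As a rough illustration of that fallback, the sketch below uses a hypothetical `StrictVisitor` with a single `visit_Name` method. It raises `NotImplementedError` instead of using `assert`, which is a deliberate variation for the example (assertions are skipped under `python -O`) and not what the class above does.

```python
import ast


class StrictVisitor(ast.NodeVisitor):
    """Hypothetical visitor that only knows how to render names."""

    def visit_Name(self, n):
        return n.id

    def generic_visit(self, node):
        # Reached for any node type without a visit_<ClassName> method
        raise NotImplementedError(type(node).__name__)


name_node = ast.parse('x', mode='eval').body       # an ast.Name
binop_node = ast.parse('x + 1', mode='eval').body  # an ast.BinOp

handled = StrictVisitor().visit(name_node)
try:
    StrictVisitor().visit(binop_node)
    unhandled = None
except NotImplementedError as err:
    unhandled = str(err)

print(handled, unhandled)
```

Failing loudly here means an unsupported construct is reported immediately, rather than being dropped from the output without notice.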
    def visit_FunctionDef(self, f):
        assert not f.decorator_list
        header = self.line('def {name}({args}):'.format(
            name=f.name,
            args=self.visit(f.args))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        self.unindent()
        return header + body + '\n'

    def visit_ClassDef(self, c):
        assert not c.decorator_list
        header = self.line('class {name}({bases}):'.format(
            name=c.name,
            bases=', '.join(self.visit(b) for b in c.bases)
        ))
        self.indent()
        body = self.lines(self.visit(s) for s in c.body)
        self.unindent()
        return header + body + '\n'

    def visit_For(self, f):
        header = self.line('for {} in {}:'.format(
            self.visit(f.target),
            self.visit(f.iter))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        orelse = self.lines(self.visit(s) for s in f.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if f.orelse else ''
        )

    def visit_While(self, w):
        # Peter 2021-6-16: Removed this assert; orelse isn't defined here
        # assert not orelse
        header = self.line('while {}:'.format(self.visit(w.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        orelse = self.lines(self.visit(s) for s in w.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if w.orelse else ''
        )

    def visit_If(self, i):
        header = self.line('if {}:'.format(self.visit(i.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in i.body)
        orelse = self.lines(self.visit(s) for s in i.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if i.orelse else ''
        )

    def visit_With(self, w):
        # Converted to Python3 withitems:
        header = self.line(
            'with {items}:'.format(
                items=', '.join(
                    '{expr}{asnames}'.format(
                        expr=self.visit(item.context_expr),
                        # Leading space keeps 'as' apart from the expression
                        asnames=(' as ' + self.visit(item.optional_vars)
                                 if item.optional_vars else '')
                    )
                    for item in w.items
                )
            )
        )
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        self.unindent()
        return header + body

    def visit_Try(self, t):
        self.indent()
        tblock = self.lines(self.visit(s) for s in t.body)
        orelse = self.lines(self.visit(s) for s in t.orelse)
        fblock = self.lines(self.visit(s) for s in t.finalbody)
        self.unindent()
        return (
            self.line('try:')
            + tblock
            + ''.join(self.visit(eh) for eh in t.handlers)
            + (self.line('else:') + orelse if orelse else '')
            + (self.line('finally:') + fblock if fblock else '')
        )

    def visit_ExceptHandler(self, eh):
        header = self.line('except{}{}{}{}:'.format(
            ' ' if eh.type else '',
            self.visit(eh.type) if eh.type else '',
            ' as ' if eh.type and eh.name else ' ' if eh.name else '',
            self.visit(eh.name) if eh.name and isinstance(eh.name, ast.AST)
            else (eh.name if eh.name else '')
        ))
        self.indent()
        body = self.lines(self.visit(s) for s in eh.body)
        self.unindent()
        return header + body

    def visit_Call(self, c):
        # Unlike Python 2, Python 3 has no starargs or starstarargs
        return '{fun}({args}{sep}{keys})'.format(
            fun=self.visit(c.func),
            args=', '.join(self.visit(a) for a in c.args),
            # Separator only when both positional and keyword args exist
            sep=', ' if c.args and c.keywords else '',
            keys=', '.join(self.visit(ka) for ka in c.keywords)
        )

    def visit_arg(self, a):
        '''[2019/01/22, lyn] Handle new arg objects in Python 3.'''
        return a.arg  # The name of the argument
[2019/01/22, lyn] Handle new arg objects in Python 3.
    def visit_arguments(self, a):
        # [2019/01/22, lyn] Note: This does *not* handle Python 3's
        # keyword-only arguments (probably moot for 111, but not
        # beyond).
        stdargs = a.args[:-len(a.defaults)] if a.defaults else a.args
        defargs = (
            zip(a.args[-len(a.defaults):], a.defaults)
            if a.defaults else []
        )
        return '{stdargs}{sep1}{defargs}{sep2}{varargs}{sep3}{kwargs}'.format(
            stdargs=', '.join(self.visit(sa) for sa in stdargs),
            sep1=', ' if 0 < len(stdargs) and defargs else '',
            defargs=', '.join('{}={}'.format(self.visit(da), self.visit(dd))
                              for da, dd in defargs),
            sep2=', ' if 0 < len(a.args) and a.vararg else '',
            # vararg and kwarg are ast.arg objects in Python 3, so visit them
            varargs='*{}'.format(self.visit(a.vararg)) if a.vararg else '',
            sep3=', ' if (0 < len(a.args) or a.vararg) and a.kwarg else '',
            kwargs='**{}'.format(self.visit(a.kwarg)) if a.kwarg else ''
        )
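The slicing in `visit_arguments` relies on how `ast.arguments` stores defaults: `defaults` lines up with the last entries of `args`, and in Python 3 `vararg` and `kwarg` are `ast.arg` objects, not plain strings. The sketch below illustrates that layout with a hypothetical helper, `split_args`, which is not part of mast. It assumes Python 3.8 or later, where literal defaults parse to `ast.Constant` nodes with a `.value` attribute.

```python
import ast


def split_args(arguments):
    """Hypothetical helper: separate plain parameters from defaulted ones."""
    n_defaults = len(arguments.defaults)
    if n_defaults:
        # Defaults always belong to the trailing positional parameters
        plain = arguments.args[:-n_defaults]
        defaulted = list(zip(arguments.args[-n_defaults:], arguments.defaults))
    else:
        plain, defaulted = arguments.args, []
    return (
        [a.arg for a in plain],
        [(a.arg, d.value) for a, d in defaulted],  # assumes constant defaults
    )


fn = ast.parse('def f(a, b, c=3, d=4, *rest, **opts): pass').body[0]
plain_names, default_pairs = split_args(fn.args)

# vararg and kwarg are ast.arg nodes; the name is in their .arg attribute
vararg_name = fn.args.vararg.arg
kwarg_name = fn.args.kwarg.arg

print(plain_names, default_pairs, vararg_name, kwarg_name)
```

The explicit `if n_defaults` branch matters because `args[:-0]` is an empty list, not the whole list.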
Inherited Members
- ast.NodeVisitor
- visit
- visit_Constant
def ast2source(node):
    """Pretty print an AST as a python source string"""
    return SourceFormatter().visit(node)
Pretty print an AST as a python source string
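One way to sanity-check a source printer of this kind is a round trip: format an AST, re-parse the result, and compare structural dumps. The sketch below is an illustration under assumptions. `round_trips` is a hypothetical helper, and it uses the standard library's `ast.unparse` (Python 3.9+) as a stand-in formatter so that it runs without potluck installed. Passing `ast2source` as `to_source` should work the same way for the node types `SourceFormatter` supports.

```python
import ast


def round_trips(source, to_source=ast.unparse):
    """Return True if formatting and re-parsing yields an equivalent AST."""
    original = ast.parse(source)
    regenerated = ast.parse(to_source(original))
    # ast.dump omits line and column attributes by default, so layout
    # differences between the two sources do not affect the comparison
    return ast.dump(original) == ast.dump(regenerated)


ok = round_trips('for i in range(3):\n    total = total + i\n')
print(ok)
```

Comparing dumps instead of source text matters for a formatter like this one, which adds parentheses around operands and so rarely reproduces the input text exactly.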
Alias for ast2source