potluck.mast

Mast: Matcher of ASTs for Python Copyright (c) 2017-2018, Benjamin P. Wood, Wellesley College

mast.py

This version is included in potluck instead of codder.

See NOTE below for how to run this file as a script.

Uses the builtin ast package for parsing of concrete syntax and representation of abstract syntax for both source code and patterns.

The main API for pattern matching with mast:

  • parse: parse an AST from a python source string
  • parse_file: parse an AST from a python source file
  • parse_pattern: parse a pattern from a pattern source string
  • match: check if an AST matches a pattern
  • find: search for the first (improper) sub-AST matching a pattern
  • findall: search for all (improper) sub-ASTs matching a pattern
  • count: count all (improper) sub-ASTs matching a pattern
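As a rough illustration of what match and count compute, here is a toy matcher that uses only the standard-library ast module. It is a sketch, not mast's implementation. The names toy_match and toy_count are hypothetical. The only pattern feature it supports is the anonymous `_` wildcard: it has no named variables, no `___` set/sequence variables, no typed literals, and no permutation of associative or mirrored operations.

```python
import ast


def toy_match(node, pat):
    """Return True if node matches pat, where an ast.Name `_` in pat
    matches any single node (a toy stand-in for mast's match)."""
    if isinstance(pat, ast.Name) and pat.id == "_":
        return True  # anonymous wildcard
    if isinstance(pat, list):
        return (isinstance(node, list) and len(node) == len(pat)
                and all(toy_match(n, p) for n, p in zip(node, pat)))
    if isinstance(pat, ast.AST):
        # Same node type, and every field of the pattern must match.
        # ast.iter_fields skips lineno/col_offset, so positions are ignored.
        return type(node) is type(pat) and all(
            toy_match(getattr(node, name, None), value)
            for name, value in ast.iter_fields(pat)
        )
    return node == pat  # identifiers, constant values, None


def toy_count(tree, pat):
    """Count the (improper) sub-ASTs of tree that match pat."""
    return sum(1 for sub in ast.walk(tree) if toy_match(sub, pat))


program = ast.parse("x = a + 1\ny = b + 1\nz = b * 2\n")
pattern = ast.parse("_ + 1", mode="eval").body  # unwrap the Expression node
print(toy_count(program, pattern))  # 2

call_pattern = ast.parse("_(_)", mode="eval").body
print(toy_match(ast.parse("f(3)", mode="eval").body, call_pattern))     # True
print(toy_match(ast.parse("f(3, 4)", mode="eval").body, call_pattern))  # False
```

Because `_(_)` has exactly one argument slot, the toy rejects `f(3, 4)`. In mast, a set/sequence variable, as in `_(___)`, matches 0 or more nodes in that position.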

API utility functions applicable to all ASTs in ast:

  • dump: a nicer version of ast.dump for pretty-printing AST structure
  • ast2source: pretty-print an AST to python source
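The standard ast.dump accepts only a single AST node and raises TypeError for a list. A multi-statement pattern is one such case, because parse_pattern returns the Module body list for it. In the sketch below, dump_any is a hypothetical, simplified stand-in for mast's dump, which additionally handles dicts and mast's own Some and FiniteIterator wrappers. The sketch shows the container recursion. It also shows the standard library's ast.unparse (Python 3.9+) as a point of comparison for ast2source, whose implementation is not shown here.

```python
import ast


def dump_any(node):
    """Simplified stand-in for mast's dump: recurse into lists and tuples,
    use ast.dump for AST nodes, and repr for anything else."""
    if isinstance(node, ast.AST):
        return ast.dump(node)
    if isinstance(node, list):
        return "[" + ", ".join(map(dump_any, node)) + "]"
    if isinstance(node, tuple):
        return "(" + ", ".join(map(dump_any, node)) + ")"
    return repr(node)


tree = ast.parse("x = 1\ny = 2")
body = tree.body  # a plain list of two Assign nodes

try:
    ast.dump(body)
except TypeError as err:
    print("ast.dump:", err)  # ast.dump: expected AST, got 'list'

print(dump_any(body))    # "[Assign(...), Assign(...)]"
print(ast.unparse(tree))  # the two statements back as source (Python 3.9+)
```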

NOTE: To run this file as a script, go up one directory and execute

python -m potluck.mast

For example, if this file is ~/Sites/potluck/potluck/mast.py, then

cd ~/Sites/potluck
python -m potluck.mast

Because of relative imports, you cannot run this file as a script from within ~/Sites/potluck/potluck. Here's what will happen:

$ cd ~/Sites/potluck/potluck
$ python mast.py
Traceback (most recent call last):
File "mast.py", line 40, in <module>
  from .util import (...
SystemError: Parent module '' not loaded, cannot perform relative import

This StackOverflow post is helpful for understanding this issue: https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time
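You can reproduce the failure without potluck. The sketch below builds a throwaway package and runs its relatively-importing module both ways. The names pkg, util, and main are hypothetical stand-ins for potluck, its helper module, and mast.py. Current Python 3 versions report the failure as `ImportError: attempted relative import with no known parent package`, not the SystemError shown above.

```python
import os
import subprocess
import sys
import tempfile


def run_both_ways():
    """Create pkg/{__init__,util,main}.py, where main.py uses a relative
    import, then run it as `python -m pkg.main` and as `python main.py`."""
    with tempfile.TemporaryDirectory() as root:
        pkg = os.path.join(root, "pkg")
        os.mkdir(pkg)
        open(os.path.join(pkg, "__init__.py"), "w").close()
        with open(os.path.join(pkg, "util.py"), "w") as f:
            f.write("VALUE = 42\n")
        with open(os.path.join(pkg, "main.py"), "w") as f:
            f.write("from .util import VALUE\nprint(VALUE)\n")

        # Like `cd ~/Sites/potluck; python -m potluck.mast`
        as_module = subprocess.run(
            [sys.executable, "-m", "pkg.main"], cwd=root,
            capture_output=True, text=True,
            env=dict(os.environ, PYTHONPATH=root),
        )
        # Like `cd ~/Sites/potluck/potluck; python mast.py`
        as_script = subprocess.run(
            [sys.executable, "main.py"], cwd=pkg,
            capture_output=True, text=True,
        )
    return as_module, as_script


as_module, as_script = run_both_ways()
print(as_module.returncode, as_module.stdout.strip())  # 0 42
print(as_script.returncode, as_script.stderr.strip().splitlines()[-1])
```

Running with -m gives the module a parent package (its `__package__` is "pkg"), which `from .util import ...` needs in order to resolve. Running the file path directly leaves the module without one.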

SOME HISTORY
[2019/01/19-23, lyn] Python 2 to 3 conversion plus debugging aids
[2021/06/22ish, Peter Mawhorter] Linting pass; import into potluck

TODO:

  • have flag to control debugging prints (look for $)

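Documentation of the Python ast package:

  • Python 2: https://docs.python.org/2/library/ast.html
  • Python 3:
      • https://docs.python.org/3/library/ast.html
      • https://greentreesnakes.readthedocs.io/en/latest/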

"""
Mast: Matcher of ASTs for Python
Copyright (c) 2017-2018, Benjamin P. Wood, Wellesley College

mast.py

This version is included in potluck instead of codder.

See NOTE below for how to run this file as a script.

Uses the builtin ast package for parsing of concrete syntax and
representation of abstract syntax for both source code and patterns.

The main API for pattern matching with mast:

- `parse`: parse an AST from a python source string
- `parse_file`: parse an AST from a python source file
- `parse_pattern`: parse a pattern from a pattern source string
- `match`: check if an AST matches a pattern
- `find`: search for the first (improper) sub-AST matching a pattern
- `findall`: search for all (improper) sub-ASTs matching a pattern
- `count`: count all (improper) sub-ASTs matching a pattern

API utility functions applicable to all ASTs in ast:

- `dump`: a nicer version of ast.dump for pretty-printing AST structure
- `ast2source`: pretty-print an AST to python source

NOTE: To run this file as a script, go up one directory and execute

```sh
python -m potluck.mast
```

For example, if this file is ~/Sites/potluck/potluck/mast.py, then

```sh
cd ~/Sites/potluck
python -m potluck.mast
```

Because of relative imports, you *cannot* run this file as a script
from within ~/Sites/potluck/potluck. Here's what will happen:

```sh
$ cd ~/Sites/potluck/potluck
$ python mast.py
Traceback (most recent call last):
File "mast.py", line 40, in <module>
  from .util import (...
SystemError: Parent module '' not loaded, cannot perform relative import
```

This StackOverflow post is helpful for understanding this issue:
https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time

SOME HISTORY
[2019/01/19-23, lyn] Python 2 to 3 conversion plus debugging aids
[2021/06/22ish, Peter Mawhorter] Linting pass; import into potluck

TODO:

* have flag to control debugging prints (look for $)

Documentation of the Python ast package:

- Python 2: https://docs.python.org/2/library/ast.html
- Python 3:
    - https://docs.python.org/3/library/ast.html
    - https://greentreesnakes.readthedocs.io/en/latest/
"""

import ast
import os
import re
import sys

from collections import OrderedDict as odict

from .mast_utils import (
    Some,
    takeone,
    FiniteIterator,
    iterone,
    iterempty,
    dict_unbind
)

import itertools
ichain = itertools.chain.from_iterable


#------------------------------------#
# AST data structure pretty printing #
#------------------------------------#

def dump(node):
    """Nicer display of ASTs."""
    return (
        ast.dump(node) if isinstance(node, ast.AST)
        else '[' + ', '.join(map(dump, node)) + ']' if type(node) == list
        else '(' + ', '.join(map(dump, node)) + ')' if type(node) == tuple
        else '{' + ', '.join(dump(key) + ': ' + dump(value)
                             for key, value in node.items()) + '}'
            if type(node) == dict
        # Sneaky way to have dump descend into Some and FiniteIterator objects
        # for debugging
        else node.dump(dump)
          if (type(node) == Some or type(node) == FiniteIterator)
          else repr(node)
    )


def showret(x, prefix=''):
    """Shim for use debugging AST return values."""
    print(prefix, dump(x))
    return x


#-------------------#
# Pattern Variables #
#-------------------#

NODE_VAR_SPEC = (1, re.compile(r'\A_(?:([a-zA-Z0-9]+)_)?\Z'))
"""Spec for names of scalar pattern variables."""

SEQ_TYPES = set([
    ast.arguments,
    ast.Assign,
    ast.Call,
    ast.Expr,
    # ast.For,
    # ast.Print, # removed from ast module in Python 3
    ast.Tuple,
    ast.List,
    list,
    type(None),
    ast.BoolOp,
])
"""Types against which sequence patterns match."""

SET_VAR_SPEC = (1, re.compile(r'\A___(?:([a-zA-Z0-9]+)___)?\Z'))
"""Spec for names of sequence pattern variables."""

if sys.version_info < (3, 8):
    # This was the setup before 3.8
    LIT_TYPES = {
        'int': lambda x: (Some(x.n)
                          if type(x) == ast.Num and type(x.n) == int
                          else None),
        'float': lambda x: (Some(x.n)
                            if type(x) == ast.Num and type(x.n) == float
                            else None),
        'str': lambda x: (Some(x.s)
                          if type(x) == ast.Str
                          else None),
        # 'bool': lambda x: (Some(bool(x.id))
        # Lyn notes this was a bug! Should have been Some(bool(x.id == 'True'))!
        #                    if type(x) == ast.Name and (x.id == 'True'
        #                                                or x.id == 'False')
        #                    else None),
        # -----
        # Above is Python 2 code for bools; below is Python 3 (where bools
        # are named constants)
        'bool': lambda x: (
            Some(x.value)
            if (
                type(x) == ast.NameConstant
            and (x.value is True or x.value is False)
            ) else None
        ),
    }
    """Types against which typed literal pattern variables match."""
else:
    # From 3.8, Constant is used in place of lots of previous stuff
    LIT_TYPES = {
        'int': lambda x: (Some(x.value)
                          if type(x) == ast.Constant and type(x.value) == int
                          else None),
        'float': lambda x: (Some(x.value)
                            if type(x) == ast.Constant
                                and type(x.value) == float
                            else None),
        'str': lambda x: (Some(x.value)
                          if type(x) == ast.Constant and type(x.value) == str
                          else None),
        'bool': lambda x: (
            Some(x.value)
            if (
                type(x) == ast.Constant
            and (x.value is True or x.value is False)
            ) else None
        ),
    }
    """Types against which typed literal pattern variables match."""

TYPED_LIT_VAR_SPEC = (
    (1, 2),
    re.compile(r'\A_([a-zA-Z0-9]+)_(' + '|'.join(LIT_TYPES.keys()) + r')_\Z')
)
"""Spec for names/types of typed literal pattern variables."""


def var_is_anonymous(identifier):
    """Determine whether a pattern variable name (string) is anonymous."""
    assert type(identifier) == str # All Python 3 strings are unicode
    return not re.search(r'[a-zA-Z0-9]', identifier)


def node_is_name(node):
    """Determine if a node is an AST Name node."""
    return isinstance(node, ast.Name)


def identifier_key(identifier, spec):
    """Extract the name of a pattern variable from its identifier string,
    returning the key if identifier is a valid variable name according
    to spec, otherwise None. An anonymous variable yields the identifier
    itself, and a spec with a tuple of groups yields a tuple.

    Examples for NODE_VAR_SPEC:
    identifier_key('_a_', NODE_VAR_SPEC) => 'a'
    identifier_key('_', NODE_VAR_SPEC) => '_'
    identifier_key('_a', NODE_VAR_SPEC) => None

    """
    assert type(identifier) == str # All Python 3 strings are unicode
    groups, regex = spec
    match = regex.match(identifier)
    if match:
        if var_is_anonymous(identifier):
            return identifier
        elif type(groups) == tuple:
            return tuple(match.group(i) for i in groups)
        else:
            return match.group(groups)
    else:
        return None


def node_var(pat, spec=NODE_VAR_SPEC, wrap=False):
    """Extract the key name of a scalar pattern variable,
    returning the key if pat is a scalar pattern variable,
    otherwise None.

    A named or anonymous node variable pattern, written `_a_` or `_`,
    respectively, may appear in any expression or identifier context
    in a pattern.  It matches any single AST in the corresponding
    position in the target program.

    """
    if wrap:
        pat = ast.Name(id=pat, ctx=None)
    elif isinstance(pat, ast.Expr):
        pat = pat.value
    elif isinstance(pat, ast.alias) and pat.asname is None:
        # [Peter Mawhorter 2021-8-29] Want to treat aliases without an
        # 'as' part kind of like normal Name nodes.
        pat = ast.Name(id=pat.name, ctx=None)
    return (
        identifier_key(pat.id, spec)
        if node_is_name(pat)
        else None
    )


def node_var_str(pat, spec=NODE_VAR_SPEC):
    """Like node_var, but pat is a bare identifier string (e.g., a
    keyword argument name) rather than an AST node."""
    return node_var(pat, spec=spec, wrap=True)


def set_var(pat, wrap=False):
    """Extract the key name of a set or sequence pattern variable,
    returning the key if pat is a set or sequence pattern variable,
    otherwise None.

    A named or anonymous set or sequence pattern variable, written
    `___a___` or `___`, respectively, may appear as an element of a
    set or sequence context in a pattern.  It matches 0 or more nodes
    in the corresponding context in the target program.

    """
    return node_var(pat,
                    spec=SET_VAR_SPEC,
                    wrap=wrap)


def set_var_str(pat):
    """Like set_var, but pat is a bare identifier string."""
    return set_var(pat, wrap=True)


def typed_lit_var(pat):
    """Extract the key name of a typed literal pattern variable,
    returning a (name, type) tuple such as ('a', 'int') if pat is a
    typed literal pattern variable, otherwise None.

    A typed literal variable pattern, written `_a_type_`, may appear
    in any expression context in a pattern.  It matches any single AST
    node for a literal of the given primitive type in the
    corresponding position in the target program.

    """
    return node_var(pat, spec=TYPED_LIT_VAR_SPEC)

# def stmt_var(pat):
#     """Extract the key name of a sequence pattern variable, returning
#     Some(key) if pat is a sequence pattern variable appearing in a
#     statement context, otherwise None.
#     """
#     return seq_var(pat.value) if isinstance(pat, ast.Expr) else None


#---------------------#
# AST Node Properties #
#---------------------#

def is_pat(p):
    """Determine if p could be a pattern (by type)."""
    # All Python 3 strings are unicode
    return isinstance(p, ast.AST) or type(p) == str


def node_is_docstring(node):
    """Is this node a docstring node?"""
    return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)


def node_is_bindable(node):
    """Can a node pattern variable bind to this node?"""
    # return isinstance(node, ast.expr) or isinstance(node, ast.stmt)
    # Modified to allow debugging prints
    result = (
        isinstance(node, ast.expr)
     or isinstance(node, ast.stmt)
     or (isinstance(node, ast.alias) and node.asname is None)
    )
    # print('\n$ node_is_bindable({}) => {}'.format(dump(node),result))
    return result


def node_is_lit(node, ty):
    """Is this node a literal primitive node?"""
    return (
        (isinstance(node, ast.Num) and ty == type(node.n)) # noqa E721
        # All Python 3 strings are unicode
     or (isinstance(node, ast.Str) and ty == str)
     or (
            isinstance(node, ast.Name)
        and ty == bool
        and (node.id == 'True' or node.id == 'False')
        )
    )


def expr_has_type(node, ty):
    """Is this expr statically guaranteed to be of this type?

    Literals and conversions have definite types.  All other types are
    conservatively statically unknown.

    """
    return node_is_lit(node, ty) or match(node, '{}(_)'.format(ty.__name__))


def node_line(node):
    """Get the line number of the source line on which this node starts.
    (best effort)"""
    try:
        return node.lineno if type(node) != list else node[0].lineno
    except Exception:
        return None


#---------------------------#
# Checkers and Transformers #
#---------------------------#

class PatternSyntaxError(BaseException):
    """Exception for errors in pattern syntax."""
    def __init__(self, node, message):
        BaseException.__init__(
            self,
            "At pattern line {}: {}".format(node_line(node), message)
        )


class PatternValidator(ast.NodeVisitor):
    """AST visitor: check pattern structure."""
    def __init__(self):
        self.parent = None
        pass

    def generic_visit(self, pat):
        oldparent = self.parent
        self.parent = pat
        try:
            return ast.NodeVisitor.generic_visit(self, pat)
        finally:
            self.parent = oldparent
            pass
        pass

    def visit_Name(self, pat):
        sn = set_var(pat)
        if sn:
            if type(self.parent) not in SEQ_TYPES:
                raise PatternSyntaxError(
                    pat,
                    "Set/sequence variable ({}) not allowed in {} node".format(
                        sn, type(self.parent)
                    )
                )
            pass
        pass

    def visit_arg(self, pat):
        '''[2019/01/22, lyn] Python 3 now has arg object for param (not Name object).
           So copied visit_Name here.'''
        sn = set_var(pat)
        if sn:
            if type(self.parent) not in SEQ_TYPES:
                raise PatternSyntaxError(
                    pat,
                    "Set/sequence variable ({}) not allowed in {} node".format(
                        sn, type(self.parent)
                    )
                )
            pass
        pass

    def visit_Call(self, c):
        # kw.arg is None for `**kwargs` entries, which are never set
        # variables (and would trip the str assertion in identifier_key).
        if 1 < sum(1 for kw in c.keywords
                   if kw.arg is not None and set_var_str(kw.arg)):
            raise PatternSyntaxError(
                c.keywords,
                "Calls may use at most one keyword argument set variable."
            )
        return self.generic_visit(c)

    def visit_keyword(self, k):
        if (k.arg is not None
            and identifier_key(k.arg, SET_VAR_SPEC)
            and not (node_is_name(k.value)
                     and var_is_anonymous(k.value.id))):
            raise PatternSyntaxError(
                k.value,
                "Value patterns for keyword argument set variables must be _."
            )
        return self.generic_visit(k)
    pass


# TODO [Peter 2021-6-24]: This pass breaks things subtly by removing
# e.g., decorators lists, resulting in an AST tree that cannot be
# compiled and run as code! In the past there was mention that without
# this pass things would break... but for now it's disabled by default.
class RemoveDocstrings(ast.NodeTransformer):
    """AST Transformer: remove all docstring nodes."""
    def filterDocstrings(self, seq):
        # print('PREFILTERED', seq)
        filt = [self.visit(n) for n in seq
                if not node_is_docstring(n)]
        # print('FILTERED', dump(filt))
        return filt

    def visit_Expr(self, node):
        if isinstance(node.value, ast.Str):
            assert False
            # return ast.copy_location(ast.Expr(value=None), node)
        else:
            return self.generic_visit(node)

    def visit_FunctionDef(self, node):
        # print('Removing docstring: %s' % dump(node.body[0].value))
        return ast.copy_location(ast.FunctionDef(
            name=node.name,
            args=self.generic_visit(node.args),
            body=self.filterDocstrings(node.body),
        ), node)

    def visit_Module(self, node):
        return ast.copy_location(ast.Module(
            body=self.filterDocstrings(node.body)
        ), node)

    def visit_For(self, node):
        return ast.copy_location(ast.For(
            body=self.filterDocstrings(node.body),
            target=node.target,
            iter=node.iter,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_While(self, node):
        return ast.copy_location(ast.While(
            body=self.filterDocstrings(node.body),
            test=node.test,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_If(self, node):
        return ast.copy_location(ast.If(
            body=self.filterDocstrings(node.body),
            test=node.test,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_With(self, node):
        return ast.copy_location(ast.With(
            body=self.filterDocstrings(node.body),
            # Python 3 just has withitems:
            items=node.items
            # Old Python 2 stuff:
            #context_expr=node.context_expr,
            #optional_vars=node.optional_vars
        ), node)

#     def visit_TryExcept(self, node):
#         return ast.copy_location(ast.TryExcept(
#             body=self.filterDocstrings(node.body),
#             handlers=self.filterDocstrings(node.handlers),
#             orelse=self.filterDocstrings(node.orelse)
#         ), node)

#     def visit_TryFinally(self, node):
#         return ast.copy_location(ast.TryFinally(
#             body=self.filterDocstrings(node.body),
#             finalbody=self.filterDocstrings(node.finalbody)
#         ), node)
#

    # In Python 3, a single Try node covers both TryExcept and TryFinally
    def visit_Try(self, node):
        return ast.copy_location(ast.Try(
            body=self.filterDocstrings(node.body),
            handlers=self.filterDocstrings(node.handlers),
            orelse=self.filterDocstrings(node.orelse),
            finalbody=self.filterDocstrings(node.finalbody)
        ), node)

    def visit_ClassDef(self, node):
        return ast.copy_location(ast.ClassDef(
            body=self.filterDocstrings(node.body),
            name=node.name,
            bases=node.bases,
            decorator_list=node.decorator_list
        ), node)


class FlattenBoolOps(ast.NodeTransformer):
    """AST transformer: flatten nested boolean expressions."""
    def visit_BoolOp(self, node):
        # Visit the operands first so that BoolOps nested at any depth
        # (e.g., `b or (c or d)` inside `a and (...)`) are flattened too.
        node = self.generic_visit(node)
        values = []
        for v in node.values:
            if isinstance(v, ast.BoolOp) and v.op == node.op:
                values += v.values
            else:
                values.append(v)
                pass
            pass
        return ast.copy_location(ast.BoolOp(op=node.op, values=values), node)
    pass


class ContainsCall(Exception):
    """Exception raised when prohibiting call expressions."""
    pass


class ProhibitCall(ast.NodeVisitor):
    """AST visitor: check call-freedom."""
    def visit_Call(self, c):
        raise ContainsCall
    pass


class SetLoadContexts(ast.NodeTransformer):
    """
    Transforms any AugStore contexts into Load contexts, for use with
    DesugarAugAssign.
    """
    def visit_AugStore(self, ctx):
        return ast.copy_location(ast.Load(), ctx)


class DesugarAugAssign(ast.NodeTransformer):
    """AST transformer: desugar augmented assignments (e.g., `+=`)
    when simple desugaring does not duplicate effectful expressions.

    FIXME: this desugaring and others cause surprising match
    counts. See: https://github.com/wellesleycs111/codder/issues/31

    FIXME: This desugaring should probably not happen in cases where
    .__iadd__ and .__add__ yield different results, for example with
    lists where .__iadd__ is really extend while .__add__ creates a new
    list, such that

        [1, 2] + "34"

    is an error, but

        x = [1, 2]
        x += "34"

    is not!

    A better approach might be to avoid desugaring entirely and instead
    provide common collections of match rules for common patterns.

    """
    def visit_AugAssign(self, assign):
        try:
            # Desugaring *all* AugAssigns is not sound.
            # Example: xs[f()] += 1  -->  xs[f()] = xs[f()] + 1
            # Check for presence of call in target.
            ProhibitCall().visit(assign.target)
            return ast.copy_location(
                ast.Assign(
                    targets=[self.visit(assign.target)],
                    value=ast.copy_location(
                        ast.BinOp(
                            left=SetLoadContexts().visit(assign.target),
                            op=self.visit(assign.op),
                            right=self.visit(assign.value)
                        ),
                        assign
                    )
                ),
                assign
            )
        except ContainsCall:
            return self.generic_visit(assign)

    def visit_AugStore(self, ctx):
        return ast.copy_location(ast.Store(), ctx)

    def visit_AugLoad(self, ctx):
        return ast.copy_location(ast.Load(), ctx)


class ExpandExplicitElsePattern(ast.NodeTransformer):
    """AST transformer: transform patterns that include `else: ___`
    to use `else: _; ___` instead, forcing the else to have at least
    one statement."""
    def visit_If(self, ifelse):
        if 1 == len(ifelse.orelse) and set_var(ifelse.orelse[0]):
            return ast.copy_location(
                ast.If(
                    test=ifelse.test,
                    body=ifelse.body,
                    orelse=[
                        ast.copy_location(
                            ast.Expr(value=ast.Name(id='_', ctx=None)),
                            ifelse.orelse[0]
                        ),
                        ifelse.orelse[0]
                    ]
                ),
                ifelse
            )
        else:
            return self.generic_visit(ifelse)


def pipe_visit(node, visitors):
    """Send an AST through a pipeline of AST visitors/transformers."""
    if 0 == len(visitors):
        return node
    else:
        v = visitors[0]
        if isinstance(v, ast.NodeTransformer):
            # visit for the transformed result
            return pipe_visit(v.visit(node), visitors[1:])
        else:
            # visit for the effect
            v.visit(node)
            return pipe_visit(node, visitors[1:])


#---------#
# Parsing #
#---------#

def parse_file(path, docstrings=True):
    """Load and parse a program AST from the given path."""
    with open(path) as f:
        return parse(f.read(), filename=path, docstrings=docstrings)
    pass


# Note: 2021-7-29 Peter Mawhorter: A bug in DesugarAugAssign was found
# and fixed (incorrectly left Store context on RHS, which was apparently
# ignored in many Python versions...). However, upon consideration I've
# decided to remove it from standard + docstring passes, since the
# problem with += and lists is pretty relevant (I've seen students
# accidentally do list += string several times in office hours)!
STANDARD_PASSES = [
    RemoveDocstrings,
    #DesugarAugAssign,
    FlattenBoolOps
]
"""Thunks for AST visitors/transformers that are applied during parsing."""


DOCSTRING_PASSES = [
    #DesugarAugAssign,
    FlattenBoolOps
]
"""Extra thunks for docstring mode?"""


class MastParseError(Exception):
    def __init__(self, pattern, error):
        super(MastParseError, self).__init__(
            'Error parsing pattern source string <<<\n{}\n>>>\n{}'.format(
                pattern, str(error)
            )
        )
        self.trigger = error # hang on to the original error
        pass
    pass


def parse(
    string,
    docstrings=True,
    filename='<unknown>',
    passes=STANDARD_PASSES
):
    """Parse an AST from a string."""
    if type(string) == str: # All Python 3 strings are unicode
        string = str(string)
    assert type(string) == str # All Python 3 strings are unicode
    if docstrings:
        # Filter into a new list: passes.remove() would permanently
        # mutate the caller's list (by default the shared STANDARD_PASSES),
        # dropping RemoveDocstrings from every later parse, even ones
        # called with docstrings=False.
        passes = [p for p in passes if p is not RemoveDocstrings]
    pipe = [thunk() for thunk in passes]

    try:
        parsed_ast = ast.parse(string, filename=filename)
    except Exception as error:
        raise MastParseError(string, error)
    return pipe_visit(parsed_ast, pipe)


PATTERN_PARSE_CACHE = {}
"""a pattern parsing cache"""


PATTERN_PASSES = STANDARD_PASSES + [
    # ExpandExplicitElsePattern,
    # Not currently used because a number of existing rules and specs
    # actually use the ability of `if _: ___; else: ___` to match
    # either an if-else or an elseless if.
]


def parse_pattern(pat, toplevel=False, docstrings=True):
    """Parse and validate a pattern."""
    if isinstance(pat, ast.AST):
        PatternValidator().visit(pat)
        return pat
    elif type(pat) == list:
        for p in pat:
            PatternValidator().visit(p)
        return pat
    elif type(pat) == str: # All Python 3 strings are unicode
        cache_key = (toplevel, docstrings, pat)
        pat_ast = PATTERN_PARSE_CACHE.get(cache_key)
        if pat_ast is None:
            # 2021-6-16 Peter commented this out (should it be used
            # instead of passes on the line below?)
            # pipe = [thunk() for thunk in PATTERN_PASSES]
            pat_ast = parse(pat, filename='<pattern>', passes=PATTERN_PASSES)
            PatternValidator().visit(pat_ast)
            if not toplevel:
                # Unwrap as needed.
                if len(pat_ast.body) == 1:
                    # If the pattern is a single definition, statement, or
                    # expression, unwrap the Module node.
                    b = pat_ast.body[0]
                    pat_ast = b.value if isinstance(b, ast.Expr) else b
                    pass
                else:
                    # If the pattern is a sequence of definitions or
                    # statements, validate from the top, but return only
                    # the Module body.
                    pat_ast = pat_ast.body
                    pass
                pass
            PATTERN_PARSE_CACHE[cache_key] = pat_ast
            pass
        return pat_ast
    else:
        assert False, 'Cannot parse pattern of type {}: {}'.format(
            type(pat),
            dump(pat)
        )


def pat(pat, toplevel=False, docstrings=True):
    """Alias for parse_pattern."""
    return parse_pattern(pat, toplevel, docstrings)


#--------------#
 802# Permutations #
 803#--------------#
 804
 805# [2019/02/09, lyn] Careful below! generally need list(map(...)) in Python 3
 806ASSOCIATIVE_OPS = list(map(type, [
 807    ast.Add(), ast.Mult(), ast.BitOr(), ast.BitXor(), ast.BitAnd(),
 808    ast.Eq(), ast.NotEq(), ast.Is(), ast.IsNot()
 809]))
 810"""Types of AST operation nodes that are associative."""
 811
 812
 813def op_is_assoc(op):
 814    """Determine if the given operation node is associative."""
 815    return type(op) in ASSOCIATIVE_OPS
 816
 817
 818MIRRORING = [
 819    (ast.Lt(), ast.Gt()),
 820    (ast.LtE(), ast.GtE()),
 821    (ast.Eq(), ast.Eq()),
 822    (ast.NotEq(), ast.NotEq()),
 823    (ast.Is(), ast.Is()),
 824    (ast.IsNot(), ast.IsNot()),
 825]
 826"""Pairs of operations that are mirrors of each other."""
 827
 828
 829def mirror(op):
 830    """Return the mirror operation of the given operation if any."""
 831    for (x, y) in MIRRORING:
 832        if type(x) == type(op):
 833            return y
 834        elif type(y) == type(op):
 835            return x
 836        pass
 837    return None
 838
 839
 840def op_has_mirror(op):
 841    """Determine if the given operation has a mirror."""
 842    return bool(mirror)
 843
 844
 845# MIRROR_OPS = set(reduce(lambda acc,(x,y): acc + [x,y], MIRRORING, []))
 846# ASSOCIATIVE_IF_PURE_OPS = set([
 847#     ast.And(), ast.Or()
 848# ])
 849
 850
 851# Patterns for potentially effectful operations
 852FUNCALL = parse_pattern('_(___)')
 853METHODCALL = parse_pattern('_._(___)')
 854ASSIGN = parse_pattern('_ = _')
 855DEL = parse_pattern('del _')
 856PRINT = parse_pattern('print(___)')
 857IMPURE2 = [METHODCALL, ASSIGN, DEL, PRINT]
 858
 859
 860# Names of functions treated as pure (effect-free)
 861PUREFUNS = set(['len', 'round', 'int', 'str', 'float', 'list', 'tuple',
 862                'map', 'filter', 'reduce', 'iter', 'all', 'any'])
 863
 864
 865class PermuteBoolOps(ast.NodeTransformer):
 866    """AST transformer: permute topmost boolean operation
 867    according to the given index order."""
 868    def __init__(self, indices):
 869        self.indices = indices
 870
 871    def visit_BoolOp(self, node):
 872        if self.indices:
 873            # Reset indices first so only the topmost BoolOp is permuted.
 874            indices, self.indices = self.indices, None
 875            values = [
 876                self.generic_visit(node.values[i]) for i in indices
 877            ]
 878            return ast.copy_location(
 879                ast.BoolOp(op=node.op, values=values),
 880                node
 881            )
 882        else:
 883            return self.generic_visit(node)
 884
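`PermuteBoolOps` reorders only the operands of the topmost `BoolOp`. The index orders themselves come from `itertools.permutations` in `permutations` below. A stripped-down transformer with the same intent is sketched here. The class name `ReorderTopBoolOp` is hypothetical, and nested boolean operations are deliberately left untouched:

```python
import ast


class ReorderTopBoolOp(ast.NodeTransformer):
    """Hypothetical transformer: reorder the operands of the first
    (outermost) BoolOp visited, leaving nested BoolOps unchanged."""

    def __init__(self, indices):
        self.indices = indices

    def visit_BoolOp(self, node):
        if self.indices is None:
            return self.generic_visit(node)
        indices, self.indices = self.indices, None  # reorder only once
        new = ast.BoolOp(op=node.op, values=[node.values[i] for i in indices])
        return ast.copy_location(new, node)


expr = ast.parse("a and b and c", mode="eval").body
print(ast.unparse(ReorderTopBoolOp((2, 0, 1)).visit(expr)))  # c and a and b
```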
 885# class PermuteAST(ast.NodeVisitor):
 886#     def visit_BinOp(self, node):
 887#         yield self.generic_visit(node)
 888#         if op_is_assoc(node) and ispure(node):
 889
 890
 891def node_is_pure(node, purefuns=[]):
 892    """Determine if the given node is (conservatively) pure (effect-free)."""
 893    return (
 894        count(node, FUNCALL,
 895              matchpred=lambda x, env:  # compare the callee's name
 896              getattr(x.func, 'id', None) not in PUREFUNS.union(purefuns)) == 0
 897        and all(count(node, pat) == 0 for pat in IMPURE2)
 898    )
 899
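The key detail in a purity check like `node_is_pure` is that a call's `func` field is an AST node (usually `ast.Name`). It must therefore be compared by its `id`, not directly against a set of strings. The standalone sketch below uses `ast.walk` instead of this module's `count` and pattern machinery. The whitelist `PURE_NAMES` and the helper `looks_pure` are hypothetical:

```python
import ast

# Hypothetical whitelist, in the spirit of PUREFUNS.
PURE_NAMES = {"len", "round", "int", "str", "float", "abs", "min", "max"}


def looks_pure(src):
    """Conservatively decide whether code is effect-free: any call to a
    non-whitelisted name, any method call, and any assignment or
    deletion counts as potentially impure."""
    for node in ast.walk(ast.parse(src)):
        if isinstance(node, ast.Call):
            # Compare the callee's *name*, not the ast.Name node itself.
            if not (isinstance(node.func, ast.Name)
                    and node.func.id in PURE_NAMES):
                return False
        elif isinstance(node, (ast.Assign, ast.AugAssign, ast.Delete,
                               ast.NamedExpr)):
            return False
    return True


print(looks_pure("len(xs) + 1"))    # True
print(looks_pure("xs.append(1)"))   # False (method call)
print(looks_pure("print(x) or y"))  # False (print is not whitelisted)
```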
 900
 901def permutations(node):
 902    """
 903    Generate all permutations of the binary and boolean operations, as
 904    well as of import orderings, in and below node, via associativity and
 905    mirroring, respecting purity. Because this is heavily exponential,
 906    limiting the number of imported names and the complexity of binary
 907    operation trees in your patterns is a good thing.
 908    """
 909    # TODO: These permutations allow matching like this:
 910    # Node: import math, sys, io
 911    # Pat: import _x_, ___
 912    # can bind x to math or io, but NOT sys (see TODO near
 913    # parse_pattern("___"))
 914    if isinstance(node, ast.Import):
 915        for perm in list_permutations(node.names):
 916            yield ast.Import(names=perm)
 917    elif isinstance(node, ast.ImportFrom):
 918        for perm in list_permutations(node.names):
 919            yield ast.ImportFrom(
 920                module=node.module,
 921                names=perm,
 922                level=node.level
 923            )
 924    elif (
 925        isinstance(node, ast.BinOp)
 926    and op_is_assoc(node.op)
 927    and node_is_pure(node)
 928    ):
 929        for left in permutations(node.left):
 930            for right in permutations(node.right):
 931                # Operands in their original order, with location info
 932                # copied from the original node.
 933                yield ast.copy_location(
 934                    ast.BinOp(left=left, op=node.op, right=right),
 935                    node
 936                )
 937                # Operands swapped; this relies on the operator being
 938                # commutative as well as associative.
 939                yield ast.copy_location(
 940                    ast.BinOp(left=right, op=node.op, right=left),
 941                    node
 942                )
 943    elif (isinstance(node, ast.Compare)
 944          and len(node.ops) == 1
 945          and op_has_mirror(node.ops[0])
 946          and node_is_pure(node)):
 947        assert len(node.comparators) == 1
 948        for left in permutations(node.left):
 949            for right in permutations(node.comparators[0]):
 950                # print('PERMUTE', dump(left), dump(node.ops), dump(right))
 951                yield ast.copy_location(ast.Compare(
 952                    left=left,
 953                    ops=node.ops,
 954                    comparators=[right]
 955                ), node)
 956                yield ast.copy_location(ast.Compare(
 957                    left=right,
 958                    ops=[mirror(node.ops[0])],
 959                    comparators=[left]
 960                ), node)
 961    elif isinstance(node, ast.BoolOp) and node_is_pure(node):
 962        #print(dump(node))
 963        stuff = [[x for x in permutations(v)] for v in node.values]
 964        prod = [x for x in itertools.product(*stuff)]
 965        # print(prod)
 966        for values in prod:
 967            # print('VALUES', map(dump,values))
 968            for indices in itertools.permutations(range(len(node.values))):
 969                # print('BOOL', map(dump, values))
 970                yield PermuteBoolOps(indices).visit(
 971                    ast.copy_location(
 972                        ast.BoolOp(op=node.op, values=values),
 973                        node
 974                    )
 975                )
 976                pass
 977            pass
 978        pass
 979    else:
 980        # print('NO', dump(node))
 981        yield node
 982    pass
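For `Import` and `ImportFrom` nodes, `permutations` yields one node per ordering of the imported names. With n names that means n! candidate patterns. The same idea is sketched standalone below using `itertools.permutations` and `ast.unparse` (Python 3.9+). The helper name `import_orderings` is hypothetical:

```python
import ast
import itertools


def import_orderings(src):
    """Return the source of every reordering of a single import
    statement's names."""
    node = ast.parse(src).body[0]
    assert isinstance(node, (ast.Import, ast.ImportFrom))
    results = []
    for names in itertools.permutations(node.names):
        if isinstance(node, ast.Import):
            new = ast.Import(names=list(names))
        else:
            new = ast.ImportFrom(module=node.module, names=list(names),
                                 level=node.level)
        results.append(ast.unparse(new))
    return results


print(import_orderings("import math, sys"))
# ['import math, sys', 'import sys, math']
```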
 983
 984
 985def list_permutations(items):
 986    """
 987    A generator which yields all possible orderings of the given list.
 988    """
 989    if len(items) <= 1:
 990        yield items[:]
 991    else:
 992        first = items[0]
 993        for subperm in list_permutations(items[1:]):
 994            for i in range(len(subperm) + 1):
 995                yield subperm[:i] + [first] + subperm[i:]
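`list_permutations` builds orderings by insertion: it permutes the tail, then inserts the first item at every position of each tail ordering. The renamed copy below (`insertion_permutations`, a hypothetical name) runs on its own and checks that strategy against `itertools.permutations`:

```python
import itertools
import math


def insertion_permutations(items):
    """Permute the tail, then insert the first item at every position
    of each tail ordering."""
    if len(items) <= 1:
        yield items[:]
    else:
        first = items[0]
        for subperm in insertion_permutations(items[1:]):
            for i in range(len(subperm) + 1):
                yield subperm[:i] + [first] + subperm[i:]


perms = list(insertion_permutations([1, 2, 3]))
print(perms)
# [[1, 2, 3], [2, 1, 3], [2, 3, 1], [1, 3, 2], [3, 1, 2], [3, 2, 1]]

# Same orderings as itertools.permutations, and n! of them.
assert sorted(map(tuple, perms)) == sorted(itertools.permutations([1, 2, 3]))
assert len(perms) == math.factorial(3)
```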
 996
 997
 998#----------#
 999# Matching #
1000#----------#
1001
1002def match(node, pat, **kwargs):
1003    """A convenience wrapper for matchpat that accepts a pattern either
1004    as a pre-parsed AST or as a pattern source string to be parsed.
1005
1006    """
1007    return matchpat(node, parse_pattern(pat), **kwargs)
1008
1009
1010def predtrue(node, matchenv):
1011    """The True predicate"""
1012    return True
1013
1014
1015def matchpat(
1016    node,
1017    pat,
1018    matchpred=predtrue,
1019    env={},
1020    gen=False,
1021    normalize=False
1022):
1023    """Match an AST against a (pre-parsed) pattern.
1024
1025    The optional keyword argument matchpred gives a predicate of type
1026    (AST node * match environment) -> bool that filters structural
1027    matches by arbitrary additional criteria.  The default is the true
1028    predicate, accepting all structural matches.
1029
1030    The optional keyword argument gen determines whether this function:
1031      - (gen=True)  yields an environment for each way pat matches node.
1032      - (gen=False) returns Some(the first of these environments), or
1033                    None if there are no matches (default).
1034
1035    EXPERIMENTAL
1036    The optional keyword argument normalize determines whether to
1037    rewrite the target AST node and the pattern to inline simple
1038    straightline variable assignments into large expressions (to the
1039    extent possible) for matching.  The default is no normalization
1040    (False).  Normalization is experimental.  It is rather ad hoc and
1041    conservative and may cause unintuitive matching behavior.  Use
1042    with caution.
1043
1044    INTERNAL
1045    The optional keyword argument env gives an initial match environment which
1046    may be used to constrain otherwise free pattern variables.  This
1047    argument is mainly intended for internal use.
1048
1049    """
1050    assert node is not None
1051    assert pat is not None
1052    if normalize:
1053        if isinstance(node, ast.AST) or type(node) == list:
1054            node = canonical_pure(node)
1055            pass
1056        if isinstance(pat, ast.AST) or type(pat) == list:
1057            pat = canonical_pure(pat)
1058            pass
1059        pass
1060
1061    # Permute the PATTERN, not the AST, so that nodes returned
1062    # always match real code.
1063    matches = (matchenv
1064               # outer iteration
1065               for permpat in permutations(pat)
1066               # inner iteration
1067               for matchenv in imatches(node, permpat,
1068                                        Some(env), True)
1069               if matchpred(node, matchenv))
1070
1071    return matches if gen else takeone(matches)
1072
1073
1074def bind(env1, name, value):
1075    """
1076    Unify the environment in option env1 with a new binding of name to
1077    value.  Return Some(extended environment) if env1 is Some(existing
1078    environment) in which name is not bound or is bound to value.
1079    Otherwise (env1 is None, or the existing binding of name is
1080    incompatible with value), return None.
1081    """
1082    # Lyn modified to allow debugging prints
1083    assert type(name) == str
1084    if env1 is None:
1085        # Return early: there is no environment whose .value to extend.
1086        return None
1087    env = env1.value
1088    if var_is_anonymous(name):
1089        # return env1
1090        result = env1
1091    elif name in env:
1092        if takeone(imatches(env[name], value, Some({}), True)):
1093            # return env1
1094            result = env1
1095        else:
1096            # return None
1097            result = None
1098    else:
1099        env = env.copy()
1100        env[name] = value
1101        # print 'bind', name, dump(value), dump(env)
1102        # return Some(env)
1103        result = Some(env)
1104    # print(
1105    #    '\n$ bind({}, {}, {}) => {}'.format(
1106    #         dump(env1),
1107    #         name,
1108    #         dump(value),
1109    #         dump(result)
1110    #     )
1111    # )
1112    return result
1113    pass
1114
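`bind` is the unification step of matching: a pattern variable may be bound once, and every later occurrence must agree with that binding. The simplified sketch below (`bind_sketch` is hypothetical) uses a plain dict or `None` in place of the `Some` option type and compares values with `==` rather than by structural matching. It also assumes `_` is the only anonymous variable name:

```python
def bind_sketch(env, name, value):
    """Return an extended environment, or None if the binding conflicts."""
    if env is None:
        return None               # an earlier binding already failed
    if name == "_":
        return env                # anonymous variables never bind
    if name in env:
        return env if env[name] == value else None
    extended = dict(env)          # never mutate the caller's environment
    extended[name] = value
    return extended


env = bind_sketch({}, "x", 1)
print(env)                        # {'x': 1}
print(bind_sketch(env, "x", 1))   # {'x': 1}  (consistent rebinding)
print(bind_sketch(env, "x", 2))   # None      (conflicting rebinding)
print(bind_sketch(env, "_", 99))  # {'x': 1}  (anonymous: ignored)
```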
1115
1116IGNORE_FIELDS = set(['ctx', 'lineno', 'col_offset'])
1117"""AST node fields to be ignored when matching children."""
1118
1119
1120def argObjToName(argObj):
1121    '''Convert Python 3 arg object into a Name node,
1122       ignoring any annotation, lineno, and col_offset.'''
1123    if argObj is None:
1124        return None
1125    else:
1126        return ast.Name(id=argObj.arg, ctx=None)
1127
1128
1129def astStr(astObj):
1130    """
1131    Converts an AST object to a string, imperfectly.
1132    """
1133    if isinstance(astObj, ast.Name):
1134        return astObj.id
1135    elif isinstance(astObj, ast.Str):
1136        return repr(astObj.s)
1137    elif isinstance(astObj, ast.Num):
1138        return str(astObj.n)
1139    elif isinstance(astObj, (list, tuple)):
1140        return str([astStr(x) for x in astObj])
1141    elif hasattr(astObj, "_fields") and len(astObj._fields) > 0:
1142        return "{}({})".format(
1143            type(astObj).__name__,
1144            ', '.join(astStr(getattr(astObj, f)) for f in astObj._fields)
1145        )
1146    elif hasattr(astObj, "_fields"):
1147        return type(astObj).__name__
1148    else:
1149        return '<{}>'.format(type(astObj).__name__)
1150
1151
1152def defaultToName(defObj):
1153    """
1154    Converts a default expression to a name for matching purposes.
1155    TODO: Actually do recursive matching on these expressions!
1156    """
1157    if isinstance(defObj, ast.Name):
1158        return defObj.id
1159    else:
1160        return astStr(defObj)
1161
1162
1163def field_values(node):
1164    """Return a list of the values of all matching-relevant fields of the
1165    given AST node, with fields in consistent list positions."""
1166
1167    # Specializations:
1168    if isinstance(node, ast.FunctionDef):
1169        return [ast.Name(id=v, ctx=None) if k == 'name'
1170                # Lyn sez: commented out following, because fields are
1171                # only name, arguments, body
1172                # else sorted(v, key=lambda x: x.arg) if k == 'keywords'
1173                else v
1174                for (k, v) in ast.iter_fields(node)
1175                if k not in IGNORE_FIELDS]
1176
1177    if isinstance(node, ast.ClassDef):
1178        return [ast.Name(id=v, ctx=None) if k == 'name'
1179                else v
1180                for (k, v) in ast.iter_fields(node)
1181                if k not in IGNORE_FIELDS]
1182
1183    if isinstance(node, ast.Call):
1184        # Old: sorted(v, key=lambda kw: kw.arg)
1185        # New: keyword args use nominal (not positional) matching.
1186        #return [sorted(v, key=lambda kw: kw.arg) if k == 'keywords' else v
1187        return [
1188            (
1189                odict((kw.arg, kw.value) for kw in v)
1190                if k == 'keywords'
1191                else v
1192            )
1193            for (k, v) in ast.iter_fields(node)
1194            if k not in IGNORE_FIELDS
1195        ]
1196
1197    # Lyn sez: ast.arguments handling is new for Python 3
1198    if isinstance(node, ast.arguments):
1199        argList = [
1200            # Need to create Name nodes to match patterns.
1201            argObjToName(argObj)
1202            for argObj in node.args
1203        ] # Ignores argObj annotation, lineno, col_offset
1204        if (
1205            node.vararg is None
1206        and node.kwarg is None
1207        and node.kwonlyargs == []
1208        and node.kw_defaults == []
1209        and node.defaults == []
1210        ):
1211            # Optimize (for debugging purposes) this very common case by
1212            # returning a singleton list of lists
1213            return [argList]
1214        else:
1215            # In unoptimized case, return list with sublists and
1216            # argObjects/Nones:
1217            # TODO: treat triple underscores separately/specially!!!
1218            return [
1219                argList,
1220                argObjToName(node.vararg),
1221                argObjToName(node.kwarg),
1222                [argObjToName(argObj) for argObj in node.kwonlyargs],
1223                # Peter 2019-9-30 the defaults cannot be reliably converted
1224                # into names, because they are expressions!
1225                node.kw_defaults,
1226                node.defaults
1227            ]
1228
1229    if isinstance(node, ast.keyword):
1230        return [ast.Name(id=v, ctx=None) if k == 'arg' else v
1231                for (k, v) in ast.iter_fields(node)
1232                if k not in IGNORE_FIELDS]
1233
1234    if isinstance(node, ast.Global):
1235        return [ast.Name(id=n, ctx=None) for n in sorted(node.names)]
1236
1237    if isinstance(node, ast.Import):
1238        return [ node.names ]
1239
1240    if isinstance(node, ast.ImportFrom):
1241        return [ast.Name(id=node.module, ctx=None), node.level, node.names]
1242
1243    if isinstance(node, ast.alias):
1244        result = [ ast.Name(id=node.name, ctx=None) ]
1245        if node.asname is not None:
1246            result.append(ast.Name(id=node.asname, ctx=None))
1247        return result
1248
1249    # General cases:
1250    if isinstance(node, ast.AST):
1251        return [v for (k, v) in ast.iter_fields(node)
1252                if k not in IGNORE_FIELDS]
1253
1254    if type(node) == list:
1255        return node
1256
1257    # Added by Peter Mawhorter 2019-4-2 to fix problem where subpatterns fail
1258    # to match within keyword arguments of functions because field_values
1259    # returns an odict as one value for functions with kwargs, and asking for
1260    # the field_values of that odict was hitting the base case below.
1261    if isinstance(node, (dict, odict)):
1262        return [
1263          ast.Name(id=n, ctx=None) for n in node.keys()
1264        ] + list(node.values())
1265
1266    return []
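For `ast.Call` nodes, `field_values` turns the `keywords` list into an ordered dictionary keyed by argument name, so keyword arguments are matched by name rather than by position. A small standalone helper (the hypothetical `call_keywords`) shows why this makes keyword order irrelevant:

```python
import ast


def call_keywords(src):
    """Map each keyword name of a single call to a dump of its value AST."""
    call = ast.parse(src, mode="eval").body
    return {kw.arg: ast.dump(kw.value) for kw in call.keywords}


# Keyword order differs, but the name -> value mapping is identical.
print(call_keywords("f(x, a=1, b=2)") == call_keywords("f(x, b=2, a=1)"))  # True
```

Dictionary equality ignores insertion order, while iteration still preserves it, which is why an ordered dictionary suits this job.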
1267
1268
1269imatchCount = 0 # For showing stack depth in debugging prints for imatches
1270
1271
1272def imatches(node, pat, env1, seq):
1273    """Exponential backtracking match generating 0 or more matching
1274    environments.  Supports multiple sequence patterns in one context,
1275    simple permutations of semantically equivalent but syntactically
1276    mirrored patterns.
1277    """
1278    # Lyn changed the early-return pattern to a named result returned at
1279    # the end, to allow debugging prints
1280    # global imatchCount #$
1281    # print('\n$ {}Entering imatches({}, {}, {}, {})'.format(
1282    #         '| '*imatchCount, dump(node), dump(pat), dump(env1), seq))
1283    # imatchCount += 1 #$
1284    result = iterempty() # default result if not overridden
1285    if env1 is None:
1286        result = iterempty()
1287        # imatchCount -= 1 #$
1288        # print(
1289        #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
1290        #         '| ' * imatchCount,
1291        #         dump(node),
1292        #         dump(pat),
1293        #         dump(env1),
1294        #         seq,
1295        #         dump(result)
1296        #     )
1297        # )
1298        return result
1299    env = env1.value
1300    assert env is not None
1301    if (
1302        (
1303            type(pat) == bool
1304         or type(pat) == str # All Python 3 strings are unicode
1305         or pat is None
1306        )
1307    and node == pat
1308    ):
1309        result = iterone(env)
1310    elif type(pat) == int or type(pat) == float:
1311        # Literal int or float pattern
1312        if type(node) == type(pat):
1313            if (
1314                (type(pat) == int and node == pat)
1315             or (type(pat) == float and abs(node - pat) < 0.001)
1316            ):
1317                result = iterone(env)
1318            pass
1319    elif node_var(pat):
1320        # Var pattern.
1321        # Match and bind name to node.
1322        if node_is_bindable(node):
1323            # [Peter Mawhorter 2021-8-29] Attempting to allow import
1324            # aliases to unify with variable references later on. If the
1325            # alias has an 'as' part we unify as if it's a name with that
1326            # ID, otherwise we use the name part as the ID.
1327            if isinstance(node, ast.alias):
1328                if node.asname:
1329                    bind_as = ast.Name(id=node.asname, ctx=None)
1330                else:
1331                    bind_as = ast.Name(id=node.name, ctx=None)
1332            else:
1333                bind_as = node
1334            env2 = bind(env1, node_var(pat), bind_as)
1335            if env2:
1336                result = iterone(env2.value)
1337            pass
1338    elif typed_lit_var(pat):
1339        # Var pattern to bind only a literal of given type.
1340        id, ty = typed_lit_var(pat)
1341        lit = LIT_TYPES[ty](node)
1342        # Match and bind name to literal.
1343        if lit:
1344            env2 = bind(env1, id, lit.value)
1345            if env2:
1346                result = iterone(env2.value)
1347            pass
1348    elif type(node) == type(pat):
1349        # Node and pattern have same type.
1350        if type(pat) == list:
1351            # Node and pattern are both lists.  Do positional matching.
1352            if len(pat) == 0:
1353                # Empty list pattern.  Node must also be empty.
1354                if len(node) == 0:
1355                    result = iterone(env)
1356                pass
1357            elif len(node) == 0:
1358                # Non-empty list pattern with empty node.
1359                # Try to match sequence subpatterns.
1360                if seq:
1361                    psn = set_var(pat[0])
1362                    if psn:
1363                        result = imatches(node, pat[1:],
1364                                          bind(env1, psn, []), seq)
1365                    pass
1366                pass
1367            else:
1368                # Both are non-empty.
1369                psn = set_var(pat[0])
1370                if seq and psn:
1371                    # First subpattern is a sequence pattern.
1372                    # Try all consumption sizes, greediest first.
1373                    # Unsophisticated exponential backtracking search.
1374                    result = ichain(
1375                        imatches(
1376                            node[i:],
1377                            pat[1:],
1378                            bind(env1, psn, node[:i]),
1379                            seq
1380                        )
1381                        for i in range(len(node), -1, -1)
1382                    )
1383                # Lyn sez: common special case helpful for more concrete
1384                # debugging results (e.g., may return FiniteIterator
1385                # rather than itertools.chain object.)
1386                elif len(node) == 1 and len(pat) == 1:
1387                    result = imatches(node[0], pat[0], env1, True)
1388                else:
1389                    # For all matches of scalar first element sub pattern.
1390                    # Generate all corresponding matches of remainder.
1391                    # Unsophisticated exponential backtracking search.
1392                    result = ichain(
1393                        imatches(node[1:], pat[1:], Some(bs), seq)
1394                       for bs in imatches(node[0], pat[0], env1, True)
1395                    )
1396                pass
1397        elif type(node) == dict or type(node) == odict:
1398            result = match_dict(node, pat, env1)
1399        else:
1400            # Node and pat have same type, but are not lists.
1401            # Match scalar structures by matching lists of their fields.
1402            if isinstance(node, ast.AST):
1403                # TODO: DEBUG
1404                #if isinstance(node, ast.Import):
1405                #    print(
1406                #        "FV2i",
1407                #        dump(field_values(node)),
1408                #        dump(field_values(pat))
1409                #    )
1410                #if isinstance(node, ast.alias):
1411                #    print(
1412                #        "FV2a",
1413                #        dump(field_values(node)),
1414                #        dump(field_values(pat))
1415                #    )
1416                result = imatches(
1417                    field_values(node),
1418                    field_values(pat),
1419                    env1,
1420                    False
1421                )
1422            pass
1423        pass
1424    # return iterempty()
1425    # imatchCount -= 1 #$
1426    # print(
1427    #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
1428    #         '| ' * imatchCount,
1429    #         dump(node),
1430    #         dump(pat),
1431    #         dump(env1),
1432    #         seq,
1433    #         dump(result)
1434    #     )
1435    # )
1436    return result
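The list case of `imatches` is a backtracking search. A sequence variable at the head of a list pattern tries every consumption size, greediest first, and each choice is followed by a match of the remainder. The sketch below runs the same search over plain Python lists. `match_seq` is hypothetical, and its `'___'` and `'_x_'` conventions only imitate mast's pattern syntax:

```python
def match_seq(items, pat, env=None):
    """Yield environments for matches of pat against items.
    '___' matches any (possibly empty) run of items, '_x_' binds x to
    one item, and anything else must be equal."""
    env = {} if env is None else env
    if not pat:
        if not items:
            yield env
        return
    head, rest = pat[0], pat[1:]
    if head == "___":
        for i in range(len(items), -1, -1):  # greediest consumption first
            yield from match_seq(items[i:], rest, env)
    elif not items:
        return
    elif isinstance(head, str) and len(head) > 2 and head[0] == head[-1] == "_":
        name = head[1:-1]
        if name in env and env[name] != items[0]:
            return  # conflicts with an earlier binding
        yield from match_seq(items[1:], rest, {**env, name: items[0]})
    elif head == items[0]:
        yield from match_seq(items[1:], rest, env)


print(list(match_seq([1, 2, 3, 4], ["___", "_x_", "___"])))
# [{'x': 4}, {'x': 3}, {'x': 2}, {'x': 1}]
```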
1437
1438
1439def match_dict(node, pat, env1):
1440    # Node and pattern are both dictionaries. Do nominal matching.
1441    # Match all named key patterns, then all single-key pattern variables,
1442    # and finally the multi-key pattern variables.
1443    assert all(type(k) == str # All Python 3 strings are unicode
1444               for k in pat)
1445    assert all(type(k) == str # All Python 3 strings are unicode
1446               for k in node)
1447
1448    def match_keys(node, pat, envopt):
1449        """Match literal keys."""
1450        keyopt = takeone(k for k in pat
1451                         if not node_var_str(k)
1452                         and not set_var_str(k))
1453        if keyopt:
1454            # There is at least one named key in the pattern.
1455            # If this key is also in the program node, then for each
1456            # match of the corresponding node value and pattern value,
1457            # generate all matches for the remaining keys.
1458            key = keyopt.value
1459            if key in node:
1460                return ichain(match_keys(dict_unbind(node, key),
1461                                         dict_unbind(pat, key),
1462                                         Some(kenv))
1463                              for kenv in imatches(node[key], pat[key],
1464                                                   envopt, False))
1465            else:
1466                return iterempty()
1467            pass
1468        else:
1469            # The pattern contains no literal keys.
1470            # Generate all matches for the node and set key variables.
1471            return match_var_keys(node, pat, envopt)
1472        pass
1473
1474    def match_var_keys(node, pat, envopt):
1475        """Match node variable keys."""
1476        keyvaropt = takeone(k for k in pat if node_var_str(k))
1477        if keyvaropt:
1478            # There is at least one single-key variable in the pattern.
1479            # For each key-value pair in the node whose value matches
1480            # this single-key variable's associated value pattern,
1481            # generate all matches for the remaining keys.
1482            keyvar = keyvaropt.value
1483            return ichain(match_var_keys(dict_unbind(node, nkey),
1484                                         dict_unbind(pat, keyvar),
1485                                         bind(Some(kenv),
1486                                              node_var_str(keyvar),
1487                                              ast.Name(id=nkey, ctx=None)))
1488                          # outer iteration:
1489                          for nkey, nval in node.items()
1490                          # inner iteration:
1491                          for kenv in imatches(nval, pat[keyvar],
1492                                               envopt, False))
1493        else:
1494            # The pattern contains no single-key variables.
1495            # Generate all matches for the set key variables.
1496            return match_set_var_keys(node, pat, envopt)
1497        pass
1498
1499    def match_set_var_keys(node, pat, envopt):
1500        """Match set variable keys."""
1501        # NOTE: see discussion of match environments for this case:
1502        # https://github.com/wellesleycs111/codder/issues/25
1503        if not envopt: return iterempty()  # an earlier bind failed
1504        keysetvaropt = takeone(k for k in pat if set_var_str(k))
1505        if keysetvaropt:
1506            # There is a multi-key variable in the pattern.
1507            # Capture all remaining key-value pairs in the node.
1508            e = bind(envopt, set_var_str(keysetvaropt.value),
1509                     [(ast.Name(id=kw, ctx=None), karg)
1510                      for kw, karg in node.items()])
1511            return iterone(e.value) if e else iterempty()
1512        elif 0 == len(node):
1513            # There is no multi-key variable in the pattern.
1514            # There is a match only if there are no remaining
1515            # keys in the node.
1516            # There should also be no remaining keys in the pattern.
1517            assert 0 == len(pat)
1518            return iterone(envopt.value)
1519        else:
1520            return iterempty()
1521
1522    return match_keys(node, pat, env1)
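`match_dict` matches keyword arguments nominally in three phases: literal keys first, then single-key variables, and finally a multi-key variable that captures everything left over. The first and last phases can be sketched on plain dictionaries. In the sketch below, `match_kwargs` is hypothetical, and the `'**rest'` capture marker is invented for illustration; it is not mast pattern syntax:

```python
def match_kwargs(node_kws, pat_kws):
    """Return an environment dict, or None if there is no match."""
    remaining = dict(node_kws)
    env = {}
    for key, value in pat_kws.items():
        if key == "**rest":
            continue
        # Literal keys must be present with an equal value.
        if key not in remaining or remaining.pop(key) != value:
            return None
    if "**rest" in pat_kws:
        env["rest"] = remaining  # capture all leftover keywords
    elif remaining:
        return None              # unmatched keywords and no capture
    return env


kws = {"sep": "", "end": "\n"}
print(match_kwargs(kws, {"sep": ""}))                  # None
print(match_kwargs(kws, {"sep": "", "**rest": None}))  # {'rest': {'end': '\n'}}
```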
1523
1524
1525def find(node, pat, **kwargs):
1526    """Pre-order search for first sub-AST matching pattern, returning
1527    (matched node, bindings)."""
1528    kwargs['gen'] = True
1529    return takeone(findall(node, parse_pattern(pat), **kwargs))
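`find` returns the first match in pre-order: a node is visited before its descendants, and siblings are visited in field order. Note that `ast.walk` visits nodes breadth-first, so a pre-order search needs its own recursion. Below is a minimal sketch over plain predicates rather than mast patterns; `preorder` and `find_first` are hypothetical names, and `ast.unparse` needs Python 3.9+:

```python
import ast


def preorder(node):
    """Yield node, then its descendants, in pre-order."""
    yield node
    for child in ast.iter_child_nodes(node):
        yield from preorder(child)


def find_first(tree, pred):
    """Return the first node in pre-order satisfying pred, or None."""
    return next((n for n in preorder(tree) if pred(n)), None)


tree = ast.parse("x = y + z")
print([n.id for n in preorder(tree) if isinstance(n, ast.Name)])  # ['x', 'y', 'z']
print(ast.unparse(find_first(tree, lambda n: isinstance(n, ast.BinOp))))  # y + z
```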
1530
1531
1532def findall(node, pat, outside=[], **kwargs):
1533    """
1534    Search for all sub-ASTs matching pattern, returning list of (matched
1535    node, bindings).
1536    """
1537    assert node is not None
1538    assert pat is not None
1539    gen = kwargs.get('gen', False)
1540    kwargs['gen'] = True
1541    pat = parse_pattern(pat)
1542    # Top-level sequence patterns are not "anchored" to the ends of
1543    # the containing block when *finding* a submatch within a node (as
1544    # opposed to matching a node exactly).  They may match any
1545    # contiguous subsequence.
1546    # - To allow sequence patterns to match starting later than the
1547    #   beginning of a program sequence, matching is attempted
1548    #   recursively with smaller and smaller suffixes of the program
1549    #   sequence.
1550    # - To allow sequence patterns to match ending earlier
1551    #   than the end of a program block, we implicitly ensure that a
1552    #   sequence wildcard pattern terminates every top-level sequence
1553    #   pattern.
1554    # TODO: Because permutations are applied later, this doesn't allow
1555    # all the matching we'd like in the following scenario
1556    # Node: [ alias(name="x"), alias(name="y"), alias(name="z") ]
1557    # Pat: [ alias(name="_a_"), alias(name="___") ]
1558    # here because order of aliases doesn't matter, we *should* be able
1559    # to bind _a_ to x, y, OR z, but it can only bind to x or z, NOT y
1560    if type(pat) == list and not set_var(pat[-1]):
1561        pat = pat + [parse_pattern('___')]
1562        pass
1563
1564    def findall_scalar_pat_iter(node):
1565        """Generate all matches of a scalar (non-list) pattern at this node
1566        or any non-excluded descendant of this node."""
1567        assert type(pat) != list
1568        assert node is not None
1569        # Yield any environment(s) for match(es) of pattern at node.
1570        envs = [e for e in matchpat(node, pat, **kwargs)]
1571        if 0 < len(envs):
1572            yield (node, envs)
1573        # Continue the search for matches in sub-ASTs of node, only if
1574        # node is not excluded by a "match outside" pattern.
1575        if not any(match(node, op) for op in outside):
1576            for n in field_values(node):
1577                if n: # Search only within non-None children.
1578                    for result in findall_scalar_pat_iter(n):
1579                        yield result
1580                        pass
1581                    pass
1582                pass
1583            pass
1584        pass
1585
1586    def findall_list_pat_iter(node):
1587        """Generate all matches of a list pattern at this node or any
1588        non-excluded descendant of this node."""
1589        assert type(pat) == list
1590        assert 0 < len(pat)
1591        if type(node) == list:
1592            # If searching against a list:
1593            # - match against the list itself
1594            # - search against the first child of the list
1595            # - search against the tail of the list
1596
1597            # Match against the list itself.
1598            # Yield any environment(s) for match(es) of pattern at node.
1599            envs = [e for e in matchpat(node, pat, **kwargs)]
1600            if 0 < len(envs):
1601                yield (node, envs)
1602            # Continue the search for matches in sub-ASTs of node,
1603            # only if node is not excluded by a "match outside"
1604            # pattern.
1605            if (
1606                not any(match(node, op) for op in outside)
1607            and 0 < len(node) # only in nonempty nodes...
1608            ):
1609                # Search for matches in the first sub-AST.
1610                for m in findall_list_pat_iter(node[0]):
1611                    yield m
1612                    pass
1613                if not set_var(pat[0]):
1614                    # If this sequence pattern does not start with
1615                    # a sequence wildcard, then:
1616                    # Search for matches in the tail of the list.
1617                    # (Includes matches against the entire tail.)
1618                    for m in findall_list_pat_iter(node[1:]):
1619                        yield m
1620                        pass
1621                    pass
1622                pass
1623            pass
1624        elif not any(match(node, op) for op in outside): # and node is not list
1625            # A list pattern cannot match against a scalar node.
1626            # Search for matches in children of this scalar (non-list)
1627            # node, only if node is not excluded by a "match outside"
1628            # pattern.
1629
1630            # Optimize to search only where list patterns could match.
1631
1632            # Body blocks
1633            for ty in [ast.ClassDef, ast.FunctionDef, ast.With, ast.Module,
1634                       ast.If, ast.For, ast.While,
1635                       # ast.TryExcept, ast.TryFinally,
1636                       # In Python 3, a single Try node covers both
1637                       # TryExcept and TryFinally
1638                       ast.Try,
1639                       ast.ExceptHandler]:
1640                if isinstance(node, ty):
1641                    for m in findall_list_pat_iter(node.body):
1642                        yield m
1643                        pass
1644                    break
1645                pass
1646            # else blocks of compound statements
1647            for ty in [ast.If, ast.For, ast.While,
1648                       # ast.TryExcept
1649                       # In Python 3, a single Try node covers both
1650                       # TryExcept and TryFinally
1651                       ast.Try
1652                       ]:
1653                if isinstance(node, ty) and node.orelse:
1654                    for m in findall_list_pat_iter(node.orelse):
1655                        yield m
1656                        pass
1657                    break
1658                pass
1659
1660#             # Except handler blocks
1661#             if isinstance(node, ast.TryExcept):
1662#                 for h in node.handlers:
1663#                     for m in findall_list_pat_iter(h.body):
1664#                         yield m
1665#                         pass
1666#                     pass
1667#                 pass
1668#             # finally blocks
1669#             if isinstance(node, ast.TryFinally):
1670#                 for m in findall_list_pat_iter(node.finalbody):
1671#                         yield m
1672#                         pass
1673#                 pass
1674
1675            # In Python 3, a single Try node covers both TryExcept and
1676            # TryFinally
1677            if isinstance(node, ast.Try):
1678                for h in node.handlers:
1679                    for m in findall_list_pat_iter(h.body):
1680                        yield m
1681                if node.finalbody:  # Also search finally blocks.
1682                    for m in findall_list_pat_iter(node.finalbody):
1683                        yield m
1684
1685            # General non-optimized version.
1686            # Must be mutually exclusive with the above if used.
1687            # for n in field_values(node):
1688            #                 if n:
1689            #                     for result in findall_list_pat_iter(n):
1690            #                         yield result
1691            #                         pass
1692            #                     pass
1693            #                 pass
1694            pass
1695        pass
1696    # Apply the right search based on pattern type.
1697    matches = (findall_list_pat_iter if type(pat) == list
1698               else findall_scalar_pat_iter)(node)
1699    # Return the generator or a list of all generated matches,
1700    # depending on gen.
1701    return matches if gen else list(matches)
1702
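
# Sketch (hypothetical helper, not used elsewhere in mast): enumerate the
# places where findall tries a list pattern. findall matches the pattern
# against a statement block, then searches the first statement's nested
# blocks and recurses on the tail block[1:], so in effect every suffix
# block[i:] of every block is tried. (The tail recursion is skipped when
# the pattern starts with a sequence wildcard, since the wildcard already
# absorbs any prefix.) This sketch only walks body/orelse/finalbody lists,
# reaching handler bodies through their ExceptHandler nodes, and does no
# matching itself.
import ast


def _sketch_block_suffixes(node):
    """Yield (owner, field, suffix) for every suffix of every
    statement block found anywhere inside node."""
    for owner in ast.walk(node):
        for field in ('body', 'orelse', 'finalbody'):
            block = getattr(owner, field, None)
            # IfExp and Lambda also have body/orelse, but as single exprs.
            if isinstance(block, list) and block:
                for i in range(len(block)):
                    yield owner, field, block[i:]
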
1703
1704def count(node, pat, **kwargs):
1705    """
1706    Count all sub-ASTs matching pattern. Does NOT count individual
1707    environments that match (i.e., ways that bindings could attach at a
1708    given node), but rather counts nodes at which one or more bindings
1709    are possible.
1710    """
1711    assert 'gen' not in kwargs
1712    return sum(1 for x in findall(node, pat,
1713                                  gen=True, **kwargs))
1714
1715
1716#---------------------------------------------------------#
1717# EXPERIMENTAL: Normalize/Inline Simple Straightline Code #
1718#---------------------------------------------------------#
1719
1720class Unimplemented(Exception):
1721    pass
1722
1723
1724SIMPLE_INLINE_TYPES = [
1725    ast.Expr, ast.Return,
1726    # ast.Print, # Removed from ast module in Python 3
1727    ast.Name, ast.Store, ast.Load, ast.Param,
1728    # Simple augs should be desugared already.
1729    # Those remaining are too tricky for a simple implementation.
1730]
1731
1732
1733class InlineAvailableExpressions(ast.NodeTransformer):
1734    def __init__(self, other=None):
1735        self.available = dict(other.available) if other else {}
1736
1737    def visit_Name(self, name):
1738        if isinstance(name.ctx, ast.Load) and name.id in self.available:
1739            return self.available[name.id]
1740        else:
1741            return self.generic_visit(name)
1742
1743    def visit_Assign(self, assign):
1744        raise Unimplemented
1745        new = ast.copy_location(ast.Assign(
1746            targets=assign.targets,
1747            value=self.visit(assign.value)
1748        ), assign)
1749        self.available[assign.targets[0].id] = new.value
1750        return new
1751
1752    def visit_If(self, ifelse):
1753        # Inline into the test.
1754        test = self.visit(ifelse.test)
1755        # Inline and accumulate in the then and else independently.
1756        body_inliner = InlineAvailableExpressions(self)
1757        orelse_inliner = InlineAvailableExpressions(self)
1758        body = body_inliner.inline_block(ifelse.body)
1759        orelse = orelse_inliner.inline_block(ifelse.orelse)
1760        # Any var->expression that is available after both branches
1761        # is available after.
1762        self.available = {
1763            name: body_inliner.available[name]
1764            for name in (set(body_inliner.available)
1765                         & set(orelse_inliner.available))
1766            if (body_inliner.available[name] == orelse_inliner.available[name])
1767        }
1768        return ast.copy_location(
1769            ast.If(test=test, body=body, orelse=orelse),
1770            ifelse
1771        )
1772
1773    def generic_visit(self, node):
1774        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
1775            raise Unimplemented()
1776        return ast.NodeTransformer.generic_visit(self, node)
1777
1778    def inline_block(self, block):
1779        # Introduce duplicate common subexpressions...
1780        return [self.visit(stmt) for stmt in block]
1781
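# Sketch (hypothetical helper, not used by NormalizePure): the forward
# "available expressions" idea behind InlineAvailableExpressions, for
# straight-line code only. Each `name = expr` makes expr available, and
# later loads of name are replaced by a copy of expr. Any statement that
# stores to or deletes a name kills that name and every available
# expression that reads it. Assumes side-effect-free right-hand sides
# (inlining may duplicate them) and no control flow or nested scopes
# (defs, lambdas, comprehensions).
import ast
import copy


def _sketch_inline_straightline(stmts):
    """Return copies of stmts with available expressions inlined.
    Assignments are kept; dead ones are left for a later DCE pass."""
    available = {}

    def reads(expr, name):
        return any(isinstance(n, ast.Name) and n.id == name
                   for n in ast.walk(expr))

    class Substitute(ast.NodeTransformer):
        def visit_Name(self, node):
            if isinstance(node.ctx, ast.Load) and node.id in available:
                return copy.deepcopy(available[node.id])
            return node

    result = []
    for stmt in stmts:
        stmt = Substitute().visit(copy.deepcopy(stmt))
        killed = {n.id for n in ast.walk(stmt)
                  if isinstance(n, ast.Name)
                  and isinstance(n.ctx, (ast.Store, ast.Del))}
        for name in killed:
            available.pop(name, None)
            # Expressions reading the old value of name are now stale.
            for key in [k for k, v in available.items() if reads(v, name)]:
                del available[key]
        if (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)
                and not reads(stmt.value, stmt.targets[0].id)):
            available[stmt.targets[0].id] = stmt.value
        result.append(stmt)
    return result
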
1782
1783class DeadCodeElim(ast.NodeTransformer):
1784    def __init__(self, other=None):
1785        self.used = set(other.used) if other else set()
1786        pass
1787
1788    def visit_Name(self, name):
1789        if isinstance(name.ctx, ast.Store):
1790            # A store defines the variable: remove it from the live set.
1791            self.used = self.used - {name.id}
1792            return name
1793        elif isinstance(name.ctx, ast.Load):
1794            # A load uses the variable: add it to the live set.
1795            self.used = self.used | {name.id}
1796            return name
1797        else:
1798            return name
1799        pass
1800
1801    def visit_Assign(self, assign):
1802        # This restriction prevents worries about things like:
1803        # x, x[0] = [[1], 2]
1804        # By using this restriction it is safe to keep the single set,
1805        # thus order of removals and additions will not be a problem
1806        # since defs are discovered first (always to left of =), then
1807        # uses are discovered next (to right of =).
1808        assert all(
1809            (
1810                node_is_name(t)
1811             or (
1812                    isinstance(t, ast.Tuple)
1813                and all(node_is_name(e) for e in t.elts)
1814                )
1815            )
1816            for t in assign.targets
1817        )
1818        # Now handled by visit_Name
1819        # self.used = self.used - set(n.id for n in assign.targets)
1820        if (any(t.id in self.used for t in assign.targets if node_is_name(t))
1821            or any(t.id in self.used
1822                   for tup in assign.targets if isinstance(tup, ast.Tuple)
1823                   for t in tup.elts if node_is_name(t))):
1824            return ast.copy_location(ast.Assign(
1825                targets=[self.visit(t) for t in assign.targets],
1826                value=self.visit(assign.value)
1827            ), assign)
1828        else:
1829            return None
1830
1831    def visit_If(self, ifelse):
1832        body_elim = DeadCodeElim(self)
1833        orelse_elim = DeadCodeElim(self)
1834        # DCE the body
1835        body = body_elim.elim_block(ifelse.body)
1836        # DCE the else
1837        orelse = orelse_elim.elim_block(ifelse.orelse)
1838        # Use the test -- TODO: could eliminate entire if sometimes.
1839        # Keep it for now for clarity.
1840        self.used = body_elim.used | orelse_elim.used
1841        test = self.visit(ifelse.test)
1842        return ast.copy_location(ast.If(test=test, body=body, orelse=orelse),
1843                                 ifelse)
1844
1845    def generic_visit(self, node):
1846        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
1847            raise Unimplemented()
1848        return ast.NodeTransformer.generic_visit(self, node)
1849
1850    def elim_block(self, block):
1851        # Scan backward (uses precede defs), drop None results, re-reverse.
1852        return [s for s in
1853                (self.visit(stmt) for stmt in block[::-1])
1854                if s][::-1]
1855
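# Sketch (hypothetical helper, not used by NormalizePure): the backward
# liveness idea behind DeadCodeElim, for straight-line code only. Walk
# the block last-to-first with the set of names still needed ("live").
# An assignment whose Name targets are all dead is dropped; otherwise its
# stores leave the live set and its loads join it. Like DeadCodeElim,
# this assumes right-hand sides have no side effects.
import ast


def _sketch_dead_assign_elim(stmts, live_out=()):
    """Return stmts without dead simple assignments. live_out names
    the variables still needed after the block."""
    live = set(live_out)
    kept = []
    for stmt in reversed(stmts):
        names = [n for n in ast.walk(stmt) if isinstance(n, ast.Name)]
        stores = {n.id for n in names if isinstance(n.ctx, ast.Store)}
        loads = {n.id for n in names if isinstance(n.ctx, ast.Load)}
        if isinstance(stmt, ast.AugAssign) and isinstance(stmt.target, ast.Name):
            loads.add(stmt.target.id)  # x += 1 also reads x
        if (isinstance(stmt, ast.Assign)
                and all(isinstance(t, ast.Name) for t in stmt.targets)
                and not (stores & live)):
            continue  # every target is dead: drop the statement
        live = (live - stores) | loads
        kept.append(stmt)
    kept.reverse()
    return kept
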
1856
1857class NormalizePure(ast.NodeTransformer):
1858    """AST transformer: normalize/inline straightline assignments into
1859    single expressions where possible."""
1860    def normalize_block(self, block):
1861        try:
1862            return DeadCodeElim().elim_block(
1863                InlineAvailableExpressions().inline_block(block)
1864            )
1865        except Unimplemented:
1866            return block
1867
1868    def visit_FunctionDef(self, fun):
1869        # Lyn warning: previously commented code below is from Python 2
1870        # and won't work in Python 3 (args have changed).
1871        # assert 0 < len(
1872        #    result.value[0].intersection(set(a.id for a in fun.args.args))
1873        # )
1874        normbody = self.normalize_block(fun.body)
1875        if normbody != fun.body:
1876            return ast.copy_location(ast.FunctionDef(
1877                name=fun.name,
1878                args=self.generic_visit(fun.args),
1879                body=normbody, decorator_list=fun.decorator_list,
1880                returns=fun.returns), fun)
1881        else:
1882            return fun
1883
1884
1885# Note 2021-7-29 Peter Mawhorter: Disabled use of DesugarAugAssign here
1886# because of concerns about students using list += string. If you want
1887# to match += you'll have to do so explicitly (and if that's a common use
1888# case, we can add something to specifications and/or patterns for that).
1889def canonical_pure(node):
1890    """Return the normalized/inlined version of an AST node."""
1891    # print(type(fun))
1892    # assert isinstance(fun, ast.FunctionDef)
1893    # print()
1894    # print('FUN', dump(fun))
1895    if type(node) == list:
1896        return NormalizePure().normalize_block(node)
1897        #    [DesugarAugAssign().visit(stmt) for stmt in node]
1898        #)
1899    else:
1900        assert isinstance(node, ast.AST)
1901        #return NormalizePure().visit(DesugarAugAssign().visit(node))
1902        return NormalizePure().visit(node)
1903
1904
1905CANON_CACHE = {}
1906
1907
1908def parse_canonical_pure(string, toplevel=False):
1909    """Parse a normalized/inlined version of a program."""
1910    if type(string) == list:
1911        return [parse_canonical_pure(x) for x in string]
1912    elif string not in CANON_CACHE:
1913        CANON_CACHE[string] = canonical_pure(
1914            parse_pattern(string, toplevel=toplevel)
1915        )
1916        pass
1917    return CANON_CACHE[string]
1918
1919#------------------------------------#
1920# Pretty Print ASTs to Python Source #
1921#------------------------------------#
1922
1923
1924INDENT = '    '
1925
1926
1927def indent(pat, indent=INDENT):
1928    """Apply indents to a source string."""
1929    return indent + pat.replace('\n', '\n' + indent)
1930
1931
1932class SourceFormatter(ast.NodeVisitor):
1933    """AST visitor: pretty print AST to python source string"""
1934    def __init__(self):
1935        ast.NodeVisitor.__init__(self)
1936        self._indent = ''
1937        pass
1938
1939    def indent(self):
1940        self._indent += INDENT
1941        pass
1942
1943    def unindent(self):
1944        self._indent = self._indent[:-len(INDENT)]
1945        pass
1946
1947    def line(self, ln):
1948        return self._indent + ln + '\n'
1949
1950    def lines(self, lst):
1951        return ''.join(lst)
1952
1953    def generic_visit(self, node):
1954        assert False, 'visiting {}'.format(ast.dump(node))
1955        pass
1956
1957    def visit_Module(self, m):
1958        return self.lines(self.visit(n) for n in m.body)
1959
1960    def visit_Interactive(self, i):
1961        return self.lines(self.visit(n) for n in i.body)
1962
1963    def visit_Expression(self, e):
1964        return self.line(self.visit(e.body))
1965
1966    def visit_FunctionDef(self, f):
1967        assert not f.decorator_list
1968        header = self.line('def {name}({args}):'.format(
1969            name=f.name,
1970            args=self.visit(f.args))
1971        )
1972        self.indent()
1973        body = self.lines(self.visit(s) for s in f.body)
1974        self.unindent()
1975        return header + body + '\n'
1976
1977    def visit_ClassDef(self, c):
1978        assert not c.decorator_list
1979        header = self.line('class {name}({bases}):'.format(
1980            name=c.name,
1981            bases=', '.join(self.visit(b) for b in c.bases)
1982        ))
1983        self.indent()
1984        body = self.lines(self.visit(s) for s in c.body)
1985        self.unindent()
1986        return header + body + '\n'
1987
1988    def visit_Return(self, r):
1989        return self.line('return' if r.value is None
1990                         else 'return {}'.format(self.visit(r.value)))
1991
1992    def visit_Delete(self, d):
1993        return self.line('del ' + ', '.join(self.visit(e) for e in d.targets))
1994
1995    def visit_Assign(self, a):
1996        # Several targets mean chained assignment: a = b = value.
1997        return self.line(' = '.join(self.visit(e) for e in a.targets)
1998                         + ' = ' + self.visit(a.value))
1999
2000    def visit_AugAssign(self, a):
2001        return self.line('{target} {op}= {expr}'.format(
2002            target=self.visit(a.target),
2003            op=self.visit(a.op),
2004            expr=self.visit(a.value))
2005        )
2006
2007# Print removed as ast node on Python 3
2008#     def visit_Print(self, p):
2009#         assert p.dest == None
2010#         return self.line('print {}{}'.format(
2011#             ', '.join(self.visit(e) for e in p.values),
2012#             ',' if p.values and not p.nl else ''
2013#         ))
2014
2015    def visit_For(self, f):
2016        header = self.line('for {} in {}:'.format(
2017            self.visit(f.target),
2018            self.visit(f.iter))
2019        )
2020        self.indent()
2021        body = self.lines(self.visit(s) for s in f.body)
2022        orelse = self.lines(self.visit(s) for s in f.orelse)
2023        self.unindent()
2024        return header + body + (
2025            self.line('else:') + orelse
2026            if f.orelse else ''
2027        )
2028
2029    def visit_While(self, w):
2030        # Peter 2021-6-16: Removed this assert; orelse isn't defined here
2031        # assert not orelse
2032        header = self.line('while {}:'.format(self.visit(w.test)))
2033        self.indent()
2034        body = self.lines(self.visit(s) for s in w.body)
2035        orelse = self.lines(self.visit(s) for s in w.orelse)
2036        self.unindent()
2037        return header + body + (
2038            self.line('else:') + orelse
2039            if w.orelse else ''
2040        )
2042
2043    def visit_If(self, i):
2044        header = self.line('if {}:'.format(self.visit(i.test)))
2045        self.indent()
2046        body = self.lines(self.visit(s) for s in i.body)
2047        orelse = self.lines(self.visit(s) for s in i.orelse)
2048        self.unindent()
2049        return header + body + (
2050            self.line('else:') + orelse
2051            if i.orelse else ''
2052        )
2053
2054    def visit_With(self, w):
2055        # Converted to Python3 withitems:
2056        header = self.line(
2057            'with {items}:'.format(
2058                items=', '.join(
2059                    '{expr}{asnames}'.format(
2060                        expr=self.visit(item.context_expr),
2061                        asnames=(' as ' + self.visit(item.optional_vars)
2062                                 if item.optional_vars else '')
2063                    )
2064                        for item in w.items
2065                )
2066            )
2067        )
2068        self.indent()
2069        body = self.lines(self.visit(s) for s in w.body)
2070        self.unindent()
2071        return header + body
2072
2073    # Python 3: raise has new abstract syntax
2074    def visit_Raise(self, r):
2075        return self.line('raise{}{}'.format(
2076            (' ' + self.visit(r.exc)) if r.exc else '',
2077            (' from ' + self.visit(r.cause)) if r.cause else ''
2078        ))
2079
2080#    def visit_Raise(self, r):
2081#        return self.line('raise{}{}{}{}{}'.format(
2082#             self.visit(r.type) if r.type else '',
2083#             ', ' if r.type and r.inst else '',
2084#             self.visit(r.inst) if r.inst else '',
2085#             ', ' if (r.type or r.inst) and r.tback else '',
2086#             self.visit(r.tback) if r.tback else ''
2087#                 ))
2088
2089#     def visit_TryExcept(self, te):
2090#         self.indent()
2091#         tblock = self.lines(self.visit(s) for s in te.body)
2092#         orelse = self.lines(self.visit(s) for s in te.orelse)
2093#         self.unindent()
2094#         return (
2095#             self.line('try:')
2096#             + tblock
2097#             + ''.join(self.visit(eh) for eh in te.handlers)
2098#             + (self.line('else:') + orelse if orelse else '' )
2099#         )
2100
2101#     def visit_TryFinally(self, tf):
2102#         self.indent()
2103#         tblock = self.lines(self.visit(s) for s in tf.body)
2104#         fblock = self.lines(self.visit(s) for s in tf.finalbody)
2105#         self.unindent()
2106#         return (
2107#             self.line('try:')
2108#             + tblock
2109#             + self.line('finally:')
2110#             + fblock
2111#         )
2112
2113    # In Python 3, a single Try node covers both TryExcept and TryFinally
2114    def visit_Try(self, t):
2115        self.indent()
2116        tblock = self.lines(self.visit(s) for s in t.body)
2117        orelse = self.lines(self.visit(s) for s in t.orelse)
2118        fblock = self.lines(self.visit(s) for s in t.finalbody)
2119        self.unindent()
2120        return (
2121            self.line('try:')
2122            + tblock
2123            + ''.join(self.visit(eh) for eh in t.handlers)
2124            + (self.line('else:') + orelse if orelse else '' )
2125            + (self.line('finally:') + fblock if fblock else '' )
2126        )
2127
2128    def visit_ExceptHandler(self, eh):
2129        header = self.line('except{}{}{}{}:'.format(
2130            ' ' if eh.type else '',
2131            self.visit(eh.type) if eh.type else '',
2132            ' as ' if eh.type and eh.name else ' ' if eh.name else '',
2133            self.visit(eh.name) if eh.name and isinstance(eh.name, ast.AST)
2134              else (eh.name if eh.name else '')
2135        ))
2136        self.indent()
2137        body = self.lines(self.visit(s) for s in eh.body)
2138        self.unindent()
2139        return header + body
2140
2141    def visit_Assert(self, a):
2142        return self.line('assert {}{}{}'.format(
2143            self.visit(a.test),
2144            ', ' if a.msg else '',
2145            self.visit(a.msg) if a.msg else ''
2146        ))
2147
2148    def visit_Import(self, i):
2149        return self.line(
2150            'import {}'.format(', '.join(self.visit(n) for n in i.names))
2151        )
2152
2153    def visit_ImportFrom(self, f):
2154        return self.line('from {}{} import {}'.format(
2155            '.' * f.level,
2156            f.module if f.module else '',
2157            ', '.join(self.visit(n) for n in f.names)
2158        ))
2159
2160    def visit_Exec(self, e):
2161        return self.line('exec {}{}{}{}{}'.format(
2162            self.visit(e.body),
2163            ' in ' if e.globals else '',
2164            self.visit(e.globals) if e.globals else '',
2165            ', ' if e.locals else '',
2166            self.visit(e.locals) if e.locals else ''
2167        ))
2168
2169    def visit_Global(self, g):
2170        return self.line('global {}'.format(', '.join(g.names)))
2171
2172    def visit_Expr(self, e):
2173        return self.line(self.visit(e.value))
2174
2175    def visit_Pass(self, p):
2176        return self.line('pass')
2177
2178    def visit_Break(self, b):
2179        return self.line('break')
2180
2181    def visit_Continue(self, c):
2182        return self.line('continue')
2183
2184    def visit_BoolOp(self, b):
2185        return ' {} '.format(
2186            self.visit(b.op)
2187        ).join('({})'.format(self.visit(e)) for e in b.values)
2188
2189    def visit_BinOp(self, b):
2190        return '({}) {} ({})'.format(
2191            self.visit(b.left),
2192            self.visit(b.op),
2193            self.visit(b.right)
2194        )
2195
2196    def visit_UnaryOp(self, u):
2197        return '{} ({})'.format(
2198            self.visit(u.op),
2199            self.visit(u.operand)
2200        )
2201
2202    def visit_Lambda(self, ld):
2203        return '(lambda {}: {})'.format(
2204            self.visit(ld.args),
2205            self.visit(ld.body)
2206        )
2207
2208    def visit_IfExp(self, i):
2209        return '({} if {} else {})'.format(
2210            self.visit(i.body),
2211            self.visit(i.test),
2212            self.visit(i.orelse)
2213        )
2214
2215    def visit_Dict(self, d):
2216        return '{{ {} }}'.format(
2217            ', '.join('{}: {}'.format(self.visit(k), self.visit(v))
2218                      for k, v in zip(d.keys, d.values))
2219        )
2220
2221    def visit_Set(self, s):
2222        return '{{ {} }}'.format(', '.join(self.visit(e) for e in s.elts))
2223
2224    def visit_ListComp(self, lc):
2225        return '[{} {}]'.format(
2226            self.visit(lc.elt),
2227            ' '.join(self.visit(g) for g in lc.generators)
2228        )
2229
2230    def visit_SetComp(self, sc):
2231        return '{{{} {}}}'.format(
2232            self.visit(sc.elt),
2233            ' '.join(self.visit(g) for g in sc.generators)
2234        )
2235
2236    def visit_DictComp(self, dc):
2237        return '{{{} {}}}'.format(
2238            '{}: {}'.format(self.visit(dc.key), self.visit(dc.value)),
2239            ' '.join(self.visit(g) for g in dc.generators)
2240        )
2241
2242    def visit_GeneratorExp(self, ge):
2243        return '({} {})'.format(
2244            self.visit(ge.elt),
2245            ' '.join(self.visit(g) for g in ge.generators)
2246        )
2247
2248    def visit_Yield(self, y):
2249        return 'yield {}'.format(self.visit(y.value) if y.value else '')
2250
2251    def visit_Compare(self, c):
2252        assert len(c.ops) == len(c.comparators)
2253        return '{} {}'.format(
2254            '({})'.format(self.visit(c.left)),
2255            ' '.join(
2256                '{} ({})'.format(self.visit(op), self.visit(expr))
2257                for op, expr in zip(c.ops, c.comparators)
2258            )
2259        )
2260
2261    def visit_Call(self, c):
2262        # return '{fun}({args}{keys}{starargs}{starstarargs})'.format(
2263        # Python 3.5+ puts *x in args (Starred), **y in keywords (arg=None).
2264        return '{fun}({args}{keys})'.format(
2265            fun=self.visit(c.func),
2266            args=', '.join(self.visit(a) for a in c.args),
2267            keys=(
2268                (', ' if c.args and c.keywords else '')
2269              + (
2270                  ', '.join(self.visit(ka) for ka in c.keywords)
2271                    if c.keywords else ''
2272                )
2273            )
2274        )
2275
2276    def visit_Repr(self, r):
2277        return 'repr({})'.format(self.visit(r.expr))
2278
2279    def visit_Num(self, n):
2280        return repr(n.n)
2281
2282    def visit_Str(self, s):
2283        return repr(s.s)
2284
2285    def visit_Attribute(self, a):
2286        return '{}.{}'.format(self.visit(a.value), a.attr)
2287
2288    def visit_Subscript(self, s):
2289        return '{}[{}]'.format(self.visit(s.value), self.visit(s.slice))
2290
2291    def visit_Name(self, n):
2292        return n.id
2293
2294    def visit_List(self, ls):
2295        return '[{}]'.format(', '.join(self.visit(e) for e in ls.elts))
2296
2297    def visit_Tuple(self, tp):
2298        return '({})'.format(', '.join(self.visit(e) for e in tp.elts))
2299
2300    def visit_Ellipsis(self, s):
2301        return '...'
2302
2303    def visit_Slice(self, s):
2304        return '{}:{}{}{}'.format(
2305            self.visit(s.lower) if s.lower else '',
2306            self.visit(s.upper) if s.upper else '',
2307            ':' if s.step else '',
2308            self.visit(s.step) if s.step else ''
2309        )
2310
2311    def visit_ExtSlice(self, es):
2312        return ', '.join(self.visit(s) for s in es.dims)
2313
2314    def visit_Index(self, i):
2315        return self.visit(i.value)
2316
2317    def visit_And(self, a):
2318        return 'and'
2319
2320    def visit_Or(self, o):
2321        return 'or'
2322
2323    def visit_Add(self, a):
2324        return '+'
2325
2326    def visit_Sub(self, a):
2327        return '-'
2328
2329    def visit_Mult(self, a):
2330        return '*'
2331
2332    def visit_Div(self, a):
2333        return '/'
2334
2335    def visit_Mod(self, a):
2336        return '%'
2337
2338    def visit_Pow(self, a):
2339        return '**'
2340
2341    def visit_LShift(self, a):
2342        return '<<'
2343
2344    def visit_RShift(self, a):
2345        return '>>'
2346
2347    def visit_BitOr(self, a):
2348        return '|'
2349
2350    def visit_BitXor(self, a):
2351        return '^'
2352
2353    def visit_BitAnd(self, a):
2354        return '&'
2355
2356    def visit_FloorDiv(self, a):
2357        return '//'
2358
2359    def visit_Invert(self, a):
2360        return '~'
2361
2362    def visit_Not(self, a):
2363        return 'not'
2364
2365    def visit_UAdd(self, a):
2366        return '+'
2367
2368    def visit_USub(self, a):
2369        return '-'
2370
2371    def visit_Eq(self, a):
2372        return '=='
2373
2374    def visit_NotEq(self, a):
2375        return '!='
2376
2377    def visit_Lt(self, a):
2378        return '<'
2379
2380    def visit_LtE(self, a):
2381        return '<='
2382
2383    def visit_Gt(self, a):
2384        return '>'
2385
2386    def visit_GtE(self, a):
2387        return '>='
2388
2389    def visit_Is(self, a):
2390        return 'is'
2391
2392    def visit_IsNot(self, a):
2393        return 'is not'
2394
2395    def visit_In(self, a):
2396        return 'in'
2397
2398    def visit_NotIn(self, a):
2399        return 'not in'
2400
2401    def visit_comprehension(self, c):
2402        return 'for {} in {}{}{}'.format(
2403            self.visit(c.target),
2404            self.visit(c.iter),
2405            ' ' if c.ifs else '',
2406            ' '.join('if {}'.format(self.visit(i)) for i in c.ifs)
2407        )
2408
2409    def visit_arg(self, a):
2410        '''[2019/01/22, lyn] Handle new arg objects in Python 3.'''
2411        return a.arg # The name of the argument
2412
2413    def visit_keyword(self, k):
2414        return ('**' if k.arg is None else k.arg + '=') + self.visit(k.value)
2415
2416    def visit_alias(self, a):
2417        return '{} as {}'.format(a.name, a.asname) if a.asname else a.name
2418
2419    def visit_arguments(self, a):
2420        # [2019/01/22, lyn] Note: This does *not* handle Python 3's
2421        # keyword-only arguments (probably moot for 111, but not
2422        # beyond).
2423        stdargs = a.args[:-len(a.defaults)] if a.defaults else a.args
2424        defargs = (
2425            zip(a.args[-len(a.defaults):], a.defaults)
2426            if a.defaults else []
2427        )
2428        return '{stdargs}{sep1}{defargs}{sep2}{varargs}{sep3}{kwargs}'.format(
2429            stdargs=', '.join(self.visit(sa) for sa in stdargs),
2430            sep1=', ' if 0 < len(stdargs) and defargs else '',
2431            defargs=', '.join('{}={}'.format(self.visit(da), self.visit(dd))
2432                              for da, dd in defargs),
2433            sep2=', ' if 0 < len(a.args) and a.vararg else '',
2434            varargs='*' + self.visit(a.vararg) if a.vararg else '',
2435            sep3=', ' if (0 < len(a.args) or a.vararg) and a.kwarg else '',
2436            kwargs='**' + self.visit(a.kwarg) if a.kwarg else ''
2437        )
2438
2439    def visit_NameConstant(self, nc):
2440        return str(nc.value)
2441
2442    def visit_Starred(self, st):
2443        # The starred value may be any expression, not just a Name.
2444        return '*' + self.visit(st.value)
2445
2446
2447def ast2source(node):
2448    """Pretty print an AST as a python source string"""
2449    return SourceFormatter().visit(node)
2450
2451
2452def source(node):
2453    """Alias for ast2source"""
2454    return ast2source(node)
2455
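# Sketch (hypothetical helper, not part of SourceFormatter): why
# visit_BinOp wraps both operands in parentheses. Fully parenthesized
# output never depends on operator precedence or associativity, so
# reparsing it gives back the same AST. Only Name, Constant, and a few
# BinOp operators are handled here; anything else raises.
import ast

_SKETCH_BINOPS = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.Div: '/',
                  ast.FloorDiv: '//', ast.Mod: '%', ast.Pow: '**'}


def _sketch_parenthesize(expr):
    """Format a small arithmetic expression AST the way visit_BinOp does."""
    if isinstance(expr, ast.BinOp):
        return '({}) {} ({})'.format(_sketch_parenthesize(expr.left),
                                     _SKETCH_BINOPS[type(expr.op)],
                                     _sketch_parenthesize(expr.right))
    if isinstance(expr, ast.Name):
        return expr.id
    if isinstance(expr, ast.Constant):
        return repr(expr.value)
    raise NotImplementedError(type(expr).__name__)
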
2456
2457#---------#
2458# Testing #
2459#---------#
2460
2461if __name__ == '__main__':
2462    tests = [
2463        (True, 1, '2', '2'),
2464        (False, 0, '2', '3'),
2465        (True, 1, 'x', '_a_'),
2466        (True, 1, '(x, y)', '(_a_, _b_)'),
2467        (False, 0, '(x, y)', '(_a_, _a_)'),
2468        (True, 1, '(x, x)', '(_a_, _a_)'),
2469        (True, 4, 'max(7,3)', '_x_'),
2470        (True, 1, 'max(7,3)', 'max(7,3)'),
2471        (False, 0, 'max(7,2)', 'max(7,3)'),
2472        (True, 1, 'max(7,3,5)', 'max(___args___)'),
2473        (False, 1, 'min(max(7,3),5)', 'max(___args___)'),
2474        (True, 2, 'min(max(7,3),5)', '_f_(___args___)'),
2475        (True, 1, 'min(max(7,3),5)', 'min(max(___maxargs___),___minargs___)'),
2476        (True, 1, 'max()', 'max(___args___)'),
2477        (True, 1, 'max(4)', 'max(4,___args___)'),
2478        (True, 1, 'max(4,5,6)', 'max(4,___args___)'),
2479        (True, 1, '"hello %s" % x', '_a_str_ % _b_'),
2480        (False, 0, 'y % x', '_a_str_ % _b_'),
2481        (False, 0, '7 % x', '_a_str_ % _b_'),
2482        (True, 1, '3', '_a_int_'),
2483        (True, 1, '3.4', '_a_float_'),
2484        (False, 0, '3', '_a_float_'),
2485        (False, 0, '3.4', '_a_int_'),
2486        (True, 1, 'True', '_a_bool_'),
2487        (True, 1, 'False', '_a_bool_'),
2488        (True, 1, 'None', 'None'),
2489        # Node vars can bind statements or exprs depending on context.
2490        (True, 7, 'print("hello"+str(3))', '_x_'), # 6 => 7 in Python 3
2491        (True, 1, 'print("hello"+str(3))', 'print(_x_)'),
2492        (True, 1, 'print(1)', 'print(_x_, ___args___)'),
2493        (False, 0, 'print(1)', 'print(_x_, _y_, ___args___)'),
2494        (False, 0, 'print(1, 2)', 'print(_x_)'),
2495        (True, 1, 'print(1, 2)', 'print(_x_, ___args___)'),
2496        (True, 1, 'print(1, 2)', 'print(_x_, _y_, ___args___)'),
2497        (True, 1, 'print(1, 2, 3)', 'print(_x_, ___args___)'),
2498        (True, 1, 'print(1, 2, 3)', 'print(_x_, _y_, ___args___)'),
2499        (True, 1,
2500         '''
2501def f(x):
2502    return 17
2503         ''',
2504         '''
2505def f(_a_):
2506    return _b_
2507         '''),
2508        (True, 1,
2509         '''
2510def f(x):
2511    return x
2512         ''',
2513         '''
2514def f(_a_):
2515    return _b_
2516         '''),
2517        (True, 1,
2518         '''
2519def f(x):
2520    return x
2521         ''',
2522         '''
2523def f(_a_):
2524    return _a_
2525         '''),
2526        (False, 0,
2527         '''
2528def f(x):
2529    return 17
2530         ''',
2531         '''
2532def f(_a_):
2533    return _a_
2534         '''),
2535        (True, 1,
2536         '''
2537def f(x):
2538    return 17
2539         ''',
2540         '''
2541def f(_x_):
2542    return _y_
2543         '''),
2544        (True, 1,
2545         '''
2546def f(x,y):
2547    print('hi')
2548    return x
2549         ''',
2550         '''
2551def f(_x_,_y_):
2552    print('hi')
2553    return _x_
2554         '''),
2555        (True, 1,
2556         '''
2557def f(x,y):
2558    print('hi')
2559    return x
2560         ''',
2561         '''
2562def _f_(_x_,_y_):
2563    print('hi')
2564    return _x_
2565         '''),
2566        (False, 0,
2567         '''
2568def f(x,y):
2569    print('hi')
2570    return y
2571         ''',
2572         '''
2573def f(_a_,_b_):
2574    print('hi')
2575    return _a_
2576         '''),
2577        (False, 0, 'x', 'y'),
2578        (True, 1,
2579         '''
2580def f(x,y):
2581    print('hi')
2582    return y
2583         ''',
2584         '''
2585def _f_(_x_,_y_):
2586    ___
2587    return _y_
2588         '''),
2589        (True, 1,
2590         '''
2591def f(x,y):
2592    print('hi')
2593    print('world')
2594    print('bye')
2595    return y
2596         ''',
2597         '''
2598def _f_(_x_,_y_):
2599    ___stmts___
2600    return _y_
2601         '''),
2602        (False, 0,
2603         '''
2604def f(x,y):
2605    print('hi')
2606    print('world')
2607    x = 4
2608    print('really')
2609    y = 7
2610    print('bye')
2611    return y
2612         ''',
2613         '''
2614def _f_(_x_,_y_):
2615    ___stmts___
2616    print(_z_)
2617    _y_ = _a_int_
2618    return _y_
2619         '''),
2620        (True, 1,
2621         '''
2622def f(x,y):
2623    print('hi')
2624    print('world')
2625    x = 4
2626    print('really')
2627    y = 7
2628    print('bye')
2629    return y
2630         ''',
2631         '''
2632def _f_(_x_,_y_):
2633    ___stmts___
2634    print(_z_)
2635    _y_ = _a_int_
2636    ___more___
2637    return _y_
2638         '''),
2639        (True, 1,
2640         '''
2641def f(x,y):
2642    print('hi')
2643    print('world')
2644    x = 4
2645    print('really')
2646    y = 7
2647    print('bye')
2648    return y
2649         ''',
2650         '''
2651def _f_(_x_,_y_):
2652    ___stmts___
2653    print(_a_)
2654    _b_ = _c_
2655    ___more___
2656    return _d_
2657         '''),
2658        (False, 1,
2659         '''
2660def f(x,y):
2661    print('hi')
2662    print('world')
2663    x = 4
2664    print('really')
2665    y = 7
2666    print('bye')
2667    return y
2668         ''',
2669         '''
2670___stmts___
2671print(_a_)
2672_b_ = _c_
2673___more___
2674return _d_
2675         '''),
2676        (True, 1,
2677         '''
2678def eyes():
2679    eye1 = Layer()
2680    eye2 = Layer()
2681    face = Layer()
2682    face.add(eye)
2683    face.add(eye2)
2684    return face
2685         ''',
2686         '''
2687def eyes():
2688    ___
2689    _face_ = Layer()
2690    ___
2691    return _face_
2692         '''),
2693        (True, 1,
2694         '''
2695def f(x,y):
2696    """I have a docstring.
2697    It is a couple lines long."""
2698    eye = Layer()
2699    if eye:
2700        face = Layer()
2701    else:
2702        face = Layer()
2703    eye2 = Layer()
2704    face.add(eye)
2705    face.add(eye2)
2706    return face
2707         ''',
2708         '''
2709def f(___args___):
2710    eye = Layer()
2711    ___
2712         '''),
2713        (True, 1, '1 == 2', '2 == 1'),
2714        (True, 1, '1 <= 2', '2 >= 1'),
2715        (False, 0, '1 <= 2', '2 <= 1'),
2716        (False, 0, 'f() <= 2', '2 >= f()'),
2717        # Hmm, are these the semantics we want for `and`?
2718        (True, 1, 'a and b and c', 'b and a and c'),
2719        (True, 1, '(a == b) == (b == c)', '(a == b) == (c == b)'),
2720        (True, 1, '(a and b) and c', 'a and (b and c)'),
2721        (True, 1, 'a and b', 'a and b'),
2722        (True, 1, 'g == "a" or g == "b" or g == "c"',
2723         '_g_ == _a_ or _g_ == _b_ or _c_ == _g_'),
2724        (True, 1, '''
2725x = 1
2726y = 2
2727''', '''
2728x = 1
2729y = 2
2730'''),
2731        (True, 1, '''
2732x = 1
2733y = 2
2734''', '''
2735___
2736y = 2
2737'''),
2738        (True, 1, '''
2739x = 1
2740if (a or b or c):
2741    return True
2742else:
2743    return False
2744        ''',
2745         '''
2746___
2747if _:
2748    return _a_bool_
2749else:
2750    return _b_bool_
2751___
2752         '''),
2753        (True, 1, '''
2754if (a or b or c):
2755    return True
2756else:
2757    return False
2758        ''',
2759         '''
2760if _:
2761    return _a_bool_
2762else:
2763    return _b_bool_
2764         '''),
2765        (True, 1, '''
2766x = 1
2767if (a or b or c):
2768    return True
2769else:
2770    return False
2771        ''',
2772         '''
2773___
2774if _:
2775    return _a_bool_
2776else:
2777    return _b_bool_
2778         '''),
2779        (False, 1, '''
2780def f():
2781    if (a or b or c):
2782        return True
2783    else:
2784        return False
2785        ''',
2786         '''
2787___
2788if _:
2789    return _a_bool_
2790else:
2791    return _b_bool_
2792___
2793         '''),
2794        (False, 1, '''
2795def isValidGesture(gesture):
2796    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
2797        return True
2798    else:
2799        return False
2800        ''', '''
2801if _:
2802    return _a_bool_
2803else:
2804    return _b_bool_
2805        '''),
2806        (False, 1, '''
2807def isValidGesture(gesture):
2808    print('blah')
2809    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
2810        return True
2811    return False
2812        ''', '''
2813___
2814if _:
2815    return _a_bool_
2816return _b_bool_
2817        '''),
2818
2819        (False, 1, '''
2820def isValidGesture(gesture):
2821    print('blah')
2822    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
2823        return True
2824    return False
2825        ''', '''
2826if _:
2827    return _a_bool_
2828return _b_bool_
2829        '''),
2830        (False, 1, '''
2831def isValidGesture(gesture):
2832    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
2833        return True
2834    return False
2835        ''', '''
2836if _:
2837    return _a_bool_
2838return _b_bool_
2839        '''),
2840        (False, 1, '''
2841def isValidGesture(gesture):
2842    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
2843        x = True
2844    x = False
2845    return x
2846        ''', '''
2847if _:
2848    _x_ = _a_bool_
2849_x_ = _b_bool_
2850        '''),
2851        (True, 1, '''
2852def isValidGesture(gesture):
2853    if (gesture == 'rock' or gesture == 'paper' or gesture == 'scissors'):
2854        return True
2855    return False
2856        ''', '''
2857def _(_):
2858    if _:
2859        return _a_bool_
2860    return _b_bool_
2861        '''),
2862        (True, 1, 'x, y = f()', '___vars___ = f()'),
2863        (True, 1, 'x, y = f()', '_vars_ = f()'),
2864        (True, 1, 'f(a=1, b=2)', 'f(b=2, a=1)'),
2865        (True, 1, '''
2866def f(x,y):
2867    """with a docstring"""
2868    if level <= 0:
2869        pass
2870    else:
2871        fd(3)
2872        lt(90)
2873        ''', '''
2874def f(_, _):
2875    ___
2876    if _:
2877        ___t___
2878    else:
2879        ___e___
2880    ___s___
2881'''),
2882        (True, 1, '''
2883class A(B):
2884    def f(self, x):
2885        pass
2886        ''', 'class _(_): ___'),
2887        # (True, True,
2888        #  'for x, y in f():\n    ___',
2889        #  'for ___vars___ in f():\n    ___'),
2890        (False, 0, 'drawLs(size/2, level - 1)', 'drawLs(_size_/2.0, _)'),
2891        (False, 0, '2', '2.0'),
2892        (False, 0, '2', '_d_float_'),
2893        (True, 1, '''
2894def keepFirstLetter(phrase):
2895    """Returns a new string that contains only the first occurrence of a
2896    letter from the original phrase.
2897    The first letter occurrence can be upper or lower case.
2898    Non-alpha characters (such as punctuations and space) are left unchanged.
2899    """
2900    #this list holds lower-cased versions of all of the letters already used
2901    usedCharacters = []
2902
2903    finalPhrase = ""
2904    for n in phrase:
2905        if n.isalpha():
2906            #we need to create a temporary lower-cased version of the letter,
2907            #so that we can check and see if we've seen an upper or lower-cased
2908            #version of this letter before
2909            tempN = n.lower()
2910            if tempN not in usedCharacters:
2911                usedCharacters.append(tempN)
2912                #but we need to add the original n, so that we can preserve
2913                #if it was upper cased or not
2914                finalPhrase = finalPhrase + n
2915
2916        #this adds all non-letter characters into the final phrase list
2917        else:
2918            finalPhrase = finalPhrase + n
2919
2920    return finalPhrase
2921        ''', '''
2922def _(___):
2923    ___
2924    _acc_ = ""
2925    ___
2926    for _ in _:
2927        ___
2928    return _acc_
2929    ___
2930        '''),
2931        (False, # Pattern should not match program
2932         1, # Pattern should be found within program
2933         # Program
2934         '''
2935def keepFirstLetter(phrase):
2936    #this list holds lower-cased versions of all of the letters already used
2937    usedCharacters = []
2938
2939    finalPhrase = ""
2940    for n in phrase:
2941        if n.isalpha():
2942            #we need to create a temporary lower-cased version of the letter,
2943            #so that we can check and see if we've seen an upper or lower-cased
2944            #version of this letter before
2945            tempN = n.lower()
2946            if tempN not in usedCharacters:
2947                usedCharacters.append(tempN)
2948                #but we need to add the original n, so that we can preserve
2949                #if it was upper cased or not
2950                finalPhrase = finalPhrase + n
2951
2952        #this adds all non-letter characters into the final phrase list
2953        else:
2954            finalPhrase = finalPhrase + n
2955
2956    return finalPhrase
2957        ''',
2958         # Pattern
2959         '''
2960___prefix___
2961_acc_ = ""
2962___middle___
2963for _ in _:
2964    ___
2965return _acc_
2966___suffix___
2967        '''),
2968        (True, 1, '''
2969a = 1
2970b = 2
2971c = 3
2972        ''', '''___; _x_ = _n_'''),
2973        (False, # Pattern should not match program
2974         1, # Pattern should be found within program
2975         # Program
2976         '''
2977def f(a):
2978    x = ""
2979    return x
2980        ''',
2981         # Pattern
2982         '''
2983_acc_ = ""
2984return _acc_
2985        '''),
2986
2987        # Treatment of elses:
2988        (True, 1, 'if x: print(1)', 'if _: _'),
2989        (False, 0, 'if x: print(1)\nelse: pass', 'if _: _'),
2990        (True, 1, 'if x: print(1)\nelse: pass', 'if _: _\nelse: ___'),
2991        (False, 0, 'if x: print(1)', '''
2992if _: _
2993else:
2994    _
2995    ___
2996'''),
2997        (True, 1, 'if x: print(1)\nelse: pass', '''
2998if _: _
2999else:
3000    _
3001    ___
3002'''),
3003
3004        # If ExpandExplicitElsePattern is used:
3005        # (False, 0, 'if x: print(1)', 'if _: _\nelse: ___'),
3006
3007        # If ExpandExplicitElsePattern is NOT used:
3008        (True, 1, 'if x: print(1)', 'if _: _\nelse: ___'),
3009
3010        # Keyword arguments
3011        (True, 1, 'f(a=1)', 'f(a=1)'),
3012        (True, 1, 'f(a=1)', 'f(_kw_=1)'),
3013        (True, 1, 'f(a=1)', 'f(_kw_=_arg_)'),
3014        (True, 1, 'f(a=1)', 'f(_=1)'),
3015        (True, 1, 'f(a=1)', 'f(_=_arg_)'),
3016        (True, 1, 'f(a=1, b=2)', 'f(_x_=_, _y_=_)'),
3017        (True, 1, 'f(a=1, b=2)', 'f(_x_=2, _y_=_)'),
3018        (False, 0, 'f(a=1, b=2)', 'f(_x_=_)'),
3019        (False, 0, 'f(a=1, b=2)', 'f(_x_=_, _y_=_, _z_=_)'),
3020        (False, 0, 'f(a=1, b=2)', 'f(_x_=2, _y_=2)'),
3021        (True, 1, 'f(a=1, b=2)', 'f(_x_=_, b=_)'),
3022        (True, 1, 'f(a=1, b=2)', 'f(b=_, _x_=_)'),
3023        (True, 1, 'f(a=1+1, b=1+1)', 'f(_c_=_x_+_x_, _d_=_y_+_y_)'),
3024        (True, 1, 'f(a=1+1, b=2+2)', 'f(_c_=_x_+_x_, _d_=_y_+_y_)'),
3025        (True, 1, 'f(a=1+2, b=2+1)', 'f(_c_=_x_+_y_, _d_=_y_+_x_)'),
3026        (True, 1, 'f(a=1+1, b=1+1)', 'f(_c_=_x_+_x_, _d_=_x_+_x_)'),
3027        (False, 0, 'f(a=1+1, b=2+2)', 'f(_c_=_x_+_x_, _d_=_x_+_x_)'),
3028        (True, 1, 'f(a=1, b=2)', 'f(___=_)'),
3029        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_)'),
3030        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_, b=_)'),
3031        (True, 1, 'f(a=1, b=2)', 'f(b=_, ___kwargs___=_)'),
3032        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_, a=_, b=_)'),
3033        (True, 1, 'f(a=1, b=2)', 'f(a=_, b=_, ___kwargs___=_)'),
3034        (True, 1, 'f(a=1, b=2)', 'f(___kwargs___=_, b=_, a=_)'),
3035        (True, 1, 'f(a=1, b=2)', 'f(b=_, a=_, ___kwargs___=_)'),
3036        (True, 1, 'f(a=1, b=2)', 'f(a=_, ___kwargs___=_, b=_)'),
3037        (True, 1, 'f(a=1, b=2)', 'f(b=_, ___kwargs___=_, a=_)'),
3038        (True, 1, 'b = 7; f(a=1, b=2)', '_x_ = _; f(_x_=_, _y_=_)'),
3039        (False, 0, 'b = 7; f(a=1, b=2)', '_x_ = _; f(_x_=1, _y_=_)'),
3040
3041        # and/or set wildcards
3042        (True, 1, 'x or y or z', '_x_ or ___rest___'),
3043        (True, 1, 'x or y or z', '_x_ or _y_ or ___rest___'),
3044        (True, 1, 'x or y or z', '_x_ or _y_ or _z_ or ___rest___'),
3045
3046        # kwarg matching
3047        (True, 1, 'def f(x=3):\n  return x', 'def _f_(x=3):\n  return x'),
3048        (True, 1, 'def f(x=3):\n  return x', 'def _f_(_=3):\n  return _'),
3049        (True, 1, 'def f(x=3):\n  return x', 'def _f_(_=3):\n  ___'),
3050        (True, 1, 'def f(x=3):\n  return x', 'def _f_(_x_=3):\n  return _x_'),
3051        (True, 1,
3052         'def f(y=7):\n  return y', 'def _f_(_x_=_y_):\n  return _x_'),
3053        (False, 0, 'def f(x=3):\n  return x', 'def _f_(_y_=7):\n  return _x_'),
3054
3055        # Succeeds!
3056        (True, 1, 'def f(x=12):\n  return x', 'def _f_(_=_):\n  ___'),
3057
3058        # TODO: Fails because we compare all defaults as one block
3059        # Zero extra keyword arguments:
3060        (True, 1, 'def f(x=17):\n  return x', 'def _f_(_=17, ___=_):\n  ___'),
3061
3062        # Should match because ___ has no default
3063        (True, 1, 'def f(x=17):\n  return x', 'def _f_(___, _=17):\n  ___'),
3064
3065        # Multiple kwargs
3066        (True, 1, 'def f(x=3, y=4):\n  return x', 'def _f_(_=3, _=4):\n  ___'),
3067        (True, 1, 'def f(x=5, y=6):\n  return x', 'def _f_(_=_, _=_):\n  ___'),
3068        # ___ doesn't match kwargs
3069        (False, 0, 'def f(x=7, y=8):\n  return x', 'def _f_(___):\n  ___'),
3070        # TODO: Fails because of default matching bug
3071        (True, 1, 'def f(x=9, y=10):\n  return x', 'def _f_(___=_):\n  ___'),
3072
3073        # Exact matching of kwarg expressions
3074        (True, 1, 'def f(x=y+3):\n  return x', 'def _f_(_=y+3):\n  ___'),
3075
3076        # Matching of kw-only args
3077        (True, 1,
3078         'def f(*a, x=5):\n  return x',
3079         'def _f_(*_, _x_=5):\n  return _x_'),
3080        # ___ does not match *_
3081        (False, 0,
3082         'def f(*a, x=6):\n  return x',
3083         'def _f_(___, _x_=6):\n  return _x_'),
3084
3085        # Multiple kw-only args
3086        (True, 1,
3087         'def f(*a, x=5, y=6):\n  return x, y',
3088         'def _f_(*_, _x_=5, _y_=6):\n  return _x_, _y_'),
3089        (False, 0,
3090         'def f(*a, x=7, y=8):\n  return x, y',
3091         'def _f_(___, _x_=7, _y_=8):\n  return _x_, _y_'),
3092
3093        # Function with docstring (must use ast.get_docstring!)
3094        (False, 0, 'def f():\n  """docstring"""', 'def _f_(___):\n  _a_str_'),
3095
3096        # Function with docstring (using ast.get_docstring)
3097        (True, 1,
3098         'def f(x):\n  """doc1"""\n  return x',
3099         'def _f_(___):\n  ___',
3100         lambda node, env: ast.get_docstring(node) is not None),
3101
3102        # Predicate requiring no (or an empty) docstring (using ast.get_docstring)
3103        (False, 0,
3104         'def f(x):\n  """doc2"""\n  return x',
3105         'def _f_(___):\n  ___',
3106         lambda node, env:(
3107             ast.get_docstring(node) is None
3108          or ast.get_docstring(node).strip() == ''
3109         )),
3110        (True, 1,
3111         'def f(x):\n  """"""\n  return x',
3112         'def _f_(___):\n  ___',
3113         lambda node, env:(
3114             ast.get_docstring(node) is None
3115          or ast.get_docstring(node).strip() == ''
3116         )),
3117        (True, 1,
3118         'def nodoc(x):\n  return x',
3119         'def _f_(___):\n  ___',
3120         lambda node, env:(
3121             ast.get_docstring(node) is None
3122          or ast.get_docstring(node).strip() == ''
3123         )),
3124
3125        # TODO: Recursive matching of kwarg expressions
3126        (True, 1, 'def f(x=y+3):\n  return x', 'def _f_(_=_+3):\n  ___'),
3127
3128        # Function with multiple normal arguments
3129        (True, 1, 'def f(x, y, z):\n  return x', 'def _f_(___):\n  ___'),
3130
3131        # Matching redundant elif conditions
3132        (
3133          True,
3134          1,
3135          'if x == 3:\n  return x\nelif not x == 3 and x > 5:\n  return x-2',
3136          'if _cond_:\n  ___\nelif not _cond_ and ___:\n  ___'
3137        ),
3138        ( # order matters
3139          False,
3140          0,
3141          'if x == 3:\n  return x\nelif x > 5 and not x == 3:\n  return x-2',
3142          'if _cond_:\n  ___\nelif not _cond_ and ___:\n  ___'
3143        ),
3144        ( # not == is not the same as !=
3145          False,
3146          0,
3147          'if x == 3:\n  return x\nelif x > 5 and x != 3:\n  return x-2',
3148          'if _cond_:\n  ___\nelif not _cond_ and ___:\n  ___'
3149        ),
3150        ( # with ___ first, the pattern follows the program's order
3151          True,
3152          1,
3153          'if x == 3:\n  return x\nelif x > 5 and not x == 3:\n  return x-2',
3154          'if _cond_:\n  ___\nelif ___ and not _cond_:\n  ___'
3155        ),
3156        ( # not == is not the same as !=
3157          True,
3158          1,
3159          'if x == 3:\n  return x\nelif x > 5 and x != 3:\n  return x-2',
3160          'if _n_ == _v_:\n  ___\nelif ___ and _n_ != _v_:\n  ___'
3161        ),
3162        ( # extra conditions do matter!
3163          False,
3164          0,
3165          'if x == 3:\n  return x\nelif not x == 3 and x > 5:\n  return x-2\n'
3166        + 'elif not x == 3 and x <= 5 and x < 0:\n  return 0',
3167          'if _cond_:\n  ___\nelif not _cond_ and ___:\n  ___'
3168        ),
3169        ( # match extra conditions:
3170          True,
3171          1,
3172          'if x == 3:\n  return x\nelif not x == 3 and x > 5:\n  return x-2\n'
3173        + 'elif not x == 3 and x <= 5 and x < 0:\n  return 0',
3174          'if _cond_:\n  ___\nelif not _cond_ and ___:\n  ___\nelif _:\n  ___'
3175        ),
3176        ( # number of conditions must match exactly:
3177          False,
3178          0,
3179          'if x == 3:\n  return x\nelif not x == 3 and x > 5:\n  return x-2\n'
3180        + 'elif not x == 3 and x <= 5 and x < 0:\n  return 0\n'
3181        + 'elif not x == 3 and x <= 5 and x >= 0 and x == 1:\n  return 1.5',
3182          'if _cond_:\n  ___\nelif not _cond_ and ___:\n  ___\nelif _:\n  ___'
3183        ),
3184        ( # matching with elif:
3185          True,
3186          2,
3187          'if x < 0:\n  x += 1\nelif x < 10:\n  x += 0.5\nelse:\n  x += 0.25',
3188          'if _:\n  ___\nelse:\n  ___'
3189        ),
3190        ( # left/right branches are both okay
3191          True,
3192          1,
3193          'x == 3',
3194          '3 == x',
3195        ),
3196        ( # order does matter for some operators
3197         False,
3198         0,
3199         '1 + 2 + 3',
3200         '3 + 2 + 1'
3201        ),
3202        ( # but not for others
3203         True,
3204         1,
3205         '1 and 2 and 3',
3206         '3 and 2 and 1'
3207        ),
3208    ]
3209    testNum = 1
3210    passes = 0
3211    fails = 0
3212    #for (expect_match, expect_count, ns, ps) in tests:
3213    for test_spec in tests:
3214        if len(test_spec) == 4:
3215            expect_match, expect_count, ns, ps = test_spec
3216            matchpred = predtrue
3217        elif len(test_spec) == 5:
3218            expect_match, expect_count, ns, ps, matchpred = test_spec
3219        else:
3220            print("Tests must have 4 or 5 parts!")
3221            print(test_spec)
3222            exit(1)
3223
3224        print('$' + '-' * 60)
3225        print('Test #{}'.format(testNum))
3226        testNum += 1
3227        n = parse(ns)
3228        n = (
3229            n.body[0].value
3230            if type(n.body[0]) == ast.Expr and len(n.body) == 1
3231            else (
3232                n.body[0]
3233                if len(n.body) == 1
3234                else n.body
3235            )
3236        )
3237        p = parse_pattern(ps)
3238        print('Program: %s => %s' % (ns, dump(n)))
3239        if isinstance(n, ast.FunctionDef):
3240            print('Docstring: %s' % (ast.get_docstring(n)))
3241        print('Pattern: %s => %s' % (ps, dump(p)))
3242        if matchpred != predtrue:
3243            print('Matchpred: {}'.format(matchpred.__name__))
3244        for gen in [False, True]:
3245            result = match(n, p, gen=gen, matchpred=matchpred)
3246            if gen:
3247                result = takeone(result)
3248                pass
3249            passed = bool(result) == expect_match
3250            if passed:
3251                passes += 1
3252            else:
3253                fails += 1
3254                pass
3255            print('match(gen=%s): %s  [%s]' % (
3256                gen, bool(result), "PASS" if passed else "FAIL"
3257            ))
3258            # if result:
3259            #     for (k,v) in result.value.items():
3260            #         print("  %s = %s" % (k, dump(v)))
3261            #         pass
3262            #     pass
3263            # pass
3264        opt = find(n, p, matchpred=matchpred)
3265        findpassed = bool(opt) == (0 < expect_count)
3266        if findpassed:
3267            passes += 1
3268        else:
3269            fails += 1
3270            pass
3271        print(
3272            'find: {}  [{}]'.format(
3273                bool(opt),
3274                # dump(opt.value[0]) if opt else None,
3275                "PASS" if findpassed else "FAIL"
3276            )
3277        )
3278        # if opt:
3279        #     for (k,v) in opt.value[1].items():
3280        #         print("  %s = %s" % (k, dump(v)))
3281        #         pass
3282        c = count(n, p, matchpred=matchpred)
3283        if c == expect_count:
3284            passes += 1
3285            print('count: %s  [PASS]' % c)
3286        else:
3287            fails += 1
3288            print('count: %s  [FAIL], expected %d' % (c, expect_count))
3289            pass
3290        print('findall:')
3291        nmatches = 0
3292        for (node, envs) in findall(n, p, matchpred=matchpred):
3293            print("  %d.  %s" % (nmatches, dump(node)))
3294            for (i, env) in enumerate(envs):
3295                print("      %s.  " % chr(ord('a') + i))
3296                for (k, v) in env.items():
3297                    print("          %s = %s" % (k, dump(v)))
3298                    pass
3299                pass
3300            # find should return first findall result.
3301            # assert 0 < nmatches or (opt and opt.value == (node, bindings))
3302            nmatches += 1
3303            pass
3304        # count should count the same number of matches that findall finds.
3305        assert nmatches == c
3306        print()
3307        pass
3308    print("%d passed, %d failed" % (passes, fails))
3309
3310    with open(__file__) as src:
3311        with open('self_printed_' + os.path.basename(__file__), 'w') as dest:
3312            dest.write(source(ast.parse(src.read())))
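The loop above accepts test specs with four or five parts and simplifies each parsed program before matching. A lone expression statement is replaced by its expression, a lone statement is kept as-is, and a longer program becomes its statement list. What follows is a minimal standalone sketch of those two steps that uses only the standard `ast` module. `unpack_spec`, `normalize_program`, and `always_true` (a stand-in for mast's `predtrue`) are hypothetical names introduced for illustration.

```python
import ast

def always_true(node, env):
    """Hypothetical stand-in for mast's default match predicate."""
    return True

def unpack_spec(test_spec):
    """Split a 4- or 5-part test spec, defaulting the match predicate."""
    if len(test_spec) == 4:
        expect_match, expect_count, program, pattern = test_spec
        return expect_match, expect_count, program, pattern, always_true
    if len(test_spec) == 5:
        return tuple(test_spec)
    raise ValueError("Tests must have 4 or 5 parts: {!r}".format(test_spec))

def normalize_program(source):
    """Mirror the harness: unwrap a lone expression statement, keep a
    lone statement, otherwise return the whole statement list."""
    body = ast.parse(source).body
    if len(body) == 1 and isinstance(body[0], ast.Expr):
        return body[0].value
    if len(body) == 1:
        return body[0]
    return body

print(type(normalize_program('max(7, 3, 5)')).__name__)  # Call
print(type(normalize_program('x = 1')).__name__)         # Assign
print(len(normalize_program('x = 1\ny = 2')))            # 2
```

Because the harness normalizes the program once, `match`, `find`, `count`, and `findall` all receive the same target for a given test.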
def ichain(iterable, /):

Alternative chain() constructor taking a single iterable argument that evaluates lazily.
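The signature and docstring above are those of the standard library's `itertools.chain.from_iterable`. Assuming `ichain` is simply that function (its import is not shown in this section), the following sketch demonstrates the laziness the docstring mentions. Inner iterables are pulled from the outer iterator only as elements are requested.

```python
from itertools import chain

ichain = chain.from_iterable  # assumption: mast's ichain is this function

def noisy_groups(log):
    """Yield groups one at a time, recording when each is produced."""
    for group in ([1, 2], [3], [4, 5]):
        log.append(group)
        yield group

log = []
flat = ichain(noisy_groups(log))
print(log)         # []  (nothing consumed yet)
print(next(flat))  # 1
print(log)         # [[1, 2]]  (only the first group was pulled)
print(list(flat))  # [2, 3, 4, 5]
```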

def dump(node):
 98def dump(node):
 99    """Nicer display of ASTs."""
100    return (
101        ast.dump(node) if isinstance(node, ast.AST)
102        else '[' + ', '.join(map(dump, node)) + ']' if type(node) == list
103        else '(' + ', '.join(map(dump, node)) + ')' if type(node) == tuple
104        else '{' + ', '.join(dump(key) + ': ' + dump(value)
105                             for key, value in node.items()) + '}'
106            if type(node) == dict
107        # Sneaky way to have dump descend into Some and FiniteIterator objects
108        # for debugging
109        else node.dump(dump)
110          if (type(node) == Some or type(node) == FiniteIterator)
111          else repr(node)
112    )

Nicer display of ASTs.
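`dump` extends `ast.dump` so that lists of statements, tuples, and binding dictionaries also print readably. A standalone sketch of the same recursion follows. `nice_dump` is a hypothetical name. The sketch omits the branch for mast's internal `Some` and `FiniteIterator` classes, and it uses `isinstance` rather than exact type checks.

```python
import ast

def nice_dump(node):
    """Render ASTs nested inside lists, tuples, and dicts (sketch of dump)."""
    if isinstance(node, ast.AST):
        return ast.dump(node)
    if isinstance(node, list):
        return '[' + ', '.join(map(nice_dump, node)) + ']'
    if isinstance(node, tuple):
        return '(' + ', '.join(map(nice_dump, node)) + ')'
    if isinstance(node, dict):
        return '{' + ', '.join(nice_dump(k) + ': ' + nice_dump(v)
                               for k, v in node.items()) + '}'
    return repr(node)

print(nice_dump([1, (2, 'a')]))  # [1, (2, 'a')]
env = {'x': ast.Name(id='a', ctx=ast.Load())}
print(nice_dump(env))            # the binding, shown via ast.dump
```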

def showret(x, prefix=''):
115def showret(x, prefix=''):
116    """Shim for use in debugging AST return values."""
117    print(prefix, dump(x))
118    return x

Shim for use in debugging AST return values.

NODE_VAR_SPEC = (1, re.compile('\\A_(?:([a-zA-Z0-9]+)_)?\\Z'))

Spec for names of scalar pattern variables.

SEQ_TYPES = {<class 'ast.Call'>, <class 'ast.BoolOp'>, <class 'ast.Tuple'>, <class 'ast.List'>, <class 'NoneType'>, <class 'list'>, <class 'ast.Assign'>, <class 'ast.arguments'>, <class 'ast.Expr'>}

Types against which sequence patterns match.

SET_VAR_SPEC = (1, re.compile('\\A___(?:([a-zA-Z0-9]+)___)?\\Z'))

Spec for names of sequence pattern variables.

TYPED_LIT_VAR_SPEC = ((1, 2), re.compile('\\A_([a-zA-Z0-9]+)_(int|float|str|bool)_\\Z'))

Spec for names/types of typed literal pattern variables.

def var_is_anonymous(identifier):
205def var_is_anonymous(identifier):
206    """Determine whether a pattern variable name (string) is anonymous."""
207    assert type(identifier) == str # All Python 3 strings are unicode
208    return not re.search(r'[a-zA-Z0-9]', identifier)

Determine whether a pattern variable name (string) is anonymous.

def node_is_name(node):
211def node_is_name(node):
212    """Determine if a node is an AST Name node."""
213    return isinstance(node, ast.Name)

Determine if a node is an AST Name node.

def identifier_key(identifier, spec):
216def identifier_key(identifier, spec):
217    """Extract the name of a pattern variable from its identifier string,
218    returning the key (a (name, type) tuple for TYPED_LIT_VAR_SPEC) if
219    identifier is a valid variable name according to spec, otherwise None.
220
221    Examples for NODE_VAR_SPEC:
222    identifier_key('_a_', NODE_VAR_SPEC) => 'a'
223    identifier_key('_', NODE_VAR_SPEC) => '_'
224    identifier_key('_a', NODE_VAR_SPEC) => None
225
226    """
227    assert type(identifier) == str # All Python 3 strings are unicode
228    groups, regex = spec
229    match = regex.match(identifier)
230    if match:
231        if var_is_anonymous(identifier):
232            return identifier
233        elif type(groups) == tuple:
234            return tuple(match.group(i) for i in groups)
235        else:
236            return match.group(groups)
237    else:
238        return None

Extract the name of a pattern variable from its identifier string, returning the key (a (name, type) tuple for TYPED_LIT_VAR_SPEC) if identifier is a valid variable name according to spec, otherwise None.

Examples for NODE_VAR_SPEC:
identifier_key('_a_', NODE_VAR_SPEC) => 'a'
identifier_key('_', NODE_VAR_SPEC) => '_'
identifier_key('_a', NODE_VAR_SPEC) => None
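The lookup can be reproduced standalone using the three specs documented above, which are copied here so the sketch runs on its own. The key is returned directly rather than wrapped in `Some`. The typed literal spec, whose groups field is a tuple, yields a (name, type) pair.

```python
import re

# Specs copied from the documented constants: (group index or indices, regex).
NODE_VAR_SPEC = (1, re.compile(r'\A_(?:([a-zA-Z0-9]+)_)?\Z'))
SET_VAR_SPEC = (1, re.compile(r'\A___(?:([a-zA-Z0-9]+)___)?\Z'))
TYPED_LIT_VAR_SPEC = ((1, 2),
                      re.compile(r'\A_([a-zA-Z0-9]+)_(int|float|str|bool)_\Z'))

def var_is_anonymous(identifier):
    # Anonymous variables (`_`, `___`) contain no letters or digits.
    return not re.search(r'[a-zA-Z0-9]', identifier)

def identifier_key(identifier, spec):
    groups, regex = spec
    m = regex.match(identifier)
    if not m:
        return None
    if var_is_anonymous(identifier):
        return identifier          # anonymous variables key on themselves
    if isinstance(groups, tuple):
        return tuple(m.group(i) for i in groups)
    return m.group(groups)

print(identifier_key('_a_', NODE_VAR_SPEC))           # a
print(identifier_key('_', NODE_VAR_SPEC))             # _
print(identifier_key('_a', NODE_VAR_SPEC))            # None
print(identifier_key('___rest___', SET_VAR_SPEC))     # rest
print(identifier_key('___', NODE_VAR_SPEC))           # None
print(identifier_key('_n_int_', TYPED_LIT_VAR_SPEC))  # ('n', 'int')
```

Because the regexes are anchored and exclude inner underscores, the three variable kinds cannot be confused with one another. For example, `___` is not a node variable, and `_n_int_` is only a typed literal variable.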

def node_var(pat, spec=(1, re.compile('\\A_(?:([a-zA-Z0-9]+)_)?\\Z')), wrap=False):
241def node_var(pat, spec=NODE_VAR_SPEC, wrap=False):
242    """Extract the key name of a scalar pattern variable,
243    returning the key if pat is a scalar pattern variable,
244    otherwise None.
245
246    A named or anonymous node variable pattern, written `_a_` or `_`,
247    respectively, may appear in any expression or identifier context
248    in a pattern.  It matches any single AST in the corresponding
249    position in the target program.
250
251    """
252    if wrap:
253        pat = ast.Name(id=pat, ctx=None)
254    elif isinstance(pat, ast.Expr):
255        pat = pat.value
256    elif isinstance(pat, ast.alias) and pat.asname is None:
257        # [Peter Mawhorter 2021-8-29] Want to treat aliases without an
258        # 'as' part kind of like normal Name nodes.
259        pat = ast.Name(id=pat.name, ctx=None)
260    return (
261        identifier_key(pat.id, spec)
262        if node_is_name(pat)
263        else None
264    )

Extract the key name of a scalar pattern variable, returning the key if pat is a scalar pattern variable, otherwise None.

A named or anonymous node variable pattern, written _a_ or _, respectively, may appear in any expression or identifier context in a pattern. It matches any single AST in the corresponding position in the target program.
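A bare pattern such as `_a_` parses as an `Expr` statement wrapping a `Name`, which is why `node_var` unwraps `Expr` nodes first. The sketch below uses a hypothetical `scalar_var_key` helper and covers only that path with `NODE_VAR_SPEC`. It leaves out the `wrap` flag and the `ast.alias` case.

```python
import ast
import re

NODE_VAR_RE = re.compile(r'\A_(?:([a-zA-Z0-9]+)_)?\Z')

def scalar_var_key(pat):
    """Hypothetical simplification of node_var for NODE_VAR_SPEC."""
    if isinstance(pat, ast.Expr):   # a bare `_a_` line parses as Expr(Name)
        pat = pat.value
    if not isinstance(pat, ast.Name):
        return None
    m = NODE_VAR_RE.match(pat.id)
    if not m:
        return None
    # Group 1 is absent only for the anonymous variable `_`.
    return pat.id if m.group(1) is None else m.group(1)

stmt = ast.parse('_a_').body[0]
print(type(stmt).__name__, scalar_var_key(stmt))   # Expr a
call = ast.parse('f(_, y)').body[0].value
print([scalar_var_key(arg) for arg in call.args])  # ['_', None]
```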

def node_var_str(pat, spec=(1, re.compile('\\A_(?:([a-zA-Z0-9]+)_)?\\Z'))):
267def node_var_str(pat, spec=NODE_VAR_SPEC):
268
269    return node_var(pat, spec=spec, wrap=True)
def set_var(pat, wrap=False):
272def set_var(pat, wrap=False):
273    """Extract the key name of a set or sequence pattern variable,
274    returning the key if pat is a set or sequence pattern variable,
275    otherwise None.
276
277    A named or anonymous set or sequence pattern variable, written
278    `___a___` or `___`, respectively, may appear as an element of a
279    set or sequence context in a pattern.  It matches 0 or more nodes
280    in the corresponding context in the target program.
281
282    """
283    return node_var(pat,
284                    spec=SET_VAR_SPEC,
285                    wrap=wrap)

Extract the key name of a set or sequence pattern variable, returning the key if pat is a set or sequence pattern variable, otherwise None.

A named or anonymous set or sequence pattern variable, written ___a___ or ___, respectively, may appear as an element of a set or sequence context in a pattern. It matches 0 or more nodes in the corresponding context in the target program.
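To illustrate what "matches 0 or more nodes" means, here is a toy backtracking matcher over lists of strings rather than ASTs. `match_seq` is a hypothetical helper, not mast's implementation. It does not check that a repeated variable binds the same slice each time. The first cases mirror the `max(4,___args___)` tests earlier in this module.

```python
import re

SET_VAR_RE = re.compile(r'\A___(?:([a-zA-Z0-9]+)___)?\Z')

def match_seq(pattern, items, env=None):
    """Match a list of strings against a pattern list in which
    ___name___ / ___ elements absorb zero or more items.
    Returns a dict of bindings, or None if there is no match."""
    env = {} if env is None else env
    if not pattern:
        return env if not items else None
    head, rest = pattern[0], pattern[1:]
    m = SET_VAR_RE.match(head)
    if m:
        # Try every split point, shortest first, backtracking on failure.
        for k in range(len(items) + 1):
            bound = dict(env)
            if m.group(1) is not None:      # named: record the absorbed slice
                bound[m.group(1)] = items[:k]
            result = match_seq(rest, items[k:], bound)
            if result is not None:
                return result
        return None
    if items and items[0] == head:          # ordinary element: exact match
        return match_seq(rest, items[1:], env)
    return None

print(match_seq(['4', '___args___'], ['4', '5', '6']))  # {'args': ['5', '6']}
print(match_seq(['4', '___args___'], ['4']))            # {'args': []}
print(match_seq(['___', 'y'], ['x', 'y']))              # {}
print(match_seq(['4', '___args___'], ['5']))            # None
```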

def set_var_str(pat):
288def set_var_str(pat):
289    return set_var(pat, wrap=True)
def typed_lit_var(pat):
292def typed_lit_var(pat):
293    """Extract the name and type of a typed literal pattern variable,
294    returning (name, type) if pat is a typed literal pattern variable,
295    otherwise None.
296
297    A typed literal variable pattern, written `_a_type_`, may appear
298    in any expression context in a pattern.  It matches any single AST
299    node for a literal of the given primitive type in the
300    corresponding position in the target program.
301
302    """
303    return node_var(pat, spec=TYPED_LIT_VAR_SPEC)

Extract the name and type of a typed literal pattern variable, returning a (name, type) tuple if pat is a typed literal pattern variable, otherwise None.

A typed literal variable pattern, written _a_type_, may appear in any expression context in a pattern. It matches any single AST node for a literal of the given primitive type in the corresponding position in the target program.
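The following is a hedged sketch of how a typed literal variable could be checked against a target node. It assumes Python 3.8 or later, where all literals parse to `ast.Constant`. mast itself tests the older `ast.Num`, `ast.Str`, and `Name`-based forms in `node_is_lit`. `typed_lit_matches` is a hypothetical helper, and the cases mirror the `_a_int_`/`_a_float_`/`_a_bool_` tests earlier in this module.

```python
import ast
import re

TYPED_LIT_RE = re.compile(r'\A_([a-zA-Z0-9]+)_(int|float|str|bool)_\Z')
TYPES = {'int': int, 'float': float, 'str': str, 'bool': bool}

def typed_lit_matches(var_name, node):
    """Return (name, node) if node is a literal of the type encoded in
    var_name (e.g. '_a_int_'), else None."""
    m = TYPED_LIT_RE.match(var_name)
    if not m or not isinstance(node, ast.Constant):
        return None
    name, type_name = m.group(1), m.group(2)
    # Exact type comparison, so True is not accepted as an int.
    if type(node.value) is not TYPES[type_name]:
        return None
    return name, node

def check(src, var):
    node = ast.parse(src, mode='eval').body
    return typed_lit_matches(var, node) is not None

print(check('3', '_a_int_'))      # True
print(check('3.4', '_a_float_'))  # True
print(check('3', '_a_float_'))    # False
print(check('True', '_a_bool_'))  # True
print(check('True', '_a_int_'))   # False
```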

def is_pat(p):
317def is_pat(p):
318    """Determine if p could be a pattern (by type)."""
319    # All Python 3 strings are unicode
320    return isinstance(p, ast.AST) or type(p) == str

Determine if p could be a pattern (by type).

def node_is_docstring(node):
323def node_is_docstring(node):
324    """Is this node a docstring node?"""
325    return isinstance(node, ast.Expr) and isinstance(node.value, ast.Str)

Is this node a docstring node?
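`ast.Str` has been deprecated since Python 3.8, where string literals parse to `ast.Constant`. On newer interpreters, an equivalent check can inspect the constant's value directly, as in this sketch. The name `is_docstring_stmt` is hypothetical. Like `node_is_docstring`, it accepts any string expression statement regardless of position. The standard `ast.get_docstring`, which the docstring tests above rely on, additionally requires the string to be the first statement of the body.

```python
import ast

def is_docstring_stmt(node):
    """Python 3.8+ sketch: an expression statement holding a str constant."""
    return (
        isinstance(node, ast.Expr)
        and isinstance(node.value, ast.Constant)
        and isinstance(node.value.value, str)
    )

fn = ast.parse('def f(x):\n    """doc1"""\n    return x').body[0]
print(is_docstring_stmt(fn.body[0]))  # True
print(is_docstring_stmt(fn.body[1]))  # False
print(ast.get_docstring(fn))          # doc1
```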

def node_is_bindable(node):
    """Can a node pattern variable bind to this node?"""
    # return isinstance(node, ast.expr) or isinstance(node, ast.stmt)
    # Modified to allow debugging prints
    result = (
        isinstance(node, ast.expr)
     or isinstance(node, ast.stmt)
     or (isinstance(node, ast.alias) and node.asname is None)
    )
    # print('\n$ node_is_bindable({}) => {}'.format(dump(node),result))
    return result

Can a node pattern variable bind to this node?

def node_is_lit(node, ty):
    """Is this node a literal primitive node?"""
    return (
        (isinstance(node, ast.Num) and ty == type(node.n)) # noqa E721
        # All Python 3 strings are unicode
     or (isinstance(node, ast.Str) and ty == str)
     or (
            isinstance(node, ast.Name)
        and ty == bool
        and (node.id == 'True' or node.id == 'False')
        )
    )

Is this node a literal primitive node?
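On Python 3.8 and later, numbers, strings, and also True and False all parse as ast.Constant, so the ast.Name branch above no longer fires for parsed code. A minimal sketch of the same idea for those versions follows; node_is_lit_modern is a hypothetical name. Comparing type() exactly keeps True (a bool, which subclasses int) from counting as an int literal.

```python
import ast


def node_is_lit_modern(node, ty):
    """Hypothetical Python 3.8+ variant of node_is_lit."""
    # Exact type comparison: type(True) is bool, not int.
    return isinstance(node, ast.Constant) and type(node.value) is ty


def expr(src):
    """Parse a single expression and return its AST node."""
    return ast.parse(src, mode='eval').body


print(node_is_lit_modern(expr('42'), int))     # True
print(node_is_lit_modern(expr('True'), int))   # False
print(node_is_lit_modern(expr('True'), bool))  # True
print(node_is_lit_modern(expr("'hi'"), str))   # True
print(node_is_lit_modern(expr('x'), str))      # False
```

As with the original, a negative number such as -1 is not a literal here, because it parses as a unary minus applied to a constant.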

def expr_has_type(node, ty):
    """Is this expr statically guaranteed to be of this type?

    Literals and conversions have definite types.  All other types are
    conservatively statically unknown.

    """
    return node_is_lit(node, ty) or match(node, '{}(_)'.format(ty.__name__))

Is this expr statically guaranteed to be of this type?

Literals and conversions have definite types. All other types are conservatively statically unknown.
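expr_has_type treats a node as having type ty when it is a literal of that type or matches the pattern ty.__name__(_), for example int(_). A standalone approximation without mast follows. It assumes Python 3.8+ and approximates the pattern int(_) as "exactly one positional argument and no keywords"; mast's own matching rules may differ in edge cases. The function name is hypothetical.

```python
import ast


def expr_has_type_sketch(node, ty):
    """Hypothetical standalone approximation of expr_has_type.

    A literal of type ty, or a call such as int(x) to the conversion
    function named ty.__name__ with one positional argument, counts as
    having type ty; anything else is conservatively 'unknown' (False).
    """
    if isinstance(node, ast.Constant) and type(node.value) is ty:
        return True
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == ty.__name__
        and len(node.args) == 1
        and not node.keywords
    )


def expr(src):
    return ast.parse(src, mode='eval').body


print(expr_has_type_sketch(expr('int(s)'), int))  # True
print(expr_has_type_sketch(expr('3.5'), float))   # True
print(expr_has_type_sketch(expr('s + 1'), int))   # False (unknown)
```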

def node_line(node):
    """Get the line number of the source line on which this node starts.
    (best effort)"""
    try:
        return node.lineno if type(node) != list else node[0].lineno
    except Exception:
        return None

Get the line number of the source line on which this node starts (best effort).
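The best-effort behavior is easiest to see on the kinds of input node_line receives. The block repeats the function body so that it runs on its own and then exercises a statement, a statement list, a node without position information, and an empty list.

```python
import ast


def node_line(node):
    # Same logic as the listing above, repeated so this block is standalone.
    try:
        # A list of statements reports the line of its first element.
        return node.lineno if type(node) != list else node[0].lineno
    except Exception:
        # AttributeError (no lineno) or IndexError (empty list).
        return None


body = ast.parse('x = 1\n\ny = 2\n').body
print(node_line(body[1]))     # 3
print(node_line(body))        # 1
print(node_line(ast.Load()))  # None
print(node_line([]))          # None
```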

class PatternSyntaxError(BaseException):
    """Exception for errors in pattern syntax."""
    def __init__(self, node, message):
        BaseException.__init__(
            self,
            "At pattern line {}: {}".format(node_line(node), message)
        )

Exception for errors in pattern syntax.

PatternSyntaxError(node, message)
Inherited Members
builtins.BaseException
with_traceback
class PatternValidator(ast.NodeVisitor):
    """AST visitor: check pattern structure."""
    def __init__(self):
        self.parent = None
        pass

    def generic_visit(self, pat):
        oldparent = self.parent
        self.parent = pat
        try:
            return ast.NodeVisitor.generic_visit(self, pat)
        finally:
            self.parent = oldparent
            pass
        pass

    def visit_Name(self, pat):
        sn = set_var(pat)
        if sn:
            if type(self.parent) not in SEQ_TYPES:
                raise PatternSyntaxError(
                    pat,
                    "Set/sequence variable ({}) not allowed in {} node".format(
                        sn, type(self.parent)
                    )
                )
            pass
        pass

    def visit_arg(self, pat):
        '''[2019/01/22, lyn] Python 3 now has arg object for param (not Name object).
           So copied visit_Name here.'''
        sn = set_var(pat)
        if sn:
            if type(self.parent) not in SEQ_TYPES:
                raise PatternSyntaxError(
                    pat,
                    "Set/sequence variable ({}) not allowed in {} node".format(
                        sn, type(self.parent)
                    )
                )
            pass
        pass

    def visit_Call(self, c):
        if 1 < sum(1 for kw in c.keywords if set_var_str(kw.arg)):
            raise PatternSyntaxError(
                c.keywords,
                "Calls may use at most one keyword argument set variable."
            )
        return self.generic_visit(c)

    def visit_keyword(self, k):
        if (identifier_key(k.arg, SET_VAR_SPEC)
            and not (node_is_name(k.value)
                     and var_is_anonymous(k.value.id))):
            raise PatternSyntaxError(
                k.value,
                "Value patterns for keyword argument set variables must be _."
            )
        return self.generic_visit(k)
    pass

AST visitor: check pattern structure.

PatternValidator()
def generic_visit(self, pat):

Called if no explicit visitor function exists for a node.

def visit_Name(self, pat):
def visit_arg(self, pat):

[2019/01/22, lyn] Python 3 now has arg object for param (not Name object). So copied visit_Name here.

def visit_Call(self, c):
def visit_keyword(self, k):
Inherited Members
ast.NodeVisitor
visit
visit_Constant
class RemoveDocstrings(ast.NodeTransformer):
    """AST Transformer: remove all docstring nodes."""
    def filterDocstrings(self, seq):
        # print('PREFILTERED', seq)
        filt = [self.visit(n) for n in seq
                if not node_is_docstring(n)]
        # print('FILTERED', dump(filt))
        return filt

    def visit_Expr(self, node):
        if isinstance(node.value, ast.Str):
            assert False
            # return ast.copy_location(ast.Expr(value=None), node)
        else:
            return self.generic_visit(node)

    def visit_FunctionDef(self, node):
        # print('Removing docstring: %s' % dump(node.body[0].value))
        return ast.copy_location(ast.FunctionDef(
            name=node.name,
            args=self.generic_visit(node.args),
            body=self.filterDocstrings(node.body),
        ), node)

    def visit_Module(self, node):
        return ast.copy_location(ast.Module(
            body=self.filterDocstrings(node.body)
        ), node)

    def visit_For(self, node):
        return ast.copy_location(ast.For(
            body=self.filterDocstrings(node.body),
            target=node.target,
            iter=node.iter,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_While(self, node):
        return ast.copy_location(ast.While(
            body=self.filterDocstrings(node.body),
            test=node.test,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_If(self, node):
        return ast.copy_location(ast.If(
            body=self.filterDocstrings(node.body),
            test=node.test,
            orelse=self.filterDocstrings(node.orelse)
        ), node)

    def visit_With(self, node):
        return ast.copy_location(ast.With(
            body=self.filterDocstrings(node.body),
            # Python 3 just has withitems:
            items=node.items
            # Old Python 2 stuff:
            #context_expr=node.context_expr,
            #optional_vars=node.optional_vars
        ), node)

#     def visit_TryExcept(self, node):
#         return ast.copy_location(ast.TryExcept(
#             body=self.filterDocstrings(node.body),
#             handlers=self.filterDocstrings(node.handlers),
#             orelse=self.filterDocstrings(node.orelse)
#         ), node)

#     def visit_TryFinally(self, node):
#         return ast.copy_location(ast.TryFinally(
#             body=self.filterDocstrings(node.body),
#             finalbody=self.filterDocstrings(node.finalbody)
#         ), node)
#

    # In Python 3, a single Try node covers both TryExcept and TryFinally
    def visit_Try(self, node):
        return ast.copy_location(ast.Try(
            body=self.filterDocstrings(node.body),
            handlers=self.filterDocstrings(node.handlers),
            orelse=self.filterDocstrings(node.orelse),
            finalbody=self.filterDocstrings(node.finalbody)
        ), node)

    def visit_ClassDef(self, node):
        return ast.copy_location(ast.ClassDef(
            body=self.filterDocstrings(node.body),
            name=node.name,
            bases=node.bases,
            decorator_list=node.decorator_list
        ), node)

AST Transformer: remove all docstring nodes.
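RemoveDocstrings rebuilds each compound statement type by hand, so statement lists it does not enumerate (for example except-handler bodies or async definitions) are visited without filtering. A hypothetical alternative, sketched below under the assumption of Python 3.9+ (for ast.unparse), overrides generic_visit to filter every field that holds a list of statements on every node. StripStringStatements and is_string_stmt are illustrative names, not part of mast. Like the original, it drops every bare string statement, not just the first one in a body.

```python
import ast


def is_string_stmt(stmt):
    """True for a bare string-literal statement (Python 3.8+ ASTs)."""
    return (
        isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Constant)
        and isinstance(stmt.value.value, str)
    )


class StripStringStatements(ast.NodeTransformer):
    """Hypothetical alternative to RemoveDocstrings.

    Filters every statement-list field (body, orelse, finalbody, ...)
    of every node, then lets NodeTransformer visit what remains.
    """
    def generic_visit(self, node):
        for field, value in ast.iter_fields(node):
            if isinstance(value, list) and value and all(
                isinstance(item, ast.stmt) for item in value
            ):
                kept = [s for s in value if not is_string_stmt(s)]
                setattr(node, field, kept)
        return super().generic_visit(node)


src = ('def f(x):\n    """Doc."""\n    if x:\n        "note"\n'
       '        return 1\n    return 2\n')
tree = StripStringStatements().visit(ast.parse(src))
print(ast.unparse(tree))
```

A body whose only statement is a string literal ends up empty, which is fine for matching but would not compile; RemoveDocstrings has the same property.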

RemoveDocstrings()
def filterDocstrings(self, seq):
def visit_Expr(self, node):
def visit_FunctionDef(self, node):
def visit_Module(self, node):
def visit_For(self, node):
def visit_While(self, node):
def visit_If(self, node):
def visit_With(self, node):
def visit_Try(self, node):
def visit_ClassDef(self, node):
Inherited Members
ast.NodeTransformer
generic_visit
ast.NodeVisitor
visit
visit_Constant
class FlattenBoolOps(ast.NodeTransformer):
    """AST transformer: flatten nested boolean expressions."""
    def visit_BoolOp(self, node):
        values = []
        for v in node.values:
            if isinstance(v, ast.BoolOp) and v.op == node.op:
                values += v.values
            else:
                values.append(v)
                pass
            pass
        return ast.copy_location(ast.BoolOp(op=node.op, values=values), node)
    pass

AST transformer: flatten nested boolean expressions.
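FlattenBoolOps.visit_BoolOp does not call generic_visit, so it splices only one level of nesting: a and (b and (c and d)) becomes a BoolOp over a, b, and the still-nested (c and d). A hypothetical variant that visits the operands first flattens chains of any depth. It also compares operator types rather than operator objects, which does not rely on the parser sharing one And/Or instance.

```python
import ast


class FlattenBoolOpsDeep(ast.NodeTransformer):
    """Hypothetical variant of FlattenBoolOps that visits operands first."""
    def visit_BoolOp(self, node):
        self.generic_visit(node)  # flatten nested BoolOps bottom-up
        values = []
        for v in node.values:
            # Splice in the operands of a nested BoolOp with the same operator.
            if isinstance(v, ast.BoolOp) and type(v.op) is type(node.op):
                values.extend(v.values)
            else:
                values.append(v)
        return ast.copy_location(ast.BoolOp(op=node.op, values=values), node)


nested = ast.parse('a and (b and (c and d))', mode='eval').body
flat = FlattenBoolOpsDeep().visit(nested)
print([n.id for n in flat.values])  # ['a', 'b', 'c', 'd']
```

Mixed operators are left alone: a and (b or c) keeps its inner or as a separate operand.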

FlattenBoolOps()
def visit_BoolOp(self, node):
Inherited Members
ast.NodeTransformer
generic_visit
ast.NodeVisitor
visit
visit_Constant
class ContainsCall(Exception):
    """Exception raised when prohibiting call expressions."""
    pass

Exception raised when prohibiting call expressions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
class ProhibitCall(ast.NodeVisitor):
    """AST visitor: check call-freedom."""
    def visit_Call(self, c):
        raise ContainsCall
    pass

AST visitor: check call-freedom.
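ProhibitCall signals a call by raising rather than returning a value, so callers wrap the visit in try/except, as DesugarAugAssign does. The block copies both classes and wraps them in has_call, a hypothetical helper, to show the pattern on augmented-assignment targets.

```python
import ast


class ContainsCall(Exception):
    """Exception raised when prohibiting call expressions."""


class ProhibitCall(ast.NodeVisitor):
    """AST visitor: check call-freedom."""
    def visit_Call(self, c):
        raise ContainsCall


def has_call(node):
    """Hypothetical helper: True if node contains any call expression."""
    try:
        ProhibitCall().visit(node)
    except ContainsCall:
        return True
    return False


print(has_call(ast.parse('xs[i] += 1').body[0].target))    # False
print(has_call(ast.parse('xs[f()] += 1').body[0].target))  # True
```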

ProhibitCall()
def visit_Call(self, c):
Inherited Members
ast.NodeVisitor
visit
generic_visit
visit_Constant
class SetLoadContexts(ast.NodeTransformer):
    """
    Transforms any AugStore contexts into Load contexts, for use with
    DesugarAugAssign.
    """
    def visit_AugStore(self, ctx):
        return ast.copy_location(ast.Load(), ctx)

Transforms any AugStore contexts into Load contexts, for use with DesugarAugAssign.

SetLoadContexts()
def visit_AugStore(self, ctx):
Inherited Members
ast.NodeTransformer
generic_visit
ast.NodeVisitor
visit
visit_Constant
class DesugarAugAssign(ast.NodeTransformer):
    """AST transformer: desugar augmented assignments (e.g., `+=`)
    when simple desugaring does not duplicate effectful expressions.

    FIXME: this desugaring and others cause surprising match
    counts. See: https://github.com/wellesleycs111/codder/issues/31

    FIXME: This desugaring should probably not happen in cases where
    .__iadd__ and .__add__ yield different results, for example with
    lists where .__iadd__ is really extend while .__add__ creates a new
    list, such that

        [1, 2] + "34"

    is an error, but

        x = [1, 2]
        x += "34"

    is not!

    A better approach might be to avoid desugaring entirely and instead
    provide common collections of match rules for common patterns.

    """
    def visit_AugAssign(self, assign):
        try:
            # Desugaring *all* AugAssigns is not sound.
            # Example: xs[f()] += 1  -->  xs[f()] = xs[f()] + 1
            # Check for presence of call in target.
            ProhibitCall().visit(assign.target)
            return ast.copy_location(
                ast.Assign(
                    targets=[self.visit(assign.target)],
                    value=ast.copy_location(
                        ast.BinOp(
                            left=SetLoadContexts().visit(assign.target),
                            op=self.visit(assign.op),
                            right=self.visit(assign.value)
                        ),
                        assign
                    )
                ),
                assign
            )
        except ContainsCall:
            return self.generic_visit(assign)

    def visit_AugStore(self, ctx):
        return ast.copy_location(ast.Store(), ctx)

    def visit_AugLoad(self, ctx):
        return ast.copy_location(ast.Load(), ctx)

AST transformer: desugar augmented assignments (e.g., +=) when simple desugaring does not duplicate effectful expressions.
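The rewrite turns target op= value into target = target op value and skips targets that contain a call. A minimal standalone sketch follows, assuming Python 3.9+ (for ast.unparse). On those versions the target of an AugAssign carries an ast.Store context, so the sketch deep-copies the target and resets every context in the copy to ast.Load, rather than sharing one node between both sides as the listing above does. desugar_aug_assign and show are hypothetical names.

```python
import ast
import copy


def desugar_aug_assign(stmt):
    """Hypothetical sketch: rewrite `target op= value` as `target = target op value`.

    Targets containing a call are left alone, since duplicating the
    target would duplicate the call's effects.
    """
    if any(isinstance(n, ast.Call) for n in ast.walk(stmt.target)):
        return stmt
    # Separate copy of the target for the right-hand side, read rather than written.
    load_target = copy.deepcopy(stmt.target)
    for n in ast.walk(load_target):
        if hasattr(n, 'ctx'):
            n.ctx = ast.Load()
    new = ast.Assign(
        targets=[stmt.target],
        value=ast.BinOp(left=load_target, op=stmt.op, right=stmt.value),
    )
    return ast.fix_missing_locations(ast.copy_location(new, stmt))


def show(src):
    return ast.unparse(desugar_aug_assign(ast.parse(src).body[0]))


print(show('x += 1'))        # x = x + 1
print(show('xs[i] -= 2'))    # xs[i] = xs[i] - 2
print(show('x *= a + b'))    # x = x * (a + b)
print(show('xs[f()] += 1'))  # xs[f()] += 1
```

Because the value is kept as the BinOp's right operand, grouping is preserved: x *= a + b becomes x * (a + b), not x * a + b.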

FIXME: this desugaring and others cause surprising match counts. See: https://github.com/wellesleycs111/codder/issues/31

FIXME: This desugaring should probably not happen in cases where .__iadd__ and .__add__ yield different results, for example with lists where .__iadd__ is really extend while .__add__ creates a new list, such that

[1, 2] + "34"

is an error, but

x = [1, 2]
x += "34"

is not!
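The list case in this FIXME can be checked directly: += on a list calls list.__iadd__, which extends the list with any iterable, while + calls list.__add__, which only accepts another list.

```python
# In-place += on a list behaves like extend and accepts any iterable.
x = [1, 2]
x += "34"
print(x)  # [1, 2, '3', '4']

# Plain + requires both operands to be lists.
try:
    [1, 2] + "34"
    plus_raised = False
except TypeError:
    plus_raised = True
print(plus_raised)  # True
```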

A better approach might be to avoid desugaring entirely and instead provide common collections of match rules for common patterns.

DesugarAugAssign()
def visit_AugAssign(self, assign):
def visit_AugStore(self, ctx):
def visit_AugLoad(self, ctx):
Inherited Members
ast.NodeTransformer
generic_visit
ast.NodeVisitor
visit
visit_Constant
class ExpandExplicitElsePattern(ast.NodeTransformer):
    """AST transformer: transform patterns that include `else: ___`
    to use `else: _; ___` instead, forcing the else to have at least
    one statement."""
    def visit_If(self, ifelse):
        if 1 == len(ifelse.orelse) and set_var(ifelse.orelse[0]):
            return ast.copy_location(
                ast.If(
                    test=ifelse.test,
                    body=ifelse.body,
                    orelse=[
                        ast.copy_location(
                            ast.Expr(value=ast.Name(id='_', ctx=None)),
                            ifelse.orelse[0]
                        ),
                        ifelse.orelse[0]
                    ]
                ),
                ifelse
            )
        else:
            return self.generic_visit(ifelse)

AST transformer: transform patterns that include else: ___ to use else: _; ___ instead, forcing the else to have at least one statement.
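The rewrite is easiest to see on a parsed pattern. In the sketch below, is_anon_set_var_stmt is a hypothetical stand-in for mast's set_var that recognizes only the anonymous form ___, and the block assumes Python 3.9+ for ast.unparse. After the transform, the else branch starts with a wildcard statement _, so it can no longer match an if with no else.

```python
import ast


def is_anon_set_var_stmt(stmt):
    # Hypothetical stand-in for set_var: recognizes only a bare ___ statement.
    return (
        isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Name)
        and stmt.value.id == '___'
    )


class ExpandElseSketch(ast.NodeTransformer):
    """Rewrite `else: ___` as `else: _; ___` (hypothetical sketch)."""
    def visit_If(self, node):
        self.generic_visit(node)  # handle nested ifs first
        if len(node.orelse) == 1 and is_anon_set_var_stmt(node.orelse[0]):
            wildcard = ast.Expr(value=ast.Name(id='_', ctx=ast.Load()))
            node.orelse = [ast.copy_location(wildcard, node.orelse[0]),
                           node.orelse[0]]
        return node


pattern = ast.parse('if _:\n    ___\nelse:\n    ___\n')
ExpandElseSketch().visit(pattern)
print(ast.unparse(pattern))
```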

ExpandExplicitElsePattern()
def visit_If(self, ifelse):
Inherited Members
ast.NodeTransformer
generic_visit
ast.NodeVisitor
visit
visit_Constant
def pipe_visit(node, visitors):
    """Send an AST through a pipeline of AST visitors/transformers."""
    if 0 == len(visitors):
        return node
    else:
        v = visitors[0]
        if isinstance(v, ast.NodeTransformer):
            # visit for the transformed result
            return pipe_visit(v.visit(node), visitors[1:])
        else:
            # visit for the effect
            v.visit(node)
            return pipe_visit(node, visitors[1:])

Send an AST through a pipeline of AST visitors/transformers.
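The recursion in pipe_visit is equivalent to a simple loop: transformers replace the tree with their result, while plain visitors run only for their side effects (such as raising on invalid input). The loop form below, pipe_visit_iter, is a hypothetical restatement; CountNames and RenameX are illustrative visitors. The isinstance test must check NodeTransformer, because every NodeTransformer is also a NodeVisitor.

```python
import ast


def pipe_visit_iter(node, visitors):
    """Hypothetical loop-based equivalent of pipe_visit."""
    for v in visitors:
        if isinstance(v, ast.NodeTransformer):
            node = v.visit(node)  # keep the transformed tree
        else:
            v.visit(node)         # run only for side effects
    return node


class CountNames(ast.NodeVisitor):
    def __init__(self):
        self.count = 0

    def visit_Name(self, node):
        self.count += 1


class RenameX(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id == 'x':
            node.id = 'y'
        return node


counter = CountNames()
tree = pipe_visit_iter(ast.parse('x = x + z'), [RenameX(), counter])
print(ast.unparse(tree), counter.count)  # y = y + z 3
```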

def parse_file(path, docstrings=True):
    """Load and parse a program AST from the given path."""
    with open(path) as f:
        return parse(f.read(), filename=path, docstrings=docstrings)
    pass

Load and parse a program AST from the given path.

STANDARD_PASSES = [<class 'potluck.mast.RemoveDocstrings'>, <class 'potluck.mast.FlattenBoolOps'>]

Thunks for AST visitors/transformers that are applied during parsing.

DOCSTRING_PASSES = [<class 'potluck.mast.FlattenBoolOps'>]

Extra thunks for docstring mode?

class MastParseError(Exception):
    def __init__(self, pattern, error):
        super(MastParseError, self).__init__(
            'Error parsing pattern source string <<<\n{}\n>>>\n{}'.format(
                pattern, str(error)
            )
        )
        self.trigger = error # hang on to the original error
        pass
    pass

Exception raised when parse cannot parse a source string; the original error is kept in the trigger attribute.

MastParseError(pattern, error)
Inherited Members
builtins.BaseException
with_traceback
def parse(
    string,
    docstrings=True,
    filename='<unknown>',
    passes=STANDARD_PASSES
):
    """Parse an AST from a string."""
    if type(string) == str: # All Python 3 strings are unicode
        string = str(string)
    assert type(string) == str # All Python 3 strings are unicode
    if docstrings and RemoveDocstrings in passes:
        # Build a filtered copy: passes.remove() would mutate the shared
        # default list (STANDARD_PASSES) for every later call.
        passes = [p for p in passes if p is not RemoveDocstrings]
    pipe = [thunk() for thunk in passes]

    try:
        parsed_ast = ast.parse(string, filename=filename)
    except Exception as error:
        raise MastParseError(string, error)
    return pipe_visit(parsed_ast, pipe)

Parse an AST from a string.

PATTERN_PARSE_CACHE = {(False, True, '_(___)'): <ast.Call object>, (False, True, '_._(___)'): <ast.Call object>, (False, True, '_ = _'): <ast.Assign object>, (False, True, 'del _'): <ast.Delete object>, (False, True, 'print(___)'): <ast.Call object>}

A pattern parsing cache, keyed by (toplevel, docstrings, pattern source string).

def parse_pattern(pat, toplevel=False, docstrings=True):
    """Parse and validate a pattern."""
    if isinstance(pat, ast.AST):
        PatternValidator().visit(pat)
        return pat
    elif type(pat) == list:
        # Validate each pattern in the list.
        for p in pat:
            PatternValidator().visit(p)
        return pat
    elif type(pat) == str: # All Python 3 strings are unicode
        cache_key = (toplevel, docstrings, pat)
        pat_ast = PATTERN_PARSE_CACHE.get(cache_key)
        if not pat_ast:
            # 2021-6-16 Peter commented this out (should it be used
            # instead of passes on the line below?)
            # pipe = [thunk() for thunk in PATTERN_PASSES]
            pat_ast = parse(pat, filename='<pattern>', passes=PATTERN_PASSES)
            PatternValidator().visit(pat_ast)
            if not toplevel:
                # Unwrap as needed.
                if len(pat_ast.body) == 1:
                    # If the pattern is a single definition, statement, or
                    # expression, unwrap the Module node.
                    b = pat_ast.body[0]
                    pat_ast = b.value if isinstance(b, ast.Expr) else b
                    pass
                else:
                    # If the pattern is a sequence of definitions or
                    # statements, validate from the top, but return only
                    # the Module body.
                    pat_ast = pat_ast.body
                    pass
                pass
            PATTERN_PARSE_CACHE[cache_key] = pat_ast
            pass
        return pat_ast
    else:
        assert False, 'Cannot parse pattern of type {}: {}'.format(
            type(pat),
            dump(pat)
        )

Parse and validate a pattern.
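When toplevel is false, parse_pattern unwraps the parsed Module: a single expression statement becomes its expression, any other single statement is returned as is, and several statements come back as a list. That is why the cache above holds ast.Call and ast.Assign objects rather than modules. A standalone sketch of just this step (unwrap is a hypothetical name; no validation or caching):

```python
import ast


def unwrap(module):
    """Sketch of parse_pattern's non-toplevel unwrapping step."""
    if len(module.body) == 1:
        b = module.body[0]
        # A lone expression statement is replaced by its expression.
        return b.value if isinstance(b, ast.Expr) else b
    # Several statements: return the statement list itself.
    return module.body


print(type(unwrap(ast.parse('print(___)'))).__name__)    # Call
print(type(unwrap(ast.parse('_ = _'))).__name__)         # Assign
print(type(unwrap(ast.parse('x = 1\ny = 2'))).__name__)  # list
```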

def pat(pat, toplevel=False, docstrings=True):
    """Alias for parse_pattern."""
    return parse_pattern(pat, toplevel, docstrings)

Alias for parse_pattern.

ASSOCIATIVE_OPS = [<class 'ast.Add'>, <class 'ast.Mult'>, <class 'ast.BitOr'>, <class 'ast.BitXor'>, <class 'ast.BitAnd'>, <class 'ast.Eq'>, <class 'ast.NotEq'>, <class 'ast.Is'>, <class 'ast.IsNot'>]

Types of AST operation nodes that are associative.

def op_is_assoc(op):
    """Determine if the given operation node is associative."""
    return type(op) in ASSOCIATIVE_OPS

Determine if the given operation node is associative.
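op_is_assoc expects the operator node (for example the op field of an ast.BinOp), not the BinOp itself: type(binop) is ast.BinOp, which never appears in ASSOCIATIVE_OPS. The block copies the list and function to show the difference.

```python
import ast

ASSOCIATIVE_OPS = [ast.Add, ast.Mult, ast.BitOr, ast.BitXor, ast.BitAnd,
                   ast.Eq, ast.NotEq, ast.Is, ast.IsNot]


def op_is_assoc(op):
    """Determine if the given operation node is associative."""
    return type(op) in ASSOCIATIVE_OPS


binop = ast.parse('a + b', mode='eval').body
print(op_is_assoc(binop.op))  # True: the operator node is an ast.Add
print(op_is_assoc(binop))     # False: a BinOp is not an operator node
```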

MIRRORING = [(<ast.Lt object>, <ast.Gt object>), (<ast.LtE object>, <ast.GtE object>), (<ast.Eq object>, <ast.Eq object>), (<ast.NotEq object>, <ast.NotEq object>), (<ast.Is object>, <ast.Is object>), (<ast.IsNot object>, <ast.IsNot object>)]

Pairs of operations that are mirrors of each other.

def mirror(op):
    """Return the mirror operation of the given operation if any."""
    for (x, y) in MIRRORING:
        if type(x) == type(op):
            return y
        elif type(y) == type(op):
            return x
        pass
    return None

Return the mirror operation of the given operation if any.

def op_has_mirror(op):
    """Determine if the given operation has a mirror."""
    return mirror(op) is not None

Determine if the given operation has a mirror.
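With the pairs in MIRRORING, swapping the operands of a one-operator comparison and replacing the operator by its mirror gives an equivalent comparison in the usual case (a < b is b > a); checking whether an operator has a mirror amounts to mirror(op) is not None. The block copies MIRRORING and mirror and applies them with mirrored_compare, a hypothetical helper; it assumes Python 3.9+ for ast.unparse.

```python
import ast

MIRRORING = [(ast.Lt(), ast.Gt()), (ast.LtE(), ast.GtE()),
             (ast.Eq(), ast.Eq()), (ast.NotEq(), ast.NotEq()),
             (ast.Is(), ast.Is()), (ast.IsNot(), ast.IsNot())]


def mirror(op):
    """Return the mirror operation of the given operation if any."""
    for (x, y) in MIRRORING:
        if type(x) == type(op):
            return y
        elif type(y) == type(op):
            return x
    return None


def mirrored_compare(src):
    """Swap the operands of a one-operator comparison, if it has a mirror."""
    cmp = ast.parse(src, mode='eval').body
    m = mirror(cmp.ops[0])
    if m is None:
        return None
    swapped = ast.Compare(left=cmp.comparators[0], ops=[m],
                          comparators=[cmp.left])
    return ast.unparse(swapped)


print(mirrored_compare('a < b'))   # b > a
print(mirrored_compare('a == b'))  # b == a
print(mirrored_compare('a in b'))  # None: `in` has no mirror
```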

class PermuteBoolOps(ast.NodeTransformer):
    """AST transformer: permute topmost boolean operation
    according to the given index order."""
    def __init__(self, indices):
        self.indices = indices

    def visit_BoolOp(self, node):
        if self.indices:
            # Clear the index order before visiting operands, so that a
            # BoolOp nested inside an operand is not permuted with it.
            indices, self.indices = self.indices, None
            values = [
                self.generic_visit(node.values[i])
                for i in indices
            ]
            return ast.copy_location(
                ast.BoolOp(op=node.op, values=values),
                node
            )
        else:
            return self.generic_visit(node)

AST transformer: permute topmost boolean operation according to the given index order.
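A standalone sketch of the same transformer follows, under a hypothetical name. It clears the stored index order before visiting operands, so only the first BoolOp reached is reordered and a BoolOp nested inside an operand keeps its original order. The block assumes Python 3.9+ for ast.unparse.

```python
import ast


class PermuteBoolOpsSketch(ast.NodeTransformer):
    """Reorder the operands of the first BoolOp visited (hypothetical sketch)."""
    def __init__(self, indices):
        self.indices = indices

    def visit_BoolOp(self, node):
        if self.indices:
            # Consume the order first so nested BoolOps are left alone.
            indices, self.indices = self.indices, None
            values = [self.generic_visit(node.values[i]) for i in indices]
            return ast.copy_location(ast.BoolOp(op=node.op, values=values), node)
        return self.generic_visit(node)


def permute(src, indices):
    expr = ast.parse(src, mode='eval').body
    return ast.unparse(PermuteBoolOpsSketch(indices).visit(expr))


print(permute('a and b and c', (2, 0, 1)))  # c and a and b

result = PermuteBoolOpsSketch((1, 0)).visit(
    ast.parse('a and not (b or c)', mode='eval').body)
inner = result.values[0].operand  # the nested `b or c`, left in place
print([n.id for n in inner.values])  # ['b', 'c']
```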

PermuteBoolOps(indices)
def visit_BoolOp(self, node):
Inherited Members
ast.NodeTransformer
generic_visit
ast.NodeVisitor
visit
visit_Constant
def node_is_pure(node, purefuns=[]):
    """Determine if the given node is (conservatively) pure (effect-free)."""
    return (
        count(node, FUNCALL,
              matchpred=lambda x, env:
              x.func not in PUREFUNS and x.func not in purefuns) == 0
        and all(count(node, pat) == 0 for pat in IMPURE2)
    )

Determine if the given node is (conservatively) pure (effect-free).
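"Conservatively pure" means the check may reject harmless code but should never accept code with effects. The sketch below illustrates that idea without mast: the whitelist PURE_NAMES and the set of impure node types are hypothetical choices, not the contents of mast's PUREFUNS or IMPURE2. Any call to something other than a whitelisted plain name counts as impure, as do yield, await, and assignment expressions. It assumes Python 3.8+ (for ast.NamedExpr).

```python
import ast

PURE_NAMES = {'len', 'abs', 'min', 'max'}  # hypothetical whitelist


def node_is_pure_sketch(node, purefuns=()):
    """Hypothetical conservative purity check."""
    allowed = PURE_NAMES | set(purefuns)
    for n in ast.walk(node):
        if isinstance(n, ast.Call):
            # Only calls to whitelisted plain names are trusted.
            if not (isinstance(n.func, ast.Name) and n.func.id in allowed):
                return False
        elif isinstance(n, (ast.Yield, ast.YieldFrom, ast.Await, ast.NamedExpr)):
            return False
    return True


def expr(src):
    return ast.parse(src, mode='eval').body


print(node_is_pure_sketch(expr('a + len(b)')))  # True
print(node_is_pure_sketch(expr('a + f(b)')))    # False
print(node_is_pure_sketch(expr('xs.pop()')))    # False
```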

def permutations(node):
    """
    Generate all permutations of the binary and boolean operations, as
    well as of import orderings, in and below node, via associativity and
    mirroring, respecting purity. Because the number of permutations grows
    exponentially, it is best to limit the number of imported names and
    the complexity of binary operation trees in your patterns.
    """
    # TODO: These permutations allow matching like this:
    # Node: import math, sys, io
    # Pat: import _x_, ___
    # can bind x to math or io, but NOT sys (see TODO near
    # parse_pattern("___"))
    if isinstance(node, ast.Import):
        for perm in list_permutations(node.names):
            yield ast.Import(names=perm)
    elif isinstance(node, ast.ImportFrom):
        for perm in list_permutations(node.names):
            yield ast.ImportFrom(
                module=node.module,
                names=perm,
                level=node.level
            )
    if (
        isinstance(node, ast.BinOp)
    and op_is_assoc(node.op)
    and node_is_pure(node)
    ):
        for left in permutations(node.left):
            for right in permutations(node.right):
                # BinOp has no ctx field; copy the original's location.
                yield ast.copy_location(ast.BinOp(
                    left=left,
                    op=node.op,
                    right=right
                ), node)
                yield ast.copy_location(ast.BinOp(
                    left=right,
                    op=node.op,
                    right=left
                ), node)
    elif (isinstance(node, ast.Compare)
          and len(node.ops) == 1
          and op_has_mirror(node.ops[0])
          and node_is_pure(node)):
        assert len(node.comparators) == 1
        for left in permutations(node.left):
            for right in permutations(node.comparators[0]):
                # print('PERMUTE', dump(left), dump(node.ops), dump(right))
                yield ast.copy_location(ast.Compare(
                    left=left,
                    ops=node.ops,
                    comparators=[right]
                ), node)
                yield ast.copy_location(ast.Compare(
                    left=right,
                    ops=[mirror(node.ops[0])],
                    comparators=[left]
                ), node)
    elif isinstance(node, ast.BoolOp) and node_is_pure(node):
        #print(dump(node))
        stuff = [[x for x in permutations(v)] for v in node.values]
        prod = [x for x in itertools.product(*stuff)]
        # print(prod)
        for values in prod:
            # print('VALUES', map(dump,values))
            for indices in itertools.permutations(range(len(node.values))):
                # print('BOOL', map(dump, values))
                yield PermuteBoolOps(indices).visit(
                    ast.copy_location(
                        ast.BoolOp(op=node.op, values=values),
                        node
                    )
                )
                pass
            pass
        pass
    else:
        # print('NO', dump(node))
        yield node
    pass

Generate all permutations of the binary and boolean operations, as well as of import orderings, in and below node, via associativity and mirroring, respecting purity. Because the count grows exponentially, keep the number of imported names and the complexity of binary operation trees in your patterns small.
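The sketch below makes the mirroring step concrete for a single level of the tree:

- For one comparison, it yields the original and its mirror (`a < b` becomes `b > a`).
- For an operator treated as swappable, it yields the original and the version with operands swapped.

Several parts are assumptions standing in for mast's own helpers. The `MIRROR` table replaces mast's `mirror`, and treating only `Add` and `Mult` as swappable replaces `op_is_assoc`. The sketch also skips the purity check and the recursion into children. `ast.unparse` needs Python 3.9 or later.

```python
import ast

# Assumed mirror table: "a OP b" is equivalent to "b MIRROR[OP] a".
MIRROR = {ast.Lt: ast.Gt, ast.Gt: ast.Lt, ast.LtE: ast.GtE,
          ast.GtE: ast.LtE, ast.Eq: ast.Eq, ast.NotEq: ast.NotEq}

# Assumed set of binary operators whose operands may be swapped.
SWAPPABLE = (ast.Add, ast.Mult)


def variants(node):
    """Yield node, then its mirrored or operand-swapped variant (one level only)."""
    yield node
    if (isinstance(node, ast.Compare) and len(node.ops) == 1
            and type(node.ops[0]) in MIRROR):
        mirrored = ast.Compare(left=node.comparators[0],
                               ops=[MIRROR[type(node.ops[0])]()],
                               comparators=[node.left])
        yield ast.copy_location(mirrored, node)
    elif isinstance(node, ast.BinOp) and isinstance(node.op, SWAPPABLE):
        swapped = ast.BinOp(left=node.right, op=node.op, right=node.left)
        yield ast.copy_location(swapped, node)


tree = ast.parse("x < y + 1", mode="eval").body
print([ast.unparse(v) for v in variants(tree)])  # ['x < y + 1', 'y + 1 > x']
```

Note that `ast.copy_location` takes both the new node and the node to copy positions from.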

def list_permutations(items):
986def list_permutations(items):
987    """
988    A generator which yields all possible orderings of the given list.
989    """
990    if len(items) <= 1:
991        yield items[:]
992    else:
993        first = items[0]
994        for subperm in list_permutations(items[1:]):
995            for i in range(len(subperm) + 1):
996                yield subperm[:i] + [first] + subperm[i:]

A generator which yields all possible orderings of the given list.
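The generator permutes the tail of the list, then inserts the first item at every position of each tail ordering. This can be checked against `itertools.permutations`. The copy of `list_permutations` below is reproduced only so that the check is self-contained.

```python
import itertools


def list_permutations(items):
    """Yield every ordering of items by inserting the head into each
    position of every ordering of the tail."""
    if len(items) <= 1:
        yield items[:]
    else:
        first = items[0]
        for subperm in list_permutations(items[1:]):
            for i in range(len(subperm) + 1):
                yield subperm[:i] + [first] + subperm[i:]


names = ["math", "sys", "io"]
ours = sorted(list_permutations(names))
reference = sorted(list(p) for p in itertools.permutations(names))
print(ours == reference, len(ours))  # True 6
```

A list of n names produces n! orderings, which is why patterns with long import lists become expensive.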

def match(node, pat, **kwargs):
1003def match(node, pat, **kwargs):
1004    """A convenience wrapper for matchpat that accepts a pattern either
1005    as a pre-parsed AST or as a pattern source string to be parsed.
1006
1007    """
1008    return matchpat(node, parse_pattern(pat), **kwargs)

A convenience wrapper for matchpat that accepts a pattern either as a pre-parsed AST or as a pattern source string to be parsed.
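Here is a minimal sketch of the "source string or pre-parsed AST" convenience, using only the standard library. `as_pattern` and `same_shape` are hypothetical names. Real mast patterns go through `parse_pattern` and support wildcards; this sketch does not, and only compares exact structure via `ast.dump`, which omits line and column attributes by default.

```python
import ast


def as_pattern(pat):
    """Accept either pattern source text or an already-parsed AST."""
    if isinstance(pat, str):
        # Parse as a single expression; mast's parse_pattern is more general.
        return ast.parse(pat, mode="eval").body
    return pat


def same_shape(node, pat):
    """Exact structural comparison, ignoring source positions."""
    return ast.dump(node) == ast.dump(as_pattern(pat))


node = ast.parse("a + 1", mode="eval").body
print(same_shape(node, "a + 1"))                               # True
print(same_shape(node, ast.parse("a + 1", mode="eval").body))  # True
print(same_shape(node, "a + 2"))                               # False
```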

def predtrue(node, matchenv):
1011def predtrue(node, matchenv):
1012    """The True predicate"""
1013    return True

The True predicate

def matchpat(node, pat, matchpred=<function predtrue>, env={}, gen=False, normalize=False):
1016def matchpat(
1017    node,
1018    pat,
1019    matchpred=predtrue,
1020    env={},
1021    gen=False,
1022    normalize=False
1023):
1024    """Match an AST against a (pre-parsed) pattern.
1025
1026    The optional keyword argument matchpred gives a predicate of type
1027    (AST node * match environment) -> bool that filters structural
1028    matches by arbitrary additional criteria.  The default is the true
1029    predicate, accepting all structural matches.
1030
1031    The optional keyword argument gen determines whether this function:
1032      - (gen=True)  yields an environment for each way pat matches node.
1033      - (gen=False) returns Some(the first of these environments), or
1034                    None if there are no matches (default).
1035
1036    EXPERIMENTAL
1037    The optional keyword argument normalize determines whether to
1038    rewrite the target AST node and the pattern to inline simple
1039    straightline variable assignments into large expressions (to the
1040    extent possible) for matching.  The default is no normalization
1041    (False).  Normalization is experimental.  It is rather ad hoc and
1042    conservative and may cause unintuitive matching behavior.  Use
1043    with caution.
1044
1045    INTERNAL
1046    The optional keyword argument env gives an initial match environment which
1047    may be used to constrain otherwise free pattern variables.  This
1048    argument is mainly intended for internal use.
1049
1050    """
1051    assert node is not None
1052    assert pat is not None
1053    if normalize:
1054        if isinstance(node, ast.AST) or type(node) == list:
1055            node = canonical_pure(node)
1056            pass
1057        if isinstance(pat, ast.AST) or type(pat) == list:
1058            pat = canonical_pure(pat)
1059            pass
1060        pass
1061
1062    # Permute the PATTERN, not the AST, so that nodes returned
1063    # always match real code.
1064    matches = (matchenv
1065               # outer iteration
1066               for permpat in permutations(pat)
1067               # inner iteration
1068               for matchenv in imatches(node, permpat,
1069                                        Some(env), True)
1070               if matchpred(node, matchenv))
1071
1072    return matches if gen else takeone(matches)

Match an AST against a (pre-parsed) pattern.

The optional keyword argument matchpred gives a predicate of type (AST node * match environment) -> bool that filters structural matches by arbitrary additional criteria. The default is the true predicate, accepting all structural matches.

The optional keyword argument gen determines whether this function:

  • (gen=True) yields an environment for each way pat matches node.
  • (gen=False) returns Some(the first of these environments), or None if there are no matches (default).

EXPERIMENTAL The optional keyword argument normalize determines whether to rewrite the target AST node and the pattern to inline simple straightline variable assignments into large expressions (to the extent possible) for matching. The default is no normalization (False). Normalization is experimental. It is rather ad hoc and conservative and may cause unintuitive matching behavior. Use with caution.

INTERNAL The optional keyword argument env gives an initial match environment which may be used to constrain otherwise free pattern variables. This argument is mainly intended for internal use.
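The `gen`/`matchpred` contract can be sketched without mast's matcher:

- Structural matches arrive as a stream of environments.
- `matchpred` filters that stream.
- With `gen=False`, only the first surviving environment is kept, wrapped in an option.

`Some`, `take_one`, `filtered_matches` and the integer-valued `candidates` are stand-ins written for illustration. mast's own `Some` and `takeone` come from its utility module and may differ in detail, and its real environments map pattern variables to AST nodes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional


@dataclass
class Some:
    """Minimal option wrapper: Some(value) for a hit, None for a miss."""
    value: Any


def take_one(it: Iterable) -> Optional[Some]:
    for x in it:
        return Some(x)
    return None


def filtered_matches(node, envs: Iterable[Dict], matchpred: Callable, gen: bool):
    """Keep environments accepted by matchpred; return all lazily or the first."""
    matches = (env for env in envs if matchpred(node, env))
    return matches if gen else take_one(matches)


# Hypothetical structural matches of a pattern like `_x_ + _y_`.
candidates = [{"x": 1, "y": 2}, {"x": 3, "y": 3}]


def same_xy(node, env):
    return env["x"] == env["y"]


print(filtered_matches(None, candidates, same_xy, gen=False))       # Some(value={'x': 3, 'y': 3})
print(list(filtered_matches(None, candidates, same_xy, gen=True)))  # [{'x': 3, 'y': 3}]
```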

def bind(env1, name, value):
1075def bind(env1, name, value):
1076    """
1077    Unify the environment in option env1 with a new binding of name to
1078    value.  Return Some(extended environment) if env1 is Some(existing
1079    environment) in which name is not bound or is bound to value.
1080    Otherwise (env1 is None, or the existing binding of name is
1081    incompatible with value), return None.
1082    """
1083    # Lyn modified to allow debugging prints
1084    assert type(name) == str
1085    if env1 is None:
1086        # Propagate an earlier failure; env1.value would raise below.
1087        return None
1088    env = env1.value
1089    if var_is_anonymous(name):
1090        # return env1
1091        result = env1
1092    elif name in env:
1093        if takeone(imatches(env[name], value, Some({}), True)):
1094            # return env1
1095            result = env1
1096        else:
1097            # return None
1098            result = None
1099    else:
1100        env = env.copy()
1101        env[name] = value
1102        # print 'bind', name, dump(value), dump(env)
1103        # return Some(env)
1104        result = Some(env)
1105    # print(
1106    #    '\n$ bind({}, {}, {}) => {}'.format(
1107    #         dump(env1),
1108    #         name,
1109    #         dump(value),
1110    #         dump(result)
1111    #     )
1112    # )
1113    return result
1114    pass

Unify the environment in option env1 with a new binding of name to value. Return Some(extended environment) if env1 is Some(existing environment) in which name is not bound or is bound to value. Otherwise (env1 is None, or the existing binding of name is incompatible with value), return None.
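The following is a simplified model of this unification step. It makes two assumptions: plain `==` equality stands in for mast's recursive `imatches` comparison, and a variable counts as anonymous when its name consists only of underscores. The sketch also shows the intended handling of a `None` input: the failure is returned immediately rather than dereferencing `env1.value`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Some:
    value: Dict[str, Any]


def is_anonymous(name: str) -> bool:
    # Assumption: anonymous variables are spelled only with underscores.
    return set(name) == {"_"}


def bind(env1: Optional[Some], name: str, value: Any) -> Optional[Some]:
    if env1 is None:
        return None                  # an earlier failure stays a failure
    env = env1.value
    if is_anonymous(name):
        return env1                  # anonymous variables never bind
    if name in env:
        # Already bound: succeed only if the two bindings agree.
        return env1 if env[name] == value else None
    extended = dict(env)             # copy so other environments are untouched
    extended[name] = value
    return Some(extended)


e = bind(Some({}), "x", 1)
print(e)                     # Some(value={'x': 1})
print(bind(e, "x", 1) is e)  # True
print(bind(e, "x", 2))       # None
print(bind(None, "x", 1))    # None
```

Copying before extending matters because backtracking search keeps older environments alive and may retry them with different bindings.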

IGNORE_FIELDS = {'col_offset', 'lineno', 'ctx'}

AST node fields to be ignored when matching children.
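`field_values` enumerates children with `ast.iter_fields`, which walks a node's `_fields`. On Python 3 the position information (`lineno`, `col_offset` and their `end_` variants) is stored in `_attributes`, not `_fields`. In practice, therefore, `ctx` is the only entry in this set that actually filters anything. The standard-library check below illustrates this on a `Name` node (Python 3.8 or later).

```python
import ast

IGNORE_FIELDS = {'col_offset', 'lineno', 'ctx'}

name = ast.parse("x", mode="eval").body   # an ast.Name node
print(name._fields)                        # ('id', 'ctx')
print('lineno' in name._attributes)        # True
print([k for k, _ in ast.iter_fields(name) if k not in IGNORE_FIELDS])  # ['id']
```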

def argObjToName(argObj):
1121def argObjToName(argObj):
1122    '''Convert Python 3 arg object into a Name node,
1123       ignoring any annotation, lineno, and col_offset.'''
1124    if argObj is None:
1125        return None
1126    else:
1127        return ast.Name(id=argObj.arg, ctx=None)

Convert Python 3 arg object into a Name node, ignoring any annotation, lineno, and col_offset.

def astStr(astObj):
1130def astStr(astObj):
1131    """
1132    Converts an AST object to a string, imperfectly.
1133    """
1134    if isinstance(astObj, ast.Name):
1135        return astObj.id
1136    elif isinstance(astObj, ast.Str):
1137        return repr(astObj.s)
1138    elif isinstance(astObj, ast.Num):
1139        return str(astObj.n)
1140    elif isinstance(astObj, (list, tuple)):
1141        return str([astStr(x) for x in astObj])
1142    elif hasattr(astObj, "_fields") and len(astObj._fields) > 0:
1143        return "{}({})".format(
1144            type(astObj).__name__,
1145            ', '.join(astStr(getattr(astObj, f)) for f in astObj._fields)
1146        )
1147    elif hasattr(astObj, "_fields"):
1148        return type(astObj).__name__
1149    else:
1150        return '<{}>'.format(type(astObj).__name__)

Converts an AST object to a string, imperfectly.

def defaultToName(defObj):
1153def defaultToName(defObj):
1154    """
1155    Converts a default expression to a name for matching purposes.
1156    TODO: Actually do recursive matching on these expressions!
1157    """
1158    if isinstance(defObj, ast.Name):
1159        return defObj.id
1160    else:
1161        return astStr(defObj)

Converts a default expression to a name for matching purposes. TODO: Actually do recursive matching on these expressions!

def field_values(node):
1164def field_values(node):
1165    """Return a list of the values of all matching-relevant fields of the
1166    given AST node, with fields in consistent list positions."""
1167
1168    # Specializations:
1169    if isinstance(node, ast.FunctionDef):
1170        return [ast.Name(id=v, ctx=None) if k == 'name'
1171                # Lyn sez: commented out following, because fields are
1172                # only name, arguments, body
1173                # else sorted(v, key=lambda x: x.arg) if k == 'keywords'
1174                else v
1175                for (k, v) in ast.iter_fields(node)
1176                if k not in IGNORE_FIELDS]
1177
1178    if isinstance(node, ast.ClassDef):
1179        return [ast.Name(id=v, ctx=None) if k == 'name'
1180                else v
1181                for (k, v) in ast.iter_fields(node)
1182                if k not in IGNORE_FIELDS]
1183
1184    if isinstance(node, ast.Call):
1185        # Old: sorted(v, key=lambda kw: kw.arg)
1186        # New: keyword args use nominal (not positional) matching.
1187        #return [sorted(v, key=lambda kw: kw.arg) if k == 'keywords' else v
1188        return [
1189            (
1190                odict((kw.arg, kw.value) for kw in v)
1191                if k == 'keywords'
1192                else v
1193            )
1194            for (k, v) in ast.iter_fields(node)
1195            if k not in IGNORE_FIELDS
1196        ]
1197
1198    # Lyn sez: ast.arguments handling is new for Python 3
1199    if isinstance(node, ast.arguments):
1200        argList = [
1201            # Need to create Name nodes to match patterns.
1202            argObjToName(argObj)
1203            for argObj in node.args
1204        ] # Ignores argObj annotation, lineno, col_offset
1205        if (
1206            node.vararg is None
1207        and node.kwarg is None
1208        and node.kwonlyargs == []
1209        and node.kw_defaults == []
1210        and node.defaults == []
1211        ):
1212            # Optimize (for debugging purposes) this very common case by
1213            # returning a singleton list of lists
1214            return [argList]
1215        else:
1216            # In unoptimized case, return list with sublists and
1217            # argObjects/Nones:
1218            # TODO: treat triple underscores separately/specially!!!
1219            return [
1220                argList,
1221                argObjToName(node.vararg),
1222                argObjToName(node.kwarg),
1223                [argObjToName(argObj) for argObj in node.kwonlyargs],
1224                # Peter 2019-9-30 the defaults cannot be reliably converted
1225                # into names, because they are expressions!
1226                node.kw_defaults,
1227                node.defaults
1228            ]
1229
1230    if isinstance(node, ast.keyword):
1231        return [ast.Name(id=v, ctx=None) if k == 'arg' else v
1232                for (k, v) in ast.iter_fields(node)
1233                if k not in IGNORE_FIELDS]
1234
1235    if isinstance(node, ast.Global):
1236        return [ast.Name(id=n, ctx=None) for n in sorted(node.names)]
1237
1238    if isinstance(node, ast.Import):
1239        return [ node.names ]
1240
1241    if isinstance(node, ast.ImportFrom):
1242        return [ast.Name(id=node.module, ctx=None), node.level, node.names]
1243
1244    if isinstance(node, ast.alias):
1245        result = [ ast.Name(id=node.name, ctx=None) ]
1246        if node.asname is not None:
1247            result.append(ast.Name(id=node.asname, ctx=None))
1248        return result
1249
1250    # General cases:
1251    if isinstance(node, ast.AST):
1252        return [v for (k, v) in ast.iter_fields(node)
1253                if k not in IGNORE_FIELDS]
1254
1255    if type(node) == list:
1256        return node
1257
1258    # Added by Peter Mawhorter 2019-4-2 to fix problem where subpatterns fail
1259    # to match within keyword arguments of functions because field_values
1260    # returns an odict as one value for functions with kwargs, and asking for
1261    # the field_values of that odict was hitting the base case below.
1262    if isinstance(node, (dict, odict)):
1263        return [
1264          ast.Name(id=n, ctx=None) for n in node.keys()
1265        ] + list(node.values())
1266
1267    return []

Return a list of the values of all matching-relevant fields of the given AST node, with fields in consistent list positions.
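For `ast.Call`, keyword arguments are converted to an ordered dictionary keyed by argument name, so they are later matched by name rather than by position. The sketch below reproduces just that conversion with the standard library. `call_fields` and `dumped` are hypothetical names. Because `OrderedDict` equality is order-sensitive, the comparison converts the keyword map to a plain `dict` first.

```python
import ast
from collections import OrderedDict


def call_fields(call: ast.Call):
    """Return [func, args, {keyword name: value}] in the spirit of field_values."""
    keywords = OrderedDict((kw.arg, kw.value) for kw in call.keywords)
    return [call.func, call.args, keywords]


def dumped(fields):
    func, args, keywords = fields
    # A plain dict compares by name only, ignoring the order keywords were written in.
    return (ast.dump(func), [ast.dump(a) for a in args],
            {k: ast.dump(v) for k, v in keywords.items()})


a = ast.parse("f(1, sep='', end='!')", mode="eval").body
b = ast.parse("f(1, end='!', sep='')", mode="eval").body
print(dumped(call_fields(a)) == dumped(call_fields(b)))  # True
```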

def imatches(node, pat, env1, seq):
1273def imatches(node, pat, env1, seq):
1274    """Exponential backtracking match generating 0 or more matching
1275    environments.  Supports multiple sequence patterns in one context,
1276    and simple permutations of semantically equivalent but syntactically
1277    mirrored patterns.
1278    """
1279    # Lyn change early-return pattern to named result returned at end to
1280    # allow debugging prints
1281    # global imatchCount #$
1282    # print('\n$ {}Entering imatches({}, {}, {}, {})'.format(
1283    #         '| '*imatchCount, dump(node), dump(pat), dump(env1), seq))
1284    # imatchCount += 1 #$
1285    result = iterempty() # default result if not overridden
1286    if env1 is None:
1287        result = iterempty()
1288        # imatchCount -= 1 #$
1289        # print(
1290        #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
1291        #         '| ' * imatchCount,
1292        #         dump(node),
1293        #         dump(pat),
1294        #         dump(env1),
1295        #         seq,
1296        #         dump(result)
1297        #     )
1298        # )
1299        return result
1300    env = env1.value
1301    assert env is not None
1302    if (
1303        (
1304            type(pat) == bool
1305         or type(pat) == str # All Python 3 strings are unicode
1306         or pat is None
1307        )
1308    and node == pat
1309    ):
1310        result = iterone(env)
1311    elif type(pat) == int or type(pat) == float:
1312        # Literal int or float pattern
1313        if type(node) == type(pat):
1314            if (
1315                (type(pat) == int and node == pat)
1316             or (type(pat) == float and abs(node - pat) < 0.001)
1317            ):
1318                result = iterone(env)
1319            pass
1320    elif node_var(pat):
1321        # Var pattern.
1322        # Match and bind name to node.
1323        if node_is_bindable(node):
1324            # [Peter Mawhorter 2021-8-29] Attempting to allow import
1325            # aliases to unify with variable references later on. If the
1326            # alias has an 'as' part we unify as if it's a name with that
1327            # ID, otherwise we use the name part as the ID.
1328            if isinstance(node, ast.alias):
1329                if node.asname:
1330                    bind_as = ast.Name(id=node.asname, ctx=None)
1331                else:
1332                    bind_as = ast.Name(id=node.name, ctx=None)
1333            else:
1334                bind_as = node
1335            env2 = bind(env1, node_var(pat), bind_as)
1336            if env2:
1337                result = iterone(env2.value)
1338            pass
1339    elif typed_lit_var(pat):
1340        # Var pattern to bind only a literal of given type.
1341        id, ty = typed_lit_var(pat)
1342        lit = LIT_TYPES[ty](node)
1343        # Match and bind name to literal.
1344        if lit:
1345            env2 = bind(env1, id, lit.value)
1346            if env2:
1347                result = iterone(env2.value)
1348            pass
1349    elif type(node) == type(pat):
1350        # Node and pattern have same type.
1351        if type(pat) == list:
1352            # Node and pattern are both lists.  Do positional matching.
1353            if len(pat) == 0:
1354                # Empty list pattern.  Node must also be empty.
1355                if len(node) == 0:
1356                    result = iterone(env)
1357                pass
1358            elif len(node) == 0:
1359                # Non-empty list pattern with empty node.
1360                # Try to match sequence subpatterns.
1361                if seq:
1362                    psn = set_var(pat[0])
1363                    if psn:
1364                        result = imatches(node, pat[1:],
1365                                          bind(env1, psn, []), seq)
1366                    pass
1367                pass
1368            else:
1369                # Both are non-empty.
1370                psn = set_var(pat[0])
1371                if seq and psn:
1372                    # First subpattern is a sequence pattern.
1373                    # Try all consumption sizes, greediest first.
1374                    # Unsophisticated exponential backtracking search.
1375                    result = ichain(
1376                        imatches(
1377                            node[i:],
1378                            pat[1:],
1379                            bind(env1, psn, node[:i]),
1380                            seq
1381                        )
1382                        for i in range(len(node), -1, -1)
1383                    )
1384                # Lyn sez: common special case helpful for more concrete
1385                # debugging results (e.g., may return FiniteIterator
1386                # rather than itertools.chain object.)
1387                elif len(node) == 1 and len(pat) == 1:
1388                    result = imatches(node[0], pat[0], env1, True)
1389                else:
1390                    # For all matches of scalar first element sub pattern.
1391                    # Generate all corresponding matches of remainder.
1392                    # Unsophisticated exponential backtracking search.
1393                    result = ichain(
1394                        imatches(node[1:], pat[1:], Some(bs), seq)
1395                       for bs in imatches(node[0], pat[0], env1, True)
1396                    )
1397                pass
1398        elif type(node) == dict or type(node) == odict:
1399            result = match_dict(node, pat, env1)
1400        else:
1401            # Node and pat have same type, but are not lists.
1402            # Match scalar structures by matching lists of their fields.
1403            if isinstance(node, ast.AST):
1404                # TODO: DEBUG
1405                #if isinstance(node, ast.Import):
1406                #    print(
1407                #        "FV2i",
1408                #        dump(field_values(node)),
1409                #        dump(field_values(pat))
1410                #    )
1411                #if isinstance(node, ast.alias):
1412                #    print(
1413                #        "FV2a",
1414                #        dump(field_values(node)),
1415                #        dump(field_values(pat))
1416                #    )
1417                result = imatches(
1418                    field_values(node),
1419                    field_values(pat),
1420                    env1,
1421                    False
1422                )
1423            pass
1424        pass
1425    # return iterempty()
1426    # imatchCount -= 1 #$
1427    # print(
1428    #     '\n$ {} Exiting imatches({}, {}, {}, {}) => {})'.format(
1429    #         '| ' * imatchCount,
1430    #         dump(node),
1431    #         dump(pat),
1432    #         dump(env1),
1433    #         seq,
1434    #         dump(result)
1435    #     )
1436    # )
1437    return result

Exponential backtracking match generating 0 or more matching environments. Supports multiple sequence patterns in one context and simple permutations of semantically equivalent but syntactically mirrored patterns.
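The sequence-variable case can be illustrated on plain Python lists. The matcher tries every prefix length, greediest first, and backtracks when the rest of the pattern fails. In this sketch several things are hypothetical stand-ins, not mast's pattern syntax:

- `seq(name)` and `one(name)` build markers for sequence variables and single-item variables.
- Other pattern items must equal the corresponding node item exactly.
- `list_matches` and the local `bind` replace mast's real helpers.

```python
from typing import Any, Dict, Iterator, List


def seq(name):   # hypothetical marker for a sequence variable
    return ("SEQ", name)


def one(name):   # hypothetical marker for a single-item variable
    return ("ONE", name)


def is_marker(p, kind):
    return isinstance(p, tuple) and len(p) == 2 and p[0] == kind


def bind(env, name, value):
    """Return an extended env, or None if name is bound to something else."""
    if name in env:
        return env if env[name] == value else None
    return {**env, name: value}


def list_matches(node: List[Any], pat: List[Any], env: Dict) -> Iterator[Dict]:
    if not pat:
        if not node:
            yield env
        return
    head, rest = pat[0], pat[1:]
    if is_marker(head, "SEQ"):
        # Try every consumption size, greediest first, backtracking on failure.
        for i in range(len(node), -1, -1):
            env2 = bind(env, head[1], node[:i])
            if env2 is not None:
                yield from list_matches(node[i:], rest, env2)
    elif node:
        if is_marker(head, "ONE"):
            env2 = bind(env, head[1], node[0])
        else:
            env2 = env if head == node[0] else None
        if env2 is not None:
            yield from list_matches(node[1:], rest, env2)


pattern = [seq("before"), one("x"), "stop", seq("after")]
for env in list_matches(["a", "b", "stop", "c"], pattern, {}):
    print(env)  # {'before': ['a'], 'x': 'b', 'after': ['c']}
```

Two sequence variables side by side, such as `[seq("p"), seq("q")]` against `[1, 2]`, produce three environments. The greediest split `p=[1, 2], q=[]` comes first, which mirrors the "greediest first" order used in `imatches`.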

def match_dict(node, pat, env1):
1440def match_dict(node, pat, env1):
1441    # Node and pattern are both dictionaries. Do nominal matching.
1442    # Match all named key patterns, then all single-key pattern variables,
1443    # and finally the multi-key pattern variables.
1444    assert all(type(k) == str # All Python 3 strings are unicode
1445               for k in pat)
1446    assert all(type(k) == str # All Python 3 strings are unicode
1447               for k in node)
1448
1449    def match_keys(node, pat, envopt):
1450        """Match literal keys."""
1451        keyopt = takeone(k for k in pat
1452                         if not node_var_str(k)
1453                         and not set_var_str(k))
1454        if keyopt:
1455            # There is at least one named key in the pattern.
1456            # If this key is also in the program node, then for each
1457            # match of the corresponding node value and pattern value,
1458            # generate all matches for the remaining keys.
1459            key = keyopt.value
1460            if key in node:
1461                return ichain(match_keys(dict_unbind(node, key),
1462                                         dict_unbind(pat, key),
1463                                         Some(kenv))
1464                              for kenv in imatches(node[key], pat[key],
1465                                                   envopt, False))
1466            else:
1467                return iterempty()
1468            pass
1469        else:
1470            # The pattern contains no literal keys.
1471            # Generate all matches for the node and set key variables.
1472            return match_var_keys(node, pat, envopt)
1473        pass
1474
1475    def match_var_keys(node, pat, envopt):
1476        """Match node variable keys."""
1477        keyvaropt = takeone(k for k in pat if node_var_str(k))
1478        if keyvaropt:
1479            # There is at least one single-key variable in the pattern.
1480            # For each key-value pair in the node whose value matches
1481            # this single-key variable's associated value pattern,
1482            # generate all matches for the remaining keys.
1483            keyvar = keyvaropt.value
1484            return ichain(match_var_keys(dict_unbind(node, nkey),
1485                                         dict_unbind(pat, keyvar),
1486                                         bind(Some(kenv),
1487                                              node_var_str(keyvar),
1488                                              ast.Name(id=nkey, ctx=None)))
1489                          # outer iteration:
1490                          for nkey, nval in node.items()
1491                          # inner iteration:
1492                          for kenv in imatches(nval, pat[keyvar],
1493                                               envopt, False))
1494        else:
1495            # The pattern contains no single-key variables.
1496            # Generate all matches for the set key variables.
1497            return match_set_var_keys(node, pat, envopt)
1498        pass
1499
1500    def match_set_var_keys(node, pat, envopt):
1501        """Match set variable keys."""
1502        # NOTE: see discussion of match environments for this case:
1503        # https://github.com/wellesleycs111/codder/issues/25
1504        if not envopt: return iterempty()  # a key-variable bind failed
1505        keysetvaropt = takeone(k for k in pat if set_var_str(k))
1506        if keysetvaropt:
1507            # There is a multi-key variable in the pattern.
1508            # Capture all remaining key-value pairs in the node.
1509            e = bind(envopt, set_var_str(keysetvaropt.value),
1510                     [(ast.Name(id=kw, ctx=None), karg)
1511                      for kw, karg in node.items()])
1512            return iterone(e.value) if e else iterempty()
1513        elif 0 == len(node):
1514            # There is no multi-key variable in the pattern.
1515            # There is a match only if there are no remaining
1516            # keys in the node.
1517            # There should also be no remaining keys in the pattern.
1518            assert 0 == len(pat)
1519            return iterone(envopt.value)
1520        else:
1521            return iterempty()
1522
1523    return match_keys(node, pat, env1)
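`match_dict` works in three phases: literal keys first, then single-key variables, then one catch-all for the remaining pairs. This order can be sketched for flat dictionaries of plain values. The sketch rests on several hypothetical conventions, none of which is mast's syntax:

- Values are compared with `==` instead of recursive pattern matching.
- Pattern keys beginning with `?` stand for key variables.
- A key beginning with `*` captures the remaining pairs.
- `dict_matches` is a made-up name.

```python
from typing import Dict, Iterator


def dict_matches(node: Dict, pat: Dict, env: Dict) -> Iterator[Dict]:
    literal = [k for k in pat if not k.startswith(("?", "*"))]
    keyvars = [k for k in pat if k.startswith("?")]
    rest = [k for k in pat if k.startswith("*")]
    if literal:
        # Phase 1: a literal key must be present with an equal value.
        k = literal[0]
        if k in node and node[k] == pat[k]:
            yield from dict_matches(
                {a: b for a, b in node.items() if a != k},
                {a: b for a, b in pat.items() if a != k}, env)
    elif keyvars:
        # Phase 2: try every node key whose value fits this key variable,
        # respecting any earlier binding of the same variable.
        kv = keyvars[0]
        for nk, nv in node.items():
            if nv == pat[kv] and env.get(kv, nk) == nk:
                yield from dict_matches(
                    {a: b for a, b in node.items() if a != nk},
                    {a: b for a, b in pat.items() if a != kv},
                    {**env, kv: nk})
    elif rest:
        # Phase 3: the catch-all variable takes whatever pairs remain.
        yield {**env, rest[0]: dict(node)}
    elif not node:
        # No catch-all: every node key must already have been consumed.
        yield env


node = {"sep": "", "end": "!", "flush": True}
pat = {"end": "!", "?k": "", "*others": None}
print(list(dict_matches(node, pat, {})))  # [{'?k': 'sep', '*others': {'flush': True}}]
```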
def find(node, pat, **kwargs):
1526def find(node, pat, **kwargs):
1527    """Pre-order search for first sub-AST matching pattern, returning
1528    (matched node, bindings)."""
1529    kwargs['gen'] = True
1530    return takeone(findall(node, parse_pattern(pat), **kwargs))

Pre-order search for first sub-AST matching pattern, returning (matched node, bindings).
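This standard-library sketch shows a pre-order search with a "match outside" exclusion, in the style of `findall_scalar_pat_iter`. Each node is tested before its children, and the search descends into a node's children only if that node is not excluded. Here a plain predicate stands in for pattern matching, and `preorder_find`, `never`, `is_print` and `skip_defs` are hypothetical names. Explicit recursion is used because `ast.walk` does not guarantee any particular visiting order. `ast.unparse` needs Python 3.9 or later.

```python
import ast
from typing import Callable, Iterator


def never(node: ast.AST) -> bool:
    return False


def preorder_find(node: ast.AST, pred: Callable[[ast.AST], bool],
                  outside: Callable[[ast.AST], bool] = never) -> Iterator[ast.AST]:
    """Yield matching nodes in pre-order, skipping the insides of excluded nodes."""
    if pred(node):
        yield node
    if not outside(node):
        for child in ast.iter_child_nodes(node):
            yield from preorder_find(child, pred, outside)


def is_print(n: ast.AST) -> bool:
    return (isinstance(n, ast.Call) and isinstance(n.func, ast.Name)
            and n.func.id == "print")


def skip_defs(n: ast.AST) -> bool:
    return isinstance(n, ast.FunctionDef)


tree = ast.parse('def helper():\n    print("inside helper")\nprint("top level")\n')
print(len(list(preorder_find(tree, is_print))))                         # 2
print([ast.unparse(c) for c in preorder_find(tree, is_print, skip_defs)])  # ["print('top level')"]
```

Taking the first result of such a generator gives the `find` behaviour: here, that is the call inside `helper`, because it precedes the top-level call in pre-order.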

def findall(node, pat, outside=[], **kwargs):
1533def findall(node, pat, outside=[], **kwargs):
1534    """
1535    Search for all sub-ASTs matching pattern, returning list of (matched
1536    node, bindings).
1537    """
1538    assert node is not None
1539    assert pat is not None
1540    gen = kwargs.get('gen', False)
1541    kwargs['gen'] = True
1542    pat = parse_pattern(pat)
1543    # Top-level sequence patterns are not "anchored" to the ends of
1544    # the containing block when *finding* a submatch within a node (as
1545    # opposed to matching a node exactly).  They may match any
1546    # contiguous subsequence.
1547    # - To allow sequence patterns to match starting later than the
1548    #   beginning of a program sequence, matching is attempted
1549    #   recursively with smaller and smaller suffixes of the program
1550    #   sequence.
1551    # - To allow sequence patterns to match ending earlier
1552    #   than the end of a program block, we implicitly ensure that a
1553    #   sequence wildcard pattern terminates every top-level sequence
1554    #   pattern.
1555    # TODO: Because permutations are applied later, this doesn't allow
1556    # all the matching we'd like in the following scenario
1557    # Node: [ alias(name="x"), alias(name="y"), alias(name="z") ]
1558    # Pat: [ alias(name="_a_"), alias(name="___") ]
1559    # here because order of aliases doesn't matter, we *should* be able
1560    # to bind _a_ to x, y, OR z, but it can only bind to x or z, NOT y
1561    if type(pat) == list and not set_var(pat[-1]):
1562        pat = pat + [parse_pattern('___')]
1563        pass
1564
1565    def findall_scalar_pat_iter(node):
1566        """Generate all matches of a scalar (non-list) pattern at this node
1567        or any non-excluded descendant of this node."""
1568        assert type(pat) != list
1569        assert node is not None
1570        # Yield any environment(s) for match(es) of pattern at node.
1571        envs = [e for e in matchpat(node, pat, **kwargs)]
1572        if 0 < len(envs):
1573            yield (node, envs)
1574        # Continue the search for matches in sub-ASTs of node, only if
1575        # node is not excluded by a "match outside" pattern.
1576        if not any(match(node, op) for op in outside):
1577            for n in field_values(node):
1578                if n: # Search only within non-None children.
1579                    for result in findall_scalar_pat_iter(n):
1580                        yield result
1581                        pass
1582                    pass
1583                pass
1584            pass
1585        pass
1586
1587    def findall_list_pat_iter(node):
1588        """Generate all matches of a list pattern at this node or any
1589        non-excluded descendant of this node."""
1590        assert type(pat) == list
1591        assert 0 < len(pat)
1592        if type(node) == list:
1593            # If searching against a list:
1594            # - match against the list itself
1595            # - search against the first child of the list
1596            # - search against the tail of the list
1597
1598            # Match against the list itself.
1599            # Yield any environment(s) for match(es) of pattern at node.
1600            envs = [e for e in matchpat(node, pat, **kwargs)]
1601            if 0 < len(envs):
                yield (node, envs)
            # Continue the search for matches in sub-ASTs of node,
            # only if node is not excluded by a "match outside"
            # pattern.
            if (
                not any(match(node, op) for op in outside)
            and 0 < len(node) # only in nonempty nodes...
            ):
                # Search for matches in the first sub-AST.
                for m in findall_list_pat_iter(node[0]):
                    yield m
                    pass
                if not set_var(pat[0]):
                    # If this sequence pattern does not start with
                    # a sequence wildcard, then:
                    # Search for matches in the tail of the list.
                    # (Includes matches against the entire tail.)
                    for m in findall_list_pat_iter(node[1:]):
                        yield m
                        pass
                    pass
                pass
            pass
        elif not any(match(node, op) for op in outside): # and node is not list
            # A list pattern cannot match against a scalar node.
            # Search for matches in children of this scalar (non-list)
            # node, only if node is not excluded by a "match outside"
            # pattern.

            # Optimize to search only where list patterns could match.

            # Body blocks
            for ty in [ast.ClassDef, ast.FunctionDef, ast.With, ast.Module,
                       ast.If, ast.For, ast.While,
                       # ast.TryExcept, ast.TryFinally,
                       # In Python 3, a single Try node covers both
                       # TryExcept and TryFinally
                       ast.Try,
                       ast.ExceptHandler]:
                if isinstance(node, ty):
                    for m in findall_list_pat_iter(node.body):
                        yield m
                        pass
                    break
                pass
            # Block else blocks
            for ty in [ast.If, ast.For, ast.While,
                       # ast.TryExcept
                       # In Python 3, a single Try node covers both
                       # TryExcept and TryFinally
                       ast.Try
                       ]:
                if isinstance(node, ty) and node.orelse:
                    for m in findall_list_pat_iter(node.orelse):
                        yield m
                        pass
                    break
                pass

#             # Except handler blocks
#             if isinstance(node, ast.TryExcept):
#                 for h in node.handlers:
#                     for m in findall_list_pat_iter(h.body):
#                         yield m
#                         pass
#                     pass
#                 pass
#             # finally blocks
#             if isinstance(node, ast.TryFinally):
#                 for m in findall_list_pat_iter(node.finalbody):
#                         yield m
#                         pass
#                 pass

            # In Python 3, a single Try node covers both TryExcept and
            # TryFinally
            if isinstance(node, ast.Try):
                for h in node.handlers:
                    for m in findall_list_pat_iter(h.body):
                        yield m
                        pass
                    pass
                # Also search finally blocks, as the Python 2
                # TryFinally case above did.
                if node.finalbody:
                    for m in findall_list_pat_iter(node.finalbody):
                        yield m
                        pass
                    pass
                pass

            # General non-optimized version.
            # Must be mutually exclusive with the above if used.
            # for n in field_values(node):
            #                 if n:
            #                     for result in findall_list_pat_iter(n):
            #                         yield result
            #                         pass
            #                     pass
            #                 pass
            pass
        pass
    # Apply the right search based on pattern type.
    matches = (findall_list_pat_iter if type(pat) == list
               else findall_scalar_pat_iter)(node)
    # Return the generator or a list of all generated matches,
    # depending on gen.
    return matches if gen else list(matches)

Search for all sub-ASTs matching pattern, returning a list of (matched node, bindings) pairs, or a generator of those pairs when gen=True.
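To make the gen flag and the result shape concrete, here is a minimal stand-alone sketch using only the standard ast module. It is not mast's matcher: the hypothetical findall_pred takes a Python predicate in place of a mast pattern, produces no bindings, and visits nodes in ast.walk's breadth-first order, which may differ from the order findall uses.

```python
import ast
from typing import Callable, Iterator, List, Union


def findall_pred(
    tree: ast.AST,
    pred: Callable[[ast.AST], bool],
    gen: bool = False,
) -> Union[Iterator[ast.AST], List[ast.AST]]:
    """Hypothetical stand-in for findall: a predicate replaces the mast
    pattern, and no bindings are produced."""
    # ast.walk yields the root and every descendant, breadth-first.
    matches = (node for node in ast.walk(tree) if pred(node))
    # Mirror findall's gen flag: a lazy generator or an eager list.
    return matches if gen else list(matches)


tree = ast.parse("x = f(1)\nif x:\n    g(x)\n")
calls = findall_pred(tree, lambda n: isinstance(n, ast.Call))
print([ast.unparse(c) for c in calls])  # ['f(1)', 'g(x)']
```

Returning the generator unchanged when gen=True lets a caller such as count consume matches one at a time without building the whole list.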

def count(node, pat, **kwargs):
def count(node, pat, **kwargs):
    """
    Count all sub-ASTs matching pattern. Does NOT count individual
    environments that match (i.e., ways that bindings could attach at a
    given node), but rather counts nodes at which one or more bindings
    are possible.
    """
    assert 'gen' not in kwargs
    return sum(1 for x in findall(node, pat,
                                  gen=True, **kwargs))

Count all sub-ASTs matching pattern. Does NOT count individual environments that match (i.e., ways that bindings could attach at a given node), but rather counts nodes at which one or more bindings are possible.
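The difference between counting nodes and counting environments can be sketched with plain data. The (node, environments) pairs below are hypothetical stand-ins for what findall(..., gen=True) yields, and count_nodes and count_envs are illustrative names, not mast functions.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical shape of findall(..., gen=True) output: each item pairs a
# matched node (a string here, for brevity) with the binding environments
# that are possible at that node.
Match = Tuple[str, List[Dict[str, str]]]


def count_nodes(matches: Iterable[Match]) -> int:
    """What count measures: one per matched node."""
    return sum(1 for _ in matches)


def count_envs(matches: Iterable[Match]) -> int:
    """What count does NOT measure: every individual environment."""
    return sum(len(envs) for _, envs in matches)


sample: List[Match] = [
    ("node_a", [{"x": "a"}, {"x": "b"}]),  # two ways to bind here
    ("node_b", [{"x": "c"}]),              # one way to bind here
]
print(count_nodes(sample), count_envs(sample))  # 2 3
```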

class Unimplemented(builtins.Exception):
class Unimplemented(Exception):
    pass

class InlineAvailableExpressions(ast.NodeTransformer):
class InlineAvailableExpressions(ast.NodeTransformer):
    def __init__(self, other=None):
        self.available = dict(other.available) if other else {}

    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Load) and name.id in self.available:
            return self.available[name.id]
        else:
            return self.generic_visit(name)

    def visit_Assign(self, assign):
        # Assignments are not supported yet: bail out so that
        # NormalizePure.normalize_block falls back to the original
        # block. The code after this raise is unreachable.
        raise Unimplemented
        new = ast.copy_location(ast.Assign(
            targets=assign.targets,
            value=self.visit(assign.value)
        ), assign)
        self.available[assign.targets[0].id] = new.value
        return new

    def visit_If(self, ifelse):
        # Inline into the test.
        test = self.visit(ifelse.test)
        # Inline and accumulate in the then and else independently.
        body_inliner = InlineAvailableExpressions(self)
        orelse_inliner = InlineAvailableExpressions(self)
        body = body_inliner.inline_block(ifelse.body)
        orelse = orelse_inliner.inline_block(ifelse.orelse)
        # Any var->expression that is available after both branches
        # is available after.
        self.available = {
            name: body_inliner.available[name]
            for name in (set(body_inliner.available)
                         & set(orelse_inliner.available))
            if (body_inliner.available[name] == orelse_inliner.available[name])
        }
        return ast.copy_location(
            ast.If(test=test, body=body, orelse=orelse),
            ifelse
        )

    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)

    def inline_block(self, block):
        # Introduce duplicate common subexpressions...
        return [self.visit(stmt) for stmt in block]
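InlineAvailableExpressions is a forward-substitution pass: after `name = expr` it records expr under name in `available`, and visit_Name replaces later loads of name with that expression. Because its visit_Assign currently raises Unimplemented, the following self-contained sketch shows what such substitution would produce on straight-line code. InlineSketch is a hypothetical class built on the standard ast module only (ast.unparse needs Python 3.9+); it assumes pure expressions, handles only single-name assignments and return, and raises NotImplementedError where mast raises Unimplemented.

```python
import ast
import copy
from typing import Dict, List


def _reads(expr: ast.AST, name: str) -> bool:
    """True if a Name node called `name` appears anywhere inside expr."""
    return any(isinstance(n, ast.Name) and n.id == name for n in ast.walk(expr))


class InlineSketch(ast.NodeTransformer):
    """Hypothetical forward-substitution pass for straight-line code.

    Assumes pure expressions and blocks made only of `name = expr`
    assignments and `return` statements; anything else raises
    NotImplementedError, standing in for mast's Unimplemented.
    """

    def __init__(self) -> None:
        self.available: Dict[str, ast.expr] = {}

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if isinstance(node.ctx, ast.Load) and node.id in self.available:
            # Deep-copy so one expression object never appears twice.
            return copy.deepcopy(self.available[node.id])
        return node

    def visit_Assign(self, node: ast.Assign) -> ast.Assign:
        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
            raise NotImplementedError("only 'name = expr' is handled")
        node.value = self.visit(node.value)
        name = node.targets[0].id
        # Rebinding `name` invalidates recorded expressions that read it.
        self.available = {k: v for k, v in self.available.items()
                          if not _reads(v, name)}
        self.available[name] = node.value
        return node

    def visit_Return(self, node: ast.Return) -> ast.Return:
        if node.value is not None:
            node.value = self.visit(node.value)
        return node

    def generic_visit(self, node: ast.AST) -> ast.AST:
        # Any statement without a dedicated visitor is unsupported.
        if isinstance(node, ast.stmt):
            raise NotImplementedError(type(node).__name__)
        return super().generic_visit(node)

    def inline_block(self, block: List[ast.stmt]) -> List[ast.stmt]:
        return [self.visit(stmt) for stmt in block]


fn = ast.parse("def f(a):\n    b = a + 1\n    c = b * 2\n    return c\n").body[0]
fn.body = InlineSketch().inline_block(fn.body)
print(ast.unparse(fn))
# def f(a):
#     b = a + 1
#     c = (a + 1) * 2
#     return (a + 1) * 2
```

Dropping recorded expressions that read a rebound name is a choice of this sketch, so that `y = x; x = 2; z = y` does not turn into `z = x`. The now-unused assignments to b and c are left for a dead-code pass to remove.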

A NodeVisitor subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the generic_visit() method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

InlineAvailableExpressions(other=None)
class DeadCodeElim(ast.NodeTransformer):
class DeadCodeElim(ast.NodeTransformer):
    def __init__(self, other=None):
        # Copy the parent's set of live names (the attribute is `used`).
        self.used = set(other.used) if other else set()
        pass

    def visit_Name(self, name):
        if isinstance(name.ctx, ast.Store):
            # Var store defines, removing from the set.
            self.used = self.used - {name.id}
            return name
        elif isinstance(name.ctx, ast.Load):
            # Var use uses, adding to the set.
            self.used = self.used | {name.id}
            return name
        else:
            return name
        pass

    def visit_Assign(self, assign):
        # This restriction prevents worries about things like:
        # x, x[0] = [[1], 2]
        # By using this restriction it is safe to keep the single set,
        # thus order of removals and additions will not be a problem
        # since defs are discovered first (always to left of =), then
        # uses are discovered next (to right of =).
        assert all(
            (
                node_is_name(t)
             or (
                    isinstance(t, ast.Tuple)
                and all(node_is_name(e) for e in t.elts)
                )
            )
            for t in assign.targets
        )
        # Now handled by visit_Name
        # self.used = self.used - set(n.id for n in assign.targets)
        # Keep the assignment only if some assigned name, bare or inside
        # a tuple target, is used later.
        if (any(t.id in self.used for t in assign.targets if node_is_name(t))
            or any(t.id in self.used
                   for tup in assign.targets
                   if isinstance(tup, ast.Tuple)
                   for t in tup.elts
                   if node_is_name(t))):
            return ast.copy_location(ast.Assign(
                targets=[self.visit(t) for t in assign.targets],
                value=self.visit(assign.value)
            ), assign)
        else:
            return None

    def visit_If(self, ifelse):
        body_elim = DeadCodeElim(self)
        orelse_elim = DeadCodeElim(self)
        # DCE the body
        body = body_elim.elim_block(ifelse.body)
        # DCE the else
        orelse = orelse_elim.elim_block(ifelse.orelse)
        # Use the test -- TODO: could eliminate entire if sometimes.
        # Keep it for now for clarity.
        self.used = body_elim.used | orelse_elim.used
        test = self.visit(ifelse.test)
        return ast.copy_location(ast.If(test=test, body=body, orelse=orelse),
                                 ifelse)

    def generic_visit(self, node):
        if not any(isinstance(node, t) for t in SIMPLE_INLINE_TYPES):
            raise Unimplemented()
        return ast.NodeTransformer.generic_visit(self, node)

    def elim_block(self, block):
        # Visit statements last-to-first so each one sees the uses that
        # follow it, drop those that visit to None, then restore order.
        return [s for s in
                (self.visit(stmt) for stmt in block[::-1])
                if s][::-1]
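elim_block walks a block in reverse, so each statement is judged against the names used after it: a Store removes a name from `used`, a Load adds it, and an assignment to a name that is not in `used` is dropped. The same backward-liveness idea can be sketched without mast's helpers. eliminate_dead and its two helper functions are hypothetical; the sketch assumes pure right-hand sides and keeps every statement other than a single-name assignment.

```python
import ast
from typing import List, Set


def _loaded(node: ast.AST) -> Set[str]:
    """Names read (Load context) anywhere inside node."""
    return {n.id for n in ast.walk(node)
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)}


def _mentioned(node: ast.AST) -> Set[str]:
    """Every name inside node, whatever its context (conservative)."""
    return {n.id for n in ast.walk(node) if isinstance(n, ast.Name)}


def eliminate_dead(block: List[ast.stmt], live_out: Set[str]) -> List[ast.stmt]:
    live = set(live_out)
    kept: List[ast.stmt] = []
    # Walk backwards so each statement sees the uses that follow it.
    for stmt in reversed(block):
        if (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)):
            name = stmt.targets[0].id
            if name not in live:
                continue  # dead store: nothing later reads this name
            live.discard(name)           # the store defines name...
            live |= _loaded(stmt.value)  # ...after reading its operands
        else:
            # Keep any other statement and treat its names as live.
            live |= _mentioned(stmt)
        kept.append(stmt)
    kept.reverse()
    return kept


fn = ast.parse("def f(a):\n    t = a * 2\n    u = a + 1\n    return u\n").body[0]
fn.body = eliminate_dead(fn.body, live_out=set())
print(ast.unparse(fn))
# def f(a):
#     u = a + 1
#     return u
```

Discarding the target before adding the right-hand side's reads matters for updates such as `x = x + 1`, where x is both defined and used.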

DeadCodeElim(other=None)
class NormalizePure(ast.NodeTransformer):
class NormalizePure(ast.NodeTransformer):
    """AST transformer: normalize/inline straight-line assignments into
    single expressions where possible."""
    def normalize_block(self, block):
        try:
            return DeadCodeElim().elim_block(
                InlineAvailableExpressions().inline_block(block)
            )
        except Unimplemented:
            return block

    def visit_FunctionDef(self, fun):
        # Lyn warning: previously commented code below is from Python 2
        # and won't work in Python 3 (args have changed).
        # assert 0 < len(
        #    result.value[0].intersection(set(a.id for a in fun.args.args))
        # )
        normbody = self.normalize_block(fun.body)
        if normbody != fun.body:
            return ast.copy_location(ast.FunctionDef(
                name=fun.name,
                args=self.generic_visit(fun.args),
                body=normbody,
                # Carry these over so later consumers (e.g.
                # SourceFormatter, which reads decorator_list) see a
                # complete node.
                decorator_list=fun.decorator_list,
                returns=fun.returns,
            ), fun)
        else:
            return fun

AST transformer: normalize/inline straight-line assignments into single expressions where possible.
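normalize_block chains two passes and treats Unimplemented as "give up and keep the original block". A minimal sketch of that fallback pattern follows; Unsupported, normalize_or_keep, and the two toy passes are hypothetical. Unlike the code above, the sketch deep-copies the block first, because ast.NodeTransformer rewrites nodes in place and a pass that fails half-way could otherwise leave the caller's tree partially modified.

```python
import ast
import copy
from typing import Callable, List

Block = List[ast.stmt]


class Unsupported(Exception):
    """Hypothetical stand-in for mast's Unimplemented."""


def normalize_or_keep(block: Block, *passes: Callable[[Block], Block]) -> Block:
    """Run passes in order; if any bails out, return the original block."""
    # Work on a copy so a failed pass cannot leave `block` half-rewritten.
    result = copy.deepcopy(block)
    try:
        for run_pass in passes:
            result = run_pass(result)
    except Unsupported:
        return block
    return result


def strip_pass(block: Block) -> Block:
    """Toy pass: drop `pass` statements."""
    return [s for s in block if not isinstance(s, ast.Pass)]


def reject_loops(block: Block) -> Block:
    """Toy pass: refuse to handle loops, like an unimplemented case."""
    if any(isinstance(s, (ast.For, ast.While)) for s in block):
        raise Unsupported("loops are not normalized")
    return block


simple = ast.parse("pass\ny = 1\n").body
print(len(normalize_or_keep(simple, strip_pass, reject_loops)))  # 1

loop = ast.parse("pass\nfor i in x:\n    pass\n").body
print(normalize_or_keep(loop, strip_pass, reject_loops) is loop)  # True
```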

NormalizePure()
def canonical_pure(node):
def canonical_pure(node):
    """Return the normalized/inlined version of an AST node."""
    # print(type(fun))
    # assert isinstance(fun, ast.FunctionDef)
    # print()
    # print('FUN', dump(fun))
    if type(node) == list:
        # Formerly also desugared augmented assignments first:
        # [DesugarAugAssign().visit(stmt) for stmt in node]
        return NormalizePure().normalize_block(node)
    else:
        assert isinstance(node, ast.AST)
        #return NormalizePure().visit(DesugarAugAssign().visit(node))
        return NormalizePure().visit(node)

Return the normalized/inlined version of an AST node.

def parse_canonical_pure(string, toplevel=False):
def parse_canonical_pure(string, toplevel=False):
    """Parse a normalized/inlined version of a program."""
    if type(string) == list:
        return [parse_canonical_pure(x, toplevel=toplevel) for x in string]
    elif string not in CANON_CACHE:
        # Note: the cache is keyed on the source string alone, so the
        # first toplevel setting used for a given string wins.
        CANON_CACHE[string] = canonical_pure(
            parse_pattern(string, toplevel=toplevel)
        )
        pass
    return CANON_CACHE[string]

Parse a normalized/inlined version of a program.
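parse_canonical_pure memoizes its results in CANON_CACHE, keyed by the source string, so a pattern that is used repeatedly is parsed and normalized only once. The same memoization can be sketched with functools.lru_cache; parse_cached is a hypothetical name, and it calls only ast.parse, without mast's normalization step.

```python
import ast
import functools


@functools.lru_cache(maxsize=None)
def parse_cached(source: str) -> ast.Module:
    """Parse each distinct source string once; repeats reuse the tree."""
    return ast.parse(source)


first = parse_cached("x = 1")
again = parse_cached("x = 1")
print(first is again)                  # True: the same cached object
print(parse_cached.cache_info().hits)  # 1
```

Because every caller receives the same cached tree, callers must treat it as read-only. lru_cache also keys on every argument, so an extra flag such as toplevel would automatically become part of the cache key.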

def indent(pat, indent=' '):
def indent(pat, indent=INDENT):
    """Apply indents to a source string."""
    return indent + pat.replace('\n', '\n' + indent)

Apply indents to a source string.
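indent prefixes the first line and every line that follows a newline, so a string that ends in a newline gets a dangling prefix after it. The copy below mirrors that logic. It renames the parameter to prefix and assumes a 4-space default for INDENT, whose exact value is not visible in this listing.

```python
def indent(pat: str, prefix: str = "    ") -> str:
    """Same logic as mast's indent; the 4-space default is an assumption
    standing in for INDENT."""
    return prefix + pat.replace("\n", "\n" + prefix)


print(repr(indent("a\nb")))          # '    a\n    b'
print(repr(indent("a\nb\n", "  ")))  # '  a\n  b\n  '
```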

class SourceFormatter(ast.NodeVisitor):
class SourceFormatter(ast.NodeVisitor):
    """AST visitor: pretty-print an AST to a Python source string."""
    def __init__(self):
        ast.NodeVisitor.__init__(self)
        self._indent = ''
        pass

    def indent(self):
        self._indent += INDENT
        pass

    def unindent(self):
        # Remove exactly one level, matching what indent() adds.
        self._indent = self._indent[:-len(INDENT)]
        pass

    def line(self, ln):
        return self._indent + ln + '\n'

    def lines(self, lst):
        return ''.join(lst)

    def generic_visit(self, node):
        assert False, 'visiting {}'.format(ast.dump(node))
        pass

    def visit_Module(self, m):
        return self.lines(self.visit(n) for n in m.body)

    def visit_Interactive(self, i):
        return self.lines(self.visit(n) for n in i.body)

    def visit_Expression(self, e):
        return self.line(self.visit(e.body))

    def visit_FunctionDef(self, f):
        assert not f.decorator_list
        header = self.line('def {name}({args}):'.format(
            name=f.name,
            args=self.visit(f.args))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        self.unindent()
        return header + body + '\n'

    def visit_ClassDef(self, c):
        assert not c.decorator_list
        header = self.line('class {name}({bases}):'.format(
            name=c.name,
            bases=', '.join(self.visit(b) for b in c.bases)
        ))
        self.indent()
        body = self.lines(self.visit(s) for s in c.body)
        self.unindent()
        return header + body + '\n'

    def visit_Return(self, r):
        return self.line('return' if r.value is None
                         else 'return {}'.format(self.visit(r.value)))

    def visit_Delete(self, d):
        # Multiple targets are comma-separated: del a, b
        return self.line('del ' + ', '.join(self.visit(e) for e in d.targets))

    def visit_Assign(self, a):
        # Chained assignment (a = b = value) has one target per '='.
        return self.line(' = '.join(self.visit(e)
                                    for e in a.targets)
                         + ' = ' + self.visit(a.value))

    def visit_AugAssign(self, a):
        return self.line('{target} {op}= {expr}'.format(
            target=self.visit(a.target),
            op=self.visit(a.op),
            expr=self.visit(a.value))
        )

# Print was removed as an AST node in Python 3
#     def visit_Print(self, p):
#         assert p.dest == None
#         return self.line('print {}{}'.format(
#             ', '.join(self.visit(e) for e in p.values),
#             ',' if p.values and not p.nl else ''
#         ))

    def visit_For(self, f):
        header = self.line('for {} in {}:'.format(
            self.visit(f.target),
            self.visit(f.iter))
        )
        self.indent()
        body = self.lines(self.visit(s) for s in f.body)
        orelse = self.lines(self.visit(s) for s in f.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if f.orelse else ''
        )

    def visit_While(self, w):
        # Peter 2021-6-16: Removed this assert; orelse isn't defined here
        # assert not orelse
        header = self.line('while {}:'.format(self.visit(w.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        orelse = self.lines(self.visit(s) for s in w.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if w.orelse else ''
        )

    def visit_If(self, i):
        header = self.line('if {}:'.format(self.visit(i.test)))
        self.indent()
        body = self.lines(self.visit(s) for s in i.body)
        orelse = self.lines(self.visit(s) for s in i.orelse)
        self.unindent()
        return header + body + (
            self.line('else:') + orelse
            if i.orelse else ''
        )

    def visit_With(self, w):
        # Converted to Python3 withitems:
        header = self.line(
            'with {items}:'.format(
                items=', '.join(
                    '{expr}{asnames}'.format(
                        expr=self.visit(item.context_expr),
                        asnames=(' as ' + self.visit(item.optional_vars)
                                 if item.optional_vars else '')
                    )
                    for item in w.items
                )
            )
        )
        self.indent()
        body = self.lines(self.visit(s) for s in w.body)
        self.unindent()
        return header + body

    # Python 3: raise has new abstract syntax
    def visit_Raise(self, r):
        return self.line('raise{}{}'.format(
            (' ' + self.visit(r.exc)) if r.exc else '',
            (' from ' + self.visit(r.cause)) if r.cause else ''
        ))

#    def visit_Raise(self, r):
#        return self.line('raise{}{}{}{}{}'.format(
#             self.visit(r.type) if r.type else '',
#             ', ' if r.type and r.inst else '',
#             self.visit(r.inst) if r.inst else '',
#             ', ' if (r.type or r.inst) and r.tback else '',
#             self.visit(r.tback) if r.tback else ''
#                 ))

#     def visit_TryExcept(self, te):
#         self.indent()
#         tblock = self.lines(self.visit(s) for s in te.body)
#         orelse = self.lines(self.visit(s) for s in te.orelse)
#         self.unindent()
#         return (
#             self.line('try:')
#             + tblock
#             + ''.join(self.visit(eh) for eh in te.handlers)
#             + (self.line('else:') + orelse if orelse else '' )
#         )

#     def visit_TryFinally(self, tf):
#         self.indent()
#         tblock = self.lines(self.visit(s) for s in tf.body)
#         fblock = self.lines(self.visit(s) for s in tf.finalbody)
#         self.unindent()
#         return (
#             self.line('try:')
#             + tblock
#             + self.line('finally:')
#             + fblock
#         )

    # In Python 3, a single Try node covers both TryExcept and TryFinally
    def visit_Try(self, t):
        self.indent()
        tblock = self.lines(self.visit(s) for s in t.body)
        orelse = self.lines(self.visit(s) for s in t.orelse)
        fblock = self.lines(self.visit(s) for s in t.finalbody)
        self.unindent()
        return (
            self.line('try:')
            + tblock
            + ''.join(self.visit(eh) for eh in t.handlers)
            + (self.line('else:') + orelse if orelse else '' )
            + (self.line('finally:') + fblock if fblock else '' )
        )

    def visit_ExceptHandler(self, eh):
        header = self.line('except{}{}{}{}:'.format(
            ' ' if eh.type else '',
            self.visit(eh.type) if eh.type else '',
            ' as ' if eh.type and eh.name else ' ' if eh.name else '',
            self.visit(eh.name) if eh.name and isinstance(eh.name, ast.AST)
              else (eh.name if eh.name else '')
        ))
        self.indent()
        body = self.lines(self.visit(s) for s in eh.body)
        self.unindent()
        return header + body

    def visit_Assert(self, a):
        return self.line('assert {}{}{}'.format(
            self.visit(a.test),
            ', ' if a.msg else '',
            self.visit(a.msg) if a.msg else ''
        ))

    def visit_Import(self, i):
        return self.line(
            'import {}'.format(', '.join(self.visit(n) for n in i.names))
        )

    def visit_ImportFrom(self, f):
        return self.line('from {}{} import {}'.format(
            '.' * f.level,
            f.module if f.module else '',
            ', '.join(self.visit(n) for n in f.names)
        ))

    def visit_Exec(self, e):
        return self.line('exec {}{}{}{}{}'.format(
            self.visit(e.body),
            ' in ' if e.globals else '',
            self.visit(e.globals) if e.globals else '',
            ', ' if e.locals else '',
            self.visit(e.locals) if e.locals else ''
        ))

    def visit_Global(self, g):
        return self.line('global {}'.format(', '.join(g.names)))

    def visit_Expr(self, e):
        return self.line(self.visit(e.value))

    def visit_Pass(self, p):
        return self.line('pass')

    def visit_Break(self, b):
        return self.line('break')

    def visit_Continue(self, c):
        return self.line('continue')

    def visit_BoolOp(self, b):
        return ' {} '.format(
            self.visit(b.op)
        ).join('({})'.format(self.visit(e)) for e in b.values)

    def visit_BinOp(self, b):
        return '({}) {} ({})'.format(
            self.visit(b.left),
            self.visit(b.op),
            self.visit(b.right)
        )

    def visit_UnaryOp(self, u):
        return '{} ({})'.format(
            self.visit(u.op),
            self.visit(u.operand)
        )

    def visit_Lambda(self, ld):
        return '(lambda {}: {})'.format(
            self.visit(ld.args),
            self.visit(ld.body)
        )

    def visit_IfExp(self, i):
        return '({} if {} else {})'.format(
            self.visit(i.body),
            self.visit(i.test),
            self.visit(i.orelse)
        )

    def visit_Dict(self, d):
        return '{{ {} }}'.format(
            ', '.join('{}: {}'.format(self.visit(k), self.visit(v))
                      for k, v in zip(d.keys, d.values))
        )

    def visit_Set(self, s):
        return '{{ {} }}'.format(', '.join(self.visit(e) for e in s.elts))

    def visit_ListComp(self, lc):
        return '[{} {}]'.format(
            self.visit(lc.elt),
            ' '.join(self.visit(g) for g in lc.generators)
        )

    def visit_SetComp(self, sc):
        return '{{{} {}}}'.format(
            self.visit(sc.elt),
            ' '.join(self.visit(g) for g in sc.generators)
        )

    def visit_DictComp(self, dc):
        return '{{{} {}}}'.format(
            '{}: {}'.format(self.visit(dc.key), self.visit(dc.value)),
            ' '.join(self.visit(g) for g in dc.generators)
        )

    def visit_GeneratorExp(self, ge):
        return '({} {})'.format(
            self.visit(ge.elt),
            ' '.join(self.visit(g) for g in ge.generators)
        )

    def visit_Yield(self, y):
        return 'yield {}'.format(self.visit(y.value) if y.value else '')

    def visit_Compare(self, c):
        assert len(c.ops) == len(c.comparators)
        return '{} {}'.format(
            '({})'.format(self.visit(c.left)),
            ' '.join(
                '{} ({})'.format(self.visit(op), self.visit(expr))
                for op, expr in zip(c.ops, c.comparators)
            )
        )

    def visit_Call(self, c):
2263        # return '{fun}({args}{keys}{starargs}{starstarargs})'.format(
2264        # Unlike Python 2, Python 3 has no starargs or startstarargs
2265        return '{fun}({args}{keys})'.format(
2266            fun=self.visit(c.func),
2267            args=', '.join(self.visit(a) for a in c.args),
2268            keys=(
2269                (', ' if c.args else '')
2270              + (
2271                  ', '.join(self.visit(ka) for ka in c.keywords)
2272                    if c.keywords else ''
2273                )
2274            )
2275        )
2276
2277    def visit_Repr(self, r):
2278        return 'repr({})'.format(self.visit(r.expr))
2279
2280    def visit_Num(self, n):
2281        return repr(n.n)
2282
2283    def visit_Str(self, s):
2284        return repr(s.s)
2285
2286    def visit_Attribute(self, a):
2287        return '{}.{}'.format(self.visit(a.value), a.attr)
2288
2289    def visit_Subscript(self, s):
2290        return '{}[{}]'.format(self.visit(s.value), self.visit(s.slice))
2291
2292    def visit_Name(self, n):
2293        return n.id
2294
2295    def visit_List(self, ls):
2296        return '[{}]'.format(', '.join(self.visit(e) for e in ls.elts))
2297
2298    def visit_Tuple(self, tp):
2299        return '({})'.format(', '.join(self.visit(e) for e in tp.elts))
2300
2301    def visit_Ellipsis(self, s):
2302        return '...'
2303
2304    def visit_Slice(self, s):
2305        return '{}:{}{}{}'.format(
2306            self.visit(s.lower) if s.lower else '',
2307            self.visit(s.upper) if s.upper else '',
2308            ':' if s.step else '',
2309            self.visit(s.step) if s.step else ''
2310        )
2311
2312    def visit_ExtSlice(self, es):
2313        return ', '.join(self.visit(s) for s in es.dims)
2314
2315    def visit_Index(self, i):
2316        return self.visit(i.value)
2317
2318    def visit_And(self, a):
2319        return 'and'
2320
2321    def visit_Or(self, o):
2322        return 'or'
2323
2324    def visit_Add(self, a):
2325        return '+'
2326
2327    def visit_Sub(self, a):
2328        return '-'
2329
2330    def visit_Mult(self, a):
2331        return '*'
2332
2333    def visit_Div(self, a):
2334        return '/'
2335
2336    def visit_Mod(self, a):
2337        return '%'
2338
2339    def visit_Pow(self, a):
2340        return '**'
2341
2342    def visit_LShift(self, a):
2343        return '<<'
2344
2345    def visit_RShift(self, a):
2346        return '>>'
2347
2348    def visit_BitOr(self, a):
2349        return '|'
2350
2351    def visit_BixXor(self, a):
2352        return '^'
2353
2354    def visit_BitAnd(self, a):
2355        return '&'
2356
2357    def visit_FloorDiv(self, a):
2358        return '//'
2359
2360    def visit_Invert(self, a):
2361        return '~'
2362
2363    def visit_Not(self, a):
2364        return 'not'
2365
2366    def visit_UAdd(self, a):
2367        return '+'
2368
2369    def visit_USub(self, a):
2370        return '-'
2371
2372    def visit_Eq(self, a):
2373        return '=='
2374
2375    def visit_NotEq(self, a):
2376        return '!='
2377
2378    def visit_Lt(self, a):
2379        return '<'
2380
2381    def visit_LtE(self, a):
2382        return '<='
2383
2384    def visit_Gt(self, a):
2385        return '>'
2386
2387    def visit_GtE(self, a):
2388        return '>='
2389
2390    def visit_Is(self, a):
2391        return 'is'
2392
2393    def visit_IsNot(self, a):
2394        return 'is not'
2395
2396    def visit_In(self, a):
2397        return 'in'
2398
2399    def visit_NotIn(self, a):
2400        return 'not in'
2401
2402    def visit_comprehension(self, c):
2403        return 'for {} in {}{}{}'.format(
2404            self.visit(c.target),
2405            self.visit(c.iter),
2406            ' ' if c.ifs else '',
2407            ' '.join('if {}'.format(self.visit(i)) for i in c.ifs)
2408        )
2409
2410    def visit_arg(self, a):
2411        '''[2019/01/22, lyn] Handle new arg objects in Python 3.'''
2412        return a.arg # The name of the argument
2413
2414    def visit_keyword(self, k):
2415        return '{}={}'.format(k.arg, self.visit(k.value))
2416
2417    def visit_alias(self, a):
2418        return '{} as {}'.format(a.name, a.asname) if a.asname else a.name
2419
2420    def visit_arguments(self, a):
2421        # [2019/01/22, lyn] Note: This does *not* handle Python 3's
2422        # keyword-only arguments (probably moot for 111, but not
2423        # beyond).
2424        stdargs = a.args[:-len(a.defaults)] if a.defaults else a.args
2425        defargs = (
2426            zip(a.args[-len(a.defaults):], a.defaults)
2427            if a.defaults else []
2428        )
2429        return '{stdargs}{sep1}{defargs}{sep2}{varargs}{sep3}{kwargs}'.format(
2430            stdargs=', '.join(self.visit(sa) for sa in stdargs),
2431            sep1=', ' if 0 < len(stdargs) and defargs else '',
2432            defargs=', '.join('{}={}'.format(self.visit(da), self.visit(dd))
2433                              for da, dd in defargs),
2434            sep2=', ' if 0 < len(a.args) and a.vararg else '',
2435            varargs='*{}'.format(a.vararg) if a.vararg else '',
2436            sep3=', ' if (0 < len(a.args) or a.vararg) and a.kwarg else '',
2437            kwargs='**{}'.format(a.kwarg) if a.kwarg else ''
2438        )
2439
2440    def visit_NameConstant(self, nc):
2441        return str(nc.value)
2442
2443    def visit_Starred(self, st):
2444        # TODO: Is this correct?
2445        return '*' + st.value.id

AST visitor: pretty print AST to python source string
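
SourceFormatter is an ordinary ast.NodeVisitor. Each visit_<NodeType> method returns the source text for one node, and builds it by visiting that node's children. The sketch below reproduces the pattern for a handful of expression nodes. It is a standalone illustration, not part of mast. The class name MiniFormatter is hypothetical, and the sketch assumes Python 3.8 or later, where literals parse to ast.Constant. Like visit_BinOp in this class, it wraps every operand in parentheses. That avoids tracking operator precedence, at the cost of some redundant parentheses.

```python
import ast


class MiniFormatter(ast.NodeVisitor):
    """Hypothetical minimal formatter: names, constants, + and *."""

    def visit_Expression(self, e):
        return self.visit(e.body)

    def visit_Name(self, n):
        return n.id

    def visit_Constant(self, c):
        return repr(c.value)

    def visit_BinOp(self, b):
        # Fully parenthesize both operands so precedence never matters
        return '({}) {} ({})'.format(
            self.visit(b.left), self.visit(b.op), self.visit(b.right))

    def visit_Add(self, op):
        return '+'

    def visit_Mult(self, op):
        return '*'


tree = ast.parse('x + 2 * y', mode='eval')
text = MiniFormatter().visit(tree)
print(text)  # (x) + ((2) * (y))

# The extra parentheses do not change the parsed tree
assert ast.dump(ast.parse(text, mode='eval')) == ast.dump(tree)
```

The final assertion shows why the redundant parentheses are harmless: re-parsing the output yields the same tree.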

SourceFormatter()
def __init__(self):
    ast.NodeVisitor.__init__(self)
    self._indent = ''

def indent(self):
    self._indent += INDENT

def unindent(self):
    # Undo exactly one indent() call, whatever the width of INDENT
    self._indent = self._indent[:-len(INDENT)]

def line(self, ln):
    return self._indent + ln + '\n'

def lines(self, lst):
    return ''.join(lst)

def generic_visit(self, node):
    assert False, 'visiting {}'.format(ast.dump(node))

Called if no explicit visitor function exists for a node.
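
ast.NodeVisitor.visit chooses a handler purely by name. It looks up the method 'visit_' + the node's class name, and falls back to generic_visit when no such attribute exists. A handler whose name does not match the AST class exactly is therefore never called, and the node ends up in generic_visit, which in this class fails an assertion. The standalone sketch below shows that lookup. StrictVisitor and handler_for are hypothetical names, and this generic_visit raises NotImplementedError instead of asserting.

```python
import ast


class StrictVisitor(ast.NodeVisitor):
    def visit_BitOr(self, node):
        return '|'

    def visit_Bitxor(self, node):  # wrong case: ast's class is BitXor
        return '^'

    def generic_visit(self, node):
        raise NotImplementedError('no visitor for ' + type(node).__name__)


def handler_for(visitor, node):
    """Same name-based lookup that ast.NodeVisitor.visit performs."""
    return getattr(visitor, 'visit_' + node.__class__.__name__,
                   visitor.generic_visit)


v = StrictVisitor()
print(v.visit(ast.BitOr()))                             # |
print(handler_for(v, ast.BitXor()) == v.generic_visit)  # True
try:
    v.visit(ast.BitXor())
except NotImplementedError as exc:
    print(exc)                                          # no visitor for BitXor
```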

def visit_Module(self, m):
    return self.lines(self.visit(n) for n in m.body)

def visit_Interactive(self, i):
    return self.lines(self.visit(n) for n in i.body)

def visit_Expression(self, e):
    return self.line(self.visit(e.body))

def visit_FunctionDef(self, f):
    assert not f.decorator_list
    header = self.line('def {name}({args}):'.format(
        name=f.name,
        args=self.visit(f.args)
    ))
    self.indent()
    body = self.lines(self.visit(s) for s in f.body)
    self.unindent()
    return header + body + '\n'

def visit_ClassDef(self, c):
    assert not c.decorator_list
    header = self.line('class {name}({bases}):'.format(
        name=c.name,
        bases=', '.join(self.visit(b) for b in c.bases)
    ))
    self.indent()
    body = self.lines(self.visit(s) for s in c.body)
    self.unindent()
    return header + body + '\n'

def visit_Return(self, r):
    return self.line('return' if r.value is None
                     else 'return {}'.format(self.visit(r.value)))

def visit_Delete(self, d):
    # Several targets (del a, b) must be comma-separated
    return self.line('del ' + ', '.join(self.visit(e) for e in d.targets))

def visit_Assign(self, a):
    # Chained assignment (a = b = 1) has one target per '=', so join with
    # ' = '; tuple unpacking (a, b = ...) is a single Tuple target
    return self.line(' = '.join(self.visit(e) for e in a.targets)
                     + ' = ' + self.visit(a.value))

def visit_AugAssign(self, a):
    return self.line('{target} {op}= {expr}'.format(
        target=self.visit(a.target),
        op=self.visit(a.op),
        expr=self.visit(a.value)
    ))

def visit_For(self, f):
    header = self.line('for {} in {}:'.format(
        self.visit(f.target),
        self.visit(f.iter)
    ))
    self.indent()
    body = self.lines(self.visit(s) for s in f.body)
    orelse = self.lines(self.visit(s) for s in f.orelse)
    self.unindent()
    return header + body + (
        self.line('else:') + orelse
        if f.orelse else ''
    )

def visit_While(self, w):
    # Peter 2021-6-16: Removed this assert; orelse isn't defined here
    # assert not orelse
    header = self.line('while {}:'.format(self.visit(w.test)))
    self.indent()
    body = self.lines(self.visit(s) for s in w.body)
    orelse = self.lines(self.visit(s) for s in w.orelse)
    self.unindent()
    return header + body + (
        self.line('else:') + orelse
        if w.orelse else ''
    )

def visit_If(self, i):
    header = self.line('if {}:'.format(self.visit(i.test)))
    self.indent()
    body = self.lines(self.visit(s) for s in i.body)
    orelse = self.lines(self.visit(s) for s in i.orelse)
    self.unindent()
    return header + body + (
        self.line('else:') + orelse
        if i.orelse else ''
    )

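The block statements (for, while, if and the rest) share one indentation scheme. The header is emitted at the current indent. The indent is widened while the body is rendered, and it is narrowed again before an else: line. Below is a standalone sketch of that scheme. It assumes four-space indentation (INDENT itself is defined elsewhere in this module) and Python 3.9+ for ast.unparse, which renders the expressions only. BlockFormatter is a hypothetical name, and the sketch handles only if/else, pass and expression statements. As in visit_If, an elif comes out as a nested if inside else:.

```python
import ast

INDENT = '    '  # assumption: four spaces per level


class BlockFormatter(ast.NodeVisitor):
    """Hypothetical sketch: if/else, pass and expression statements only."""

    def __init__(self):
        self._indent = ''

    def line(self, text):
        return self._indent + text + '\n'

    def block(self, stmts):
        # Widen the indent for the nested statements, then restore it
        self._indent += INDENT
        out = ''.join(self.visit(s) for s in stmts)
        self._indent = self._indent[:-len(INDENT)]
        return out

    def visit_Module(self, m):
        return ''.join(self.visit(s) for s in m.body)

    def visit_If(self, i):
        out = self.line('if {}:'.format(ast.unparse(i.test)))
        out += self.block(i.body)
        if i.orelse:
            # else: sits at the same level as the if: header
            out += self.line('else:') + self.block(i.orelse)
        return out

    def visit_Pass(self, p):
        return self.line('pass')

    def visit_Expr(self, e):
        return self.line(ast.unparse(e.value))


src = 'if a:\n    f()\nelif b:\n    pass\n'
print(BlockFormatter().visit(ast.parse(src)), end='')
# if a:
#     f()
# else:
#     if b:
#         pass
```
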
def visit_With(self, w):
    # Converted to Python 3 withitems. ' as ' needs its leading space;
    # otherwise "with lock as l" would be printed as "with lockas l".
    header = self.line(
        'with {items}:'.format(
            items=', '.join(
                '{expr}{asname}'.format(
                    expr=self.visit(item.context_expr),
                    asname=(' as ' + self.visit(item.optional_vars)
                            if item.optional_vars else '')
                )
                for item in w.items
            )
        )
    )
    self.indent()
    body = self.lines(self.visit(s) for s in w.body)
    self.unindent()
    return header + body

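Since Python 3.3, a with statement that names several context managers is a single ast.With node. Its items list holds one ast.withitem per manager, and each withitem has a context_expr and an optional optional_vars target. That is the structure visit_With walks. A quick standalone check follows; the source string is an arbitrary example.

```python
import ast

stmt = ast.parse('with open(path) as fh, lock:\n    pass\n').body[0]

for item in stmt.items:
    target = item.optional_vars.id if item.optional_vars else None
    print(type(item.context_expr).__name__, target)
# Call fh
# Name None
```
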
def visit_Raise(self, r):
    return self.line('raise{}{}'.format(
        (' ' + self.visit(r.exc)) if r.exc else '',
        (' from ' + self.visit(r.cause)) if r.cause else ''
    ))

def visit_Try(self, t):
    self.indent()
    tblock = self.lines(self.visit(s) for s in t.body)
    orelse = self.lines(self.visit(s) for s in t.orelse)
    fblock = self.lines(self.visit(s) for s in t.finalbody)
    self.unindent()
    return (
        self.line('try:')
        + tblock
        + ''.join(self.visit(eh) for eh in t.handlers)
        + (self.line('else:') + orelse if orelse else '')
        + (self.line('finally:') + fblock if fblock else '')
    )

def visit_ExceptHandler(self, eh):
    header = self.line('except{}{}{}{}:'.format(
        ' ' if eh.type else '',
        self.visit(eh.type) if eh.type else '',
        ' as ' if eh.type and eh.name else ' ' if eh.name else '',
        self.visit(eh.name) if eh.name and isinstance(eh.name, ast.AST)
        else (eh.name if eh.name else '')
    ))
    self.indent()
    body = self.lines(self.visit(s) for s in eh.body)
    self.unindent()
    return header + body

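In Python 3, ExceptHandler.name is a plain string, or None when there is no "as" clause. Python 2 stored an expression node there instead. That difference is why visit_ExceptHandler calls self.visit on the name only when it is an ast.AST. A short check of the Python 3 shape:

```python
import ast

src = 'try:\n    pass\nexcept ValueError as err:\n    pass\nexcept:\n    pass\n'
handlers = ast.parse(src).body[0].handlers

for h in handlers:
    print(type(h.type).__name__, repr(h.name))
# Name 'err'
# NoneType None
```
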
def visit_Assert(self, a):
    return self.line('assert {}{}{}'.format(
        self.visit(a.test),
        ', ' if a.msg else '',
        self.visit(a.msg) if a.msg else ''
    ))

def visit_Import(self, i):
    return self.line(
        'import {}'.format(', '.join(self.visit(n) for n in i.names))
    )

def visit_ImportFrom(self, f):
    return self.line('from {}{} import {}'.format(
        '.' * f.level,
        f.module if f.module else '',
        ', '.join(self.visit(n) for n in f.names)
    ))

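For ImportFrom, level counts the leading dots of a relative import. The module attribute is None when nothing follows the dots, as in "from . import x", which is why visit_ImportFrom substitutes an empty string. Each imported name is an ast.alias with name and asname attributes. For example:

```python
import ast

rel = ast.parse('from ..util import a as b, c').body[0]
print(rel.level, rel.module, [(al.name, al.asname) for al in rel.names])
# 2 util [('a', 'b'), ('c', None)]

bare = ast.parse('from . import sibling').body[0]
print(bare.level, bare.module)
# 1 None
```
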
def visit_Exec(self, e):
    return self.line('exec {}{}{}{}{}'.format(
        self.visit(e.body),
        ' in ' if e.globals else '',
        self.visit(e.globals) if e.globals else '',
        ', ' if e.locals else '',
        self.visit(e.locals) if e.locals else ''
    ))

def visit_Global(self, g):
    return self.line('global {}'.format(', '.join(g.names)))

def visit_Expr(self, e):
    return self.line(self.visit(e.value))

def visit_Pass(self, p):
    return self.line('pass')

def visit_Break(self, b):
    return self.line('break')

def visit_Continue(self, c):
    return self.line('continue')

def visit_BoolOp(self, b):
    return ' {} '.format(
        self.visit(b.op)
    ).join('({})'.format(self.visit(e)) for e in b.values)

def visit_BinOp(self, b):
    return '({}) {} ({})'.format(
        self.visit(b.left),
        self.visit(b.op),
        self.visit(b.right)
    )

def visit_UnaryOp(self, u):
    return '{} ({})'.format(
        self.visit(u.op),
        self.visit(u.operand)
    )

def visit_Lambda(self, ld):
    return '(lambda {}: {})'.format(
        self.visit(ld.args),
        self.visit(ld.body)
    )

def visit_IfExp(self, i):
    return '({} if {} else {})'.format(
        self.visit(i.body),
        self.visit(i.test),
        self.visit(i.orelse)
    )

def visit_Dict(self, d):
    return '{{ {} }}'.format(
        ', '.join('{}: {}'.format(self.visit(k), self.visit(v))
                  for k, v in zip(d.keys, d.values))
    )

def visit_Set(self, s):
    return '{{ {} }}'.format(', '.join(self.visit(e) for e in s.elts))

def visit_ListComp(self, lc):
    return '[{} {}]'.format(
        self.visit(lc.elt),
        ' '.join(self.visit(g) for g in lc.generators)
    )

def visit_SetComp(self, sc):
    return '{{{} {}}}'.format(
        self.visit(sc.elt),
        ' '.join(self.visit(g) for g in sc.generators)
    )

def visit_DictComp(self, dc):
    return '{{{} {}}}'.format(
        '{}: {}'.format(self.visit(dc.key), self.visit(dc.value)),
        ' '.join(self.visit(g) for g in dc.generators)
    )

def visit_GeneratorExp(self, ge):
    return '({} {})'.format(
        self.visit(ge.elt),
        ' '.join(self.visit(g) for g in ge.generators)
    )

def visit_Yield(self, y):
    # A bare 'yield' gets no trailing space
    return 'yield' + (' ' + self.visit(y.value) if y.value else '')

def visit_Compare(self, c):
    assert len(c.ops) == len(c.comparators)
    return '{} {}'.format(
        '({})'.format(self.visit(c.left)),
        ' '.join(
            '{} ({})'.format(self.visit(op), self.visit(expr))
            for op, expr in zip(c.ops, c.comparators)
        )
    )

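A chained comparison such as 0 <= i < n is a single ast.Compare node. Its left attribute holds the first operand, and its ops and comparators lists are parallel, which is why visit_Compare zips them. Below is a sketch of that pairing without the extra parentheses. It assumes Python 3.9+ for ast.unparse, and the helper name format_compare is hypothetical.

```python
import ast

SYMBOLS = {ast.Eq: '==', ast.NotEq: '!=', ast.Lt: '<', ast.LtE: '<=',
           ast.Gt: '>', ast.GtE: '>=', ast.Is: 'is', ast.IsNot: 'is not',
           ast.In: 'in', ast.NotIn: 'not in'}


def format_compare(node):
    """Hypothetical helper: render an ast.Compare as a flat chain."""
    assert len(node.ops) == len(node.comparators)
    parts = [ast.unparse(node.left)]
    # ops[i] sits between the previous operand and comparators[i]
    for op, right in zip(node.ops, node.comparators):
        parts += [SYMBOLS[type(op)], ast.unparse(right)]
    return ' '.join(parts)


cmp = ast.parse('0 <= i < n', mode='eval').body
print(len(cmp.ops), format_compare(cmp))  # 2 0 <= i < n
```
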
def visit_Call(self, c):
    # Since Python 3.5, ast.Call has no starargs or kwargs fields (Python 2
    # and early Python 3 had them): *args appear in c.args as Starred nodes
    # and **kwargs in c.keywords as keywords whose arg is None.
    params = [self.visit(a) for a in c.args]
    params += [self.visit(ka) for ka in c.keywords]
    # One join avoids a dangling separator such as f(x, )
    return '{fun}({params})'.format(
        fun=self.visit(c.func),
        params=', '.join(params)
    )

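How Python 3.5+ represents star arguments can be confirmed directly from the ast module. ast.Call has only func, args and keywords fields. A *rest argument is an ast.Starred entry in args, and **extra is an ast.keyword in keywords whose arg is None. The call below is an arbitrary example:

```python
import ast

call = ast.parse('f(x, *rest, key=1, **extra)', mode='eval').body

print(ast.Call._fields)                        # ('func', 'args', 'keywords')
print([type(a).__name__ for a in call.args])   # ['Name', 'Starred']
print([kw.arg for kw in call.keywords])        # ['key', None]
```
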
def visit_Repr(self, r):
    return 'repr({})'.format(self.visit(r.expr))

def visit_Num(self, n):
    return repr(n.n)

def visit_Str(self, s):
    return repr(s.s)

def visit_Attribute(self, a):
    return '{}.{}'.format(self.visit(a.value), a.attr)

def visit_Subscript(self, s):
    return '{}[{}]'.format(self.visit(s.value), self.visit(s.slice))

def visit_Name(self, n):
    return n.id

def visit_List(self, ls):
    return '[{}]'.format(', '.join(self.visit(e) for e in ls.elts))

def visit_Tuple(self, tp):
    elts = [self.visit(e) for e in tp.elts]
    # A one-element tuple needs a trailing comma: (x,) rather than (x)
    return '({}{})'.format(', '.join(elts), ',' if len(elts) == 1 else '')

def visit_Ellipsis(self, s):
    return '...'

def visit_Slice(self, s):
    return '{}:{}{}{}'.format(
        self.visit(s.lower) if s.lower else '',
        self.visit(s.upper) if s.upper else '',
        ':' if s.step else '',
        self.visit(s.step) if s.step else ''
    )

def visit_ExtSlice(self, es):
    return ', '.join(self.visit(s) for s in es.dims)

def visit_Index(self, i):
    return self.visit(i.value)

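visit_Index and visit_ExtSlice only come into play on Python 3.8 and earlier. Starting with Python 3.9, ast no longer produces Index or ExtSlice nodes. Subscript.slice holds the index expression directly, and a multi-dimensional subscript such as a[1:n, ::2] is stored as an ast.Tuple whose elements are ast.Slice nodes or plain expressions. The check below prints which representation the running interpreter uses:

```python
import ast
import sys

sub = ast.parse('a[1:n, ::2]', mode='eval').body
print(sys.version_info[:2], type(sub.slice).__name__)  # Tuple on 3.9+, ExtSlice before

if sys.version_info >= (3, 9):
    for dim in sub.slice.elts:
        # Omitted bounds are None, which visit_Slice relies on
        print(type(dim).__name__,
              dim.lower is None, dim.upper is None, dim.step is None)
```
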
def visit_And(self, a):
    return 'and'

def visit_Or(self, o):
    return 'or'

def visit_Add(self, a):
    return '+'

def visit_Sub(self, a):
    return '-'

def visit_Mult(self, a):
    return '*'

def visit_Div(self, a):
    return '/'

def visit_Mod(self, a):
    return '%'

def visit_Pow(self, a):
    return '**'

def visit_LShift(self, a):
    return '<<'

def visit_RShift(self, a):
    return '>>'

def visit_BitOr(self, a):
    return '|'

def visit_BitXor(self, a):
    # Name must match the ast.BitXor class for NodeVisitor dispatch
    return '^'

def visit_BitAnd(self, a):
    return '&'

def visit_FloorDiv(self, a):
    return '//'

def visit_Invert(self, a):
    return '~'

def visit_Not(self, a):
    return 'not'

def visit_UAdd(self, a):
    return '+'

def visit_USub(self, a):
    return '-'

def visit_Eq(self, a):
    return '=='

def visit_NotEq(self, a):
    return '!='

def visit_Lt(self, a):
    return '<'

def visit_LtE(self, a):
    return '<='

def visit_Gt(self, a):
    return '>'

def visit_GtE(self, a):
    return '>='

def visit_Is(self, a):
    return 'is'

def visit_IsNot(self, a):
    return 'is not'

def visit_In(self, a):
    return 'in'

def visit_NotIn(self, a):
    return 'not in'

def visit_comprehension(self, c):
    return 'for {} in {}{}{}'.format(
        self.visit(c.target),
        self.visit(c.iter),
        ' ' if c.ifs else '',
        ' '.join('if {}'.format(self.visit(i)) for i in c.ifs)
    )

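Each for clause of a comprehension is one ast.comprehension node with a target, an iter, and a list of ifs. Every if that follows a given for clause lands in that clause's ifs. The enclosing ListComp, SetComp, DictComp or GeneratorExp keeps the clauses, in order, in its generators list, and the visitors above join them with spaces. A small standalone check (the comprehension is an arbitrary example):

```python
import ast

comp = ast.parse('[x * y for x in xs if x for y in ys if y if x != y]',
                 mode='eval').body

print(type(comp).__name__, len(comp.generators))  # ListComp 2
for gen in comp.generators:
    print(gen.target.id, gen.iter.id, len(gen.ifs))
# x xs 1
# y ys 2
```
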
def visit_arg(self, a):
    '''[2019/01/22, lyn] Handle new arg objects in Python 3.'''
    return a.arg  # The name of the argument

[2019/01/22, lyn] Handle new arg objects in Python 3.

def visit_keyword(self, k):
    # k.arg is None for a **kwargs argument, as in f(**opts)
    if k.arg is None:
        return '**{}'.format(self.visit(k.value))
    return '{}={}'.format(k.arg, self.visit(k.value))

def visit_alias(self, a):
    return '{} as {}'.format(a.name, a.asname) if a.asname else a.name

def visit_arguments(self, a):
    # [2019/01/22, lyn] Note: This does *not* handle Python 3's
    # keyword-only arguments (probably moot for 111, but not
    # beyond).
    # Positional-only arguments (Python 3.8+) are not handled either.
    stdargs = a.args[:-len(a.defaults)] if a.defaults else a.args
    defargs = (
        list(zip(a.args[-len(a.defaults):], a.defaults))
        if a.defaults else []
    )
    # In Python 3, a.vararg and a.kwarg are ast.arg nodes (or None), not
    # strings, so they are visited to get the parameter names
    return '{stdargs}{sep1}{defargs}{sep2}{varargs}{sep3}{kwargs}'.format(
        stdargs=', '.join(self.visit(sa) for sa in stdargs),
        sep1=', ' if 0 < len(stdargs) and defargs else '',
        defargs=', '.join('{}={}'.format(self.visit(da), self.visit(dd))
                          for da, dd in defargs),
        sep2=', ' if 0 < len(a.args) and a.vararg else '',
        varargs='*{}'.format(self.visit(a.vararg)) if a.vararg else '',
        sep3=', ' if (0 < len(a.args) or a.vararg) and a.kwarg else '',
        kwargs='**{}'.format(self.visit(a.kwarg)) if a.kwarg else ''
    )

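In an ast.arguments node, defaults holds values only for the trailing positional parameters, so the last default belongs to the last positional parameter. visit_arguments relies on that alignment when it slices a.args. On Python 3.8+, positional-only parameters sit in a separate posonlyargs list, and defaults then covers posonlyargs and args together. Keyword-only parameters are stored separately in kwonlyargs and kw_defaults, where None marks a parameter without a default. Below is a sketch of the pairing, assuming Python 3.9+ for ast.unparse; pair_defaults is a hypothetical helper name.

```python
import ast


def pair_defaults(args_node):
    """Hypothetical helper: [(name, default_source_or_None), ...] for
    every parameter that can be passed positionally."""
    # defaults is aligned to the END of posonlyargs + args
    params = list(getattr(args_node, 'posonlyargs', [])) + args_node.args
    defaults = args_node.defaults
    first = len(params) - len(defaults)
    return [
        (p.arg, ast.unparse(defaults[i - first]) if i >= first else None)
        for i, p in enumerate(params)
    ]


fn = ast.parse('def f(a, b, c=1, d="x", *rest, e, g=2, **kw): pass').body[0]
print(pair_defaults(fn.args))
# [('a', None), ('b', None), ('c', '1'), ('d', "'x'")]
print([k.arg for k in fn.args.kwonlyargs], fn.args.kw_defaults[0])
# ['e', 'g'] None
print(type(fn.args.vararg).__name__, fn.args.vararg.arg, fn.args.kwarg.arg)
# arg rest kw
```

The last line shows that *rest and **kw are ast.arg nodes rather than plain strings in Python 3.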
def visit_NameConstant(self, nc):
    return str(nc.value)

def visit_Starred(self, st):
    # Starred wraps an arbitrary expression (*args, *f(x), *a.b), not
    # only a Name, so visit it rather than reading .id
    return '*' + self.visit(st.value)

Inherited Members
ast.NodeVisitor
visit
visit_Constant
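
The inherited visit_Constant matters on Python 3.8 and later. On those versions, ast.parse produces ast.Constant for every literal (numbers, strings, bytes, None, True/False and ...) instead of the older Num, Str, NameConstant and Ellipsis nodes. The base-class visit_Constant forwards such nodes to legacy handlers like visit_Num or visit_Str when they exist, and issues a deprecation warning when it does. A formatter written only for newer interpreters could handle every literal in one method, as in this hypothetical sketch:

```python
import ast


class ConstantFormatter(ast.NodeVisitor):
    """Hypothetical: one handler for all literal kinds (Python 3.8+)."""

    def visit_Constant(self, c):
        # repr(...) would give 'Ellipsis', so spell the literal out
        if c.value is Ellipsis:
            return '...'
        return repr(c.value)


fmt = ConstantFormatter()
for src in ['42', '2.5', '"hi"', "b'raw'", 'None', 'True', '...']:
    node = ast.parse(src, mode='eval').body
    print(type(node).__name__, fmt.visit(node))
```
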
def ast2source(node):
    """Pretty print an AST as a python source string"""
    return SourceFormatter().visit(node)

Pretty print an AST as a python source string
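
A practical test for a source pretty-printer like this is a round trip. Parse the source, print it, parse the printout, and compare the two trees with ast.dump. By default, ast.dump omits line and column attributes, so only the structure is compared. Below is a sketch of such a check. It assumes Python 3.9+ and uses the standard library's ast.unparse as a stand-in printer; ast2source could presumably be passed as to_source instead. round_trips and strip_parens are hypothetical names.

```python
import ast


def round_trips(src, to_source=ast.unparse):
    """Hypothetical check: does printing and re-parsing src keep its AST?"""
    tree = ast.parse(src)
    return ast.dump(ast.parse(to_source(tree))) == ast.dump(tree)


def strip_parens(tree):
    # Deliberately broken printer: drops every parenthesis
    return ast.unparse(tree).replace('(', '').replace(')', '')


sample = 'y = (a + b) * c\nfor i in range(3):\n    del a[i], b\n'
print(round_trips(sample))                # True
print(round_trips(sample, strip_parens))  # False
```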

def source(node):
    """Alias for ast2source"""
    return ast2source(node)

Alias for ast2source