From afb358313aacfb8bf63c8dc084eea17c7903b9f8 Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Tue, 21 Dec 2021 16:43:03 -0800 Subject: [PATCH] unparser: refactor delimiting with context managers in ast.unparse Backport of https://github.com/python/cpython/commit/4b3b1226e86df6cd45e921c8f2ad23c3639c43b2 --- .../external/spack_astunparse/unparser.py | 367 +++++++++--------- 1 file changed, 177 insertions(+), 190 deletions(-) diff --git a/lib/spack/external/spack_astunparse/unparser.py b/lib/spack/external/spack_astunparse/unparser.py index d3b57684d9..56e60c4c14 100644 --- a/lib/spack/external/spack_astunparse/unparser.py +++ b/lib/spack/external/spack_astunparse/unparser.py @@ -13,6 +13,7 @@ # TODO: if we require Python 3.7, use its `nullcontext()` +@contextmanager def nullcontext(): yield @@ -101,6 +102,21 @@ def __exit__(self, exc_type, exc_value, traceback): def block(self): return self._Block(self) + @contextmanager + def delimit(self, start, end): + """A context manager for preparing the source for expressions. It adds + *start* to the buffer and enters, after exit it adds *end*.""" + + self.write(start) + yield + self.write(end) + + def delimit_if(self, start, end, condition): + if condition: + return self.delimit(start, end) + else: + return nullcontext() + def dispatch(self, tree): "Dispatcher function, dispatching tree type T to method _T." if isinstance(tree, list): @@ -135,11 +151,10 @@ def _Expr(self, tree): self.dispatch(tree.value) def _NamedExpr(self, tree): - self.write("(") - self.dispatch(tree.target) - self.write(" := ") - self.dispatch(tree.value) - self.write(")") + with self.delimit("(", ")"): + self.dispatch(tree.target) + self.write(" := ") + self.dispatch(tree.value) def _Import(self, t): self.fill("import ") @@ -172,11 +187,9 @@ def _AugAssign(self, t): def _AnnAssign(self, t): self.fill() - if not t.simple and isinstance(t.target, ast.Name): - self.write('(') - self.dispatch(t.target) - if not t.simple and isinstance(t.target, ast.Name): - self.write(')') + with self.delimit_if( + "(", ")", not node.simple and isinstance(t.target, ast.Name)): + self.dispatch(t.target) self.write(": ") self.dispatch(t.annotation) if t.value: @@ -250,28 +263,25 @@ def _Nonlocal(self, t): interleave(lambda: self.write(", "), self.write, t.names) def _Await(self, t): - self.write("(") - self.write("await") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") + with self.delimit("(", ")"): + self.write("await") + if t.value: + self.write(" ") + self.dispatch(t.value) def _Yield(self, t): - self.write("(") - self.write("yield") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") + with self.delimit("(", ")"): + self.write("yield") + if t.value: + self.write(" ") + self.dispatch(t.value) def _YieldFrom(self, t): - self.write("(") - self.write("yield from") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") + with self.delimit("(", ")"): + self.write("yield from") + if t.value: + self.write(" ") + self.dispatch(t.value) def _Raise(self, t): self.fill("raise") @@ -356,35 +366,33 @@ def _ClassDef(self, t): self.dispatch(deco) self.fill("class "+t.name) if six.PY3: - self.write("(") - comma = False - for e in t.bases: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - for e in t.keywords: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - if sys.version_info[:2] < (3, 5): - if t.starargs: + with self.delimit("(", ")"): + comma = False + for e in t.bases: if comma: self.write(", ") else: comma = True - self.write("*") - self.dispatch(t.starargs) - if t.kwargs: + self.dispatch(e) + for e in t.keywords: if comma: self.write(", ") else: comma = True - self.write("**") - self.dispatch(t.kwargs) - self.write(")") + self.dispatch(e) + if sys.version_info[:2] < (3, 5): + if t.starargs: + if comma: self.write(", ") + else: comma = True + self.write("*") + self.dispatch(t.starargs) + if t.kwargs: + if comma: self.write(", ") + else: comma = True + self.write("**") + self.dispatch(t.kwargs) elif t.bases: - self.write("(") + with self.delimit("(", ")"): for a in t.bases[:-1]: self.dispatch(a) self.write(", ") self.dispatch(t.bases[-1]) - self.write(")") with self.block(): self.dispatch(t.body) @@ -399,10 +407,10 @@ def __FunctionDef_helper(self, t, fill_suffix): for deco in t.decorator_list: self.fill("@") self.dispatch(deco) - def_str = fill_suffix+" "+t.name + "(" + def_str = fill_suffix + " " + t.name self.fill(def_str) - self.dispatch(t.args) - self.write(")") + with self.delimit("(", ")"): + self.dispatch(t.args) if getattr(t, "returns", False): self.write(" -> ") self.dispatch(t.returns) @@ -574,13 +582,12 @@ def _write_constant(self, value): def _Constant(self, t): value = t.value if isinstance(value, tuple): - self.write("(") - if len(value) == 1: - self._write_constant(value[0]) - self.write(",") - else: - interleave(lambda: self.write(", "), self._write_constant, value) - self.write(")") + with self.delimit("(", ")"): + if len(value) == 1: + self._write_constant(value[0]) + self.write(",") + else: + interleave(lambda: self.write(", "), self._write_constant, value) elif value is Ellipsis: # instead of `...` for Py2 compatibility self.write("...") else: @@ -594,49 +601,41 @@ def _Num(self, t): self.write(repr_n.replace("inf", INFSTR)) else: # Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2. - if repr_n.startswith("-"): - self.write("(") - if "inf" in repr_n and repr_n.endswith("*j"): - repr_n = repr_n.replace("*j", "j") - # Substitute overflowing decimal literal for AST infinities. - self.write(repr_n.replace("inf", INFSTR)) - if repr_n.startswith("-"): - self.write(")") + with self.delimit_if("(", ")", repr_n.startswith("-")): + if "inf" in repr_n and repr_n.endswith("*j"): + repr_n = repr_n.replace("*j", "j") + # Substitute overflowing decimal literal for AST infinities. + self.write(repr_n.replace("inf", INFSTR)) def _List(self, t): - self.write("[") - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("]") + with self.delimit("[", "]"): + interleave(lambda: self.write(", "), self.dispatch, t.elts) def _ListComp(self, t): - self.write("[") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("]") + with self.delimit("[", "]"): + self.dispatch(t.elt) + for gen in t.generators: + self.dispatch(gen) def _GeneratorExp(self, t): - self.write("(") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write(")") + with self.delimit("(", ")"): + self.dispatch(t.elt) + for gen in t.generators: + self.dispatch(gen) def _SetComp(self, t): - self.write("{") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("}") + with self.delimit("{", "}"): + self.dispatch(t.elt) + for gen in t.generators: + self.dispatch(gen) def _DictComp(self, t): - self.write("{") - self.dispatch(t.key) - self.write(": ") - self.dispatch(t.value) - for gen in t.generators: - self.dispatch(gen) - self.write("}") + with self.delimit("{", "}"): + self.dispatch(t.key) + self.write(": ") + self.dispatch(t.value) + for gen in t.generators: + self.dispatch(gen) def _comprehension(self, t): if getattr(t, 'is_async', False): @@ -651,22 +650,19 @@ def _comprehension(self, t): self.dispatch(if_clause) def _IfExp(self, t): - self.write("(") - self.dispatch(t.body) - self.write(" if ") - self.dispatch(t.test) - self.write(" else ") - self.dispatch(t.orelse) - self.write(")") + with self.delimit("(", ")"): + self.dispatch(t.body) + self.write(" if ") + self.dispatch(t.test) + self.write(" else ") + self.dispatch(t.orelse) def _Set(self, t): assert(t.elts) # should be at least one element - self.write("{") - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("}") + with self.delimit("{", "}"): + interleave(lambda: self.write(", "), self.dispatch, t.elts) def _Dict(self, t): - self.write("{") def write_key_value_pair(k, v): self.dispatch(k) self.write(": ") @@ -681,64 +677,59 @@ def write_item(item): self.dispatch(v) else: write_key_value_pair(k, v) - interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values)) - self.write("}") + + with self.delimit("{", "}"): + interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values)) def _Tuple(self, t): - self.write("(") - if len(t.elts) == 1: - elt = t.elts[0] - self.dispatch(elt) - self.write(",") - else: - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write(")") + with self.delimit("(", ")"): + if len(t.elts) == 1: + elt = t.elts[0] + self.dispatch(elt) + self.write(",") + else: + interleave(lambda: self.write(", "), self.dispatch, t.elts) unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} def _UnaryOp(self, t): - self.write("(") - self.write(self.unop[t.op.__class__.__name__]) - if not self._py_ver_consistent: - self.write(" ") - if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num): - # If we're applying unary minus to a number, parenthesize the number. - # This is necessary: -2147483648 is different from -(2147483648) on - # a 32-bit machine (the first is an int, the second a long), and - # -7j is different from -(7j). (The first has real part 0.0, the second - # has real part -0.0.) - self.write("(") - self.dispatch(t.operand) - self.write(")") - else: - self.dispatch(t.operand) - self.write(")") + with self.delimit("(", ")"): + self.write(self.unop[t.op.__class__.__name__]) + if not self._py_ver_consistent: + self.write(" ") + if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num): + # If we're applying unary minus to a number, parenthesize the number. + # This is necessary: -2147483648 is different from -(2147483648) on + # a 32-bit machine (the first is an int, the second a long), and + # -7j is different from -(7j). (The first has real part 0.0, the second + # has real part -0.0.) + with self.delimit("(", ")"): + self.dispatch(t.operand) + else: + self.dispatch(t.operand) binop = { "Add":"+", "Sub":"-", "Mult":"*", "MatMult":"@", "Div":"/", "Mod":"%", "LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&", "FloorDiv":"//", "Pow": "**"} def _BinOp(self, t): - self.write("(") - self.dispatch(t.left) - self.write(" " + self.binop[t.op.__class__.__name__] + " ") - self.dispatch(t.right) - self.write(")") + with self.delimit("(", ")"): + self.dispatch(t.left) + self.write(" " + self.binop[t.op.__class__.__name__] + " ") + self.dispatch(t.right) cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=", "Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"} def _Compare(self, t): - self.write("(") - self.dispatch(t.left) - for o, e in zip(t.ops, t.comparators): - self.write(" " + self.cmpops[o.__class__.__name__] + " ") - self.dispatch(e) - self.write(")") + with self.delimit("(", ")"): + self.dispatch(t.left) + for o, e in zip(t.ops, t.comparators): + self.write(" " + self.cmpops[o.__class__.__name__] + " ") + self.dispatch(e) boolops = {ast.And: 'and', ast.Or: 'or'} def _BoolOp(self, t): - self.write("(") - s = " %s " % self.boolops[t.op.__class__] - interleave(lambda: self.write(s), self.dispatch, t.values) - self.write(")") + with self.delimit("(", ")"): + s = " %s " % self.boolops[t.op.__class__] + interleave(lambda: self.write(s), self.dispatch, t.values) def _Attribute(self,t): self.dispatch(t.value) @@ -752,55 +743,52 @@ def _Attribute(self,t): def _Call(self, t): self.dispatch(t.func) - self.write("(") - comma = False + with self.delimit("(", ")"): + comma = False - # move starred arguments last in Python 3.5+, for consistency w/earlier versions - star_and_kwargs = [] - move_stars_last = sys.version_info[:2] >= (3, 5) + # starred arguments last in Python 3.5+, for consistency w/earlier versions + star_and_kwargs = [] + move_stars_last = sys.version_info[:2] >= (3, 5) - for e in t.args: - if move_stars_last and isinstance(e, ast.Starred): - star_and_kwargs.append(e) - else: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) + for e in t.args: + if move_stars_last and isinstance(e, ast.Starred): + star_and_kwargs.append(e) + else: + if comma: self.write(", ") + else: comma = True + self.dispatch(e) - for e in t.keywords: - # starting from Python 3.5 this denotes a kwargs part of the invocation - if e.arg is None and move_stars_last: - star_and_kwargs.append(e) - else: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) + for e in t.keywords: + # starting from Python 3.5 this denotes a kwargs part of the invocation + if e.arg is None and move_stars_last: + star_and_kwargs.append(e) + else: + if comma: self.write(", ") + else: comma = True + self.dispatch(e) - if move_stars_last: - for e in star_and_kwargs: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) + if move_stars_last: + for e in star_and_kwargs: + if comma: self.write(", ") + else: comma = True + self.dispatch(e) - if sys.version_info[:2] < (3, 5): - if t.starargs: - if comma: self.write(", ") - else: comma = True - self.write("*") - self.dispatch(t.starargs) - if t.kwargs: - if comma: self.write(", ") - else: comma = True - self.write("**") - self.dispatch(t.kwargs) - - self.write(")") + if sys.version_info[:2] < (3, 5): + if t.starargs: + if comma: self.write(", ") + else: comma = True + self.write("*") + self.dispatch(t.starargs) + if t.kwargs: + if comma: self.write(", ") + else: comma = True + self.write("**") + self.dispatch(t.kwargs) def _Subscript(self, t): self.dispatch(t.value) - self.write("[") - self.dispatch(t.slice) - self.write("]") + with self.delimit("[", "]"): + self.dispatch(t.slice) def _Starred(self, t): self.write("*") @@ -902,12 +890,11 @@ def _keyword(self, t): self.dispatch(t.value) def _Lambda(self, t): - self.write("(") - self.write("lambda ") - self.dispatch(t.args) - self.write(": ") - self.dispatch(t.body) - self.write(")") + with self.delimit("(", ")"): + self.write("lambda ") + self.dispatch(t.args) + self.write(": ") + self.dispatch(t.body) def _alias(self, t): self.write(t.name)