unparser: refactor delimiting with context managers in ast.unparse

Backport of 4b3b1226e8
This commit is contained in:
Todd Gamblin 2021-12-21 16:43:03 -08:00 committed by Greg Becker
parent 5847eb1e65
commit afb358313a

View file

@ -13,6 +13,7 @@
# TODO: if we require Python 3.7, use its `nullcontext()` # TODO: if we require Python 3.7, use its `nullcontext()`
@contextmanager
def nullcontext(): def nullcontext():
yield yield
@ -101,6 +102,21 @@ def __exit__(self, exc_type, exc_value, traceback):
def block(self): def block(self):
return self._Block(self) return self._Block(self)
@contextmanager
def delimit(self, start, end):
"""A context manager for preparing the source for expressions. It adds
*start* to the buffer and enters, after exit it adds *end*."""
self.write(start)
yield
self.write(end)
def delimit_if(self, start, end, condition):
if condition:
return self.delimit(start, end)
else:
return nullcontext()
def dispatch(self, tree): def dispatch(self, tree):
"Dispatcher function, dispatching tree type T to method _T." "Dispatcher function, dispatching tree type T to method _T."
if isinstance(tree, list): if isinstance(tree, list):
@ -135,11 +151,10 @@ def _Expr(self, tree):
self.dispatch(tree.value) self.dispatch(tree.value)
def _NamedExpr(self, tree): def _NamedExpr(self, tree):
self.write("(") with self.delimit("(", ")"):
self.dispatch(tree.target) self.dispatch(tree.target)
self.write(" := ") self.write(" := ")
self.dispatch(tree.value) self.dispatch(tree.value)
self.write(")")
def _Import(self, t): def _Import(self, t):
self.fill("import ") self.fill("import ")
@ -172,11 +187,9 @@ def _AugAssign(self, t):
def _AnnAssign(self, t): def _AnnAssign(self, t):
self.fill() self.fill()
if not t.simple and isinstance(t.target, ast.Name): with self.delimit_if(
self.write('(') "(", ")", not node.simple and isinstance(t.target, ast.Name)):
self.dispatch(t.target) self.dispatch(t.target)
if not t.simple and isinstance(t.target, ast.Name):
self.write(')')
self.write(": ") self.write(": ")
self.dispatch(t.annotation) self.dispatch(t.annotation)
if t.value: if t.value:
@ -250,28 +263,25 @@ def _Nonlocal(self, t):
interleave(lambda: self.write(", "), self.write, t.names) interleave(lambda: self.write(", "), self.write, t.names)
def _Await(self, t): def _Await(self, t):
self.write("(") with self.delimit("(", ")"):
self.write("await") self.write("await")
if t.value: if t.value:
self.write(" ") self.write(" ")
self.dispatch(t.value) self.dispatch(t.value)
self.write(")")
def _Yield(self, t): def _Yield(self, t):
self.write("(") with self.delimit("(", ")"):
self.write("yield") self.write("yield")
if t.value: if t.value:
self.write(" ") self.write(" ")
self.dispatch(t.value) self.dispatch(t.value)
self.write(")")
def _YieldFrom(self, t): def _YieldFrom(self, t):
self.write("(") with self.delimit("(", ")"):
self.write("yield from") self.write("yield from")
if t.value: if t.value:
self.write(" ") self.write(" ")
self.dispatch(t.value) self.dispatch(t.value)
self.write(")")
def _Raise(self, t): def _Raise(self, t):
self.fill("raise") self.fill("raise")
@ -356,7 +366,7 @@ def _ClassDef(self, t):
self.dispatch(deco) self.dispatch(deco)
self.fill("class "+t.name) self.fill("class "+t.name)
if six.PY3: if six.PY3:
self.write("(") with self.delimit("(", ")"):
comma = False comma = False
for e in t.bases: for e in t.bases:
if comma: self.write(", ") if comma: self.write(", ")
@ -377,14 +387,12 @@ def _ClassDef(self, t):
else: comma = True else: comma = True
self.write("**") self.write("**")
self.dispatch(t.kwargs) self.dispatch(t.kwargs)
self.write(")")
elif t.bases: elif t.bases:
self.write("(") with self.delimit("(", ")"):
for a in t.bases[:-1]: for a in t.bases[:-1]:
self.dispatch(a) self.dispatch(a)
self.write(", ") self.write(", ")
self.dispatch(t.bases[-1]) self.dispatch(t.bases[-1])
self.write(")")
with self.block(): with self.block():
self.dispatch(t.body) self.dispatch(t.body)
@ -399,10 +407,10 @@ def __FunctionDef_helper(self, t, fill_suffix):
for deco in t.decorator_list: for deco in t.decorator_list:
self.fill("@") self.fill("@")
self.dispatch(deco) self.dispatch(deco)
def_str = fill_suffix+" "+t.name + "(" def_str = fill_suffix + " " + t.name
self.fill(def_str) self.fill(def_str)
with self.delimit("(", ")"):
self.dispatch(t.args) self.dispatch(t.args)
self.write(")")
if getattr(t, "returns", False): if getattr(t, "returns", False):
self.write(" -> ") self.write(" -> ")
self.dispatch(t.returns) self.dispatch(t.returns)
@ -574,13 +582,12 @@ def _write_constant(self, value):
def _Constant(self, t): def _Constant(self, t):
value = t.value value = t.value
if isinstance(value, tuple): if isinstance(value, tuple):
self.write("(") with self.delimit("(", ")"):
if len(value) == 1: if len(value) == 1:
self._write_constant(value[0]) self._write_constant(value[0])
self.write(",") self.write(",")
else: else:
interleave(lambda: self.write(", "), self._write_constant, value) interleave(lambda: self.write(", "), self._write_constant, value)
self.write(")")
elif value is Ellipsis: # instead of `...` for Py2 compatibility elif value is Ellipsis: # instead of `...` for Py2 compatibility
self.write("...") self.write("...")
else: else:
@ -594,49 +601,41 @@ def _Num(self, t):
self.write(repr_n.replace("inf", INFSTR)) self.write(repr_n.replace("inf", INFSTR))
else: else:
# Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2. # Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2.
if repr_n.startswith("-"): with self.delimit_if("(", ")", repr_n.startswith("-")):
self.write("(")
if "inf" in repr_n and repr_n.endswith("*j"): if "inf" in repr_n and repr_n.endswith("*j"):
repr_n = repr_n.replace("*j", "j") repr_n = repr_n.replace("*j", "j")
# Substitute overflowing decimal literal for AST infinities. # Substitute overflowing decimal literal for AST infinities.
self.write(repr_n.replace("inf", INFSTR)) self.write(repr_n.replace("inf", INFSTR))
if repr_n.startswith("-"):
self.write(")")
def _List(self, t): def _List(self, t):
self.write("[") with self.delimit("[", "]"):
interleave(lambda: self.write(", "), self.dispatch, t.elts) interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("]")
def _ListComp(self, t): def _ListComp(self, t):
self.write("[") with self.delimit("[", "]"):
self.dispatch(t.elt) self.dispatch(t.elt)
for gen in t.generators: for gen in t.generators:
self.dispatch(gen) self.dispatch(gen)
self.write("]")
def _GeneratorExp(self, t): def _GeneratorExp(self, t):
self.write("(") with self.delimit("(", ")"):
self.dispatch(t.elt) self.dispatch(t.elt)
for gen in t.generators: for gen in t.generators:
self.dispatch(gen) self.dispatch(gen)
self.write(")")
def _SetComp(self, t): def _SetComp(self, t):
self.write("{") with self.delimit("{", "}"):
self.dispatch(t.elt) self.dispatch(t.elt)
for gen in t.generators: for gen in t.generators:
self.dispatch(gen) self.dispatch(gen)
self.write("}")
def _DictComp(self, t): def _DictComp(self, t):
self.write("{") with self.delimit("{", "}"):
self.dispatch(t.key) self.dispatch(t.key)
self.write(": ") self.write(": ")
self.dispatch(t.value) self.dispatch(t.value)
for gen in t.generators: for gen in t.generators:
self.dispatch(gen) self.dispatch(gen)
self.write("}")
def _comprehension(self, t): def _comprehension(self, t):
if getattr(t, 'is_async', False): if getattr(t, 'is_async', False):
@ -651,22 +650,19 @@ def _comprehension(self, t):
self.dispatch(if_clause) self.dispatch(if_clause)
def _IfExp(self, t): def _IfExp(self, t):
self.write("(") with self.delimit("(", ")"):
self.dispatch(t.body) self.dispatch(t.body)
self.write(" if ") self.write(" if ")
self.dispatch(t.test) self.dispatch(t.test)
self.write(" else ") self.write(" else ")
self.dispatch(t.orelse) self.dispatch(t.orelse)
self.write(")")
def _Set(self, t): def _Set(self, t):
assert(t.elts) # should be at least one element assert(t.elts) # should be at least one element
self.write("{") with self.delimit("{", "}"):
interleave(lambda: self.write(", "), self.dispatch, t.elts) interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write("}")
def _Dict(self, t): def _Dict(self, t):
self.write("{")
def write_key_value_pair(k, v): def write_key_value_pair(k, v):
self.dispatch(k) self.dispatch(k)
self.write(": ") self.write(": ")
@ -681,22 +677,22 @@ def write_item(item):
self.dispatch(v) self.dispatch(v)
else: else:
write_key_value_pair(k, v) write_key_value_pair(k, v)
with self.delimit("{", "}"):
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values)) interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
self.write("}")
def _Tuple(self, t): def _Tuple(self, t):
self.write("(") with self.delimit("(", ")"):
if len(t.elts) == 1: if len(t.elts) == 1:
elt = t.elts[0] elt = t.elts[0]
self.dispatch(elt) self.dispatch(elt)
self.write(",") self.write(",")
else: else:
interleave(lambda: self.write(", "), self.dispatch, t.elts) interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")")
unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"}
def _UnaryOp(self, t): def _UnaryOp(self, t):
self.write("(") with self.delimit("(", ")"):
self.write(self.unop[t.op.__class__.__name__]) self.write(self.unop[t.op.__class__.__name__])
if not self._py_ver_consistent: if not self._py_ver_consistent:
self.write(" ") self.write(" ")
@ -706,39 +702,34 @@ def _UnaryOp(self, t):
# a 32-bit machine (the first is an int, the second a long), and # a 32-bit machine (the first is an int, the second a long), and
# -7j is different from -(7j). (The first has real part 0.0, the second # -7j is different from -(7j). (The first has real part 0.0, the second
# has real part -0.0.) # has real part -0.0.)
self.write("(") with self.delimit("(", ")"):
self.dispatch(t.operand) self.dispatch(t.operand)
self.write(")")
else: else:
self.dispatch(t.operand) self.dispatch(t.operand)
self.write(")")
binop = { "Add":"+", "Sub":"-", "Mult":"*", "MatMult":"@", "Div":"/", "Mod":"%", binop = { "Add":"+", "Sub":"-", "Mult":"*", "MatMult":"@", "Div":"/", "Mod":"%",
"LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&", "LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&",
"FloorDiv":"//", "Pow": "**"} "FloorDiv":"//", "Pow": "**"}
def _BinOp(self, t): def _BinOp(self, t):
self.write("(") with self.delimit("(", ")"):
self.dispatch(t.left) self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ") self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.dispatch(t.right) self.dispatch(t.right)
self.write(")")
cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=", cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=",
"Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"} "Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"}
def _Compare(self, t): def _Compare(self, t):
self.write("(") with self.delimit("(", ")"):
self.dispatch(t.left) self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators): for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ") self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.dispatch(e) self.dispatch(e)
self.write(")")
boolops = {ast.And: 'and', ast.Or: 'or'} boolops = {ast.And: 'and', ast.Or: 'or'}
def _BoolOp(self, t): def _BoolOp(self, t):
self.write("(") with self.delimit("(", ")"):
s = " %s " % self.boolops[t.op.__class__] s = " %s " % self.boolops[t.op.__class__]
interleave(lambda: self.write(s), self.dispatch, t.values) interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")")
def _Attribute(self,t): def _Attribute(self,t):
self.dispatch(t.value) self.dispatch(t.value)
@ -752,10 +743,10 @@ def _Attribute(self,t):
def _Call(self, t): def _Call(self, t):
self.dispatch(t.func) self.dispatch(t.func)
self.write("(") with self.delimit("(", ")"):
comma = False comma = False
# move starred arguments last in Python 3.5+, for consistency w/earlier versions # starred arguments last in Python 3.5+, for consistency w/earlier versions
star_and_kwargs = [] star_and_kwargs = []
move_stars_last = sys.version_info[:2] >= (3, 5) move_stars_last = sys.version_info[:2] >= (3, 5)
@ -794,13 +785,10 @@ def _Call(self, t):
self.write("**") self.write("**")
self.dispatch(t.kwargs) self.dispatch(t.kwargs)
self.write(")")
def _Subscript(self, t): def _Subscript(self, t):
self.dispatch(t.value) self.dispatch(t.value)
self.write("[") with self.delimit("[", "]"):
self.dispatch(t.slice) self.dispatch(t.slice)
self.write("]")
def _Starred(self, t): def _Starred(self, t):
self.write("*") self.write("*")
@ -902,12 +890,11 @@ def _keyword(self, t):
self.dispatch(t.value) self.dispatch(t.value)
def _Lambda(self, t): def _Lambda(self, t):
self.write("(") with self.delimit("(", ")"):
self.write("lambda ") self.write("lambda ")
self.dispatch(t.args) self.dispatch(t.args)
self.write(": ") self.write(": ")
self.dispatch(t.body) self.dispatch(t.body)
self.write(")")
def _alias(self, t): def _alias(self, t):
self.write(t.name) self.write(t.name)