unparser: add block() context manager for indentation

This is a backport of a refactor from cpython 3.9
This commit is contained in:
Todd Gamblin 2021-12-21 16:15:48 -08:00 committed by Greg Becker
parent 2badd6500e
commit 5847eb1e65

View file

@ -1,12 +1,22 @@
"Usage: unparse.py <path to source file>" "Usage: unparse.py <path to source file>"
from __future__ import print_function, unicode_literals from __future__ import print_function, unicode_literals
import six
import sys
import ast import ast
import os import os
import sys
import tokenize import tokenize
from contextlib import contextmanager
import six
from six import StringIO from six import StringIO
# TODO: if we require Python 3.7, use its `nullcontext()`
def nullcontext():
yield
# Large float and imaginary literals get turned into infinities in the AST. # Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR. # We unparse those infinities to INFSTR.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
@ -74,14 +84,22 @@ def write(self, text):
"Append a piece of text to the current line." "Append a piece of text to the current line."
self.f.write(six.text_type(text)) self.f.write(six.text_type(text))
def enter(self): class _Block:
"Print ':', and increase the indentation." """A context manager for preparing the source for blocks. It adds
self.write(":") the character ':', increases the indentation on enter and decreases
self._indent += 1 the indentation on exit."""
def __init__(self, unparser):
self.unparser = unparser
def leave(self): def __enter__(self):
"Decrease the indentation level." self.unparser.write(":")
self._indent -= 1 self.unparser._indent += 1
def __exit__(self, exc_type, exc_value, traceback):
self.unparser._indent -= 1
def block(self):
return self._Block(self)
def dispatch(self, tree): def dispatch(self, tree):
"Dispatcher function, dispatching tree type T to method _T." "Dispatcher function, dispatching tree type T to method _T."
@ -279,35 +297,30 @@ def _Raise(self, t):
def _Try(self, t): def _Try(self, t):
self.fill("try") self.fill("try")
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
for ex in t.handlers: for ex in t.handlers:
self.dispatch(ex) self.dispatch(ex)
if t.orelse: if t.orelse:
self.fill("else") self.fill("else")
self.enter() with self.block():
self.dispatch(t.orelse) self.dispatch(t.orelse)
self.leave()
if t.finalbody: if t.finalbody:
self.fill("finally") self.fill("finally")
self.enter() with self.block():
self.dispatch(t.finalbody) self.dispatch(t.finalbody)
self.leave()
def _TryExcept(self, t): def _TryExcept(self, t):
self.fill("try") self.fill("try")
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
for ex in t.handlers: for ex in t.handlers:
self.dispatch(ex) self.dispatch(ex)
if t.orelse: if t.orelse:
self.fill("else") self.fill("else")
self.enter() with self.block():
self.dispatch(t.orelse) self.dispatch(t.orelse)
self.leave()
def _TryFinally(self, t): def _TryFinally(self, t):
if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept): if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept):
@ -315,14 +328,12 @@ def _TryFinally(self, t):
self.dispatch(t.body) self.dispatch(t.body)
else: else:
self.fill("try") self.fill("try")
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
self.fill("finally") self.fill("finally")
self.enter() with self.block():
self.dispatch(t.finalbody) self.dispatch(t.finalbody)
self.leave()
def _ExceptHandler(self, t): def _ExceptHandler(self, t):
self.fill("except") self.fill("except")
@ -335,9 +346,8 @@ def _ExceptHandler(self, t):
self.write(t.name) self.write(t.name)
else: else:
self.dispatch(t.name) self.dispatch(t.name)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
def _ClassDef(self, t): def _ClassDef(self, t):
self.write("\n") self.write("\n")
@ -375,9 +385,8 @@ def _ClassDef(self, t):
self.write(", ") self.write(", ")
self.dispatch(t.bases[-1]) self.dispatch(t.bases[-1])
self.write(")") self.write(")")
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
def _FunctionDef(self, t): def _FunctionDef(self, t):
self.__FunctionDef_helper(t, "def") self.__FunctionDef_helper(t, "def")
@ -397,9 +406,8 @@ def __FunctionDef_helper(self, t, fill_suffix):
if getattr(t, "returns", False): if getattr(t, "returns", False):
self.write(" -> ") self.write(" -> ")
self.dispatch(t.returns) self.dispatch(t.returns)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
def _For(self, t): def _For(self, t):
self.__For_helper("for ", t) self.__For_helper("for ", t)
@ -412,48 +420,41 @@ def __For_helper(self, fill, t):
self.dispatch(t.target) self.dispatch(t.target)
self.write(" in ") self.write(" in ")
self.dispatch(t.iter) self.dispatch(t.iter)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
if t.orelse: if t.orelse:
self.fill("else") self.fill("else")
self.enter() with self.block():
self.dispatch(t.orelse) self.dispatch(t.orelse)
self.leave()
def _If(self, t): def _If(self, t):
self.fill("if ") self.fill("if ")
self.dispatch(t.test) self.dispatch(t.test)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
# collapse nested ifs into equivalent elifs. # collapse nested ifs into equivalent elifs.
while (t.orelse and len(t.orelse) == 1 and while (t.orelse and len(t.orelse) == 1 and
isinstance(t.orelse[0], ast.If)): isinstance(t.orelse[0], ast.If)):
t = t.orelse[0] t = t.orelse[0]
self.fill("elif ") self.fill("elif ")
self.dispatch(t.test) self.dispatch(t.test)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
# final else # final else
if t.orelse: if t.orelse:
self.fill("else") self.fill("else")
self.enter() with self.block():
self.dispatch(t.orelse) self.dispatch(t.orelse)
self.leave()
def _While(self, t): def _While(self, t):
self.fill("while ") self.fill("while ")
self.dispatch(t.test) self.dispatch(t.test)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
if t.orelse: if t.orelse:
self.fill("else") self.fill("else")
self.enter() with self.block():
self.dispatch(t.orelse) self.dispatch(t.orelse)
self.leave()
def _generic_With(self, t, async_=False): def _generic_With(self, t, async_=False):
self.fill("async with " if async_ else "with ") self.fill("async with " if async_ else "with ")
@ -464,9 +465,8 @@ def _generic_With(self, t, async_=False):
if t.optional_vars: if t.optional_vars:
self.write(" as ") self.write(" as ")
self.dispatch(t.optional_vars) self.dispatch(t.optional_vars)
self.enter() with self.block():
self.dispatch(t.body) self.dispatch(t.body)
self.leave()
def _With(self, t): def _With(self, t):
self._generic_With(t) self._generic_With(t)