package_hash: fix handling of multimethods and add tests

Package hashing was not properly handling multimethods. In particular, it was removing
any functions that had decorators from the output, so we'd miss things like
`@run_after("install")`, etc.

There were also problems with handling multiple `@when`'s in a single file, and with
handling `@when` functions that *had* to be evaluated dynamically.

- [x] Rework static `@when` resolution for package hash
- [x] Ensure that functions with decorators are not removed from output
- [x] Add tests for many different @when scenarios (multiple @when's,
      combining with other decorators, default/no default, etc.)

Co-authored-by: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com>
This commit is contained in:
Todd Gamblin 2021-12-23 00:53:44 -08:00 committed by Greg Becker
parent 93a6c51d88
commit 800229a448
2 changed files with 260 additions and 34 deletions

View file

@ -150,3 +150,133 @@ def test_remove_directives():
for name in spack.directives.directive_names:
assert name not in unparsed
many_multimethods = """\
class Pkg:
def foo(self):
print("ONE")
@when("@1.0")
def foo(self):
print("TWO")
@when("@2.0")
@when(sys.platform == "darwin")
def foo(self):
print("THREE")
@when("@3.0")
def foo(self):
print("FOUR")
# this one should always stay
@run_after("install")
def some_function(self):
print("FIVE")
"""
def test_multimethod_resolution(tmpdir):
when_pkg = tmpdir.join("pkg.py")
with when_pkg.open("w") as f:
f.write(many_multimethods)
# all are false but the default
filtered = ph.canonical_source("pkg@4.0", str(when_pkg))
assert "ONE" in filtered
assert "TWO" not in filtered
assert "THREE" not in filtered
assert "FOUR" not in filtered
assert "FIVE" in filtered
# we know first @when overrides default and others are false
filtered = ph.canonical_source("pkg@1.0", str(when_pkg))
assert "ONE" not in filtered
assert "TWO" in filtered
assert "THREE" not in filtered
assert "FOUR" not in filtered
assert "FIVE" in filtered
# we know last @when overrides default and others are false
filtered = ph.canonical_source("pkg@3.0", str(when_pkg))
assert "ONE" not in filtered
assert "TWO" not in filtered
assert "THREE" not in filtered
assert "FOUR" in filtered
assert "FIVE" in filtered
# we don't know if default or THREE will win, include both
filtered = ph.canonical_source("pkg@2.0", str(when_pkg))
assert "ONE" in filtered
assert "TWO" not in filtered
assert "THREE" in filtered
assert "FOUR" not in filtered
assert "FIVE" in filtered
more_dynamic_multimethods = """\
class Pkg:
@when(sys.platform == "darwin")
def foo(self):
print("ONE")
@when("@1.0")
def foo(self):
print("TWO")
# this one isn't dynamic, but an int fails the Spec parse,
# so it's kept because it has to be evaluated at runtime.
@when("@2.0")
@when(1)
def foo(self):
print("THREE")
@when("@3.0")
def foo(self):
print("FOUR")
# this one should always stay
@run_after("install")
def some_function(self):
print("FIVE")
"""
def test_more_dynamic_multimethod_resolution(tmpdir):
when_pkg = tmpdir.join("pkg.py")
with when_pkg.open("w") as f:
f.write(more_dynamic_multimethods)
# we know the first one is the only one that can win.
filtered = ph.canonical_source("pkg@4.0", str(when_pkg))
assert "ONE" in filtered
assert "TWO" not in filtered
assert "THREE" not in filtered
assert "FOUR" not in filtered
assert "FIVE" in filtered
# now we have to include ONE and TWO because ONE may win dynamically.
filtered = ph.canonical_source("pkg@1.0", str(when_pkg))
assert "ONE" in filtered
assert "TWO" in filtered
assert "THREE" not in filtered
assert "FOUR" not in filtered
assert "FIVE" in filtered
# we know FOUR is true and TWO and THREE are false, but ONE may
# still win dynamically.
filtered = ph.canonical_source("pkg@3.0", str(when_pkg))
assert "ONE" in filtered
assert "TWO" not in filtered
assert "THREE" not in filtered
assert "FOUR" in filtered
assert "FIVE" in filtered
# TWO and FOUR can't be satisfied, but ONE or THREE could win
filtered = ph.canonical_source("pkg@2.0", str(when_pkg))
assert "ONE" in filtered
assert "TWO" not in filtered
assert "THREE" in filtered
assert "FOUR" not in filtered
assert "FIVE" in filtered

View file

@ -11,7 +11,9 @@
import spack.package
import spack.repo
import spack.spec
import spack.util.hash
import spack.util.naming
from spack.util.unparse import unparse
class RemoveDocstrings(ast.NodeTransformer):
@ -82,70 +84,164 @@ def visit_ClassDef(self, node): # noqa
class TagMultiMethods(ast.NodeVisitor):
"""Tag @when-decorated methods in a spec."""
"""Tag @when-decorated methods in a package AST."""
def __init__(self, spec):
self.spec = spec
# map from function name to (implementation, condition_list) tuples
self.methods = {}
def visit_FunctionDef(self, node): # noqa
nodes = self.methods.setdefault(node.name, [])
if node.decorator_list:
dec = node.decorator_list[0]
def visit_FunctionDef(self, func): # noqa
conditions = []
for dec in func.decorator_list:
if isinstance(dec, ast.Call) and dec.func.id == 'when':
try:
# evaluate spec condition for any when's
cond = dec.args[0].s
nodes.append(
(node, self.spec.satisfies(cond, strict=True)))
conditions.append(self.spec.satisfies(cond, strict=True))
except AttributeError:
# In this case the condition for the 'when' decorator is
# not a string literal (for example it may be a Python
# variable name). Therefore the function is added
# unconditionally since we don't know whether the
# constraint applies or not.
nodes.append((node, None))
else:
nodes.append((node, None))
# variable name). We append None because we don't know
# whether the constraint applies or not, and it should be included
# unless some other constraint is False.
conditions.append(None)
# anything defined without conditions will overwrite prior definitions
if not conditions:
self.methods[func.name] = []
# add all discovered conditions on this node to the node list
impl_conditions = self.methods.setdefault(func.name, [])
impl_conditions.append((func, conditions))
# don't modify the AST -- return the untouched function node
return func
class ResolveMultiMethods(ast.NodeTransformer):
"""Remove methods which do not exist if their @when is not satisfied."""
"""Remove multi-methods when we know statically that they won't be used.
Say we have multi-methods like this::
class SomePackage:
def foo(self): print("implementation 1")
@when("@1.0")
def foo(self): print("implementation 2")
@when("@2.0")
@when(sys.platform == "darwin")
def foo(self): print("implementation 3")
@when("@3.0")
def foo(self): print("implementation 4")
The multimethod that will be chosen at runtime depends on the package spec and on
whether we're on the darwin platform *at build time* (the darwin condition for
implementation 3 is dynamic). We know the package spec statically; we don't know
statically what the runtime environment will be. We need to include things that can
possibly affect package behavior in the package hash, and we want to exclude things
when we know that they will not affect package behavior.
If we're at version 4.0, we know that implementation 1 will win, because some @when
for 2, 3, and 4 will be `False`. We should only include implementation 1.
If we're at version 1.0, we know that implementation 2 will win, because it
overrides implementation 1. We should only include implementation 2.
If we're at version 3.0, we know that implementation 4 will win, because it
overrides implementation 1 (the default), and some @when on all others will be
False.
If we're at version 2.0, it's a bit more complicated. We know we can remove
implementations 2 and 4, because their @when's will never be satisfied. But, the
choice between implementations 1 and 3 will happen at runtime (this is a bad example
because the spec itself has platform information, and we should prefer to use that,
but we allow arbitrary boolean expressions in @when's, so this example suffices).
For this case, we end up needing to include *both* implementation 1 and 3 in the
package hash, because either could be chosen.
"""
def __init__(self, methods):
self.methods = methods
def resolve(self, node):
if node.name not in self.methods:
raise PackageHashError(
"Future traversal visited new node: %s" % node.name)
def resolve(self, impl_conditions):
"""Given list of nodes and conditions, figure out which node will be chosen."""
result = []
default = None
for impl, conditions in impl_conditions:
# if there's a default implementation with no conditions, remember that.
if not conditions:
default = impl
result.append(default)
continue
result = None
for n, cond in self.methods[node.name]:
if cond:
return n
if cond is None:
result = n
# any known-false @when means the method won't be used
if any(c is False for c in conditions):
continue
# anything with all known-true conditions will be picked if it's first
if all(c is True for c in conditions):
if result and result[0] is default:
return [impl] # we know the first MM will always win
# if anything dynamic comes before it we don't know if it'll win,
# so just let this result get appended
# anything else has to be determined dynamically, so add it to a list
result.append(impl)
# if nothing was picked, the last definition wins.
return result
def visit_FunctionDef(self, node): # noqa
if self.resolve(node) is node:
node.decorator_list = []
return node
def visit_FunctionDef(self, func): # noqa
# if the function def wasn't visited on the first traversal there is a problem
assert func.name in self.methods, "Inconsistent package traversal!"
# if the function is a multimethod, need to resolve it statically
impl_conditions = self.methods[func.name]
resolutions = self.resolve(impl_conditions)
if not any(r is func for r in resolutions):
# multimethod did not resolve to this function; remove it
return None
# if we get here, this function is a possible resolution for a multi-method.
# it might be the only one, or there might be several that have to be evaluated
# dynamcially. Either way, we include the function.
# strip the when decorators (preserve the rest)
func.decorator_list = [
dec for dec in func.decorator_list
if not (isinstance(dec, ast.Call) and dec.func.id == 'when')
]
return func
def package_content(spec):
return ast.dump(package_ast(spec))
def canonical_source(spec, filename=None):
return unparse(package_ast(spec, filename=filename), py_ver_consistent=True)
def canonical_source_hash(spec, filename=None):
source = canonical_source(spec, filename)
return spack.util.hash.b32_hash(source)
def package_hash(spec, content=None):
if content is None:
content = package_content(spec)
return hashlib.sha256(content.encode('utf-8')).digest().lower()
def package_ast(spec):
def package_ast(spec, filename=None):
spec = spack.spec.Spec(spec)
if not filename:
filename = spack.repo.path.filename_for_package_name(spec.name)
with open(filename) as f:
text = f.read()
root = ast.parse(text)
@ -154,10 +250,10 @@ def package_ast(spec):
RemoveDirectives(spec).visit(root)
fmm = TagMultiMethods(spec)
fmm.visit(root)
tagger = TagMultiMethods(spec)
tagger.visit(root)
root = ResolveMultiMethods(fmm.methods).visit(root)
root = ResolveMultiMethods(tagger.methods).visit(root)
return root