Change multimethods to take first match instead of being rigid.
- Formerly required that one and only one spec match - Now allows first match in a list (more flexible and more intuitive) - introduces new bug that provides() doesn't do the correct thing when a version is not in a range that has been explicitly provided. - TODO: fix this.
This commit is contained in:
parent
36e6ef9fbd
commit
2f520d6119
5 changed files with 71 additions and 46 deletions
|
@ -3,6 +3,8 @@
|
||||||
Developer Guide
|
Developer Guide
|
||||||
=====================
|
=====================
|
||||||
|
|
||||||
|
This guide is intended for people who want to work on Spack's inner
|
||||||
|
workings. Right now it's pretty sparse.
|
||||||
|
|
||||||
Spec objects
|
Spec objects
|
||||||
-------------------------
|
-------------------------
|
||||||
|
|
|
@ -1060,17 +1060,17 @@ for example:
|
||||||
# the default, called when no @when specs match
|
# the default, called when no @when specs match
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@when('mpi@3:')
|
@when('^mpi@3:')
|
||||||
def setup_mpi(self):
|
def setup_mpi(self):
|
||||||
# this will be called when mpi is version 3 or higher
|
# this will be called when mpi is version 3 or higher
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@when('mpi@2:')
|
@when('^mpi@2:')
|
||||||
def setup_mpi(self):
|
def setup_mpi(self):
|
||||||
# this will be called when mpi is version 2 or higher
|
# this will be called when mpi is version 2 or higher
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@when('mpi@1:')
|
@when('^mpi@1:')
|
||||||
def setup_mpi(self):
|
def setup_mpi(self):
|
||||||
# this will be called when mpi is version 1 or higher
|
# this will be called when mpi is version 1 or higher
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -52,20 +52,20 @@ class SpecMultiMethod(object):
|
||||||
registers method versions with them.
|
registers method versions with them.
|
||||||
|
|
||||||
To register a method, you can do something like this:
|
To register a method, you can do something like this:
|
||||||
mf = SpecMultiMethod()
|
mm = SpecMultiMethod()
|
||||||
mf.register("^chaos_5_x86_64_ib", some_method)
|
mm.register("^chaos_5_x86_64_ib", some_method)
|
||||||
|
|
||||||
The object registered needs to be a Spec or some string that
|
The object registered needs to be a Spec or some string that
|
||||||
will parse to be a valid spec.
|
will parse to be a valid spec.
|
||||||
|
|
||||||
When the pmf is actually called, it selects a version of the
|
When the mm is actually called, it selects a version of the
|
||||||
method to call based on the sys_type of the object it is
|
method to call based on the sys_type of the object it is
|
||||||
called on.
|
called on.
|
||||||
|
|
||||||
See the docs for decorators below for more details.
|
See the docs for decorators below for more details.
|
||||||
"""
|
"""
|
||||||
def __init__(self, default=None):
|
def __init__(self, default=None):
|
||||||
self.method_map = {}
|
self.method_list = []
|
||||||
self.default = default
|
self.default = default
|
||||||
if default:
|
if default:
|
||||||
functools.update_wrapper(self, default)
|
functools.update_wrapper(self, default)
|
||||||
|
@ -73,7 +73,7 @@ def __init__(self, default=None):
|
||||||
|
|
||||||
def register(self, spec, method):
|
def register(self, spec, method):
|
||||||
"""Register a version of a method for a particular sys_type."""
|
"""Register a version of a method for a particular sys_type."""
|
||||||
self.method_map[spec] = method
|
self.method_list.append((spec, method))
|
||||||
|
|
||||||
if not hasattr(self, '__name__'):
|
if not hasattr(self, '__name__'):
|
||||||
functools.update_wrapper(self, method)
|
functools.update_wrapper(self, method)
|
||||||
|
@ -87,33 +87,25 @@ def __get__(self, obj, objtype):
|
||||||
|
|
||||||
|
|
||||||
def __call__(self, package_self, *args, **kwargs):
|
def __call__(self, package_self, *args, **kwargs):
|
||||||
"""Try to find a method that matches package_self.sys_type.
|
"""Find the first method with a spec that matches the
|
||||||
If none is found, call the default method that this was
|
package's spec. If none is found, call the default
|
||||||
initialized with. If there is no default, raise an error.
|
or if there is none, then raise a NoSuchMethodError.
|
||||||
"""
|
"""
|
||||||
spec = package_self.spec
|
for spec, method in self.method_list:
|
||||||
matching_specs = [s for s in self.method_map if s.satisfies(spec)]
|
if spec.satisfies(package_self.spec):
|
||||||
num_matches = len(matching_specs)
|
return method(package_self, *args, **kwargs)
|
||||||
if num_matches == 0:
|
|
||||||
if self.default is None:
|
|
||||||
raise NoSuchMethodError(type(package_self), self.__name__,
|
|
||||||
spec, self.method_map.keys())
|
|
||||||
else:
|
|
||||||
method = self.default
|
|
||||||
|
|
||||||
elif num_matches == 1:
|
|
||||||
method = self.method_map[matching_specs[0]]
|
|
||||||
|
|
||||||
|
if self.default:
|
||||||
|
return self.default(package_self, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise AmbiguousMethodError(type(package_self), self.__name__,
|
raise NoSuchMethodError(
|
||||||
spec, matching_specs)
|
type(package_self), self.__name__, spec,
|
||||||
|
[m[0] for m in self.method_list])
|
||||||
return method(package_self, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "SpecMultiMethod {\n\tdefault: %s,\n\tspecs: %s\n}" % (
|
return "SpecMultiMethod {\n\tdefault: %s,\n\tspecs: %s\n}" % (
|
||||||
self.default, self.method_map)
|
self.default, self.method_list)
|
||||||
|
|
||||||
|
|
||||||
class when(object):
|
class when(object):
|
||||||
|
@ -207,12 +199,3 @@ def __init__(self, cls, method_name, spec, possible_specs):
|
||||||
"Package %s does not support %s called with %s. Options are: %s"
|
"Package %s does not support %s called with %s. Options are: %s"
|
||||||
% (cls.__name__, method_name, spec,
|
% (cls.__name__, method_name, spec,
|
||||||
", ".join(str(s) for s in possible_specs)))
|
", ".join(str(s) for s in possible_specs)))
|
||||||
|
|
||||||
|
|
||||||
class AmbiguousMethodError(spack.error.SpackError):
|
|
||||||
"""Raised when we can't find a version of a multi-method."""
|
|
||||||
def __init__(self, cls, method_name, spec, matching_specs):
|
|
||||||
super(AmbiguousMethodError, self).__init__(
|
|
||||||
"Package %s has multiple versions of %s that match %s: %s"
|
|
||||||
% (cls.__name__, method_name, spec,
|
|
||||||
",".join(str(s) for s in matching_specs)))
|
|
||||||
|
|
|
@ -27,17 +27,36 @@ def no_version_2(self):
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# These functions overlap too much, so there is ambiguity
|
# These functions overlap, so there is ambiguity, but we'll take
|
||||||
|
# the first one.
|
||||||
#
|
#
|
||||||
@when('@:4')
|
@when('@:4')
|
||||||
def version_overlap(self):
|
def version_overlap(self):
|
||||||
pass
|
return 1
|
||||||
|
|
||||||
@when('@2:')
|
@when('@2:')
|
||||||
def version_overlap(self):
|
def version_overlap(self):
|
||||||
pass
|
return 2
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# More complicated case with cascading versions.
|
||||||
|
#
|
||||||
|
def mpi_version(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@when('^mpi@3:')
|
||||||
|
def mpi_version(self):
|
||||||
|
return 3
|
||||||
|
|
||||||
|
@when('^mpi@2:')
|
||||||
|
def mpi_version(self):
|
||||||
|
return 2
|
||||||
|
|
||||||
|
@when('^mpi@1:')
|
||||||
|
def mpi_version(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Use these to test whether the default method is called when no
|
# Use these to test whether the default method is called when no
|
||||||
|
|
|
@ -30,8 +30,30 @@ def test_one_version_match(self):
|
||||||
|
|
||||||
|
|
||||||
def test_version_overlap(self):
|
def test_version_overlap(self):
|
||||||
pkg = packages.get('multimethod@3.0')
|
pkg = packages.get('multimethod@2.0')
|
||||||
self.assertRaises(AmbiguousMethodError, pkg.version_overlap)
|
self.assertEqual(pkg.version_overlap(), 1)
|
||||||
|
|
||||||
|
pkg = packages.get('multimethod@5.0')
|
||||||
|
self.assertEqual(pkg.version_overlap(), 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mpi_version(self):
|
||||||
|
pkg = packages.get('multimethod^mpich@3.0.4')
|
||||||
|
self.assertEqual(pkg.mpi_version(), 3)
|
||||||
|
|
||||||
|
pkg = packages.get('multimethod^mpich2@1.2')
|
||||||
|
self.assertEqual(pkg.mpi_version(), 2)
|
||||||
|
|
||||||
|
pkg = packages.get('multimethod^mpich@1.0')
|
||||||
|
self.assertEqual(pkg.mpi_version(), 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_undefined_mpi_version(self):
|
||||||
|
# This currently fails because provides() doesn't do
|
||||||
|
# the right thing undefined version ranges.
|
||||||
|
# TODO: fix this.
|
||||||
|
pkg = packages.get('multimethod^mpich@0.4')
|
||||||
|
self.assertEqual(pkg.mpi_version(), 0)
|
||||||
|
|
||||||
|
|
||||||
def test_default_works(self):
|
def test_default_works(self):
|
||||||
|
@ -69,11 +91,10 @@ def test_dependency_match(self):
|
||||||
pkg = packages.get('multimethod^mpich')
|
pkg = packages.get('multimethod^mpich')
|
||||||
self.assertEqual(pkg.different_by_dep(), 'mpich')
|
self.assertEqual(pkg.different_by_dep(), 'mpich')
|
||||||
|
|
||||||
|
# If we try to switch on some entirely different dep, it's ambiguous,
|
||||||
def test_ambiguous_dep(self):
|
# but should take the first option
|
||||||
"""If we try to switch on some entirely different dep, it's ambiguous"""
|
|
||||||
pkg = packages.get('multimethod^foobar')
|
pkg = packages.get('multimethod^foobar')
|
||||||
self.assertRaises(AmbiguousMethodError, pkg.different_by_dep)
|
self.assertEqual(pkg.different_by_dep(), 'mpich')
|
||||||
|
|
||||||
|
|
||||||
def test_virtual_dep_match(self):
|
def test_virtual_dep_match(self):
|
||||||
|
|
Loading…
Reference in a new issue