Change multimethods to take first match instead of being rigid.
- Formerly required that one and only one spec match - Now allows first match in a list (more flexible and more intuitive) - introduces new bug that provides() doesn't do the correct thing when a version is not in a range that has been explicitly provided. - TODO: fix this.
This commit is contained in:
parent
36e6ef9fbd
commit
2f520d6119
5 changed files with 71 additions and 46 deletions
|
@ -3,6 +3,8 @@
|
|||
Developer Guide
|
||||
=====================
|
||||
|
||||
This guide is intended for people who want to work on Spack's inner
|
||||
workings. Right now it's pretty sparse.
|
||||
|
||||
Spec objects
|
||||
-------------------------
|
||||
|
|
|
@ -1060,17 +1060,17 @@ for example:
|
|||
# the default, called when no @when specs match
|
||||
pass
|
||||
|
||||
@when('mpi@3:')
|
||||
@when('^mpi@3:')
|
||||
def setup_mpi(self):
|
||||
# this will be called when mpi is version 3 or higher
|
||||
pass
|
||||
|
||||
@when('mpi@2:')
|
||||
@when('^mpi@2:')
|
||||
def setup_mpi(self):
|
||||
# this will be called when mpi is version 2 or higher
|
||||
pass
|
||||
|
||||
@when('mpi@1:')
|
||||
@when('^mpi@1:')
|
||||
def setup_mpi(self):
|
||||
# this will be called when mpi is version 1 or higher
|
||||
pass
|
||||
|
|
|
@ -52,20 +52,20 @@ class SpecMultiMethod(object):
|
|||
registers method versions with them.
|
||||
|
||||
To register a method, you can do something like this:
|
||||
mf = SpecMultiMethod()
|
||||
mf.register("^chaos_5_x86_64_ib", some_method)
|
||||
mm = SpecMultiMethod()
|
||||
mm.register("^chaos_5_x86_64_ib", some_method)
|
||||
|
||||
The object registered needs to be a Spec or some string that
|
||||
will parse to be a valid spec.
|
||||
|
||||
When the pmf is actually called, it selects a version of the
|
||||
When the mm is actually called, it selects a version of the
|
||||
method to call based on the sys_type of the object it is
|
||||
called on.
|
||||
|
||||
See the docs for decorators below for more details.
|
||||
"""
|
||||
def __init__(self, default=None):
|
||||
self.method_map = {}
|
||||
self.method_list = []
|
||||
self.default = default
|
||||
if default:
|
||||
functools.update_wrapper(self, default)
|
||||
|
@ -73,7 +73,7 @@ def __init__(self, default=None):
|
|||
|
||||
def register(self, spec, method):
|
||||
"""Register a version of a method for a particular sys_type."""
|
||||
self.method_map[spec] = method
|
||||
self.method_list.append((spec, method))
|
||||
|
||||
if not hasattr(self, '__name__'):
|
||||
functools.update_wrapper(self, method)
|
||||
|
@ -87,33 +87,25 @@ def __get__(self, obj, objtype):
|
|||
|
||||
|
||||
def __call__(self, package_self, *args, **kwargs):
|
||||
"""Try to find a method that matches package_self.sys_type.
|
||||
If none is found, call the default method that this was
|
||||
initialized with. If there is no default, raise an error.
|
||||
"""Find the first method with a spec that matches the
|
||||
package's spec. If none is found, call the default
|
||||
or if there is none, then raise a NoSuchMethodError.
|
||||
"""
|
||||
spec = package_self.spec
|
||||
matching_specs = [s for s in self.method_map if s.satisfies(spec)]
|
||||
num_matches = len(matching_specs)
|
||||
if num_matches == 0:
|
||||
if self.default is None:
|
||||
raise NoSuchMethodError(type(package_self), self.__name__,
|
||||
spec, self.method_map.keys())
|
||||
else:
|
||||
method = self.default
|
||||
|
||||
elif num_matches == 1:
|
||||
method = self.method_map[matching_specs[0]]
|
||||
|
||||
else:
|
||||
raise AmbiguousMethodError(type(package_self), self.__name__,
|
||||
spec, matching_specs)
|
||||
|
||||
for spec, method in self.method_list:
|
||||
if spec.satisfies(package_self.spec):
|
||||
return method(package_self, *args, **kwargs)
|
||||
|
||||
if self.default:
|
||||
return self.default(package_self, *args, **kwargs)
|
||||
else:
|
||||
raise NoSuchMethodError(
|
||||
type(package_self), self.__name__, spec,
|
||||
[m[0] for m in self.method_list])
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "SpecMultiMethod {\n\tdefault: %s,\n\tspecs: %s\n}" % (
|
||||
self.default, self.method_map)
|
||||
self.default, self.method_list)
|
||||
|
||||
|
||||
class when(object):
|
||||
|
@ -207,12 +199,3 @@ def __init__(self, cls, method_name, spec, possible_specs):
|
|||
"Package %s does not support %s called with %s. Options are: %s"
|
||||
% (cls.__name__, method_name, spec,
|
||||
", ".join(str(s) for s in possible_specs)))
|
||||
|
||||
|
||||
class AmbiguousMethodError(spack.error.SpackError):
|
||||
"""Raised when we can't find a version of a multi-method."""
|
||||
def __init__(self, cls, method_name, spec, matching_specs):
|
||||
super(AmbiguousMethodError, self).__init__(
|
||||
"Package %s has multiple versions of %s that match %s: %s"
|
||||
% (cls.__name__, method_name, spec,
|
||||
",".join(str(s) for s in matching_specs)))
|
||||
|
|
|
@ -27,17 +27,36 @@ def no_version_2(self):
|
|||
|
||||
|
||||
#
|
||||
# These functions overlap too much, so there is ambiguity
|
||||
# These functions overlap, so there is ambiguity, but we'll take
|
||||
# the first one.
|
||||
#
|
||||
@when('@:4')
|
||||
def version_overlap(self):
|
||||
pass
|
||||
return 1
|
||||
|
||||
@when('@2:')
|
||||
def version_overlap(self):
|
||||
pass
|
||||
return 2
|
||||
|
||||
|
||||
#
|
||||
# More complicated case with cascading versions.
|
||||
#
|
||||
def mpi_version(self):
|
||||
return 0
|
||||
|
||||
@when('^mpi@3:')
|
||||
def mpi_version(self):
|
||||
return 3
|
||||
|
||||
@when('^mpi@2:')
|
||||
def mpi_version(self):
|
||||
return 2
|
||||
|
||||
@when('^mpi@1:')
|
||||
def mpi_version(self):
|
||||
return 1
|
||||
|
||||
|
||||
#
|
||||
# Use these to test whether the default method is called when no
|
||||
|
|
|
@ -30,8 +30,30 @@ def test_one_version_match(self):
|
|||
|
||||
|
||||
def test_version_overlap(self):
|
||||
pkg = packages.get('multimethod@3.0')
|
||||
self.assertRaises(AmbiguousMethodError, pkg.version_overlap)
|
||||
pkg = packages.get('multimethod@2.0')
|
||||
self.assertEqual(pkg.version_overlap(), 1)
|
||||
|
||||
pkg = packages.get('multimethod@5.0')
|
||||
self.assertEqual(pkg.version_overlap(), 2)
|
||||
|
||||
|
||||
def test_mpi_version(self):
|
||||
pkg = packages.get('multimethod^mpich@3.0.4')
|
||||
self.assertEqual(pkg.mpi_version(), 3)
|
||||
|
||||
pkg = packages.get('multimethod^mpich2@1.2')
|
||||
self.assertEqual(pkg.mpi_version(), 2)
|
||||
|
||||
pkg = packages.get('multimethod^mpich@1.0')
|
||||
self.assertEqual(pkg.mpi_version(), 1)
|
||||
|
||||
|
||||
def test_undefined_mpi_version(self):
|
||||
# This currently fails because provides() doesn't do
|
||||
# the right thing undefined version ranges.
|
||||
# TODO: fix this.
|
||||
pkg = packages.get('multimethod^mpich@0.4')
|
||||
self.assertEqual(pkg.mpi_version(), 0)
|
||||
|
||||
|
||||
def test_default_works(self):
|
||||
|
@ -69,11 +91,10 @@ def test_dependency_match(self):
|
|||
pkg = packages.get('multimethod^mpich')
|
||||
self.assertEqual(pkg.different_by_dep(), 'mpich')
|
||||
|
||||
|
||||
def test_ambiguous_dep(self):
|
||||
"""If we try to switch on some entirely different dep, it's ambiguous"""
|
||||
# If we try to switch on some entirely different dep, it's ambiguous,
|
||||
# but should take the first option
|
||||
pkg = packages.get('multimethod^foobar')
|
||||
self.assertRaises(AmbiguousMethodError, pkg.different_by_dep)
|
||||
self.assertEqual(pkg.different_by_dep(), 'mpich')
|
||||
|
||||
|
||||
def test_virtual_dep_match(self):
|
||||
|
|
Loading…
Reference in a new issue