Change multimethods to take first match instead of being rigid.

- Formerly required that one and only one spec match
- Now allows first match in a list (more flexible and more intuitive)

- introduces new bug that provides() doesn't do the correct thing
  when a version is not in a range that has been explicitly provided.
- TODO: fix this.
This commit is contained in:
Todd Gamblin 2014-01-07 09:51:51 +01:00
parent 36e6ef9fbd
commit 2f520d6119
5 changed files with 71 additions and 46 deletions

View file

@ -3,6 +3,8 @@
Developer Guide
=====================
This guide is intended for people who want to work on Spack's inner
workings. Right now it's pretty sparse.
Spec objects
-------------------------

View file

@ -1060,17 +1060,17 @@ for example:
# the default, called when no @when specs match
pass
@when('mpi@3:')
@when('^mpi@3:')
def setup_mpi(self):
# this will be called when mpi is version 3 or higher
pass
@when('mpi@2:')
@when('^mpi@2:')
def setup_mpi(self):
# this will be called when mpi is version 2 or higher
pass
@when('mpi@1:')
@when('^mpi@1:')
def setup_mpi(self):
# this will be called when mpi is version 1 or higher
pass

View file

@ -52,20 +52,20 @@ class SpecMultiMethod(object):
registers method versions with them.
To register a method, you can do something like this:
mf = SpecMultiMethod()
mf.register("^chaos_5_x86_64_ib", some_method)
mm = SpecMultiMethod()
mm.register("^chaos_5_x86_64_ib", some_method)
The object registered needs to be a Spec or some string that
will parse to be a valid spec.
When the pmf is actually called, it selects a version of the
When the mm is actually called, it selects a version of the
method to call based on the sys_type of the object it is
called on.
See the docs for decorators below for more details.
"""
def __init__(self, default=None):
self.method_map = {}
self.method_list = []
self.default = default
if default:
functools.update_wrapper(self, default)
@ -73,7 +73,7 @@ def __init__(self, default=None):
def register(self, spec, method):
"""Register a version of a method for a particular sys_type."""
self.method_map[spec] = method
self.method_list.append((spec, method))
if not hasattr(self, '__name__'):
functools.update_wrapper(self, method)
@ -87,33 +87,25 @@ def __get__(self, obj, objtype):
def __call__(self, package_self, *args, **kwargs):
"""Try to find a method that matches package_self.sys_type.
If none is found, call the default method that this was
initialized with. If there is no default, raise an error.
"""Find the first method with a spec that matches the
package's spec. If none is found, call the default
or if there is none, then raise a NoSuchMethodError.
"""
spec = package_self.spec
matching_specs = [s for s in self.method_map if s.satisfies(spec)]
num_matches = len(matching_specs)
if num_matches == 0:
if self.default is None:
raise NoSuchMethodError(type(package_self), self.__name__,
spec, self.method_map.keys())
else:
method = self.default
elif num_matches == 1:
method = self.method_map[matching_specs[0]]
else:
raise AmbiguousMethodError(type(package_self), self.__name__,
spec, matching_specs)
for spec, method in self.method_list:
if spec.satisfies(package_self.spec):
return method(package_self, *args, **kwargs)
if self.default:
return self.default(package_self, *args, **kwargs)
else:
raise NoSuchMethodError(
type(package_self), self.__name__, spec,
[m[0] for m in self.method_list])
def __str__(self):
return "SpecMultiMethod {\n\tdefault: %s,\n\tspecs: %s\n}" % (
self.default, self.method_map)
self.default, self.method_list)
class when(object):
@ -207,12 +199,3 @@ def __init__(self, cls, method_name, spec, possible_specs):
"Package %s does not support %s called with %s. Options are: %s"
% (cls.__name__, method_name, spec,
", ".join(str(s) for s in possible_specs)))
class AmbiguousMethodError(spack.error.SpackError):
"""Raised when we can't find a version of a multi-method."""
def __init__(self, cls, method_name, spec, matching_specs):
super(AmbiguousMethodError, self).__init__(
"Package %s has multiple versions of %s that match %s: %s"
% (cls.__name__, method_name, spec,
",".join(str(s) for s in matching_specs)))

View file

@ -27,17 +27,36 @@ def no_version_2(self):
#
# These functions overlap too much, so there is ambiguity
# These functions overlap, so there is ambiguity, but we'll take
# the first one.
#
@when('@:4')
def version_overlap(self):
pass
return 1
@when('@2:')
def version_overlap(self):
pass
return 2
#
# More complicated case with cascading versions.
#
def mpi_version(self):
return 0
@when('^mpi@3:')
def mpi_version(self):
return 3
@when('^mpi@2:')
def mpi_version(self):
return 2
@when('^mpi@1:')
def mpi_version(self):
return 1
#
# Use these to test whether the default method is called when no

View file

@ -30,8 +30,30 @@ def test_one_version_match(self):
def test_version_overlap(self):
pkg = packages.get('multimethod@3.0')
self.assertRaises(AmbiguousMethodError, pkg.version_overlap)
pkg = packages.get('multimethod@2.0')
self.assertEqual(pkg.version_overlap(), 1)
pkg = packages.get('multimethod@5.0')
self.assertEqual(pkg.version_overlap(), 2)
def test_mpi_version(self):
pkg = packages.get('multimethod^mpich@3.0.4')
self.assertEqual(pkg.mpi_version(), 3)
pkg = packages.get('multimethod^mpich2@1.2')
self.assertEqual(pkg.mpi_version(), 2)
pkg = packages.get('multimethod^mpich@1.0')
self.assertEqual(pkg.mpi_version(), 1)
def test_undefined_mpi_version(self):
# This currently fails because provides() doesn't do
# the right thing undefined version ranges.
# TODO: fix this.
pkg = packages.get('multimethod^mpich@0.4')
self.assertEqual(pkg.mpi_version(), 0)
def test_default_works(self):
@ -69,11 +91,10 @@ def test_dependency_match(self):
pkg = packages.get('multimethod^mpich')
self.assertEqual(pkg.different_by_dep(), 'mpich')
def test_ambiguous_dep(self):
"""If we try to switch on some entirely different dep, it's ambiguous"""
# If we try to switch on some entirely different dep, it's ambiguous,
# but should take the first option
pkg = packages.get('multimethod^foobar')
self.assertRaises(AmbiguousMethodError, pkg.different_by_dep)
self.assertEqual(pkg.different_by_dep(), 'mpich')
def test_virtual_dep_match(self):