SPACK-2: Multimethods for specs.
- multi_function.py -> multimethod.py - Added @when decorator, which takes a spec and implements matching for method dispatch - Added multimethod unit test, covers basic cases.
This commit is contained in:
parent
99b05fd571
commit
f7706d231d
11 changed files with 363 additions and 220 deletions
|
@ -4,4 +4,4 @@
|
|||
|
||||
from package import Package
|
||||
from relations import depends_on, provides
|
||||
from multi_function import platform
|
||||
from multimethod import when
|
||||
|
|
|
@ -1,147 +0,0 @@
|
|||
"""This module contains utilities for using multi-functions in spack.
|
||||
You can think of multi-functions like overloaded functions -- they're
|
||||
functions with the same name, and we need to select a version of the
|
||||
function based on some criteria. e.g., for overloaded functions, you
|
||||
would select a version of the function to call based on the types of
|
||||
its arguments.
|
||||
|
||||
For spack, we might want to select a version of the function based on
|
||||
the platform we want to build a package for, or based on the versions
|
||||
of the dependencies of the package.
|
||||
"""
|
||||
import sys
|
||||
import functools
|
||||
|
||||
import spack.architecture
|
||||
import spack.error as serr
|
||||
|
||||
class NoSuchVersionError(serr.SpackError):
|
||||
"""Raised when we can't find a version of a function for a platform."""
|
||||
def __init__(self, fun_name, sys_type):
|
||||
super(NoSuchVersionError, self).__init__(
|
||||
"No version of %s found for %s!" % (fun_name, sys_type))
|
||||
|
||||
|
||||
class PlatformMultiFunction(object):
|
||||
"""This is a callable type for storing a collection of versions
|
||||
of an instance method. The platform decorator (see docs below)
|
||||
creates PlatformMultiFunctions and registers function versions
|
||||
with them.
|
||||
|
||||
To register a function, you can do something like this:
|
||||
pmf = PlatformMultiFunction()
|
||||
pmf.regsiter("chaos_5_x86_64_ib", some_function)
|
||||
|
||||
When the pmf is actually called, it selects a version of
|
||||
the function to call based on the sys_type of the object
|
||||
it is called on.
|
||||
|
||||
See the docs for the platform decorator for more details.
|
||||
"""
|
||||
def __init__(self, default=None):
|
||||
self.function_map = {}
|
||||
self.default = default
|
||||
if default:
|
||||
self.__name__ = default.__name__
|
||||
|
||||
def register(self, platform, function):
|
||||
"""Register a version of a function for a particular sys_type."""
|
||||
self.function_map[platform] = function
|
||||
if not hasattr(self, '__name__'):
|
||||
self.__name__ = function.__name__
|
||||
else:
|
||||
assert(self.__name__ == function.__name__)
|
||||
|
||||
def __get__(self, obj, objtype):
|
||||
"""This makes __call__ support instance methods."""
|
||||
return functools.partial(self.__call__, obj)
|
||||
|
||||
def __call__(self, package_self, *args, **kwargs):
|
||||
"""Try to find a function that matches package_self.sys_type.
|
||||
If none is found, call the default function that this was
|
||||
initialized with. If there is no default, raise an error.
|
||||
"""
|
||||
# TODO: make this work with specs.
|
||||
sys_type = package_self.sys_type
|
||||
function = self.function_map.get(sys_type, self.default)
|
||||
if function:
|
||||
function(package_self, *args, **kwargs)
|
||||
else:
|
||||
raise NoSuchVersionError(self.__name__, sys_type)
|
||||
|
||||
def __str__(self):
|
||||
return "<%s, %s>" % (self.default, self.function_map)
|
||||
|
||||
|
||||
class platform(object):
|
||||
"""This annotation lets packages declare platform-specific versions
|
||||
of functions like install(). For example:
|
||||
|
||||
class SomePackage(Package):
|
||||
...
|
||||
|
||||
def install(self, prefix):
|
||||
# Do default install
|
||||
|
||||
@platform('chaos_5_x86_64_ib')
|
||||
def install(self, prefix):
|
||||
# This will be executed instead of the default install if
|
||||
# the package's sys_type() is chaos_5_x86_64_ib.
|
||||
|
||||
@platform('bgqos_0")
|
||||
def install(self, prefix):
|
||||
# This will be executed if the package's sys_type is bgqos_0
|
||||
|
||||
This allows each package to have a default version of install() AND
|
||||
specialized versions for particular platforms. The version that is
|
||||
called depends on the sys_type of SomePackage.
|
||||
|
||||
Note that this works for functions other than install, as well. So,
|
||||
if you only have part of the install that is platform specific, you
|
||||
could do this:
|
||||
|
||||
class SomePackage(Package):
|
||||
...
|
||||
|
||||
def setup(self):
|
||||
# do nothing in the default case
|
||||
pass
|
||||
|
||||
@platform('chaos_5_x86_64_ib')
|
||||
def setup(self):
|
||||
# do something for x86_64
|
||||
|
||||
def install(self, prefix):
|
||||
# Do common install stuff
|
||||
self.setup()
|
||||
# Do more common install stuff
|
||||
|
||||
If there is no specialized version for the package's sys_type, the
|
||||
default (un-decorated) version will be called. If there is no default
|
||||
version and no specialized version, the call raises a
|
||||
NoSuchVersionError.
|
||||
|
||||
Note that the default version of install() must *always* come first.
|
||||
Otherwise it will override all of the platform-specific versions.
|
||||
There's not much we can do to get around this because of the way
|
||||
decorators work.
|
||||
"""
|
||||
class platform(object):
|
||||
def __init__(self, sys_type):
|
||||
self.sys_type = sys_type
|
||||
|
||||
def __call__(self, fun):
|
||||
# Record the sys_type as an attribute on this function
|
||||
fun.sys_type = self.sys_type
|
||||
|
||||
# Get the first definition of the function in the calling scope
|
||||
calling_frame = sys._getframe(1).f_locals
|
||||
original_fun = calling_frame.get(fun.__name__)
|
||||
|
||||
# Create a multifunction out of the original function if it
|
||||
# isn't one already.
|
||||
if not type(original_fun) == PlatformMultiFunction:
|
||||
original_fun = PlatformMultiFunction(original_fun)
|
||||
|
||||
original_fun.register(self.sys_type, fun)
|
||||
return original_fun
|
211
lib/spack/spack/multimethod.py
Normal file
211
lib/spack/spack/multimethod.py
Normal file
|
@ -0,0 +1,211 @@
|
|||
"""This module contains utilities for using multi-methods in
|
||||
spack. You can think of multi-methods like overloaded methods --
|
||||
they're methods with the same name, and we need to select a version
|
||||
of the method based on some criteria. e.g., for overloaded
|
||||
methods, you would select a version of the method to call based on
|
||||
the types of its arguments.
|
||||
|
||||
In spack, multi-methods are used to ease the life of package
|
||||
authors. They allow methods like install() (or other methods
|
||||
called by install()) to declare multiple versions to be called when
|
||||
the package is instantiated with different specs. e.g., if the
|
||||
package is built with OpenMPI on x86_64,, you might want to call a
|
||||
different install method than if it was built for mpich2 on
|
||||
BlueGene/Q. Likewise, you might want to do a different type of
|
||||
install for different versions of the package.
|
||||
|
||||
Multi-methods provide a simple decorator-based syntax for this that
|
||||
avoids overly complicated rat nests of if statements. Obviously,
|
||||
depending on the scenario, regular old conditionals might be clearer,
|
||||
so package authors should use their judgement.
|
||||
"""
|
||||
import sys
|
||||
import functools
|
||||
import collections
|
||||
|
||||
import spack.architecture
|
||||
import spack.error
|
||||
from spack.util.lang import *
|
||||
from spack.spec import parse_local_spec
|
||||
|
||||
|
||||
class SpecMultiMethod(object):
|
||||
"""This implements a multi-method for Spack specs. Packages are
|
||||
instantiated with a particular spec, and you may want to
|
||||
execute different versions of methods based on what the spec
|
||||
looks like. For example, you might want to call a different
|
||||
version of install() for one platform than you call on another.
|
||||
|
||||
The SpecMultiMethod class implements a callable object that
|
||||
handles method dispatch. When it is called, it looks through
|
||||
registered methods and their associated specs, and it tries
|
||||
to find one that matches the package's spec. If it finds one
|
||||
(and only one), it will call that method.
|
||||
|
||||
The package author is responsible for ensuring that only one
|
||||
condition on multi-methods ever evaluates to true. If
|
||||
multiple methods evaluate to true, this will raise an
|
||||
exception.
|
||||
|
||||
This is intended for use with decorators (see below). The
|
||||
decorator (see docs below) creates SpecMultiMethods and
|
||||
registers method versions with them.
|
||||
|
||||
To register a method, you can do something like this:
|
||||
mf = SpecMultiMethod()
|
||||
mf.regsiter("^chaos_5_x86_64_ib", some_method)
|
||||
|
||||
The object registered needs to be a Spec or some string that
|
||||
will parse to be a valid spec.
|
||||
|
||||
When the pmf is actually called, it selects a version of the
|
||||
method to call based on the sys_type of the object it is
|
||||
called on.
|
||||
|
||||
See the docs for decorators below for more details.
|
||||
"""
|
||||
def __init__(self, default=None):
|
||||
self.method_map = {}
|
||||
self.default = default
|
||||
if default:
|
||||
functools.update_wrapper(self, default)
|
||||
|
||||
|
||||
def register(self, spec, method):
|
||||
"""Register a version of a method for a particular sys_type."""
|
||||
self.method_map[spec] = method
|
||||
|
||||
if not hasattr(self, '__name__'):
|
||||
functools.update_wrapper(self, method)
|
||||
else:
|
||||
assert(self.__name__ == method.__name__)
|
||||
|
||||
|
||||
def __get__(self, obj, objtype):
|
||||
"""This makes __call__ support instance methods."""
|
||||
return functools.partial(self.__call__, obj)
|
||||
|
||||
|
||||
def __call__(self, package_self, *args, **kwargs):
|
||||
"""Try to find a method that matches package_self.sys_type.
|
||||
If none is found, call the default method that this was
|
||||
initialized with. If there is no default, raise an error.
|
||||
"""
|
||||
spec = package_self.spec
|
||||
matching_specs = [s for s in self.method_map if s.satisfies(spec)]
|
||||
|
||||
if not matching_specs and self.default is None:
|
||||
raise NoSuchMethodVersionError(type(package_self), self.__name__,
|
||||
spec, self.method_map.keys())
|
||||
elif len(matching_specs) > 1:
|
||||
raise AmbiguousMethodVersionError(type(package_self), self.__name__,
|
||||
spec, matching_specs)
|
||||
|
||||
method = self.method_map[matching_specs[0]]
|
||||
return method(package_self, *args, **kwargs)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "<%s, %s>" % (self.default, self.method_map)
|
||||
|
||||
|
||||
class when(object):
|
||||
"""This annotation lets packages declare multiple versions of
|
||||
methods like install() that depend on the package's spec.
|
||||
For example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
class SomePackage(Package):
|
||||
...
|
||||
|
||||
def install(self, prefix):
|
||||
# Do default install
|
||||
|
||||
@when('=chaos_5_x86_64_ib')
|
||||
def install(self, prefix):
|
||||
# This will be executed instead of the default install if
|
||||
# the package's sys_type() is chaos_5_x86_64_ib.
|
||||
|
||||
@when('=bgqos_0")
|
||||
def install(self, prefix):
|
||||
# This will be executed if the package's sys_type is bgqos_0
|
||||
|
||||
This allows each package to have a default version of install() AND
|
||||
specialized versions for particular platforms. The version that is
|
||||
called depends on the sys_type of SomePackage.
|
||||
|
||||
Note that this works for methods other than install, as well. So,
|
||||
if you only have part of the install that is platform specific, you
|
||||
could do this:
|
||||
|
||||
class SomePackage(Package):
|
||||
...
|
||||
# virtual dependence on MPI.
|
||||
# could resolve to mpich, mpich2, OpenMPI
|
||||
depends_on('mpi')
|
||||
|
||||
def setup(self):
|
||||
# do nothing in the default case
|
||||
pass
|
||||
|
||||
@when('^openmpi')
|
||||
def setup(self):
|
||||
# do something special when this is built with OpenMPI for
|
||||
# its MPI implementations.
|
||||
|
||||
|
||||
def install(self, prefix):
|
||||
# Do common install stuff
|
||||
self.setup()
|
||||
# Do more common install stuff
|
||||
|
||||
There must be one (and only one) @when clause that matches the
|
||||
package's spec. If there is more than one, or if none match,
|
||||
then the method will raise an exception when it's called.
|
||||
|
||||
Note that the default version of decorated methods must
|
||||
*always* come first. Otherwise it will override all of the
|
||||
platform-specific versions. There's not much we can do to get
|
||||
around this because of the way decorators work.
|
||||
"""
|
||||
class when(object):
|
||||
def __init__(self, spec):
|
||||
pkg = get_calling_package_name()
|
||||
self.spec = parse_local_spec(spec, pkg)
|
||||
|
||||
def __call__(self, method):
|
||||
# Get the first definition of the method in the calling scope
|
||||
original_method = caller_locals().get(method.__name__)
|
||||
|
||||
# Create a multimethod out of the original method if it
|
||||
# isn't one already.
|
||||
if not type(original_method) == SpecMultiMethod:
|
||||
original_method = SpecMultiMethod(original_method)
|
||||
|
||||
original_method.register(self.spec, method)
|
||||
return original_method
|
||||
|
||||
|
||||
class MultiMethodError(spack.error.SpackError):
|
||||
"""Superclass for multimethod dispatch errors"""
|
||||
def __init__(self, message):
|
||||
super(MultiMethodError, self).__init__(message)
|
||||
|
||||
|
||||
class NoSuchMethodVersionError(spack.error.SpackError):
|
||||
"""Raised when we can't find a version of a multi-method."""
|
||||
def __init__(self, cls, method_name, spec, possible_specs):
|
||||
super(NoSuchMethodVersionError, self).__init__(
|
||||
"Package %s does not support %s called with %s. Options are: %s"
|
||||
% (cls.__name__, method_name, spec,
|
||||
", ".join(str(s) for s in possible_specs)))
|
||||
|
||||
|
||||
class AmbiguousMethodVersionError(spack.error.SpackError):
|
||||
"""Raised when we can't find a version of a multi-method."""
|
||||
def __init__(self, cls, method_name, spec, matching_specs):
|
||||
super(AmbiguousMethodVersionError, self).__init__(
|
||||
"Package %s has multiple versions of %s that match %s: %s"
|
||||
% (cls.__name__, method_name, spec,
|
||||
",".join(str(s) for s in matching_specs)))
|
|
@ -25,7 +25,6 @@
|
|||
import multiprocessing
|
||||
import url
|
||||
|
||||
from spack.multi_function import platform
|
||||
import spack.util.crypto as crypto
|
||||
from spack.version import *
|
||||
from spack.stage import Stage
|
||||
|
|
|
@ -50,79 +50,18 @@ class Mpileaks(Package):
|
|||
|
||||
import spack
|
||||
import spack.spec
|
||||
from spack.spec import Spec
|
||||
import spack.error
|
||||
from spack.spec import Spec, parse_local_spec
|
||||
from spack.packages import packages_module
|
||||
|
||||
|
||||
def _caller_locals():
|
||||
"""This will return the locals of the *parent* of the caller.
|
||||
This allows a fucntion to insert variables into its caller's
|
||||
scope. Yes, this is some black magic, and yes it's useful
|
||||
for implementing things like depends_on and provides.
|
||||
"""
|
||||
stack = inspect.stack()
|
||||
try:
|
||||
return stack[2][0].f_locals
|
||||
finally:
|
||||
del stack
|
||||
|
||||
|
||||
def _get_calling_package_name():
|
||||
"""Make sure that the caller is a class definition, and return
|
||||
the module's name. This is useful for getting the name of
|
||||
spack packages from inside a relation function.
|
||||
"""
|
||||
stack = inspect.stack()
|
||||
try:
|
||||
# get calling function name (the relation)
|
||||
relation = stack[1][3]
|
||||
|
||||
# Make sure locals contain __module__
|
||||
caller_locals = stack[2][0].f_locals
|
||||
finally:
|
||||
del stack
|
||||
|
||||
if not '__module__' in caller_locals:
|
||||
raise ScopeError(relation)
|
||||
|
||||
module_name = caller_locals['__module__']
|
||||
base_name = module_name.split('.')[-1]
|
||||
return base_name
|
||||
|
||||
|
||||
def _parse_local_spec(spec_like, pkg_name):
|
||||
"""Allow the user to omit the package name part of a spec in relations.
|
||||
e.g., provides('mpi@2', when='@1.9:') says that this package provides
|
||||
MPI-3 when its version is higher than 1.9.
|
||||
"""
|
||||
if not isinstance(spec_like, (str, Spec)):
|
||||
raise TypeError('spec must be Spec or spec string. Found %s'
|
||||
% type(spec_like))
|
||||
|
||||
if isinstance(spec_like, str):
|
||||
try:
|
||||
local_spec = Spec(spec_like)
|
||||
except spack.parse.ParseError:
|
||||
local_spec = Spec(pkg_name + spec_like)
|
||||
if local_spec.name != pkg_name: raise ValueError(
|
||||
"Invalid spec for package %s: %s" % (pkg_name, spec_like))
|
||||
else:
|
||||
local_spec = spec_like
|
||||
|
||||
if local_spec.name != pkg_name:
|
||||
raise ValueError("Spec name '%s' must match package name '%s'"
|
||||
% (spec_like.name, pkg_name))
|
||||
|
||||
return local_spec
|
||||
from spack.util.lang import *
|
||||
|
||||
|
||||
"""Adds a dependencies local variable in the locals of
|
||||
the calling class, based on args. """
|
||||
def depends_on(*specs):
|
||||
pkg = _get_calling_package_name()
|
||||
pkg = get_calling_package_name()
|
||||
|
||||
dependencies = _caller_locals().setdefault('dependencies', {})
|
||||
dependencies = caller_locals().setdefault('dependencies', {})
|
||||
for string in specs:
|
||||
for spec in spack.spec.parse(string):
|
||||
if pkg == spec.name:
|
||||
|
@ -135,11 +74,11 @@ def provides(*specs, **kwargs):
|
|||
'mpi', other packages can declare that they depend on "mpi", and spack
|
||||
can use the providing package to satisfy the dependency.
|
||||
"""
|
||||
pkg = _get_calling_package_name()
|
||||
pkg = get_calling_package_name()
|
||||
spec_string = kwargs.get('when', pkg)
|
||||
provider_spec = _parse_local_spec(spec_string, pkg)
|
||||
provider_spec = parse_local_spec(spec_string, pkg)
|
||||
|
||||
provided = _caller_locals().setdefault("provided", {})
|
||||
provided = caller_locals().setdefault("provided", {})
|
||||
for string in specs:
|
||||
for provided_spec in spack.spec.parse(string):
|
||||
if pkg == provided_spec.name:
|
||||
|
|
|
@ -1097,6 +1097,34 @@ def parse(string):
|
|||
return SpecParser().parse(string)
|
||||
|
||||
|
||||
def parse_local_spec(spec_like, pkg_name):
|
||||
"""Allow the user to omit the package name part of a spec if they
|
||||
know what it has to be already.
|
||||
|
||||
e.g., provides('mpi@2', when='@1.9:') says that this package
|
||||
provides MPI-3 when its version is higher than 1.9.
|
||||
"""
|
||||
if not isinstance(spec_like, (str, Spec)):
|
||||
raise TypeError('spec must be Spec or spec string. Found %s'
|
||||
% type(spec_like))
|
||||
|
||||
if isinstance(spec_like, str):
|
||||
try:
|
||||
local_spec = Spec(spec_like)
|
||||
except spack.parse.ParseError:
|
||||
local_spec = Spec(pkg_name + spec_like)
|
||||
if local_spec.name != pkg_name: raise ValueError(
|
||||
"Invalid spec for package %s: %s" % (pkg_name, spec_like))
|
||||
else:
|
||||
local_spec = spec_like
|
||||
|
||||
if local_spec.name != pkg_name:
|
||||
raise ValueError("Spec name '%s' must match package name '%s'"
|
||||
% (local_spec.name, pkg_name))
|
||||
|
||||
return local_spec
|
||||
|
||||
|
||||
class SpecError(spack.error.SpackError):
|
||||
"""Superclass for all errors that occur while constructing specs."""
|
||||
def __init__(self, message):
|
||||
|
|
|
@ -5,19 +5,24 @@
|
|||
from spack.colify import colify
|
||||
import spack.tty as tty
|
||||
|
||||
"""Names of tests to be included in Spack's test suite"""
|
||||
test_names = ['versions',
|
||||
'url_parse',
|
||||
'stage',
|
||||
'spec_syntax',
|
||||
'spec_dag',
|
||||
'concretize']
|
||||
'concretize',
|
||||
'multimethod']
|
||||
|
||||
|
||||
def list_tests():
|
||||
"""Return names of all tests that can be run for Spack."""
|
||||
return test_names
|
||||
|
||||
|
||||
def run(names, verbose=False):
|
||||
"""Run tests with the supplied names. Names should be a list. If
|
||||
it's empty, run ALL of Spack's tests."""
|
||||
verbosity = 1 if not verbose else 2
|
||||
|
||||
if not names:
|
||||
|
@ -35,6 +40,7 @@ def run(names, verbose=False):
|
|||
testsRun = errors = failures = skipped = 0
|
||||
for test in names:
|
||||
module = 'spack.test.' + test
|
||||
print module
|
||||
suite = unittest.defaultTestLoader.loadTestsFromName(module)
|
||||
|
||||
tty.msg("Running test: %s" % test)
|
||||
|
|
39
lib/spack/spack/test/mock_packages/multimethod.py
Normal file
39
lib/spack/spack/test/mock_packages/multimethod.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
from spack import *
|
||||
|
||||
|
||||
class Multimethod(Package):
|
||||
"""This package is designed for use with Spack's multimethod test.
|
||||
It has a bunch of test cases for the @when decorator that the
|
||||
test uses.
|
||||
"""
|
||||
|
||||
homepage = 'http://www.example.com/'
|
||||
url = 'http://www.example.com/example-1.0.tar.gz'
|
||||
|
||||
#
|
||||
# These functions are only valid for versions 1, 2, and 3.
|
||||
#
|
||||
@when('@1.0')
|
||||
def no_version_2(self):
|
||||
return 1
|
||||
|
||||
@when('@3.0')
|
||||
def no_version_2(self):
|
||||
return 3
|
||||
|
||||
@when('@4.0')
|
||||
def no_version_2(self):
|
||||
return 4
|
||||
|
||||
|
||||
#
|
||||
# These functions overlap too much, so there is ambiguity
|
||||
#
|
||||
@when('@:4')
|
||||
def version_overlap(self):
|
||||
pass
|
||||
|
||||
@when('@2:')
|
||||
def version_overlap(self):
|
||||
pass
|
||||
|
34
lib/spack/spack/test/multimethod.py
Normal file
34
lib/spack/spack/test/multimethod.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
"""
|
||||
Test for multi_method dispatch.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import spack.packages as packages
|
||||
from spack.multimethod import *
|
||||
from spack.version import *
|
||||
from spack.spec import Spec
|
||||
from spack.multimethod import when
|
||||
from spack.test.mock_packages_test import *
|
||||
|
||||
|
||||
class MultiMethodTest(MockPackagesTest):
|
||||
|
||||
def test_no_version_match(self):
|
||||
pkg = packages.get('multimethod@2.0')
|
||||
self.assertRaises(NoSuchMethodVersionError, pkg.no_version_2)
|
||||
|
||||
def test_one_version_match(self):
|
||||
pkg = packages.get('multimethod@1.0')
|
||||
self.assertEqual(pkg.no_version_2(), 1)
|
||||
|
||||
pkg = packages.get('multimethod@3.0')
|
||||
self.assertEqual(pkg.no_version_2(), 3)
|
||||
|
||||
pkg = packages.get('multimethod@4.0')
|
||||
self.assertEqual(pkg.no_version_2(), 4)
|
||||
|
||||
|
||||
def test_multiple_matches(self):
|
||||
pkg = packages.get('multimethod@3.0')
|
||||
self.assertRaises(AmbiguousMethodVersionError, pkg.version_overlap)
|
||||
|
|
@ -14,8 +14,6 @@
|
|||
from spack.spec import Spec
|
||||
from spack.test.mock_packages_test import *
|
||||
|
||||
mock_packages_path = new_path(spack.module_path, 'test', 'mock_packages')
|
||||
|
||||
|
||||
class ValidationTest(MockPackagesTest):
|
||||
|
||||
|
|
|
@ -9,6 +9,42 @@
|
|||
ignore_modules = [r'^\.#', '~$']
|
||||
|
||||
|
||||
def caller_locals():
|
||||
"""This will return the locals of the *parent* of the caller.
|
||||
This allows a fucntion to insert variables into its caller's
|
||||
scope. Yes, this is some black magic, and yes it's useful
|
||||
for implementing things like depends_on and provides.
|
||||
"""
|
||||
stack = inspect.stack()
|
||||
try:
|
||||
return stack[2][0].f_locals
|
||||
finally:
|
||||
del stack
|
||||
|
||||
|
||||
def get_calling_package_name():
|
||||
"""Make sure that the caller is a class definition, and return
|
||||
the module's name. This is useful for getting the name of
|
||||
spack packages from inside a relation function.
|
||||
"""
|
||||
stack = inspect.stack()
|
||||
try:
|
||||
# get calling function name (the relation)
|
||||
relation = stack[1][3]
|
||||
|
||||
# Make sure locals contain __module__
|
||||
caller_locals = stack[2][0].f_locals
|
||||
finally:
|
||||
del stack
|
||||
|
||||
if not '__module__' in caller_locals:
|
||||
raise ScopeError(relation)
|
||||
|
||||
module_name = caller_locals['__module__']
|
||||
base_name = module_name.split('.')[-1]
|
||||
return base_name
|
||||
|
||||
|
||||
def attr_required(obj, attr_name):
|
||||
"""Ensure that a class has a required attribute."""
|
||||
if not hasattr(obj, attr_name):
|
||||
|
|
Loading…
Reference in a new issue