Spec constraints and normalization now work.

- Specs can be "constrained" by other specs, throw exceptions when constraint
  can't be satisfied.

- Normalize will put a spec in DAG form and merge all package constraints with
  the spec.

- Ready to add concretization policies for abstract specs now.
This commit is contained in:
Todd Gamblin 2013-10-15 03:04:25 -07:00
parent 3fb7699e1e
commit db07c7f611
17 changed files with 481 additions and 199 deletions

View file

@ -12,6 +12,5 @@ def setup_parser(subparser):
def spec(parser, args):
specs = spack.cmd.parse_specs(args.specs)
for spec in specs:
print spec.colorized()
print " --> ", spec.concretized().colorized()
print spec.concretized().concrete()
spec.normalize()
print spec.tree()

View file

@ -32,7 +32,7 @@
install_layout = DefaultDirectoryLayout(install_path)
# Version information
spack_version = Version("0.2")
spack_version = Version("0.5")
# User's editor from the environment
editor = Executable(os.environ.get("EDITOR", ""))

View file

@ -381,7 +381,7 @@ def sanity_check(self):
@property
@memoized
def all_dependencies(self):
"""Set of all transitive dependencies of this package."""
"""Dict(str -> Package) of all transitive dependencies of this package."""
all_deps = set(self.dependencies)
for dep in self.dependencies:
dep_pkg = packages.get(dep.name)

View file

@ -9,6 +9,7 @@
import spack.error
import spack.spec
from spack.util.filesystem import new_path
from spack.util.lang import list_modules
import spack.arch as arch
# Valid package names can contain '-' but can't start with it.
@ -19,13 +20,12 @@
instances = {}
def get(spec):
spec = spack.spec.make_spec(spec)
if not spec in instances:
package_class = get_class_for_package_name(spec.name)
instances[spec] = package_class(spec)
def get(pkg_name):
if not pkg_name in instances:
package_class = get_class_for_package_name(pkg_name)
instances[pkg_name] = package_class(pkg_name)
return instances[spec]
return instances[pkg_name]
def valid_package_name(pkg_name):

View file

@ -0,0 +1,14 @@
from spack import *
class Callpath(Package):
homepage = "https://github.com/tgamblin/callpath"
url = "http://github.com/tgamblin/callpath-0.2.tar.gz"
md5 = "foobarbaz"
depends_on("dyninst")
depends_on("mpich")
def install(self, prefix):
configure("--prefix=%s" % prefix)
make()
make("install")

View file

@ -0,0 +1,14 @@
from spack import *
class Dyninst(Package):
homepage = "https://paradyn.org"
url = "http://www.dyninst.org/sites/default/files/downloads/dyninst/8.1.2/DyninstAPI-8.1.2.tgz"
md5 = "bf03b33375afa66fe0efa46ce3f4b17a"
depends_on("libelf")
depends_on("libdwarf")
def install(self, prefix):
configure("--prefix=%s" % prefix)
make()
make("install")

View file

@ -11,7 +11,7 @@ class Libdwarf(Package):
list_url = "http://reality.sgiweb.org/davea/dwarf.html"
depends_on("libelf")
depends_on("libelf@0:1")
def clean(self):

View file

@ -0,0 +1,11 @@
from spack import *
class Mpich(Package):
homepage = "http://www.mpich.org"
url = "http://www.mpich.org/static/downloads/3.0.4/mpich-3.0.4.tar.gz"
md5 = "9c5d5d4fe1e17dd12153f40bc5b6dbc0"
def install(self, prefix):
configure("--prefix=%s" % prefix)
make()
make("install")

View file

@ -0,0 +1,14 @@
from spack import *
class Mpileaks(Package):
homepage = "http://www.llnl.gov"
url = "http://www.llnl.gov/mpileaks-1.0.tar.gz"
md5 = "foobarbaz"
depends_on("mpich")
depends_on("callpath")
def install(self, prefix):
configure("--prefix=%s" % prefix)
make()
make("install")

View file

@ -1,20 +1,6 @@
import re
import spack.error as err
import itertools
class ParseError(err.SpackError):
"""Raised when we don't hit an error while parsing."""
def __init__(self, message, string, pos):
super(ParseError, self).__init__(message)
self.string = string
self.pos = pos
class LexError(ParseError):
"""Raised when we don't know how to lex something."""
def __init__(self, message, string, pos):
super(LexError, self).__init__(message, string, pos)
import spack.error
class Token:
@ -109,3 +95,17 @@ def parse(self, text):
self.text = text
self.push_tokens(self.lexer.lex(text))
return self.do_parse()
class ParseError(spack.error.SpackError):
"""Raised when we don't hit an error while parsing."""
def __init__(self, message, string, pos):
super(ParseError, self).__init__(message)
self.string = string
self.pos = pos
class LexError(ParseError):
"""Raised when we don't know how to lex something."""
def __init__(self, message, string, pos):
super(LexError, self).__init__(message, string, pos)

View file

@ -62,7 +62,6 @@
expansion when it is the first character in an id typed on the command line.
"""
import sys
from functools import total_ordering
from StringIO import StringIO
import tty
@ -72,8 +71,11 @@
import spack.compilers.gcc
import spack.packages as packages
import spack.arch as arch
from spack.version import *
from spack.color import *
from spack.util.lang import *
from spack.util.string import *
"""This map determines the coloring of specs when using color output.
We make the fields different colors to enhance readability.
@ -109,6 +111,7 @@ def __call__(self, match):
return colorize(re.sub(separators, insert_color(), str(spec)) + '@.')
@key_ordering
class Compiler(object):
"""The Compiler field represents the compiler or range of compiler
versions that a package should be built with. Compilers have a
@ -128,6 +131,19 @@ def _add_version(self, version):
self.versions.add(version)
def satisfies(self, other):
return (self.name == other.name and
self.versions.overlaps(other.versions))
def constrain(self, other):
if not self.satisfies(other.compiler):
raise UnsatisfiableCompilerSpecError(
"%s does not satisfy %s" % (self.compiler, other.compiler))
self.versions.intersect(other.versions)
@property
def concrete(self):
return self.versions.concrete
@ -163,16 +179,8 @@ def copy(self):
return clone
def __eq__(self, other):
return (self.name, self.versions) == (other.name, other.versions)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.name, self.versions))
def _cmp_key(self):
return (self.name, self.versions)
def __str__(self):
@ -183,7 +191,7 @@ def __str__(self):
return out
@total_ordering
@key_ordering
class Variant(object):
"""Variants are named, build-time options for a package. Names depend
on the particular package being built, and each named variant can
@ -194,67 +202,21 @@ def __init__(self, name, enabled):
self.enabled = enabled
def __eq__(self, other):
return self.name == other.name and self.enabled == other.enabled
def __ne__(self, other):
return not (self == other)
@property
def tuple(self):
def _cmp_key(self):
return (self.name, self.enabled)
def __hash__(self):
return hash(self.tuple)
def __lt__(self, other):
return self.tuple < other.tuple
def __str__(self):
out = '+' if self.enabled else '~'
return out + self.name
@total_ordering
class HashableMap(dict):
"""This is a hashable, comparable dictionary. Hash is performed on
a tuple of the values in the dictionary."""
def __eq__(self, other):
return (len(self) == len(other) and
sorted(self.values()) == sorted(other.values()))
def __ne__(self, other):
return not (self == other)
def __lt__(self, other):
return tuple(sorted(self.values())) < tuple(sorted(other.values()))
def __hash__(self):
return hash(tuple(sorted(self.values())))
def copy(self):
"""Type-agnostic clone method. Preserves subclass type."""
# Construct a new dict of my type
T = type(self)
clone = T()
# Copy everything from this dict into it.
for key in self:
clone[key] = self[key]
return clone
class VariantMap(HashableMap):
def satisfies(self, other):
return all(self[key].enabled == other[key].enabled
for key in other if key in self)
def __str__(self):
sorted_keys = sorted(self.keys())
return ''.join(str(self[key]) for key in sorted_keys)
@ -268,13 +230,18 @@ def concrete(self):
return all(d.concrete for d in self.values())
def satisfies(self, other):
return all(self[name].satisfies(other[name]) for name in self
if name in other)
def __str__(self):
sorted_keys = sorted(self.keys())
sorted_dep_names = sorted(self.keys())
return ''.join(
["^" + str(self[name]) for name in sorted_keys])
["^" + str(self[name]) for name in sorted_dep_names])
@total_ordering
@key_ordering
class Spec(object):
def __init__(self, name):
self.name = name
@ -322,7 +289,7 @@ def _add_dependency(self, dep):
@property
def concrete(self):
return (self.versions.concrete
return bool(self.versions.concrete
# TODO: support variants
and self.architecture
and self.compiler and self.compiler.concrete
@ -349,7 +316,7 @@ def _concretize(self):
"""
# TODO: modularize the process of selecting concrete versions.
# There should be a set of user-configurable policies for these decisions.
self.check_sanity()
self.validate()
# take the system's architecture for starters
if not self.architecture:
@ -370,60 +337,118 @@ def _concretize(self):
# Ensure dependencies have right versions
@property
def traverse_deps(self, visited=None):
"""Yields dependencies in depth-first order"""
if not visited:
visited = set()
def check_sanity(self):
"""Check names of packages and dependency validity."""
self.check_package_name_sanity()
self.check_dependency_sanity()
self.check_dependence_constraint_sanity()
def check_package_name_sanity(self):
"""Ensure that all packages mentioned in the spec exist."""
packages.get(self.name)
for dep in self.dependencies.values():
packages.get(dep.name)
def check_dependency_sanity(self):
"""Ensure that dependencies specified on the spec are actual
dependencies of the package it represents.
"""
pkg = packages.get(self.name)
dep_names = set(dep.name for dep in pkg.all_dependencies)
invalid_dependencies = [d.name for d in self.dependencies.values()
if d.name not in dep_names]
if invalid_dependencies:
raise InvalidDependencyException(
"The packages (%s) are not dependencies of %s" %
(','.join(invalid_dependencies), self.name))
def check_dependence_constraint_sanity(self):
"""Ensure that package's dependencies have consistent constraints on
their dependencies.
"""
pkg = packages.get(self.name)
specs = {}
for spec in pkg.all_dependencies:
if not spec.name in specs:
specs[spec.name] = spec
for name in sorted(self.dependencies.keys()):
dep = dependencies[name]
if dep in visited:
continue
merged = specs[spec.name]
# Specs in deps can't be disjoint.
if not spec.versions.overlaps(merged.versions):
raise InvalidConstraintException(
"One package %s, version constraint %s conflicts with %s"
% (pkg.name, spec.versions, merged.versions))
for d in dep.traverse_deps(seen):
yield d
yield dep
def merge(self, other):
"""Considering these specs as constraints, attempt to merge.
Raise an exception if specs are disjoint.
"""
pass
def _normalize_helper(self, visited, spec_deps):
"""Recursive helper function for _normalize."""
if self.name in visited:
return
visited.add(self.name)
# Combine constraints from package dependencies with
# information in this spec's dependencies.
pkg = packages.get(self.name)
for pkg_dep in pkg.dependencies:
name = pkg_dep.name
if name not in spec_deps:
# Clone the spec from the package
spec_deps[name] = pkg_dep.copy()
try:
# intersect package information with spec info
spec_deps[name].constrain(pkg_dep)
except UnsatisfiableSpecError, e:
error_type = type(e)
raise error_type(
"Violated depends_on constraint from package %s: %s"
% (self.name, e.message))
# Add merged spec to my deps and recurse
self.dependencies[name] = spec_deps[name]
self.dependencies[name]._normalize_helper(visited, spec_deps)
def normalize(self):
if any(dep.dependencies for dep in self.dependencies.values()):
raise SpecError("Spec has already been normalized.")
self.validate_package_names()
spec_deps = self.dependencies
self.dependencies = DependencyMap()
visited = set()
self._normalize_helper(visited, spec_deps)
# If there are deps specified but not visited, they're not
# actually deps of this package. Raise an error.
extra = set(spec_deps.viewkeys()).difference(visited)
if extra:
raise InvalidDependencyException(
self.name + " does not depend on " + comma_or(extra))
def validate_package_names(self):
for name in self.dependencies:
packages.get(name)
def constrain(self, other):
if not self.versions.overlaps(other.versions):
raise UnsatisfiableVersionSpecError(
"%s does not satisfy %s" % (self.versions, other.versions))
conflicting_variants = [
v for v in other.variants if v in self.variants and
self.variants[v].enabled != other.variants[v].enabled]
if conflicting_variants:
raise UnsatisfiableVariantSpecError(comma_and(
"%s does not satisfy %s" % (self.variants[v], other.variants[v])
for v in conflicting_variants))
if self.architecture is not None and other.architecture is not None:
if self.architecture != other.architecture:
raise UnsatisfiableArchitectureSpecError(
"Asked for architecture %s, but required %s"
% (self.architecture, other.architecture))
if self.compiler is not None and other.compiler is not None:
self.compiler.constrain(other.compiler)
elif self.compiler is None:
self.compiler = other.compiler
self.versions.intersect(other.versions)
self.variants.update(other.variants)
self.architecture = self.architecture or other.architecture
def satisfies(self, other):
def sat(attribute):
s = getattr(self, attribute)
o = getattr(other, attribute)
return not s or not o or s.satisfies(o)
return (self.name == other.name and
all(sat(attr) for attr in
('versions', 'variants', 'compiler', 'architecture')) and
# TODO: what does it mean to satisfy deps?
self.dependencies.satisfies(other.dependencies))
def concretized(self):
@ -451,43 +476,16 @@ def version(self):
return self.versions[0]
@property
def tuple(self):
def _cmp_key(self):
return (self.name, self.versions, self.variants,
self.architecture, self.compiler, self.dependencies)
@property
def tuple(self):
return (self.name, self.versions, self.variants, self.architecture,
self.compiler, self.dependencies)
def __eq__(self, other):
return self.tuple == other.tuple
def __ne__(self, other):
return not (self == other)
def __lt__(self, other):
return self.tuple < other.tuple
def __hash__(self):
return hash(self.tuple)
self.architecture, self.compiler)
def colorized(self):
return colorize_spec(self)
def __repr__(self):
return str(self)
def __str__(self):
def str_without_deps(self):
out = self.name
# If the version range is entirely open, omit it
@ -502,10 +500,26 @@ def __str__(self):
if self.architecture:
out += "=%s" % self.architecture
out += str(self.dependencies)
return out
def tree(self, indent=""):
"""Prints out this spec and its dependencies, tree-formatted
with indentation."""
out = indent + self.str_without_deps()
for dep in sorted(self.dependencies.keys()):
out += "\n" + self.dependencies[dep].tree(indent + " ")
return out
def __repr__(self):
return str(self)
def __str__(self):
return self.str_without_deps() + str(self.dependencies)
#
# These are possible token types in the spec grammar.
#
@ -580,7 +594,7 @@ def spec(self):
# If there was no version in the spec, consier it an open range
if not added_version:
spec.versions = VersionList([':'])
spec.versions = VersionList(':')
return spec
@ -721,7 +735,31 @@ def __init__(self, message):
super(InvalidDependencyException, self).__init__(message)
class InvalidConstraintException(SpecError):
"""Raised when a package dependencies conflict."""
class UnsatisfiableSpecError(SpecError):
"""Raised when a spec conflicts with package constraints."""
def __init__(self, message):
super(InvalidConstraintException, self).__init__(message)
super(UnsatisfiableSpecError, self).__init__(message)
class UnsatisfiableVersionSpecError(UnsatisfiableSpecError):
"""Raised when a spec version conflicts with package constraints."""
def __init__(self, message):
super(UnsatisfiableVersionSpecError, self).__init__(message)
class UnsatisfiableCompilerSpecError(UnsatisfiableSpecError):
"""Raised when a spec comiler conflicts with package constraints."""
def __init__(self, message):
super(UnsatisfiableCompilerSpecError, self).__init__(message)
class UnsatisfiableVariantSpecError(UnsatisfiableSpecError):
"""Raised when a spec variant conflicts with package constraints."""
def __init__(self, message):
super(UnsatisfiableVariantSpecError, self).__init__(message)
class UnsatisfiableArchitectureSpecError(UnsatisfiableSpecError):
"""Raised when a spec architecture conflicts with package constraints."""
def __init__(self, message):
super(UnsatisfiableArchitectureSpecError, self).__init__(message)

View file

@ -6,8 +6,12 @@ class ConcretizeTest(unittest.TestCase):
def check_concretize(self, abstract_spec):
abstract = spack.spec.parse_one(abstract_spec)
print abstract
print abstract.concretized()
print abstract.concretized().concrete
self.assertTrue(abstract.concretized().concrete)
def test_packages(self):
self.check_concretize("libelf")
pass
#self.check_concretize("libelf")

View file

@ -59,6 +59,25 @@ def check_lex(self, tokens, spec):
# Only check the type for non-identifiers.
self.assertEqual(tok.type, spec_tok.type)
def check_satisfies(self, lspec, rspec):
l = spack.spec.parse_one(lspec)
r = spack.spec.parse_one(rspec)
self.assertTrue(l.satisfies(r) and r.satisfies(l))
# These should not raise
l.constrain(r)
r.constrain(l)
def check_constrain(self, expected, constrained, constraint):
exp = spack.spec.parse_one(expected)
constrained = spack.spec.parse_one(constrained)
constraint = spack.spec.parse_one(constraint)
constrained.constrain(constraint)
self.assertEqual(exp, constrained)
# ================================================================================
# Parse checks
# ===============================================================================
@ -117,6 +136,18 @@ def test_duplicate_compiler(self):
self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%gcc%intel")
# ================================================================================
# Satisfiability and constraints
# ================================================================================
def test_satisfies(self):
self.check_satisfies('libelf@0.8.13', 'libelf@0:1')
self.check_satisfies('libdwarf^libelf@0.8.13', 'libdwarf^libelf@0:1')
def test_constrain(self):
self.check_constrain('libelf@0:1', 'libelf', 'libelf@0:1')
# ================================================================================
# Lex checks
# ================================================================================

View file

@ -59,6 +59,10 @@ def assert_no_overlap(self, v1, v2):
self.assertFalse(ver(v1).overlaps(ver(v2)))
def check_intersection(self, expected, a, b):
self.assertEqual(ver(expected), ver(a).intersection(ver(b)))
def test_two_segments(self):
self.assert_ver_eq('1.0', '1.0')
self.assert_ver_lt('1.0', '2.0')
@ -215,6 +219,7 @@ def test_ranges_overlap(self):
self.assert_overlaps('1.2:', '1.6:')
self.assert_overlaps(':', ':')
self.assert_overlaps(':', '1.6:1.9')
self.assert_overlaps('1.6:1.9', ':')
def test_lists_overlap(self):
@ -258,3 +263,17 @@ def test_canonicalize_list(self):
self.assert_canonical([':'],
[':,1.3, 1.3.1,1.3.9,1.4 : 1.5 , 1.3 : 1.4'])
def test_intersection(self):
self.check_intersection('2.5',
'1.0:2.5', '2.5:3.0')
self.check_intersection('2.5:2.7',
'1.0:2.7', '2.5:3.0')
self.check_intersection('0:1', ':', '0:1')
self.check_intersection(['1.0', '2.5:2.7'],
['1.0:2.7'], ['2.5:3.0','1.0'])
self.check_intersection(['2.5:2.7'],
['1.1:2.7'], ['2.5:3.0','1.0'])
self.check_intersection(['0:1'], [':'], ['0:1'])

View file

@ -1,8 +1,20 @@
import os
import re
import sys
import functools
import inspect
from spack.util.filesystem import new_path
def has_method(cls, name):
for base in inspect.getmro(cls):
if base is object:
continue
if name in base.__dict__:
return True
return False
def memoized(obj):
"""Decorator that caches the results of a function, storing them
in an attribute of that function."""
@ -30,3 +42,54 @@ def list_modules(directory):
elif name.endswith('.py'):
yield re.sub('.py$', '', name)
def key_ordering(cls):
"""Decorates a class with extra methods that implement rich comparison
operations and __hash__. The decorator assumes that the class
implements a function called _cmp_key(). The rich comparison operations
will compare objects using this key, and the __hash__ function will
return the hash of this key.
If a class already has __eq__, __ne__, __lt__, __le__, __gt__, or __ge__
defined, this decorator will overwrite them. If the class does not
have a _cmp_key method, then this will raise a TypeError.
"""
def setter(name, value):
value.__name__ = name
setattr(cls, name, value)
if not has_method(cls, '_cmp_key'):
raise TypeError("'%s' doesn't define _cmp_key()." % cls.__name__)
setter('__eq__', lambda s,o: o is not None and s._cmp_key() == o._cmp_key())
setter('__lt__', lambda s,o: o is not None and s._cmp_key() < o._cmp_key())
setter('__le__', lambda s,o: o is not None and s._cmp_key() <= o._cmp_key())
setter('__ne__', lambda s,o: o is None or s._cmp_key() != o._cmp_key())
setter('__gt__', lambda s,o: o is None or s._cmp_key() > o._cmp_key())
setter('__ge__', lambda s,o: o is None or s._cmp_key() >= o._cmp_key())
setter('__hash__', lambda self: hash(self._cmp_key()))
return cls
@key_ordering
class HashableMap(dict):
"""This is a hashable, comparable dictionary. Hash is performed on
a tuple of the values in the dictionary."""
def _cmp_key(self):
return tuple(sorted(self.values()))
def copy(self):
"""Type-agnostic clone method. Preserves subclass type."""
# Construct a new dict of my type
T = type(self)
clone = T()
# Copy everything from this dict into it.
for key in self:
clone[key] = self[key]
return clone

View file

@ -0,0 +1,23 @@
def comma_list(sequence, article=''):
if type(sequence) != list:
sequence = list(sequence)
if not sequence:
return
elif len(sequence) == 1:
return sequence[0]
else:
out = ', '.join(str(s) for s in sequence[:-1])
out += ', '
if article:
out += article + ' '
out += str(sequence[-1])
return out
def comma_or(sequence):
return comma_list(sequence, 'or')
def comma_and(sequence):
return comma_list(sequence, 'and')

View file

@ -13,8 +13,10 @@
__eq__, __ne__, __lt__, __gt__, __ge__, __le__, __hash__
__contains__
satisfies
overlaps
merge
union
intersection
concrete
True if the Version, VersionRange or VersionList represents
a single version.
@ -161,6 +163,7 @@ def __str__(self):
def concrete(self):
return self
@coerced
def __lt__(self, other):
"""Version comparison is designed for consistency with the way RPM
@ -219,13 +222,21 @@ def overlaps(self, other):
@coerced
def merge(self, other):
def union(self, other):
if self == other:
return self
else:
return VersionList([self, other])
@coerced
def intersection(self, other):
if self == other:
return self
else:
return VersionList()
@total_ordering
class VersionRange(object):
def __init__(self, start, end):
@ -295,9 +306,21 @@ def overlaps(self, other):
@coerced
def merge(self, other):
def union(self, other):
if self.overlaps(other):
return VersionRange(none_low.min(self.start, other.start),
none_high.max(self.end, other.end))
else:
return VersionList([self, other])
@coerced
def intersection(self, other):
if self.overlaps(other):
return VersionRange(none_low.max(self.start, other.start),
none_high.min(self.end, other.end))
else:
return VersionList()
def __hash__(self):
@ -338,12 +361,12 @@ def add(self, version):
i = bisect_left(self, version)
while i-1 >= 0 and version.overlaps(self[i-1]):
version = version.merge(self[i-1])
version = version.union(self[i-1])
del self.versions[i-1]
i -= 1
while i < len(self) and version.overlaps(self[i]):
version = version.merge(self[i])
version = version.union(self[i])
del self.versions[i]
self.versions.insert(i, version)
@ -384,25 +407,54 @@ def highest(self):
return self[-1].highest()
def satisfies(self, other):
"""Synonym for overlaps."""
return self.overlaps(other)
@coerced
def overlaps(self, other):
if not other or not self:
return False
i = o = 0
while i < len(self) and o < len(other):
if self[i].overlaps(other[o]):
s = o = 0
while s < len(self) and o < len(other):
if self[s].overlaps(other[o]):
return True
elif self[i] < other[o]:
i += 1
elif self[s] < other[o]:
s += 1
else:
o += 1
return False
@coerced
def merge(self, other):
return VersionList(self.versions + other.versions)
def update(self, other):
for v in other.versions:
self.add(v)
@coerced
def union(self, other):
result = self.copy()
result.update(other)
return result
@coerced
def intersection(self, other):
# TODO: make this faster. This is O(n^2).
result = VersionList()
for s in self:
for o in other:
result.add(s.intersection(o))
return result
@coerced
def intersect(self, other):
isection = self.intersection(other)
self.versions = isection.versions
@coerced