Spec constraints and normalization now work.

- Specs can be "constrained" by other specs, and throw exceptions when a
  constraint can't be satisfied.

- Normalize puts a spec in DAG form and merges all package constraints into
  the spec.

- Ready to add concretization policies for abstract specs now.
Todd Gamblin 2013-10-15 03:04:25 -07:00
parent 3fb7699e1e
commit db07c7f611
17 changed files with 481 additions and 199 deletions
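
For illustration only (not part of this commit): a minimal sketch of how the new
constrain/normalize API fits together, using only calls that appear in the diffs
below (spack.spec.parse_one, satisfies, constrain, normalize, tree) and the mock
packages this commit adds (mpileaks, libelf, libdwarf). Treat it as a usage
sketch, not code from the repository.

    import spack.spec
    from spack.spec import UnsatisfiableVersionSpecError

    # satisfies() checks that the two specs overlap; constrain() merges the
    # other spec's constraints into this one.
    libdwarf = spack.spec.parse_one("libdwarf^libelf@0.8.13")
    constraint = spack.spec.parse_one("libdwarf^libelf@0:1")
    assert libdwarf.satisfies(constraint)
    libdwarf.constrain(constraint)

    # Disjoint version ranges raise one of the new Unsatisfiable*SpecError types.
    try:
        spack.spec.parse_one("libelf@0.8.13").constrain(
            spack.spec.parse_one("libelf@2:"))
    except UnsatisfiableVersionSpecError, e:
        print "conflict:", e.message

    # normalize() pulls depends_on() constraints from the packages into the
    # spec and puts it in DAG form; tree() renders it with indentation, as the
    # updated `spack spec` command does.
    mpileaks = spack.spec.parse_one("mpileaks")
    mpileaks.normalize()
    print mpileaks.tree()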

View file

@@ -12,6 +12,5 @@ def setup_parser(subparser):
 def spec(parser, args):
     specs = spack.cmd.parse_specs(args.specs)
     for spec in specs:
-        print spec.colorized()
-        print " --> ", spec.concretized().colorized()
-        print spec.concretized().concrete()
+        spec.normalize()
+        print spec.tree()

View file

@@ -32,7 +32,7 @@
 install_layout = DefaultDirectoryLayout(install_path)

 # Version information
-spack_version = Version("0.2")
+spack_version = Version("0.5")

 # User's editor from the environment
 editor = Executable(os.environ.get("EDITOR", ""))

View file

@@ -381,7 +381,7 @@ def sanity_check(self):
     @property
     @memoized
     def all_dependencies(self):
-        """Set of all transitive dependencies of this package."""
+        """Dict(str -> Package) of all transitive dependencies of this package."""
         all_deps = set(self.dependencies)
         for dep in self.dependencies:
             dep_pkg = packages.get(dep.name)

View file

@@ -9,6 +9,7 @@
 import spack.error
 import spack.spec
 from spack.util.filesystem import new_path
+from spack.util.lang import list_modules
 import spack.arch as arch

 # Valid package names can contain '-' but can't start with it.
@@ -19,13 +20,12 @@
 instances = {}

-def get(spec):
-    spec = spack.spec.make_spec(spec)
-    if not spec in instances:
-        package_class = get_class_for_package_name(spec.name)
-        instances[spec] = package_class(spec)
+def get(pkg_name):
+    if not pkg_name in instances:
+        package_class = get_class_for_package_name(pkg_name)
+        instances[pkg_name] = package_class(pkg_name)

-    return instances[spec]
+    return instances[pkg_name]


 def valid_package_name(pkg_name):

View file

@@ -0,0 +1,14 @@
+from spack import *
+
+class Callpath(Package):
+    homepage = "https://github.com/tgamblin/callpath"
+    url      = "http://github.com/tgamblin/callpath-0.2.tar.gz"
+    md5      = "foobarbaz"
+
+    depends_on("dyninst")
+    depends_on("mpich")
+
+    def install(self, prefix):
+        configure("--prefix=%s" % prefix)
+        make()
+        make("install")

View file

@@ -0,0 +1,14 @@
+from spack import *
+
+class Dyninst(Package):
+    homepage = "https://paradyn.org"
+    url      = "http://www.dyninst.org/sites/default/files/downloads/dyninst/8.1.2/DyninstAPI-8.1.2.tgz"
+    md5      = "bf03b33375afa66fe0efa46ce3f4b17a"
+
+    depends_on("libelf")
+    depends_on("libdwarf")
+
+    def install(self, prefix):
+        configure("--prefix=%s" % prefix)
+        make()
+        make("install")

View file

@@ -11,7 +11,7 @@ class Libdwarf(Package):
     list_url = "http://reality.sgiweb.org/davea/dwarf.html"

-    depends_on("libelf")
+    depends_on("libelf@0:1")

     def clean(self):

View file

@@ -0,0 +1,11 @@
+from spack import *
+
+class Mpich(Package):
+    homepage = "http://www.mpich.org"
+    url      = "http://www.mpich.org/static/downloads/3.0.4/mpich-3.0.4.tar.gz"
+    md5      = "9c5d5d4fe1e17dd12153f40bc5b6dbc0"
+
+    def install(self, prefix):
+        configure("--prefix=%s" % prefix)
+        make()
+        make("install")

View file

@@ -0,0 +1,14 @@
+from spack import *
+
+class Mpileaks(Package):
+    homepage = "http://www.llnl.gov"
+    url      = "http://www.llnl.gov/mpileaks-1.0.tar.gz"
+    md5      = "foobarbaz"
+
+    depends_on("mpich")
+    depends_on("callpath")
+
+    def install(self, prefix):
+        configure("--prefix=%s" % prefix)
+        make()
+        make("install")

View file

@@ -1,20 +1,6 @@
 import re
-import spack.error as err
 import itertools
+import spack.error

-class ParseError(err.SpackError):
-    """Raised when we don't hit an error while parsing."""
-    def __init__(self, message, string, pos):
-        super(ParseError, self).__init__(message)
-        self.string = string
-        self.pos = pos
-
-
-class LexError(ParseError):
-    """Raised when we don't know how to lex something."""
-    def __init__(self, message, string, pos):
-        super(LexError, self).__init__(message, string, pos)
-

 class Token:
@@ -109,3 +95,17 @@ def parse(self, text):
         self.text = text
         self.push_tokens(self.lexer.lex(text))
         return self.do_parse()
+
+
+class ParseError(spack.error.SpackError):
+    """Raised when we don't hit an error while parsing."""
+    def __init__(self, message, string, pos):
+        super(ParseError, self).__init__(message)
+        self.string = string
+        self.pos = pos
+
+
+class LexError(ParseError):
+    """Raised when we don't know how to lex something."""
+    def __init__(self, message, string, pos):
+        super(LexError, self).__init__(message, string, pos)

View file

@@ -62,7 +62,6 @@
 expansion when it is the first character in an id typed on the command line.
 """
 import sys
-from functools import total_ordering
 from StringIO import StringIO

 import tty
@@ -72,8 +71,11 @@
 import spack.compilers.gcc
 import spack.packages as packages
 import spack.arch as arch

 from spack.version import *
 from spack.color import *
+from spack.util.lang import *
+from spack.util.string import *
+

 """This map determines the coloring of specs when using color output.
    We make the fields different colors to enhance readability.
@@ -109,6 +111,7 @@ def __call__(self, match):
     return colorize(re.sub(separators, insert_color(), str(spec)) + '@.')


+@key_ordering
 class Compiler(object):
     """The Compiler field represents the compiler or range of compiler
        versions that a package should be built with.  Compilers have a
@@ -128,6 +131,19 @@ def _add_version(self, version):
         self.versions.add(version)


+    def satisfies(self, other):
+        return (self.name == other.name and
+                self.versions.overlaps(other.versions))
+
+
+    def constrain(self, other):
+        if not self.satisfies(other.compiler):
+            raise UnsatisfiableCompilerSpecError(
+                "%s does not satisfy %s" % (self.compiler, other.compiler))
+
+        self.versions.intersect(other.versions)
+
+
     @property
     def concrete(self):
         return self.versions.concrete
@@ -163,16 +179,8 @@ def copy(self):
         return clone


-    def __eq__(self, other):
-        return (self.name, self.versions) == (other.name, other.versions)
-
-
-    def __ne__(self, other):
-        return not (self == other)
-
-
-    def __hash__(self):
-        return hash((self.name, self.versions))
+    def _cmp_key(self):
+        return (self.name, self.versions)


     def __str__(self):
@@ -183,7 +191,7 @@ def __str__(self):
         return out


-@total_ordering
+@key_ordering
 class Variant(object):
     """Variants are named, build-time options for a package.  Names depend
        on the particular package being built, and each named variant can
@@ -194,67 +202,21 @@ def __init__(self, name, enabled):
         self.enabled = enabled


-    def __eq__(self, other):
-        return self.name == other.name and self.enabled == other.enabled
-
-
-    def __ne__(self, other):
-        return not (self == other)
-
-
-    @property
-    def tuple(self):
+    def _cmp_key(self):
         return (self.name, self.enabled)


-    def __hash__(self):
-        return hash(self.tuple)
-
-
-    def __lt__(self, other):
-        return self.tuple < other.tuple
-
-
     def __str__(self):
         out = '+' if self.enabled else '~'
         return out + self.name


-@total_ordering
-class HashableMap(dict):
-    """This is a hashable, comparable dictionary.  Hash is performed on
-       a tuple of the values in the dictionary."""
-    def __eq__(self, other):
-        return (len(self) == len(other) and
-                sorted(self.values()) == sorted(other.values()))
-
-
-    def __ne__(self, other):
-        return not (self == other)
-
-
-    def __lt__(self, other):
-        return tuple(sorted(self.values())) < tuple(sorted(other.values()))
-
-
-    def __hash__(self):
-        return hash(tuple(sorted(self.values())))
-
-
-    def copy(self):
-        """Type-agnostic clone method.  Preserves subclass type."""
-        # Construct a new dict of my type
-        T = type(self)
-        clone = T()
-
-        # Copy everything from this dict into it.
-        for key in self:
-            clone[key] = self[key]
-        return clone
-
-
 class VariantMap(HashableMap):
+    def satisfies(self, other):
+        return all(self[key].enabled == other[key].enabled
+                   for key in other if key in self)
+
+
     def __str__(self):
         sorted_keys = sorted(self.keys())
         return ''.join(str(self[key]) for key in sorted_keys)
@@ -268,13 +230,18 @@ def concrete(self):
         return all(d.concrete for d in self.values())


+    def satisfies(self, other):
+        return all(self[name].satisfies(other[name]) for name in self
+                   if name in other)
+
+
     def __str__(self):
-        sorted_keys = sorted(self.keys())
+        sorted_dep_names = sorted(self.keys())
         return ''.join(
-            ["^" + str(self[name]) for name in sorted_keys])
+            ["^" + str(self[name]) for name in sorted_dep_names])


-@total_ordering
+@key_ordering
 class Spec(object):
     def __init__(self, name):
         self.name = name
@@ -322,11 +289,11 @@ def _add_dependency(self, dep):
     @property
     def concrete(self):
-        return (self.versions.concrete
-                # TODO: support variants
-                and self.architecture
-                and self.compiler and self.compiler.concrete
-                and self.dependencies.concrete)
+        return bool(self.versions.concrete
+                    # TODO: support variants
+                    and self.architecture
+                    and self.compiler and self.compiler.concrete
+                    and self.dependencies.concrete)


     def _concretize(self):
@@ -349,7 +316,7 @@ def _concretize(self):
         """
         # TODO: modularize the process of selecting concrete versions.
         # There should be a set of user-configurable policies for these decisions.
-        self.check_sanity()
+        self.validate()

         # take the system's architecture for starters
         if not self.architecture:
@@ -370,60 +337,118 @@ def _concretize(self):
         # Ensure dependencies have right versions


-    def check_sanity(self):
-        """Check names of packages and dependency validity."""
-        self.check_package_name_sanity()
-        self.check_dependency_sanity()
-        self.check_dependence_constraint_sanity()
-
-
-    def check_package_name_sanity(self):
-        """Ensure that all packages mentioned in the spec exist."""
-        packages.get(self.name)
-        for dep in self.dependencies.values():
-            packages.get(dep.name)
-
-
-    def check_dependency_sanity(self):
-        """Ensure that dependencies specified on the spec are actual
-           dependencies of the package it represents.
-        """
-        pkg = packages.get(self.name)
-        dep_names = set(dep.name for dep in pkg.all_dependencies)
-        invalid_dependencies = [d.name for d in self.dependencies.values()
-                                if d.name not in dep_names]
-        if invalid_dependencies:
-            raise InvalidDependencyException(
-                "The packages (%s) are not dependencies of %s" %
-                (','.join(invalid_dependencies), self.name))
-
-
-    def check_dependence_constraint_sanity(self):
-        """Ensure that package's dependencies have consistent constraints on
-           their dependencies.
-        """
-        pkg = packages.get(self.name)
-        specs = {}
-        for spec in pkg.all_dependencies:
-            if not spec.name in specs:
-                specs[spec.name] = spec
-                continue
-
-            merged = specs[spec.name]
-
-            # Specs in deps can't be disjoint.
-            if not spec.versions.overlaps(merged.versions):
-                raise InvalidConstraintException(
-                    "One package %s, version constraint %s conflicts with %s"
-                    % (pkg.name, spec.versions, merged.versions))
-
-
-    def merge(self, other):
-        """Considering these specs as constraints, attempt to merge.
-           Raise an exception if specs are disjoint.
-        """
-        pass
+    @property
+    def traverse_deps(self, visited=None):
+        """Yields dependencies in depth-first order"""
+        if not visited:
+            visited = set()
+
+        for name in sorted(self.dependencies.keys()):
+            dep = dependencies[name]
+            if dep in visited:
+                continue
+
+            for d in dep.traverse_deps(seen):
+                yield d
+            yield dep
+
+
+    def _normalize_helper(self, visited, spec_deps):
+        """Recursive helper function for _normalize."""
+        if self.name in visited:
+            return
+        visited.add(self.name)
+
+        # Combine constraints from package dependencies with
+        # information in this spec's dependencies.
+        pkg = packages.get(self.name)
+        for pkg_dep in pkg.dependencies:
+            name = pkg_dep.name
+
+            if name not in spec_deps:
+                # Clone the spec from the package
+                spec_deps[name] = pkg_dep.copy()
+
+            try:
+                # intersect package information with spec info
+                spec_deps[name].constrain(pkg_dep)
+            except UnsatisfiableSpecError, e:
+                error_type = type(e)
+                raise error_type(
+                    "Violated depends_on constraint from package %s: %s"
+                    % (self.name, e.message))
+
+            # Add merged spec to my deps and recurse
+            self.dependencies[name] = spec_deps[name]
+            self.dependencies[name]._normalize_helper(visited, spec_deps)
+
+
+    def normalize(self):
+        if any(dep.dependencies for dep in self.dependencies.values()):
+            raise SpecError("Spec has already been normalized.")
+
+        self.validate_package_names()
+
+        spec_deps = self.dependencies
+        self.dependencies = DependencyMap()
+
+        visited = set()
+        self._normalize_helper(visited, spec_deps)
+
+        # If there are deps specified but not visited, they're not
+        # actually deps of this package.  Raise an error.
+        extra = set(spec_deps.viewkeys()).difference(visited)
+        if extra:
+            raise InvalidDependencyException(
+                self.name + " does not depend on " + comma_or(extra))
+
+
+    def validate_package_names(self):
+        for name in self.dependencies:
+            packages.get(name)
+
+
+    def constrain(self, other):
+        if not self.versions.overlaps(other.versions):
+            raise UnsatisfiableVersionSpecError(
+                "%s does not satisfy %s" % (self.versions, other.versions))
+
+        conflicting_variants = [
+            v for v in other.variants if v in self.variants and
+            self.variants[v].enabled != other.variants[v].enabled]
+        if conflicting_variants:
+            raise UnsatisfiableVariantSpecError(comma_and(
+                "%s does not satisfy %s" % (self.variants[v], other.variants[v])
+                for v in conflicting_variants))
+
+        if self.architecture is not None and other.architecture is not None:
+            if self.architecture != other.architecture:
+                raise UnsatisfiableArchitectureSpecError(
+                    "Asked for architecture %s, but required %s"
+                    % (self.architecture, other.architecture))
+
+        if self.compiler is not None and other.compiler is not None:
+            self.compiler.constrain(other.compiler)
+        elif self.compiler is None:
+            self.compiler = other.compiler
+
+        self.versions.intersect(other.versions)
+        self.variants.update(other.variants)
+        self.architecture = self.architecture or other.architecture
+
+
+    def satisfies(self, other):
+        def sat(attribute):
+            s = getattr(self, attribute)
+            o = getattr(other, attribute)
+            return not s or not o or s.satisfies(o)
+
+        return (self.name == other.name and
+                all(sat(attr) for attr in
+                    ('versions', 'variants', 'compiler', 'architecture')) and
+                # TODO: what does it mean to satisfy deps?
+                self.dependencies.satisfies(other.dependencies))


     def concretized(self):
@@ -451,43 +476,16 @@ def version(self):
         return self.versions[0]


-    @property
-    def tuple(self):
-        return (self.name, self.versions, self.variants,
-                self.architecture, self.compiler, self.dependencies)
-
-
-    @property
-    def tuple(self):
-        return (self.name, self.versions, self.variants, self.architecture,
-                self.compiler, self.dependencies)
-
-
-    def __eq__(self, other):
-        return self.tuple == other.tuple
-
-
-    def __ne__(self, other):
-        return not (self == other)
-
-
-    def __lt__(self, other):
-        return self.tuple < other.tuple
-
-
-    def __hash__(self):
-        return hash(self.tuple)
+    def _cmp_key(self):
+        return (self.name, self.versions, self.variants,
+                self.architecture, self.compiler)


     def colorized(self):
         return colorize_spec(self)


-    def __repr__(self):
-        return str(self)
-
-
-    def __str__(self):
+    def str_without_deps(self):
         out = self.name

         # If the version range is entirely open, omit it
@@ -502,10 +500,26 @@ def __str__(self):
         if self.architecture:
             out += "=%s" % self.architecture

-        out += str(self.dependencies)
         return out


+    def tree(self, indent=""):
+        """Prints out this spec and its dependencies, tree-formatted
+           with indentation."""
+        out = indent + self.str_without_deps()
+        for dep in sorted(self.dependencies.keys()):
+            out += "\n" + self.dependencies[dep].tree(indent + "    ")
+        return out
+
+
+    def __repr__(self):
+        return str(self)
+
+
+    def __str__(self):
+        return self.str_without_deps() + str(self.dependencies)
+

 #
 # These are possible token types in the spec grammar.
 #
@@ -580,7 +594,7 @@ def spec(self):
         # If there was no version in the spec, consier it an open range
         if not added_version:
-            spec.versions = VersionList([':'])
+            spec.versions = VersionList(':')

         return spec
@@ -721,7 +735,31 @@ def __init__(self, message):
         super(InvalidDependencyException, self).__init__(message)


-class InvalidConstraintException(SpecError):
-    """Raised when a package dependencies conflict."""
+class UnsatisfiableSpecError(SpecError):
+    """Raised when a spec conflicts with package constraints."""
     def __init__(self, message):
-        super(InvalidConstraintException, self).__init__(message)
+        super(UnsatisfiableSpecError, self).__init__(message)
+
+
+class UnsatisfiableVersionSpecError(UnsatisfiableSpecError):
+    """Raised when a spec version conflicts with package constraints."""
+    def __init__(self, message):
+        super(UnsatisfiableVersionSpecError, self).__init__(message)
+
+
+class UnsatisfiableCompilerSpecError(UnsatisfiableSpecError):
+    """Raised when a spec comiler conflicts with package constraints."""
+    def __init__(self, message):
+        super(UnsatisfiableCompilerSpecError, self).__init__(message)
+
+
+class UnsatisfiableVariantSpecError(UnsatisfiableSpecError):
+    """Raised when a spec variant conflicts with package constraints."""
+    def __init__(self, message):
+        super(UnsatisfiableVariantSpecError, self).__init__(message)
+
+
+class UnsatisfiableArchitectureSpecError(UnsatisfiableSpecError):
+    """Raised when a spec architecture conflicts with package constraints."""
+    def __init__(self, message):
+        super(UnsatisfiableArchitectureSpecError, self).__init__(message)

View file

@@ -6,8 +6,12 @@ class ConcretizeTest(unittest.TestCase):
     def check_concretize(self, abstract_spec):
         abstract = spack.spec.parse_one(abstract_spec)
+        print abstract
+        print abstract.concretized()
+        print abstract.concretized().concrete
         self.assertTrue(abstract.concretized().concrete)


     def test_packages(self):
-        self.check_concretize("libelf")
+        pass
+        #self.check_concretize("libelf")

View file

@@ -59,6 +59,25 @@ def check_lex(self, tokens, spec):
                 # Only check the type for non-identifiers.
                 self.assertEqual(tok.type, spec_tok.type)


+    def check_satisfies(self, lspec, rspec):
+        l = spack.spec.parse_one(lspec)
+        r = spack.spec.parse_one(rspec)
+        self.assertTrue(l.satisfies(r) and r.satisfies(l))
+
+        # These should not raise
+        l.constrain(r)
+        r.constrain(l)
+
+
+    def check_constrain(self, expected, constrained, constraint):
+        exp = spack.spec.parse_one(expected)
+        constrained = spack.spec.parse_one(constrained)
+        constraint = spack.spec.parse_one(constraint)
+        constrained.constrain(constraint)
+        self.assertEqual(exp, constrained)
+
+
     # ================================================================================
     # Parse checks
     # ================================================================================
@@ -117,6 +136,18 @@ def test_duplicate_compiler(self):
         self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%gcc%intel")


+    # ================================================================================
+    # Satisfiability and constraints
+    # ================================================================================
+    def test_satisfies(self):
+        self.check_satisfies('libelf@0.8.13', 'libelf@0:1')
+        self.check_satisfies('libdwarf^libelf@0.8.13', 'libdwarf^libelf@0:1')
+
+
+    def test_constrain(self):
+        self.check_constrain('libelf@0:1', 'libelf', 'libelf@0:1')
+
+
     # ================================================================================
     # Lex checks
     # ================================================================================

View file

@@ -59,6 +59,10 @@ def assert_no_overlap(self, v1, v2):
         self.assertFalse(ver(v1).overlaps(ver(v2)))

+    def check_intersection(self, expected, a, b):
+        self.assertEqual(ver(expected), ver(a).intersection(ver(b)))
+
+
     def test_two_segments(self):
         self.assert_ver_eq('1.0', '1.0')
         self.assert_ver_lt('1.0', '2.0')
@@ -215,6 +219,7 @@ def test_ranges_overlap(self):
         self.assert_overlaps('1.2:', '1.6:')
         self.assert_overlaps(':', ':')
         self.assert_overlaps(':', '1.6:1.9')
+        self.assert_overlaps('1.6:1.9', ':')


     def test_lists_overlap(self):
@@ -258,3 +263,17 @@ def test_canonicalize_list(self):
         self.assert_canonical([':'],
                               [':,1.3, 1.3.1,1.3.9,1.4 : 1.5 , 1.3 : 1.4'])
+
+
+    def test_intersection(self):
+        self.check_intersection('2.5',
+                                '1.0:2.5', '2.5:3.0')
+        self.check_intersection('2.5:2.7',
+                                '1.0:2.7', '2.5:3.0')
+        self.check_intersection('0:1', ':', '0:1')
+
+        self.check_intersection(['1.0', '2.5:2.7'],
+                                ['1.0:2.7'], ['2.5:3.0','1.0'])
+        self.check_intersection(['2.5:2.7'],
+                                ['1.1:2.7'], ['2.5:3.0','1.0'])
+        self.check_intersection(['0:1'], [':'], ['0:1'])

View file

@@ -1,8 +1,20 @@
 import os
 import re
+import sys
 import functools
+import inspect
 from spack.util.filesystem import new_path


+def has_method(cls, name):
+    for base in inspect.getmro(cls):
+        if base is object:
+            continue
+        if name in base.__dict__:
+            return True
+    return False
+
+
 def memoized(obj):
     """Decorator that caches the results of a function, storing them
        in an attribute of that function."""
@@ -30,3 +42,54 @@ def list_modules(directory):
         elif name.endswith('.py'):
             yield re.sub('.py$', '', name)
+
+
+def key_ordering(cls):
+    """Decorates a class with extra methods that implement rich comparison
+       operations and __hash__.  The decorator assumes that the class
+       implements a function called _cmp_key().  The rich comparison operations
+       will compare objects using this key, and the __hash__ function will
+       return the hash of this key.
+
+       If a class already has __eq__, __ne__, __lt__, __le__, __gt__, or __ge__
+       defined, this decorator will overwrite them.  If the class does not
+       have a _cmp_key method, then this will raise a TypeError.
+    """
+    def setter(name, value):
+        value.__name__ = name
+        setattr(cls, name, value)
+
+    if not has_method(cls, '_cmp_key'):
+        raise TypeError("'%s' doesn't define _cmp_key()." % cls.__name__)
+
+    setter('__eq__', lambda s,o: o is not None and s._cmp_key() == o._cmp_key())
+    setter('__lt__', lambda s,o: o is not None and s._cmp_key() <  o._cmp_key())
+    setter('__le__', lambda s,o: o is not None and s._cmp_key() <= o._cmp_key())
+
+    setter('__ne__', lambda s,o: o is None or s._cmp_key() != o._cmp_key())
+    setter('__gt__', lambda s,o: o is None or s._cmp_key() >  o._cmp_key())
+    setter('__ge__', lambda s,o: o is None or s._cmp_key() >= o._cmp_key())
+
+    setter('__hash__', lambda self: hash(self._cmp_key()))
+
+    return cls
+
+
+@key_ordering
+class HashableMap(dict):
+    """This is a hashable, comparable dictionary.  Hash is performed on
+       a tuple of the values in the dictionary."""
+    def _cmp_key(self):
+        return tuple(sorted(self.values()))
+
+
+    def copy(self):
+        """Type-agnostic clone method.  Preserves subclass type."""
+        # Construct a new dict of my type
+        T = type(self)
+        clone = T()
+
+        # Copy everything from this dict into it.
+        for key in self:
+            clone[key] = self[key]
+        return clone
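
For illustration only (not part of this commit): a minimal sketch of the intended
use of the new @key_ordering decorator; the Interval class here is a hypothetical
example, not something in the Spack source.

    from spack.util.lang import key_ordering

    @key_ordering
    class Interval(object):
        """Toy class: comparisons and hashing both derive from _cmp_key()."""
        def __init__(self, lo, hi):
            self.lo, self.hi = lo, hi

        def _cmp_key(self):
            return (self.lo, self.hi)

    assert Interval(1, 2) == Interval(1, 2)
    assert Interval(1, 2) < Interval(1, 3)
    assert hash(Interval(1, 2)) == hash(Interval(1, 2))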

View file

@@ -0,0 +1,23 @@
+def comma_list(sequence, article=''):
+    if type(sequence) != list:
+        sequence = list(sequence)
+
+    if not sequence:
+        return
+    elif len(sequence) == 1:
+        return sequence[0]
+    else:
+        out = ', '.join(str(s) for s in sequence[:-1])
+        out += ', '
+        if article:
+            out += article + ' '
+        out += str(sequence[-1])
+        return out
+
+
+def comma_or(sequence):
+    return comma_list(sequence, 'or')
+
+
+def comma_and(sequence):
+    return comma_list(sequence, 'and')

View file

@@ -13,8 +13,10 @@
   __eq__, __ne__, __lt__, __gt__, __ge__, __le__, __hash__
   __contains__
+  satisfies
   overlaps
-  merge
+  union
+  intersection
   concrete
       True if the Version, VersionRange or VersionList represents
       a single version.
@@ -161,6 +163,7 @@ def __str__(self):
     def concrete(self):
         return self

+    @coerced
     def __lt__(self, other):
         """Version comparison is designed for consistency with the way RPM
@@ -219,13 +222,21 @@ def overlaps(self, other):
     @coerced
-    def merge(self, other):
+    def union(self, other):
         if self == other:
             return self
         else:
             return VersionList([self, other])


+    @coerced
+    def intersection(self, other):
+        if self == other:
+            return self
+        else:
+            return VersionList()
+
+
 @total_ordering
 class VersionRange(object):
     def __init__(self, start, end):
@@ -295,9 +306,21 @@ def overlaps(self, other):
     @coerced
-    def merge(self, other):
-        return VersionRange(none_low.min(self.start, other.start),
-                            none_high.max(self.end, other.end))
+    def union(self, other):
+        if self.overlaps(other):
+            return VersionRange(none_low.min(self.start, other.start),
+                                none_high.max(self.end, other.end))
+        else:
+            return VersionList([self, other])
+
+
+    @coerced
+    def intersection(self, other):
+        if self.overlaps(other):
+            return VersionRange(none_low.max(self.start, other.start),
+                                none_high.min(self.end, other.end))
+        else:
+            return VersionList()


     def __hash__(self):
@@ -338,12 +361,12 @@ def add(self, version):
             i = bisect_left(self, version)

             while i-1 >= 0 and version.overlaps(self[i-1]):
-                version = version.merge(self[i-1])
+                version = version.union(self[i-1])
                 del self.versions[i-1]
                 i -= 1

             while i < len(self) and version.overlaps(self[i]):
-                version = version.merge(self[i])
+                version = version.union(self[i])
                 del self.versions[i]

         self.versions.insert(i, version)
@@ -384,25 +407,54 @@ def highest(self):
         return self[-1].highest()


+    def satisfies(self, other):
+        """Synonym for overlaps."""
+        return self.overlaps(other)
+
+
     @coerced
     def overlaps(self, other):
         if not other or not self:
             return False

-        i = o = 0
-        while i < len(self) and o < len(other):
-            if self[i].overlaps(other[o]):
+        s = o = 0
+        while s < len(self) and o < len(other):
+            if self[s].overlaps(other[o]):
                 return True
-            elif self[i] < other[o]:
-                i += 1
+            elif self[s] < other[o]:
+                s += 1
             else:
                 o += 1
         return False


     @coerced
-    def merge(self, other):
-        return VersionList(self.versions + other.versions)
+    def update(self, other):
+        for v in other.versions:
+            self.add(v)
+
+
+    @coerced
+    def union(self, other):
+        result = self.copy()
+        result.update(other)
+        return result
+
+
+    @coerced
+    def intersection(self, other):
+        # TODO: make this faster.  This is O(n^2).
+        result = VersionList()
+        for s in self:
+            for o in other:
+                result.add(s.intersection(o))
+        return result
+
+
+    @coerced
+    def intersect(self, other):
+        isection = self.intersection(other)
+        self.versions = isection.versions


     @coerced