Checkpoint commit: much-improved spec class.

Still organizing things.
Todd Gamblin 2013-10-07 17:57:27 -07:00
parent 157737efbe
commit 618571b807
27 changed files with 1419 additions and 369 deletions

View file

@ -19,6 +19,7 @@ sys.path.insert(0, SPACK_LIB_PATH)
del SPACK_FILE, SPACK_PREFIX, SPACK_LIB_PATH del SPACK_FILE, SPACK_PREFIX, SPACK_LIB_PATH
import spack import spack
import spack.tty as tty import spack.tty as tty
from spack.error import SpackError
# Command parsing # Command parsing
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -50,5 +51,12 @@ spack.debug = args.debug
command = spack.cmd.get_command(args.command) command = spack.cmd.get_command(args.command)
try: try:
command(parser, args) command(parser, args)
except SpackError, e:
if spack.debug:
# In debug mode, raise with a full stack trace.
raise
else:
# Otherwise print a nice simple message.
tty.die(e.message)
except KeyboardInterrupt: except KeyboardInterrupt:
tty.die("Got a keyboard interrupt from the user.") tty.die("Got a keyboard interrupt from the user.")

View file

@ -14,7 +14,8 @@ def setup_parser(subparser):
help="delete and re-expand the entire stage directory") help="delete and re-expand the entire stage directory")
subparser.add_argument('-d', "--dist", action="store_true", dest='dist', subparser.add_argument('-d', "--dist", action="store_true", dest='dist',
help="delete the downloaded archive.") help="delete the downloaded archive.")
subparser.add_argument('packages', nargs=argparse.REMAINDER, help="specs of packages to clean") subparser.add_argument('packages', nargs=argparse.REMAINDER,
help="specs of packages to clean")
def clean(parser, args): def clean(parser, args):

View file

@ -0,0 +1,9 @@
import spack.compilers
import spack.tty as tty
from spack.colify import colify
description = "List available compilers"
def compilers(parser, args):
tty.msg("Supported compilers")
colify(spack.compilers.supported_compilers(), indent=4)

View file

@ -57,13 +57,13 @@ def create(parser, args):
# make a stage and fetch the archive. # make a stage and fetch the archive.
try: try:
stage = Stage("%s-%s" % (name, version), url) stage = Stage("spack-create/%s-%s" % (name, version), url)
archive_file = stage.fetch() archive_file = stage.fetch()
except spack.FailedDownloadException, e: except spack.FailedDownloadException, e:
tty.die(e.message) tty.die(e.message)
md5 = spack.md5(archive_file) md5 = spack.md5(archive_file)
class_name = packages.class_for(name) class_name = packages.class_name_for_package_name(name)
# Write out a template for the file
tty.msg("Editing %s." % path) tty.msg("Editing %s." % path)

View file

@ -9,53 +9,15 @@
import spack.url as url import spack.url as url
import spack.tty as tty import spack.tty as tty
description ="List spack packages" description ="List spack packages"
def setup_parser(subparser): def setup_parser(subparser):
subparser.add_argument('-v', '--versions', metavar="PACKAGE", dest='version_package',
help='List available versions of a package (experimental).')
subparser.add_argument('-i', '--installed', action='store_true', dest='installed', subparser.add_argument('-i', '--installed', action='store_true', dest='installed',
help='List installed packages for each platform along with versions.') help='List installed packages for each platform along with versions.')
def list(parser, args): def list(parser, args):
if args.installed: if args.installed:
pkgs = packages.installed_packages() colify(str(pkg) for pkg in packages.installed_packages())
for sys_type in pkgs:
print "%s:" % sys_type
package_vers = []
for pkg in pkgs[sys_type]:
pv = [pkg.name + "@" + v for v in pkg.installed_versions]
package_vers.extend(pv)
colify(sorted(package_vers), indent=4)
elif args.version_package:
pkg = packages.get(args.version_package)
# Run curl but grab the mime type from the http headers
try:
listing = spack.curl('-s', '-L', pkg.list_url, return_output=True)
except CalledProcessError:
tty.die("Fetching %s failed." % pkg.list_url,
"'list -v' requires an internet connection.")
url_regex = os.path.basename(url.wildcard_version(pkg.url))
strings = re.findall(url_regex, listing)
versions = []
wildcard = pkg.version.wildcard()
for s in strings:
match = re.search(wildcard, s)
if match:
versions.append(ver(match.group(0)))
if not versions:
tty.die("Found no versions for %s" % pkg.name,
"Listing versions is experimental. You may need to add the list_url",
"attribute to the package to tell Spack where to look for versions.")
colify(str(v) for v in reversed(sorted(set(versions))))
else: else:
colify(packages.all_package_names()) colify(packages.all_package_names())

View file

@ -0,0 +1,17 @@
import argparse
import spack.cmd
import spack.tty as tty
import spack
description = "parse specs and print them out to the command line."
def setup_parser(subparser):
subparser.add_argument('specs', nargs=argparse.REMAINDER, help="specs of packages")
def spec(parser, args):
specs = spack.cmd.parse_specs(args.specs)
for spec in specs:
print spec.colorized()
print " --> ", spec.concretized().colorized()
print spec.concretized().concrete()

View file

@ -0,0 +1,20 @@
import os
import re
from subprocess import CalledProcessError
import spack
import spack.packages as packages
import spack.url as url
import spack.tty as tty
from spack.colify import colify
from spack.version import ver
description ="List available versions of a package"
def setup_parser(subparser):
subparser.add_argument('package', metavar='PACKAGE', help='Package to list versions for')
def versions(parser, args):
pkg = packages.get(args.package)
colify(reversed(pkg.available_versions))

View file

@ -94,9 +94,10 @@ def colify(elts, **options):
indent = options.get("indent", 0) indent = options.get("indent", 0)
padding = options.get("padding", 2) padding = options.get("padding", 2)
# elts needs to be in an array so we can count the elements # elts needs to be an array of strings so we can count the elements
if not type(elts) == list: elts = [str(elt) for elt in elts]
elts = list(elts) if not elts:
return
if not output.isatty(): if not output.isatty():
for elt in elts: for elt in elts:

View file

@ -97,9 +97,11 @@ def __call__(self, match):
elif m == '@.': elif m == '@.':
return self.escape(0) return self.escape(0)
elif m == '@' or (style and not color): elif m == '@' or (style and not color):
raise ColorParseError("Incomplete color format: '%s'" % m) raise ColorParseError("Incomplete color format: '%s' in %s"
% (m, match.string))
elif color not in colors: elif color not in colors:
raise ColorParseError("invalid color specifier: '%s'" % color) raise ColorParseError("invalid color specifier: '%s' in '%s'"
% (color, match.string))
colored_text = '' colored_text = ''
if text: if text:
@ -141,6 +143,10 @@ def cprint(string, stream=sys.stdout, color=None):
"""Same as cwrite, but writes a trailing newline to the stream.""" """Same as cwrite, but writes a trailing newline to the stream."""
cwrite(string + "\n", stream, color) cwrite(string + "\n", stream, color)
def cescape(string):
"""Replace all @ with @@ in the string provided."""
return str(string).replace('@', '@@')
class ColorStream(object): class ColorStream(object):
def __init__(self, stream, color=None): def __init__(self, stream, color=None):

View file

@ -0,0 +1,16 @@
#
# This needs to be expanded for full compiler support.
#
import spack
import spack.compilers.gcc
from spack.utils import list_modules, memoized
@memoized
def supported_compilers():
return [c for c in list_modules(spack.compilers_path)]
def get_compiler():
return Compiler('gcc', spack.compilers.gcc.get_version())
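A minimal usage sketch for the new compilers module (assuming it is importable as spack.compilers, as the imports above suggest; the exact output depends on which compiler modules live under spack.compilers_path):

    import spack.compilers

    # With the gcc and intel stubs added in this commit, this should print
    # the discovered compiler module names, e.g. ['gcc', 'intel'].
    print spack.compilers.supported_compilers()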

View file

@ -0,0 +1,15 @@
#
# This is a stub module. It should be expanded when we implement full
# compiler support.
#
import subprocess
from spack.version import Version
cc = 'gcc'
cxx = 'g++'
fortran = 'gfortran'
def get_version():
v = subprocess.check_output([cc, '-dumpversion'])
return Version(v)

View file

@ -0,0 +1,15 @@
#
# This is a stub module. It should be expanded when we implement full
# compiler support.
#
import subprocess
from spack.version import Version
cc = 'icc'
cxx = 'icc'
fortran = 'ifort'
def get_version():
v = subprocess.check_output([cc, '-dumpversion'])
return Version(v)

View file

@ -1,20 +0,0 @@
"""
This file defines the dependence relation in spack.
"""
import packages
class Dependency(object):
"""Represents a dependency from one package to another.
"""
def __init__(self, name):
self.name = name
@property
def package(self):
return packages.get(self.name)
def __str__(self):
return "<dep: %s>" % self.name

View file

@ -0,0 +1,98 @@
import exceptions
import re
import os
import spack.spec as spec
from spack.utils import *
from spack.error import SpackError
class DirectoryLayout(object):
"""A directory layout is used to associate unique paths with specs.
Different installations are going to want different layouts for their
install, and they can use this to customize the nesting structure of
spack installs.
"""
def __init__(self, root):
self.root = root
def all_specs(self):
"""To be implemented by subclasses to traverse all specs for which there is
a directory within the root.
"""
raise NotImplementedError()
def relative_path_for_spec(self, spec):
"""Implemented by subclasses to return a relative path from the install
root to a unique location for the provided spec."""
raise NotImplementedError()
def path_for_spec(self, spec):
"""Return an absolute path from the root to a directory for the spec."""
if not spec.concrete:
raise ValueError("path_for_spec requires a concrete spec.")
path = self.relative_path_for_spec(spec)
assert(not path.startswith(self.root))
return os.path.join(self.root, path)
def remove_path_for_spec(self, spec):
"""Removes a prefix and any empty parent directories from the root."""
path = self.path_for_spec(spec)
assert(path.startswith(self.root))
if os.path.exists(path):
shutil.rmtree(path, True)
path = os.path.dirname(path)
while not os.listdir(path) and path != self.root:
os.rmdir(path)
path = os.path.dirname(path)
def traverse_dirs_at_depth(root, depth, path_tuple=(), curdepth=0):
"""For each directory at <depth> within <root>, return a tuple representing
the ancestors of that directory.
"""
if curdepth == depth and curdepth != 0:
yield path_tuple
elif depth > curdepth:
for filename in os.listdir(root):
child = os.path.join(root, filename)
if os.path.isdir(child):
child_tuple = path_tuple + (filename,)
for tup in traverse_dirs_at_depth(
child, depth, child_tuple, curdepth+1):
yield tup
class DefaultDirectoryLayout(DirectoryLayout):
def __init__(self, root):
super(DefaultDirectoryLayout, self).__init__(root)
def relative_path_for_spec(self, spec):
if not spec.concrete:
raise ValueError("relative_path_for_spec requires a concrete spec.")
return new_path(
spec.architecture,
spec.compiler,
"%s@%s%s%s" % (spec.name,
spec.version,
spec.variants,
spec.dependencies))
def all_specs(self):
if not os.path.isdir(self.root):
return
for path in traverse_dirs_at_depth(self.root, 3):
arch, compiler, last_dir = path
spec_str = "%s%%%s=%s" % (last_dir, compiler, arch)
yield spec.parse(spec_str)
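A rough sketch of the nesting DefaultDirectoryLayout produces; every value below is hypothetical and only meant to show the root/architecture/compiler/name@version scheme:

    import os

    root     = "/opt/spack"             # hypothetical install root
    arch     = "chaos_5_x86_64_ib"      # hypothetical sys_type
    compiler = "gcc@4.7"                # str() of a concrete Compiler
    last_dir = "libelf@0.8.13"          # name@version plus variant/dependency strings
    print os.path.join(root, arch, compiler, last_dir)
    # /opt/spack/chaos_5_x86_64_ib/gcc@4.7/libelf@0.8.13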

View file

@ -2,6 +2,7 @@
from version import Version from version import Version
from utils import * from utils import *
import arch import arch
from directory_layout import DefaultDirectoryLayout
# This lives in $prefix/lib/spack/spack/__file__
prefix = ancestor(__file__, 4) prefix = ancestor(__file__, 4)
@ -10,16 +11,23 @@
spack_file = new_path(prefix, "bin", "spack") spack_file = new_path(prefix, "bin", "spack")
# spack directory hierarchy # spack directory hierarchy
lib_path = new_path(prefix, "lib", "spack") lib_path = new_path(prefix, "lib", "spack")
env_path = new_path(lib_path, "env") env_path = new_path(lib_path, "env")
module_path = new_path(lib_path, "spack") module_path = new_path(lib_path, "spack")
packages_path = new_path(module_path, "packages") packages_path = new_path(module_path, "packages")
test_path = new_path(module_path, "test") compilers_path = new_path(module_path, "compilers")
test_path = new_path(module_path, "test")
var_path = new_path(prefix, "var", "spack") var_path = new_path(prefix, "var", "spack")
stage_path = new_path(var_path, "stage") stage_path = new_path(var_path, "stage")
install_path = new_path(prefix, "opt") install_path = new_path(prefix, "opt")
#
# This controls how spack lays out install prefixes and
# stage directories.
#
install_layout = DefaultDirectoryLayout(install_path)
# Version information # Version information
spack_version = Version("0.2") spack_version = Version("0.2")

View file

@ -61,6 +61,7 @@ def __call__(self, package_self, *args, **kwargs):
If none is found, call the default function that this was If none is found, call the default function that this was
initialized with. If there is no default, raise an error. initialized with. If there is no default, raise an error.
""" """
# TODO: make this work with specs.
sys_type = package_self.sys_type sys_type = package_self.sys_type
function = self.function_map.get(sys_type, self.default) function = self.function_map.get(sys_type, self.default)
if function: if function:

View file

@ -0,0 +1,60 @@
"""
Functions for comparing values that may potentially be None.
Functions prefixed with 'none_low_' treat None as less than all other values.
Functions prefixed with 'none_high_' treat None as greater than all other values.
"""
def none_low_lt(lhs, rhs):
"""Less-than comparison. None is lower than any value."""
return lhs != rhs and (lhs == None or (rhs != None and lhs < rhs))
def none_low_le(lhs, rhs):
"""Less-than-or-equal comparison. None is less than any value."""
return lhs == rhs or none_low_lt(lhs, rhs)
def none_low_gt(lhs, rhs):
"""Greater-than comparison. None is less than any value."""
return lhs != rhs and not none_low_lt(lhs, rhs)
def none_low_ge(lhs, rhs):
"""Greater-than-or-equal comparison. None is less than any value."""
return lhs == rhs or none_low_gt(lhs, rhs)
def none_low_min(lhs, rhs):
"""Minimum function where None is less than any value."""
if lhs == None or rhs == None:
return None
else:
return min(lhs, rhs)
def none_high_lt(lhs, rhs):
"""Less-than comparison. None is greater than any value."""
return lhs != rhs and (rhs == None or (lhs != None and lhs < rhs))
def none_high_le(lhs, rhs):
"""Less-than-or-equal comparison. None is greater than any value."""
return lhs == rhs or none_high_lt(lhs, rhs)
def none_high_gt(lhs, rhs):
"""Greater-than comparison. None is greater than any value."""
return lhs != rhs and not none_high_lt(lhs, rhs)
def none_high_ge(lhs, rhs):
"""Greater-than-or-equal comparison. None is greater than any value."""
return lhs == rhs or none_high_gt(lhs, rhs)
def none_high_max(lhs, rhs):
"""Maximum function where None is greater than any value."""
if lhs == None or rhs == None:
return None
else:
return max(lhs, rhs)
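A standalone illustration of the 'none_low' convention; the function body is copied here only so the snippet runs on its own, since the new module's import path is not shown in this diff:

    def none_low_lt(lhs, rhs):
        """Less-than comparison. None is lower than any value."""
        return lhs != rhs and (lhs == None or (rhs != None and lhs < rhs))

    assert none_low_lt(None, 0)        # None sorts below everything
    assert not none_low_lt(0, None)
    assert none_low_lt(1, 2)
    assert not none_low_lt(2, 2)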

View file

@ -9,7 +9,6 @@
rundown on spack and how it differs from homebrew, look at the rundown on spack and how it differs from homebrew, look at the
README. README.
""" """
import sys
import inspect import inspect
import os import os
import re import re
@ -18,18 +17,18 @@
import shutil import shutil
from spack import * from spack import *
import spack.spec
import packages import packages
import tty import tty
import attr import attr
import validate import validate
import url import url
import arch
from spec import Compiler from spec import Compiler
from version import Version from version import *
from multi_function import platform from multi_function import platform
from stage import Stage from stage import Stage
from dependency import *
class Package(object): class Package(object):
@ -106,6 +105,21 @@ def install(self, prefix):
install() This function tells spack how to build and install the install() This function tells spack how to build and install the
software it downloaded. software it downloaded.
Optional Attributes
---------------------
You can also optionally add these attributes, if needed:
list_url
Webpage to scrape for available version strings. Default is the
directory containing the tarball; use this if the default isn't
correct so that invoking 'spack versions' will work for this
package.
url_version(self, version)
When spack downloads packages at particular versions, it just
converts version to string with str(version). Override this if
your package needs special version formatting in its URL. boost
is an example of a package that needs this.
Creating Packages Creating Packages
=================== ===================
As a package creator, you can probably ignore most of the preceding As a package creator, you can probably ignore most of the preceding
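As a concrete (and entirely hypothetical) example of the url_version override described above, a boost-style package whose tarball names use underscores might look like this; the class, URL, and md5 are invented for illustration:

    from spack import *

    class Boost(Package):
        homepage = "http://example.com/boost"
        url      = "http://example.com/boost_1_55_0.tar.bz2"
        md5      = "00000000000000000000000000000000"

        def url_version(self, version):
            # '1.55.0' -> '1_55_0' so the download URL is formed correctly.
            return str(version).replace('.', '_')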
@ -209,7 +223,7 @@ class SomePackage(Package):
A package's lifecycle over a run of Spack looks something like this: A package's lifecycle over a run of Spack looks something like this:
packge p = new Package() # Done for you by spack p = Package() # Done for you by spack
p.do_fetch() # called by spack commands in spack/cmd. p.do_fetch() # called by spack commands in spack/cmd.
p.do_stage() # see spack.stage.Stage docs. p.do_stage() # see spack.stage.Stage docs.
@ -231,9 +245,15 @@ class SomePackage(Package):
clean() (some of them do this), and others to provide custom behavior. clean() (some of them do this), and others to provide custom behavior.
""" """
#
# These variables are per-package metadata and will be defined by subclasses.
#
"""By default a package has no dependencies.""" """By default a package has no dependencies."""
dependencies = [] dependencies = []
#
# These are default values for instance variables.
#
"""By default we build in parallel. Subclasses can override this.""" """By default we build in parallel. Subclasses can override this."""
parallel = True parallel = True
@ -243,19 +263,14 @@ class SomePackage(Package):
"""Controls whether install and uninstall check deps before running.""" """Controls whether install and uninstall check deps before running."""
ignore_dependencies = False ignore_dependencies = False
# TODO: multi-compiler support def __init__(self, spec):
"""Default compiler for this package""" # These attributes are required for all packages.
compiler = Compiler('gcc')
def __init__(self, sys_type = arch.sys_type()):
# Check for attributes that derived classes must set.
attr.required(self, 'homepage') attr.required(self, 'homepage')
attr.required(self, 'url') attr.required(self, 'url')
attr.required(self, 'md5') attr.required(self, 'md5')
# Architecture for this package. # this determines how the package should be built.
self.sys_type = sys_type self.spec = spec
# Name of package is the name of its module (the file that contains it) # Name of package is the name of its module (the file that contains it)
self.name = inspect.getmodulename(self.module.__file__) self.name = inspect.getmodulename(self.module.__file__)
@ -277,16 +292,16 @@ def __init__(self, sys_type = arch.sys_type()):
elif type(self.version) == string: elif type(self.version) == string:
self.version = Version(self.version) self.version = Version(self.version)
# This adds a bunch of convenience commands to the package's module scope. # Empty at first; only compute dependent packages if necessary
self.add_commands_to_module()
# Empty at first; only compute dependents if necessary
self._dependents = None self._dependents = None
# stage used to build this package. # This is set by scraping a web page.
self.stage = Stage(self.stage_name, self.url) self._available_versions = None
# Set a default list URL (place to find lots of versions) # stage used to build this package.
self.stage = Stage("%s-%s" % (self.name, self.version), self.url)
# Set a default list URL (place to find available versions)
if not hasattr(self, 'list_url'): if not hasattr(self, 'list_url'):
self.list_url = os.path.dirname(self.url) self.list_url = os.path.dirname(self.url)
@ -356,6 +371,24 @@ def dependents(self):
return tuple(self._dependents) return tuple(self._dependents)
def sanity_check(self):
"""Ensure that this package and its dependencies don't have conflicting
requirements."""
deps = sorted(self.all_dependencies, key=lambda d: d.name)
@property
@memoized
def all_dependencies(self):
"""Set of all transitive dependencies of this package."""
all_deps = set(self.dependencies)
for dep in self.dependencies:
dep_pkg = packages.get(dep.name)
all_deps = all_deps.union(dep_pkg.all_dependencies)
return all_deps
@property @property
def installed(self): def installed(self):
return os.path.exists(self.prefix) return os.path.exists(self.prefix)
@ -379,35 +412,10 @@ def all_dependents(self):
return tuple(all_deps) return tuple(all_deps)
@property
def stage_name(self):
return "%s-%s" % (self.name, self.version)
#
# Below properties determine the path where this package is installed.
#
@property
def platform_path(self):
"""Directory for binaries for the current platform."""
return new_path(install_path, self.sys_type)
@property
def package_path(self):
"""Directory for different versions of this package. Lives just above prefix."""
return new_path(self.platform_path, self.name)
@property
def installed_versions(self):
return [ver for ver in os.listdir(self.package_path)
if os.path.isdir(new_path(self.package_path, ver))]
@property @property
def prefix(self): def prefix(self):
"""Packages are installed in $spack_prefix/opt/<sys_type>/<name>/<version>""" """Get the prefix into which this package should be installed."""
return new_path(self.package_path, self.version) return spack.install_layout.path_for_spec(self.spec)
def url_version(self, version): def url_version(self, version):
@ -417,24 +425,14 @@ def url_version(self, version):
override this, e.g. for boost versions where you need to ensure that there override this, e.g. for boost versions where you need to ensure that there
are _'s in the download URL. are _'s in the download URL.
""" """
return version.string return str(version)
def remove_prefix(self): def remove_prefix(self):
"""Removes the prefix for a package along with any empty parent directories.""" """Removes the prefix for a package along with any empty parent directories."""
if self.dirty: if self.dirty:
return return
spack.install_layout.remove_path_for_spec(self.spec)
if os.path.exists(self.prefix):
shutil.rmtree(self.prefix, True)
for dir in (self.package_path, self.platform_path):
if not os.path.isdir(dir):
continue
if not os.listdir(dir):
os.rmdir(dir)
else:
break
def do_fetch(self): def do_fetch(self):
@ -469,6 +467,9 @@ def do_install(self):
"""This class should call this version of the install method. """This class should call this version of the install method.
Package implementations should override install(). Package implementations should override install().
""" """
if not self.spec.concrete:
raise ValueError("Can only install concrete packages.")
if os.path.exists(self.prefix): if os.path.exists(self.prefix):
tty.msg("%s is already installed." % self.name) tty.msg("%s is already installed." % self.name)
tty.pkg(self.prefix) tty.pkg(self.prefix)
@ -480,6 +481,10 @@ def do_install(self):
self.do_stage() self.do_stage()
self.setup_install_environment() self.setup_install_environment()
# Add convenience commands to the package's module scope to
# make building easier.
self.add_commands_to_module()
tty.msg("Building %s." % self.name) tty.msg("Building %s." % self.name)
try: try:
self.install(self.prefix) self.install(self.prefix)
@ -599,6 +604,34 @@ def do_clean_dist(self):
tty.msg("Successfully cleaned %s" % self.name) tty.msg("Successfully cleaned %s" % self.name)
@property
def available_versions(self):
if not self._available_versions:
self._available_versions = VersionList()
try:
# Run curl but grab the mime type from the http headers
listing = spack.curl('-s', '-L', self.list_url, return_output=True)
url_regex = os.path.basename(url.wildcard_version(self.url))
strings = re.findall(url_regex, listing)
wildcard = self.version.wildcard()
for s in strings:
match = re.search(wildcard, s)
if match:
self._available_versions.add(ver(match.group(0)))
except CalledProcessError:
tty.warn("Fetching %s failed." % self.list_url,
"Package.available_versions requires an internet connection.",
"Version list may be incomplete.")
if not self._available_versions:
tty.warn("Found no versions for %s" % self.name,
"Packate.available_versions may require adding the list_url attribute",
"to the package to tell Spack where to look for versions.")
self._available_versions = [self.version]
return self._available_versions
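The scraping above amounts to matching a wildcard-version regex against a fetched listing page. A standalone sketch of the idea, with a made-up listing and a hand-written regex standing in for url.wildcard_version():

    import re

    listing   = 'libelf-0.8.12.tar.gz  libelf-0.8.13.tar.gz'
    url_regex = r'libelf-(\d+\.\d+\.\d+)\.tar\.gz'
    found = sorted(set(m.group(1) for m in re.finditer(url_regex, listing)))
    print found    # ['0.8.12', '0.8.13']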
class MakeExecutable(Executable): class MakeExecutable(Executable):
"""Special Executable for make so the user can specify parallel or """Special Executable for make so the user can specify parallel or
not on a per-invocation basis. Using 'parallel' as a kwarg will not on a per-invocation basis. Using 'parallel' as a kwarg will

View file

@ -7,68 +7,51 @@
import spack import spack
import spack.error import spack.error
import spack.spec
from spack.utils import * from spack.utils import *
import spack.arch as arch import spack.arch as arch
# Valid package names can contain '-' but can't start with it.
# Valid package names -- can contain - but can't start with it. valid_package_re = r'^\w[\w-]*$'
valid_package = r'^\w[\w-]*$'
# Don't allow consecutive [_-] in package names # Don't allow consecutive [_-] in package names
invalid_package = r'[_-][_-]+' invalid_package_re = r'[_-][_-]+'
instances = {} instances = {}
def get(spec):
spec = spack.spec.make_spec(spec)
if not spec in instances:
package_class = get_class_for_package_name(spec.name)
instances[spec] = package_class(spec)
def get(pkg, arch=arch.sys_type()): return instances[spec]
key = (pkg, arch)
if not key in instances:
package_class = get_class(pkg)
instances[key] = package_class(arch)
return instances[key]
class InvalidPackageNameError(spack.error.SpackError): def valid_package_name(pkg_name):
"""Raised when we encounter a bad package name.""" return (re.match(valid_package_re, pkg_name) and
def __init__(self, name): not re.search(invalid_package_re, pkg_name))
super(InvalidPackageNameError, self).__init__(
"Invalid package name: " + name)
self.name = name
def valid_name(pkg): def validate_package_name(pkg_name):
return re.match(valid_package, pkg) and not re.search(invalid_package, pkg) if not valid_package_name(pkg_name):
raise InvalidPackageNameError(pkg_name)
def validate_name(pkg): def filename_for_package_name(pkg_name):
if not valid_name(pkg):
raise InvalidPackageNameError(pkg)
def filename_for(pkg):
"""Get the filename where a package name should be stored.""" """Get the filename where a package name should be stored."""
validate_name(pkg) validate_package_name(pkg_name)
return new_path(spack.packages_path, "%s.py" % pkg) return new_path(spack.packages_path, "%s.py" % pkg_name)
def installed_packages(**kwargs): def installed_packages():
"""Returns a dict from systype strings to lists of Package objects.""" return spack.install_layout.all_specs()
pkgs = {}
if not os.path.isdir(spack.install_path):
return pkgs
for sys_type in os.listdir(spack.install_path):
sys_type = sys_type
sys_path = new_path(spack.install_path, sys_type)
pkgs[sys_type] = [get(pkg) for pkg in os.listdir(sys_path)
if os.path.isdir(new_path(sys_path, pkg))]
return pkgs
def all_package_names(): def all_package_names():
"""Generator function for all packages.""" """Generator function for all packages."""
for mod in list_modules(spack.packages_path): for module in list_modules(spack.packages_path):
yield mod yield module
def all_packages(): def all_packages():
@ -76,12 +59,12 @@ def all_packages():
yield get(name) yield get(name)
def class_for(pkg): def class_name_for_package_name(pkg_name):
"""Get a name for the class the package file should contain. Note that """Get a name for the class the package file should contain. Note that
conflicts don't matter because the classes are in different modules. conflicts don't matter because the classes are in different modules.
""" """
validate_name(pkg) validate_package_name(pkg_name)
class_name = string.capwords(pkg.replace('_', '-'), '-') class_name = string.capwords(pkg_name.replace('_', '-'), '-')
# If a class starts with a number, prefix it with Number_ to make it a valid # If a class starts with a number, prefix it with Number_ to make it a valid
# Python class name. # Python class name.
@ -91,25 +74,27 @@ def class_for(pkg):
return class_name return class_name
def get_class(pkg): def get_class_for_package_name(pkg_name):
file = filename_for(pkg) file_name = filename_for_package_name(pkg_name)
if os.path.exists(file): if os.path.exists(file_name):
if not os.path.isfile(file): if not os.path.isfile(file_name):
tty.die("Something's wrong. '%s' is not a file!" % file) tty.die("Something's wrong. '%s' is not a file!" % file_name)
if not os.access(file, os.R_OK): if not os.access(file_name, os.R_OK):
tty.die("Cannot read '%s'!" % file) tty.die("Cannot read '%s'!" % file_name)
else:
raise UnknownPackageError(pkg_name)
class_name = pkg.capitalize() class_name = pkg_name.capitalize()
try: try:
module_name = "%s.%s" % (__name__, pkg) module_name = "%s.%s" % (__name__, pkg_name)
module = __import__(module_name, fromlist=[class_name]) module = __import__(module_name, fromlist=[class_name])
except ImportError, e: except ImportError, e:
tty.die("Error while importing %s.%s:\n%s" % (pkg, class_name, e.message)) tty.die("Error while importing %s.%s:\n%s" % (pkg_name, class_name, e.message))
klass = getattr(module, class_name) klass = getattr(module, class_name)
if not inspect.isclass(klass): if not inspect.isclass(klass):
tty.die("%s.%s is not a class" % (pkg, class_name)) tty.die("%s.%s is not a class" % (pkg_name, class_name))
return klass return klass
@ -152,3 +137,19 @@ def quote(string):
for pair in deps: for pair in deps:
out.write(' "%s" -> "%s"\n' % pair) out.write(' "%s" -> "%s"\n' % pair)
out.write('}\n') out.write('}\n')
class InvalidPackageNameError(spack.error.SpackError):
"""Raised when we encounter a bad package name."""
def __init__(self, name):
super(InvalidPackageNameError, self).__init__(
"Invalid package name: " + name)
self.name = name
class UnknownPackageError(spack.error.SpackError):
"""Raised when we encounter a package spack doesn't have."""
def __init__(self, name):
super(UnknownPackageError, self).__init__("Package %s not found." % name)
self.name = name

View file

@ -45,18 +45,19 @@ class Mpileaks(Package):
spack install mpileaks ^mpich spack install mpileaks ^mpich
""" """
import sys import sys
from dependency import Dependency import spack.spec
def depends_on(*args): def depends_on(*specs):
"""Adds a dependencies local variable in the locals of """Adds a dependencies local variable in the locals of
the calling class, based on args. the calling class, based on args.
""" """
# Get the enclosing package's scope and add deps to it. # Get the enclosing package's scope and add deps to it.
locals = sys._getframe(1).f_locals locals = sys._getframe(1).f_locals
dependencies = locals.setdefault("dependencies", []) dependencies = locals.setdefault("dependencies", [])
for name in args: for string in specs:
dependencies.append(Dependency(name)) for spec in spack.spec.parse(string):
dependencies.append(spec)
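A hypothetical package body using the spec-based depends_on; the package name, URL, and checksum are invented, but it shows that a version constraint can now ride along on the dependency:

    from spack import *

    class Mpileaks(Package):
        homepage = "http://example.com/mpileaks"
        url      = "http://example.com/mpileaks-1.0.tar.gz"
        md5      = "00000000000000000000000000000000"

        depends_on("mpich@1.2:")    # parsed into a Spec, not a Dependency

        def install(self, prefix):
            configure("--prefix=%s" % prefix)
            make()
            make("install")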
def provides(*args): def provides(*args):

View file

@ -68,101 +68,275 @@
import tty import tty
import spack.parse import spack.parse
import spack.error import spack.error
from spack.version import Version, VersionRange import spack.compilers
from spack.color import ColorStream import spack.compilers.gcc
import spack.packages as packages
import spack.arch as arch
from spack.version import *
from spack.color import *
# Color formats for various parts of specs when using color output. """This map determines the coloring of specs when using color output.
compiler_fmt = '@g' We make the fields different colors to enhance readability.
version_fmt = '@c' See spack.color for descriptions of the color codes.
architecture_fmt = '@m' """
variant_enabled_fmt = '@B' color_formats = {'%' : '@g', # compiler
variant_disabled_fmt = '@r' '@' : '@c', # version
'=' : '@m', # architecture
'+' : '@B', # enable variant
'~' : '@r', # disable variant
'^' : '@.'} # dependency
"""Regex used for splitting by spec field separators."""
separators = '[%s]' % ''.join(color_formats.keys())
class SpecError(spack.error.SpackError): def colorize_spec(spec):
"""Superclass for all errors that occur while constructing specs.""" """Returns a spec colorized according to the colors specified in
def __init__(self, message): color_formats."""
super(SpecError, self).__init__(message) class insert_color:
def __init__(self):
self.last = None
class DuplicateDependencyError(SpecError): def __call__(self, match):
"""Raised when the same dependency occurs in a spec twice.""" # ignore compiler versions (color same as compiler)
def __init__(self, message): sep = match.group(0)
super(DuplicateDependencyError, self).__init__(message) if self.last == '%' and sep == '@':
return cescape(sep)
self.last = sep
class DuplicateVariantError(SpecError): return '%s%s' % (color_formats[sep], cescape(sep))
"""Raised when the same variant occurs in a spec twice."""
def __init__(self, message):
super(DuplicateVariantError, self).__init__(message)
class DuplicateCompilerError(SpecError): return colorize(re.sub(separators, insert_color(), str(spec)) + '@.')
"""Raised when the same compiler occurs in a spec twice."""
def __init__(self, message):
super(DuplicateCompilerError, self).__init__(message)
class DuplicateArchitectureError(SpecError):
"""Raised when the same architecture occurs in a spec twice."""
def __init__(self, message):
super(DuplicateArchitectureError, self).__init__(message)
class Compiler(object): class Compiler(object):
def __init__(self, name): """The Compiler field represents the compiler or range of compiler
versions that a package should be built with. Compilers have a
name and a version list.
"""
def __init__(self, name, version=None):
if name not in spack.compilers.supported_compilers():
raise UnknownCompilerError(name)
self.name = name self.name = name
self.versions = [] self.versions = VersionList()
if version:
self.versions.add(version)
def add_version(self, version):
self.versions.append(version)
def stringify(self, **kwargs): def _add_version(self, version):
color = kwargs.get("color", False) self.versions.add(version)
out = StringIO()
out.write("%s{%%%s}" % (compiler_fmt, self.name))
if self.versions: @property
vlist = ",".join(str(v) for v in sorted(self.versions)) def concrete(self):
out.write("%s{@%s}" % (compiler_fmt, vlist)) return self.versions.concrete
return out.getvalue()
def _concretize(self):
"""If this spec could describe more than one version, variant, or build
of a package, this will resolve it to be concrete.
"""
# TODO: support compilers other than GCC.
if self.concrete:
return
gcc_version = spack.compilers.gcc.get_version()
self.versions = VersionList([gcc_version])
def concretized(self):
clone = self.copy()
clone._concretize()
return clone
@property
def version(self):
if not self.concrete:
raise SpecError("Spec is not concrete: " + str(self))
return self.versions[0]
def copy(self):
clone = Compiler(self.name)
clone.versions = self.versions.copy()
return clone
def __eq__(self, other):
return (self.name, self.versions) == (other.name, other.versions)
def __ne__(self, other):
return not (self == other)
def __hash__(self):
return hash((self.name, self.versions))
def __str__(self): def __str__(self):
return self.stringify() out = self.name
if self.versions:
vlist = ",".join(str(v) for v in sorted(self.versions))
out += "@%s" % vlist
return out
@total_ordering
class Variant(object):
"""Variants are named, build-time options for a package. Names depend
on the particular package being built, and each named variant can
be enabled or disabled.
"""
def __init__(self, name, enabled):
self.name = name
self.enabled = enabled
def __eq__(self, other):
return self.name == other.name and self.enabled == other.enabled
def __ne__(self, other):
return not (self == other)
@property
def tuple(self):
return (self.name, self.enabled)
def __hash__(self):
return hash(self.tuple)
def __lt__(self, other):
return self.tuple < other.tuple
def __str__(self):
out = '+' if self.enabled else '~'
return out + self.name
@total_ordering
class HashableMap(dict):
"""This is a hashable, comparable dictionary. Hash is performed on
a tuple of the values in the dictionary."""
def __eq__(self, other):
return (len(self) == len(other) and
sorted(self.values()) == sorted(other.values()))
def __ne__(self, other):
return not (self == other)
def __lt__(self, other):
return tuple(sorted(self.values())) < tuple(sorted(other.values()))
def __hash__(self):
return hash(tuple(sorted(self.values())))
def copy(self):
"""Type-agnostic clone method. Preserves subclass type."""
# Construct a new dict of my type
T = type(self)
clone = T()
# Copy everything from this dict into it.
for key in self:
clone[key] = self[key]
return clone
class VariantMap(HashableMap):
def __str__(self):
sorted_keys = sorted(self.keys())
return ''.join(str(self[key]) for key in sorted_keys)
class DependencyMap(HashableMap):
"""Each spec has a DependencyMap containing specs for its dependencies.
The DependencyMap is keyed by name. """
@property
def concrete(self):
return all(d.concrete for d in self.values())
def __str__(self):
sorted_keys = sorted(self.keys())
return ''.join(
["^" + str(self[name]) for name in sorted_keys])
@total_ordering
class Spec(object): class Spec(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self._package = None self.versions = VersionList()
self.versions = [] self.variants = VariantMap()
self.variants = {}
self.architecture = None self.architecture = None
self.compiler = None self.compiler = None
self.dependencies = {} self.dependencies = DependencyMap()
def add_version(self, version): #
self.versions.append(version) # Private routines here are called by the parser when building a spec.
#
def _add_version(self, version):
"""Called by the parser to add an allowable version."""
self.versions.add(version)
def add_variant(self, name, enabled):
def _add_variant(self, name, enabled):
"""Called by the parser to add a variant."""
if name in self.variants: raise DuplicateVariantError( if name in self.variants: raise DuplicateVariantError(
"Cannot specify variant '%s' twice" % name) "Cannot specify variant '%s' twice" % name)
self.variants[name] = enabled self.variants[name] = Variant(name, enabled)
def add_compiler(self, compiler):
def _set_compiler(self, compiler):
"""Called by the parser to set the compiler."""
if self.compiler: raise DuplicateCompilerError( if self.compiler: raise DuplicateCompilerError(
"Spec for '%s' cannot have two compilers." % self.name) "Spec for '%s' cannot have two compilers." % self.name)
self.compiler = compiler self.compiler = compiler
def add_architecture(self, architecture):
def _set_architecture(self, architecture):
"""Called by the parser to set the architecture."""
if self.architecture: raise DuplicateArchitectureError( if self.architecture: raise DuplicateArchitectureError(
"Spec for '%s' cannot have two architectures." % self.name) "Spec for '%s' cannot have two architectures." % self.name)
self.architecture = architecture self.architecture = architecture
def add_dependency(self, dep):
def _add_dependency(self, dep):
"""Called by the parser to add another spec as a dependency."""
if dep.name in self.dependencies: if dep.name in self.dependencies:
raise DuplicateDependencyError("Cannot depend on '%s' twice" % dep) raise DuplicateDependencyError("Cannot depend on '%s' twice" % dep)
self.dependencies[dep.name] = dep self.dependencies[dep.name] = dep
def canonicalize(self):
"""Ensures that the spec is in canonical form. @property
def concrete(self):
return (self.versions.concrete
# TODO: support variants
and self.architecture
and self.compiler and self.compiler.concrete
and self.dependencies.concrete)
def _concretize(self):
"""A spec is concrete if it describes one build of a package uniquely.
This will ensure that this spec is concrete.
If this spec could describe more than one version, variant, or build
of a package, this will resolve it to be concrete.
Ensures that the spec is in canonical form.
This means: This means:
1. All dependencies of this package and of its dependencies are 1. All dependencies of this package and of its dependencies are
@ -173,49 +347,164 @@ def canonicalize(self):
that each package exists and that spec criteria don't violate package
criteria. criteria.
""" """
pass # TODO: modularize the process of selecting concrete versions.
# There should be a set of user-configurable policies for these decisions.
self.check_sanity()
@property # take the system's architecture for starters
def package(self): if not self.architecture:
if self._package == None: self.architecture = arch.sys_type()
self._package = packages.get(self.name)
return self._package
def stringify(self, **kwargs):
color = kwargs.get("color", False)
out = ColorStream(StringIO(), color)
out.write("%s" % self.name)
if self.versions:
vlist = ",".join(str(v) for v in sorted(self.versions))
out.write("%s{@%s}" % (version_fmt, vlist))
if self.compiler: if self.compiler:
out.write(self.compiler.stringify(color=color)) self.compiler._concretize()
for name in sorted(self.variants.keys()): # TODO: handle variants.
enabled = self.variants[name]
if enabled:
out.write('%s{+%s}' % (variant_enabled_fmt, name))
else:
out.write('%s{~%s}' % (variant_disabled_fmt, name))
if self.architecture: pkg = packages.get(self.name)
out.write("%s{=%s}" % (architecture_fmt, self.architecture))
for name in sorted(self.dependencies.keys()): # Take the highest version in a range
dep = " ^" + self.dependencies[name].stringify(color=color) if not self.versions.concrete:
out.write(dep, raw=True) preferred = self.versions.highest() or pkg.version
self.versions = VersionList([preferred])
return out.getvalue() # Ensure dependencies have right versions
def check_sanity(self):
"""Check names of packages and dependency validity."""
self.check_package_name_sanity()
self.check_dependency_sanity()
self.check_dependence_constraint_sanity()
def check_package_name_sanity(self):
"""Ensure that all packages mentioned in the spec exist."""
packages.get(self.name)
for dep in self.dependencies.values():
packages.get(dep.name)
def check_dependency_sanity(self):
"""Ensure that dependencies specified on the spec are actual
dependencies of the package it represents.
"""
pkg = packages.get(self.name)
dep_names = set(dep.name for dep in pkg.all_dependencies)
invalid_dependencies = [d.name for d in self.dependencies.values()
if d.name not in dep_names]
if invalid_dependencies:
raise InvalidDependencyException(
"The packages (%s) are not dependencies of %s" %
(','.join(invalid_dependencies), self.name))
def check_dependence_constraint_sanity(self):
"""Ensure that package's dependencies have consistent constraints on
their dependencies.
"""
pkg = packages.get(self.name)
specs = {}
for spec in pkg.all_dependencies:
if not spec.name in specs:
specs[spec.name] = spec
continue
merged = specs[spec.name]
# Specs in deps can't be disjoint.
if not spec.versions.overlaps(merged.versions):
raise InvalidConstraintException(
"One package %s, version constraint %s conflicts with %s"
% (pkg.name, spec.versions, merged.versions))
def merge(self, other):
"""Considering these specs as constraints, attempt to merge.
Raise an exception if specs are disjoint.
"""
pass
def concretized(self):
clone = self.copy()
clone._concretize()
return clone
def copy(self):
clone = Spec(self.name)
clone.versions = self.versions.copy()
clone.variants = self.variants.copy()
clone.architecture = self.architecture
clone.compiler = None
if self.compiler:
clone.compiler = self.compiler.copy()
clone.dependencies = self.dependencies.copy()
return clone
@property
def version(self):
if not self.concrete:
raise SpecError("Spec is not concrete: " + str(self))
return self.versions[0]
@property
def tuple(self):
return (self.name, self.versions, self.variants,
self.architecture, self.compiler, self.dependencies)
def __eq__(self, other):
return self.tuple == other.tuple
def __ne__(self, other):
return not (self == other)
def __lt__(self, other):
return self.tuple < other.tuple
def __hash__(self):
return hash(self.tuple)
def colorized(self):
return colorize_spec(self)
def __repr__(self):
return str(self)
def write(self, stream=sys.stdout):
isatty = stream.isatty()
stream.write(self.stringify(color=isatty))
def __str__(self): def __str__(self):
return self.stringify() out = self.name
# If the version range is entirely open, omit it
if self.versions and self.versions != VersionList([':']):
out += "@%s" % self.versions
if self.compiler:
out += "%%%s" % self.compiler
out += str(self.variants)
if self.architecture:
out += "=%s" % self.architecture
out += str(self.dependencies)
return out
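For reference, round-tripping a spec through the parser and back to a string now looks roughly like this (the spec is illustrative; compare the test expectations further down):

    import spack.spec

    s = spack.spec.parse_one("mvapich_foo ^_openmpi@1.2:1.4,1.6%intel@12.1+debug~qt_4")
    print str(s)
    # mvapich_foo^_openmpi@1.2:1.4,1.6%intel@12.1+debug~qt_4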
# #
# These are possible token types in the spec grammar. # These are possible token types in the spec grammar.
@ -254,7 +543,7 @@ def do_parse(self):
if not specs: if not specs:
self.last_token_error("Dependency has no package") self.last_token_error("Dependency has no package")
self.expect(ID) self.expect(ID)
specs[-1].add_dependency(self.spec()) specs[-1]._add_dependency(self.spec())
else: else:
self.unexpected_token() self.unexpected_token()
@ -265,28 +554,34 @@ def do_parse(self):
def spec(self): def spec(self):
self.check_identifier() self.check_identifier()
spec = Spec(self.token.value) spec = Spec(self.token.value)
added_version = False
while self.next: while self.next:
if self.accept(AT): if self.accept(AT):
vlist = self.version_list() vlist = self.version_list()
for version in vlist: for version in vlist:
spec.add_version(version) spec._add_version(version)
added_version = True
elif self.accept(ON): elif self.accept(ON):
spec.add_variant(self.variant(), True) spec._add_variant(self.variant(), True)
elif self.accept(OFF): elif self.accept(OFF):
spec.add_variant(self.variant(), False) spec._add_variant(self.variant(), False)
elif self.accept(PCT): elif self.accept(PCT):
spec.add_compiler(self.compiler()) spec._set_compiler(self.compiler())
elif self.accept(EQ): elif self.accept(EQ):
spec.add_architecture(self.architecture()) spec._set_architecture(self.architecture())
else: else:
break break
# If there was no version in the spec, consider it an open range
if not added_version:
spec.versions = VersionList([':'])
return spec return spec
@ -318,12 +613,9 @@ def version(self):
# No colon and no id: invalid version. # No colon and no id: invalid version.
self.next_token_error("Invalid version specifier") self.next_token_error("Invalid version specifier")
if not start and not end: if start: start = Version(start)
self.next_token_error("Lone colon: version range needs a version") if end: end = Version(end)
else: return VersionRange(start, end)
if start: start = Version(start)
if end: end = Version(end)
return VersionRange(start, end)
def version_list(self): def version_list(self):
@ -341,7 +633,7 @@ def compiler(self):
if self.accept(AT): if self.accept(AT):
vlist = self.version_list() vlist = self.version_list()
for version in vlist: for version in vlist:
compiler.add_version(version) compiler._add_version(version)
return compiler return compiler
@ -357,3 +649,79 @@ def check_identifier(self):
def parse(string): def parse(string):
"""Returns a list of specs from an input string.""" """Returns a list of specs from an input string."""
return SpecParser().parse(string) return SpecParser().parse(string)
def parse_one(string):
"""Parses a string containing only one spec, then returns that
spec. If more than one spec is found, raises a ValueError.
"""
spec_list = parse(string)
if len(spec_list) > 1:
raise ValueError("string contains more than one spec!")
elif len(spec_list) < 1:
raise ValueError("string contains no specs!")
return spec_list[0]
def make_spec(spec_like):
if type(spec_like) == str:
specs = parse(spec_like)
if len(specs) != 1:
raise ValueError("String contains multiple specs: '%s'" % spec_like)
return specs[0]
elif type(spec_like) == Spec:
return spec_like
else:
raise TypeError("Can't make spec out of %s" % type(spec_like))
class SpecError(spack.error.SpackError):
"""Superclass for all errors that occur while constructing specs."""
def __init__(self, message):
super(SpecError, self).__init__(message)
class DuplicateDependencyError(SpecError):
"""Raised when the same dependency occurs in a spec twice."""
def __init__(self, message):
super(DuplicateDependencyError, self).__init__(message)
class DuplicateVariantError(SpecError):
"""Raised when the same variant occurs in a spec twice."""
def __init__(self, message):
super(DuplicateVariantError, self).__init__(message)
class DuplicateCompilerError(SpecError):
"""Raised when the same compiler occurs in a spec twice."""
def __init__(self, message):
super(DuplicateCompilerError, self).__init__(message)
class UnknownCompilerError(SpecError):
"""Raised when the user asks for a compiler spack doesn't know about."""
def __init__(self, compiler_name):
super(UnknownCompilerError, self).__init__(
"Unknown compiler: %s" % compiler_name)
class DuplicateArchitectureError(SpecError):
"""Raised when the same architecture occurs in a spec twice."""
def __init__(self, message):
super(DuplicateArchitectureError, self).__init__(message)
class InvalidDependencyException(SpecError):
"""Raised when a dependency in a spec is not actually a dependency
of the package."""
def __init__(self, message):
super(InvalidDependencyException, self).__init__(message)
class InvalidConstraintException(SpecError):
"""Raised when a package dependencies conflict."""
def __init__(self, message):
super(InvalidConstraintException, self).__init__(message)

View file

@ -18,8 +18,8 @@ def __init__(self, url):
class Stage(object): class Stage(object):
"""A Stage object manaages a directory where an archive is downloaded, """A Stage object manaages a directory where an archive is downloaded,
expanded, and built before being installed. A stage's lifecycle looks expanded, and built before being installed. It also handles downloading
like this: the archive. A stage's lifecycle looks like this:
setup() Create the stage directory. setup() Create the stage directory.
fetch() Fetch a source archive into the stage. fetch() Fetch a source archive into the stage.
@ -32,21 +32,16 @@ class Stage(object):
in a tmp directory. Otherwise, stages are created directly in in a tmp directory. Otherwise, stages are created directly in
spack.stage_path. spack.stage_path.
""" """
def __init__(self, path, url):
def __init__(self, stage_name, url):
"""Create a stage object. """Create a stage object.
Parameters: Parameters:
stage_name Name of the stage directory that will be created. path Relative path from the stage root to where the stage will
url URL of the archive to be downloaded into this stage. be created.
url URL of the archive to be downloaded into this stage.
""" """
self.stage_name = stage_name self.path = os.path.join(spack.stage_path, path)
self.url = url self.url = url
@property
def path(self):
"""Absolute path to the stage directory."""
return spack.new_path(spack.stage_path, self.stage_name)
def setup(self): def setup(self):
"""Creates the stage directory. """Creates the stage directory.
@ -103,8 +98,7 @@ def setup(self):
if username: if username:
tmp_dir = spack.new_path(tmp_dir, username) tmp_dir = spack.new_path(tmp_dir, username)
spack.mkdirp(tmp_dir) spack.mkdirp(tmp_dir)
tmp_dir = tempfile.mkdtemp( tmp_dir = tempfile.mkdtemp('.stage', 'spack-stage-', tmp_dir)
'.stage', self.stage_name + '-', tmp_dir)
os.symlink(tmp_dir, self.path) os.symlink(tmp_dir, self.path)
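With the new constructor, the given path is simply joined under spack.stage_path; for example (package name and URL are hypothetical):

    from spack.stage import Stage

    stage = Stage("spack-create/libelf-0.8.13",
                  "http://example.com/libelf-0.8.13.tar.gz")
    print stage.path    # <spack.stage_path>/spack-create/libelf-0.8.13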

View file

@ -0,0 +1,13 @@
import unittest
import spack.spec
class ConcretizeTest(unittest.TestCase):
def check_concretize(self, abstract_spec):
abstract = spack.spec.parse_one(abstract_spec)
self.assertTrue(abstract.concretized().concrete)
def test_packages(self):
self.check_concretize("libelf")

View file

@ -1,6 +1,7 @@
import unittest import unittest
import spack.spec
from spack.spec import * from spack.spec import *
from spack.parse import * from spack.parse import Token, ParseError
# Sample output for a complex lexing. # Sample output for a complex lexing.
complex_lex = [Token(ID, 'mvapich_foo'), complex_lex = [Token(ID, 'mvapich_foo'),
@ -29,10 +30,6 @@
class SpecTest(unittest.TestCase): class SpecTest(unittest.TestCase):
def setUp(self):
self.parser = SpecParser()
self.lexer = SpecLexer()
# ================================================================================ # ================================================================================
# Parse checks # Parse checks
# ================================================================================ # ================================================================================
@ -47,14 +44,14 @@ def check_parse(self, expected, spec=None):
""" """
if spec == None: if spec == None:
spec = expected spec = expected
output = self.parser.parse(spec) output = spack.spec.parse(spec)
parsed = (" ".join(str(spec) for spec in output)) parsed = (" ".join(str(spec) for spec in output))
self.assertEqual(expected, parsed) self.assertEqual(expected, parsed)
def check_lex(self, tokens, spec): def check_lex(self, tokens, spec):
"""Check that the provided spec parses to the provided list of tokens.""" """Check that the provided spec parses to the provided list of tokens."""
lex_output = self.lexer.lex(spec) lex_output = SpecLexer().lex(spec)
for tok, spec_tok in zip(tokens, lex_output): for tok, spec_tok in zip(tokens, lex_output):
if tok.type == ID: if tok.type == ID:
self.assertEqual(tok, spec_tok) self.assertEqual(tok, spec_tok)
@ -71,31 +68,33 @@ def test_package_names(self):
self.check_parse("_mvapich_foo") self.check_parse("_mvapich_foo")
def test_simple_dependence(self): def test_simple_dependence(self):
self.check_parse("openmpi ^hwloc") self.check_parse("openmpi^hwloc")
self.check_parse("openmpi ^hwloc ^libunwind") self.check_parse("openmpi^hwloc^libunwind")
def test_dependencies_with_versions(self): def test_dependencies_with_versions(self):
self.check_parse("openmpi ^hwloc@1.2e6") self.check_parse("openmpi^hwloc@1.2e6")
self.check_parse("openmpi ^hwloc@1.2e6:") self.check_parse("openmpi^hwloc@1.2e6:")
self.check_parse("openmpi ^hwloc@:1.4b7-rc3") self.check_parse("openmpi^hwloc@:1.4b7-rc3")
self.check_parse("openmpi ^hwloc@1.2e6:1.4b7-rc3") self.check_parse("openmpi^hwloc@1.2e6:1.4b7-rc3")
def test_full_specs(self): def test_full_specs(self):
self.check_parse("mvapich_foo ^_openmpi@1.2:1.4,1.6%intel@12.1+debug~qt_4 ^stackwalker@8.1_1e") self.check_parse("mvapich_foo^_openmpi@1.2:1.4,1.6%intel@12.1+debug~qt_4^stackwalker@8.1_1e")
def test_canonicalize(self): def test_canonicalize(self):
self.check_parse( self.check_parse(
"mvapich_foo ^_openmpi@1.2:1.4,1.6%intel@12.1:12.6+debug~qt_4 ^stackwalker@8.1_1e", "mvapich_foo^_openmpi@1.2:1.4,1.6%intel@12.1:12.6+debug~qt_4^stackwalker@8.1_1e",
"mvapich_foo ^_openmpi@1.6,1.2:1.4%intel@12.1:12.6+debug~qt_4 ^stackwalker@8.1_1e") "mvapich_foo ^_openmpi@1.6,1.2:1.4%intel@12.1:12.6+debug~qt_4 ^stackwalker@8.1_1e")
self.check_parse( self.check_parse(
"mvapich_foo ^_openmpi@1.2:1.4,1.6%intel@12.1:12.6+debug~qt_4 ^stackwalker@8.1_1e", "mvapich_foo^_openmpi@1.2:1.4,1.6%intel@12.1:12.6+debug~qt_4^stackwalker@8.1_1e",
"mvapich_foo ^stackwalker@8.1_1e ^_openmpi@1.6,1.2:1.4%intel@12.1:12.6~qt_4+debug") "mvapich_foo ^stackwalker@8.1_1e ^_openmpi@1.6,1.2:1.4%intel@12.1:12.6~qt_4+debug")
self.check_parse( self.check_parse(
"x ^y@1,2:3,4%intel@1,2,3,4+a~b+c~d+e~f", "x^y@1,2:3,4%intel@1,2,3,4+a~b+c~d+e~f",
"x ^y~f+e~d+c~b+a@4,2:3,1%intel@4,3,2,1") "x ^y~f+e~d+c~b+a@4,2:3,1%intel@4,3,2,1")
self.check_parse("x^y", "x@: ^y@:")
def test_parse_errors(self): def test_parse_errors(self):
self.assertRaises(ParseError, self.check_parse, "x@@1.2") self.assertRaises(ParseError, self.check_parse, "x@@1.2")
self.assertRaises(ParseError, self.check_parse, "x ^y@@1.2") self.assertRaises(ParseError, self.check_parse, "x ^y@@1.2")
@ -111,11 +110,11 @@ def test_duplicate_depdendence(self):
def test_duplicate_compiler(self): def test_duplicate_compiler(self):
self.assertRaises(DuplicateCompilerError, self.check_parse, "x%intel%intel") self.assertRaises(DuplicateCompilerError, self.check_parse, "x%intel%intel")
self.assertRaises(DuplicateCompilerError, self.check_parse, "x%intel%gnu") self.assertRaises(DuplicateCompilerError, self.check_parse, "x%intel%gcc")
self.assertRaises(DuplicateCompilerError, self.check_parse, "x%gnu%intel") self.assertRaises(DuplicateCompilerError, self.check_parse, "x%gcc%intel")
self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%intel%intel") self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%intel%intel")
self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%intel%gnu") self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%intel%gcc")
self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%gnu%intel") self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%gcc%intel")
# ================================================================================ # ================================================================================
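For reference: the canonical strings expected above fall out of the new version handling later in this commit. A comma-separated version constraint is stored as a sorted, merged VersionList, which is why @1.6,1.2:1.4 prints back as @1.2:1.4,1.6. A minimal sketch, assuming the rewritten module below is importable as spack.version:

from spack.version import ver

# Entries are kept sorted and overlapping entries are merged, so the
# string form of a version list is already canonical.
print(ver('1.6,1.2:1.4'))   # -> 1.2:1.4,1.6
print(ver('4,2:3,1'))       # -> 1,2:3,4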

View file

@@ -39,6 +39,26 @@ def assert_ver_eq(self, a, b):
self.assertTrue(a <= b) self.assertTrue(a <= b)
def assert_in(self, needle, haystack):
self.assertTrue(ver(needle) in ver(haystack))
def assert_not_in(self, needle, haystack):
self.assertFalse(ver(needle) in ver(haystack))
def assert_canonical(self, canonical_list, version_list):
self.assertEqual(ver(canonical_list), ver(version_list))
def assert_overlaps(self, v1, v2):
self.assertTrue(ver(v1).overlaps(ver(v2)))
def assert_no_overlap(self, v1, v2):
self.assertFalse(ver(v1).overlaps(ver(v2)))
def test_two_segments(self): def test_two_segments(self):
self.assert_ver_eq('1.0', '1.0') self.assert_ver_eq('1.0', '1.0')
self.assert_ver_lt('1.0', '2.0') self.assert_ver_lt('1.0', '2.0')
@@ -50,6 +70,7 @@ def test_three_segments(self):
self.assert_ver_lt('2.0', '2.0.1') self.assert_ver_lt('2.0', '2.0.1')
self.assert_ver_gt('2.0.1', '2.0') self.assert_ver_gt('2.0.1', '2.0')
def test_alpha(self): def test_alpha(self):
# TODO: not sure whether I like this. 2.0.1a is *usually* # TODO: not sure whether I like this. 2.0.1a is *usually*
# TODO: less than 2.0.1, but special-casing it makes version # TODO: less than 2.0.1, but special-casing it makes version
@@ -58,6 +79,7 @@ def test_alpha(self):
self.assert_ver_gt('2.0.1a', '2.0.1') self.assert_ver_gt('2.0.1a', '2.0.1')
self.assert_ver_lt('2.0.1', '2.0.1a') self.assert_ver_lt('2.0.1', '2.0.1a')
def test_patch(self): def test_patch(self):
self.assert_ver_eq('5.5p1', '5.5p1') self.assert_ver_eq('5.5p1', '5.5p1')
self.assert_ver_lt('5.5p1', '5.5p2') self.assert_ver_lt('5.5p1', '5.5p2')
@@ -66,6 +88,7 @@ def test_patch(self):
self.assert_ver_lt('5.5p1', '5.5p10') self.assert_ver_lt('5.5p1', '5.5p10')
self.assert_ver_gt('5.5p10', '5.5p1') self.assert_ver_gt('5.5p10', '5.5p1')
def test_num_alpha_with_no_separator(self): def test_num_alpha_with_no_separator(self):
self.assert_ver_lt('10xyz', '10.1xyz') self.assert_ver_lt('10xyz', '10.1xyz')
self.assert_ver_gt('10.1xyz', '10xyz') self.assert_ver_gt('10.1xyz', '10xyz')
@@ -73,6 +96,7 @@ def test_num_alpha_with_no_separator(self):
self.assert_ver_lt('xyz10', 'xyz10.1') self.assert_ver_lt('xyz10', 'xyz10.1')
self.assert_ver_gt('xyz10.1', 'xyz10') self.assert_ver_gt('xyz10.1', 'xyz10')
def test_alpha_with_dots(self): def test_alpha_with_dots(self):
self.assert_ver_eq('xyz.4', 'xyz.4') self.assert_ver_eq('xyz.4', 'xyz.4')
self.assert_ver_lt('xyz.4', '8') self.assert_ver_lt('xyz.4', '8')
@@ -80,25 +104,30 @@ def test_alpha_with_dots(self):
self.assert_ver_lt('xyz.4', '2') self.assert_ver_lt('xyz.4', '2')
self.assert_ver_gt('2', 'xyz.4') self.assert_ver_gt('2', 'xyz.4')
def test_nums_and_patch(self): def test_nums_and_patch(self):
self.assert_ver_lt('5.5p2', '5.6p1') self.assert_ver_lt('5.5p2', '5.6p1')
self.assert_ver_gt('5.6p1', '5.5p2') self.assert_ver_gt('5.6p1', '5.5p2')
self.assert_ver_lt('5.6p1', '6.5p1') self.assert_ver_lt('5.6p1', '6.5p1')
self.assert_ver_gt('6.5p1', '5.6p1') self.assert_ver_gt('6.5p1', '5.6p1')
def test_rc_versions(self): def test_rc_versions(self):
self.assert_ver_gt('6.0.rc1', '6.0') self.assert_ver_gt('6.0.rc1', '6.0')
self.assert_ver_lt('6.0', '6.0.rc1') self.assert_ver_lt('6.0', '6.0.rc1')
def test_alpha_beta(self): def test_alpha_beta(self):
self.assert_ver_gt('10b2', '10a1') self.assert_ver_gt('10b2', '10a1')
self.assert_ver_lt('10a2', '10b2') self.assert_ver_lt('10a2', '10b2')
def test_double_alpha(self): def test_double_alpha(self):
self.assert_ver_eq('1.0aa', '1.0aa') self.assert_ver_eq('1.0aa', '1.0aa')
self.assert_ver_lt('1.0a', '1.0aa') self.assert_ver_lt('1.0a', '1.0aa')
self.assert_ver_gt('1.0aa', '1.0a') self.assert_ver_gt('1.0aa', '1.0a')
def test_padded_numbers(self): def test_padded_numbers(self):
self.assert_ver_eq('10.0001', '10.0001') self.assert_ver_eq('10.0001', '10.0001')
self.assert_ver_eq('10.0001', '10.1') self.assert_ver_eq('10.0001', '10.1')
@@ -106,20 +135,24 @@ def test_padded_numbers(self):
self.assert_ver_lt('10.0001', '10.0039') self.assert_ver_lt('10.0001', '10.0039')
self.assert_ver_gt('10.0039', '10.0001') self.assert_ver_gt('10.0039', '10.0001')
def test_close_numbers(self): def test_close_numbers(self):
self.assert_ver_lt('4.999.9', '5.0') self.assert_ver_lt('4.999.9', '5.0')
self.assert_ver_gt('5.0', '4.999.9') self.assert_ver_gt('5.0', '4.999.9')
def test_date_stamps(self): def test_date_stamps(self):
self.assert_ver_eq('20101121', '20101121') self.assert_ver_eq('20101121', '20101121')
self.assert_ver_lt('20101121', '20101122') self.assert_ver_lt('20101121', '20101122')
self.assert_ver_gt('20101122', '20101121') self.assert_ver_gt('20101122', '20101121')
def test_underscores(self): def test_underscores(self):
self.assert_ver_eq('2_0', '2_0') self.assert_ver_eq('2_0', '2_0')
self.assert_ver_eq('2.0', '2_0') self.assert_ver_eq('2.0', '2_0')
self.assert_ver_eq('2_0', '2.0') self.assert_ver_eq('2_0', '2.0')
def test_rpm_oddities(self): def test_rpm_oddities(self):
self.assert_ver_eq('1b.fc17', '1b.fc17') self.assert_ver_eq('1b.fc17', '1b.fc17')
self.assert_ver_lt('1b.fc17', '1.fc17') self.assert_ver_lt('1b.fc17', '1.fc17')
@@ -139,3 +172,89 @@ def test_version_ranges(self):
self.assert_ver_lt('1.2:1.4', '1.5:1.6') self.assert_ver_lt('1.2:1.4', '1.5:1.6')
self.assert_ver_gt('1.5:1.6', '1.2:1.4') self.assert_ver_gt('1.5:1.6', '1.2:1.4')
def test_contains(self):
self.assert_in('1.3', '1.2:1.4')
self.assert_in('1.2.5', '1.2:1.4')
self.assert_in('1.3.5', '1.2:1.4')
self.assert_in('1.3.5-7', '1.2:1.4')
self.assert_not_in('1.1', '1.2:1.4')
self.assert_not_in('1.5', '1.2:1.4')
self.assert_not_in('1.4.2', '1.2:1.4')
self.assert_in('1.2.8', '1.2.7:1.4')
self.assert_in('1.2.7:1.4', ':')
self.assert_not_in('1.2.5', '1.2.7:1.4')
self.assert_not_in('1.4.1', '1.2.7:1.4')
def test_in_list(self):
self.assert_in('1.2', ['1.5', '1.2', '1.3'])
self.assert_in('1.2.5', ['1.5', '1.2:1.3'])
self.assert_in('1.5', ['1.5', '1.2:1.3'])
self.assert_not_in('1.4', ['1.5', '1.2:1.3'])
self.assert_in('1.2.5:1.2.7', [':'])
self.assert_in('1.2.5:1.2.7', ['1.5', '1.2:1.3'])
self.assert_not_in('1.2.5:1.5', ['1.5', '1.2:1.3'])
self.assert_not_in('1.1:1.2.5', ['1.5', '1.2:1.3'])
def test_ranges_overlap(self):
self.assert_overlaps('1.2', '1.2')
self.assert_overlaps('1.2.1', '1.2.1')
self.assert_overlaps('1.2.1b', '1.2.1b')
self.assert_overlaps('1.2:1.7', '1.6:1.9')
self.assert_overlaps(':1.7', '1.6:1.9')
self.assert_overlaps(':1.7', ':1.9')
self.assert_overlaps(':1.7', '1.6:')
self.assert_overlaps('1.2:', '1.6:1.9')
self.assert_overlaps('1.2:', ':1.9')
self.assert_overlaps('1.2:', '1.6:')
self.assert_overlaps(':', ':')
self.assert_overlaps(':', '1.6:1.9')
def test_lists_overlap(self):
self.assert_overlaps('1.2b:1.7,5', '1.6:1.9,1')
self.assert_overlaps('1,2,3,4,5', '3,4,5,6,7')
self.assert_overlaps('1,2,3,4,5', '5,6,7')
self.assert_overlaps('1,2,3,4,5', '5:7')
self.assert_overlaps('1,2,3,4,5', '3, 6:7')
self.assert_overlaps('1, 2, 4, 6.5', '3, 6:7')
self.assert_overlaps('1, 2, 4, 6.5', ':, 5, 8')
self.assert_overlaps('1, 2, 4, 6.5', ':')
self.assert_no_overlap('1, 2, 4', '3, 6:7')
self.assert_no_overlap('1,2,3,4,5', '6,7')
self.assert_no_overlap('1,2,3,4,5', '6:7')
def test_canonicalize_list(self):
self.assert_canonical(['1.2', '1.3', '1.4'],
['1.2', '1.3', '1.3', '1.4'])
self.assert_canonical(['1.2', '1.3:1.4'],
['1.2', '1.3', '1.3:1.4'])
self.assert_canonical(['1.2', '1.3:1.4'],
['1.2', '1.3:1.4', '1.4'])
self.assert_canonical(['1.3:1.4'],
['1.3:1.4', '1.3', '1.3.1', '1.3.9', '1.4'])
self.assert_canonical(['1.3:1.4'],
['1.3', '1.3.1', '1.3.9', '1.4', '1.3:1.4'])
self.assert_canonical(['1.3:1.5'],
['1.3', '1.3.1', '1.3.9', '1.4:1.5', '1.3:1.4'])
self.assert_canonical(['1.3:1.5'],
['1.3, 1.3.1,1.3.9,1.4:1.5,1.3:1.4'])
self.assert_canonical(['1.3:1.5'],
['1.3, 1.3.1,1.3.9,1.4 : 1.5 , 1.3 : 1.4'])
self.assert_canonical([':'],
[':,1.3, 1.3.1,1.3.9,1.4 : 1.5 , 1.3 : 1.4'])
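The containment and overlap rules exercised by these tests can be summarized in a few lines. This is a hedged sketch, again assuming the module is importable as spack.version; every assertion below mirrors a test case above:

from spack.version import ver

assert ver('1.3') in ver('1.2:1.4')             # a range contains versions inside it
assert ver('1.4.2') not in ver('1.2:1.4')       # 1.4.2 sorts after 1.4, so it falls outside
assert ver('1.2.5') in ver(['1.5', '1.2:1.3'])  # a list contains what any of its elements contains
assert ver('1.2:1.7').overlaps(ver('1.6:1.9'))  # ranges overlap if they share any version
assert ver(['1.2', '1.3', '1.3:1.4']) == ver(['1.2', '1.3:1.4'])  # lists canonicalize on construction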

View file

@@ -1,18 +1,18 @@
import sys import sys
import spack import spack
from spack.color import cprint from spack.color import *
indent = " " indent = " "
def msg(message, *args): def msg(message, *args):
cprint("@*b{==>} @*w{%s}" % str(message)) cprint("@*b{==>} @*w{%s}" % cescape(message))
for arg in args: for arg in args:
print indent + str(arg) print indent + str(arg)
def info(message, *args, **kwargs): def info(message, *args, **kwargs):
format = kwargs.get('format', '*b') format = kwargs.get('format', '*b')
cprint("@%s{==>} %s" % (format, str(message))) cprint("@%s{==>} %s" % (format, cescape(message)))
for arg in args: for arg in args:
print indent + str(arg) print indent + str(arg)
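A small usage sketch for the two helpers above (the messages are made up, and the module path spack.tty is assumed). The switch from str() to cescape() means a literal '@' in a message, such as one embedded in a spec string, is escaped rather than being interpreted as color markup by cprint():

import spack.tty as tty

tty.msg("Staging libelf@0.8.13")             # '@' in the message no longer triggers color markup
tty.info("Already staged", "nothing to do")  # extra args are printed indented below the arrow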

View file

@@ -1,29 +1,83 @@
"""
This file implements Version and version-ish objects. These are:
Version
A single version of a package.
VersionRange
A range of versions of a package.
VersionList
A list of Versions and VersionRanges.
All of these types support the following operations, which can
be called on any of the types:
__eq__, __ne__, __lt__, __gt__, __ge__, __le__, __hash__
__contains__
overlaps
merge
concrete
True if the Version, VersionRange or VersionList represents
a single version.
"""
import os import os
import sys
import re import re
from bisect import bisect_left
from functools import total_ordering from functools import total_ordering
import utils import utils
from none_compare import *
import spack.error import spack.error
# Valid version characters # Valid version characters
VALID_VERSION = r'[A-Za-z0-9_.-]' VALID_VERSION = r'[A-Za-z0-9_.-]'
def int_if_int(string): def int_if_int(string):
"""Convert a string to int if possible. Otherwise, return a string.""" """Convert a string to int if possible. Otherwise, return a string."""
try: try:
return int(string) return int(string)
except: except ValueError:
return string return string
def ver(string): def coerce_versions(a, b):
"""Parses either a version or version range from a string.""" """Convert both a and b to the 'greatest' type between them, in this order:
if ':' in string: Version < VersionRange < VersionList
start, end = string.split(':') This is used to simplify comparison operations below so that we're always
return VersionRange(Version(start), Version(end)) comparing things that are of the same type.
"""
order = (Version, VersionRange, VersionList)
ta, tb = type(a), type(b)
def check_type(t):
if t not in order:
raise TypeError("coerce_versions cannot be called on %s" % t)
check_type(ta)
check_type(tb)
if ta == tb:
return (a, b)
elif order.index(ta) > order.index(tb):
if ta == VersionRange:
return (a, VersionRange(b, b))
else:
return (a, VersionList([b]))
else: else:
return Version(string) if tb == VersionRange:
return (VersionRange(a, a), b)
else:
return (VersionList([a]), b)
def coerced(method):
"""Decorator that ensures that argument types of a method are coerced."""
def coercing_method(a, b):
if type(a) == type(b) or a is None or b is None:
return method(a, b)
else:
ca, cb = coerce_versions(a, b)
return getattr(ca, method.__name__)(cb)
return coercing_method
@total_ordering @total_ordering
@@ -33,7 +87,8 @@ def __init__(self, string):
if not re.match(VALID_VERSION, string): if not re.match(VALID_VERSION, string):
raise ValueError("Bad characters in version string: %s" % string) raise ValueError("Bad characters in version string: %s" % string)
# preserve the original string # preserve the original string, but trimmed.
string = string.strip()
self.string = string self.string = string
# Split version into alphabetical and numeric segments # Split version into alphabetical and numeric segments
@@ -52,6 +107,15 @@ def up_to(self, index):
""" """
return '.'.join(str(x) for x in self[:index]) return '.'.join(str(x) for x in self[:index])
def lowest(self):
return self
def highest(self):
return self
def wildcard(self): def wildcard(self):
"""Create a regex that will match variants of this version string.""" """Create a regex that will match variants of this version string."""
def a_or_n(seg): def a_or_n(seg):
@@ -75,31 +139,39 @@ def a_or_n(seg):
wc += ')?' * (len(seg_res) - 1) wc += ')?' * (len(seg_res) - 1)
return wc return wc
def __iter__(self): def __iter__(self):
for v in self.version: for v in self.version:
yield v yield v
def __getitem__(self, idx): def __getitem__(self, idx):
return tuple(self.version[idx]) return tuple(self.version[idx])
def __repr__(self): def __repr__(self):
return self.string return self.string
def __str__(self): def __str__(self):
return self.string return self.string
@property
def concrete(self):
return self
@coerced
def __lt__(self, other): def __lt__(self, other):
"""Version comparison is designed for consistency with the way RPM """Version comparison is designed for consistency with the way RPM
does things. If you need more complicated versions in installed does things. If you need more complicated versions in installed
packages, you should override your package's version string to packages, you should override your package's version string to
express it more sensibly. express it more sensibly.
""" """
assert(other is not None) if other is None:
return False
# Let VersionRange do all the range-based comparison
if type(other) == VersionRange:
return not other < self
# Coerce if other is not a Version
# simple equality test first. # simple equality test first.
if self.version == other.version: if self.version == other.version:
return False return False
@@ -121,22 +193,42 @@ def __lt__(self, other):
# If the common prefix is equal, the one with more segments is bigger. # If the common prefix is equal, the one with more segments is bigger.
return len(self.version) < len(other.version) return len(self.version) < len(other.version)
@coerced
def __eq__(self, other): def __eq__(self, other):
"""Implemented to match __lt__. See __lt__.""" return (other is not None and
if type(other) != Version: type(other) == Version and self.version == other.version)
return False
return self.version == other.version
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
def __hash__(self): def __hash__(self):
return hash(self.version) return hash(self.version)
@coerced
def __contains__(self, other):
return self == other
@coerced
def overlaps(self, other):
return self == other
@coerced
def merge(self, other):
if self == other:
return self
else:
return VersionList([self, other])
@total_ordering @total_ordering
class VersionRange(object): class VersionRange(object):
def __init__(self, start, end=None): def __init__(self, start, end):
if type(start) == str: if type(start) == str:
start = Version(start) start = Version(start)
if type(end) == str: if type(end) == str:
@@ -148,37 +240,74 @@ def __init__(self, start, end=None):
raise ValueError("Invalid Version range: %s" % self) raise ValueError("Invalid Version range: %s" % self)
def lowest(self):
return self.start
def highest(self):
return self.end
@coerced
def __lt__(self, other): def __lt__(self, other):
if type(other) == Version: """Sort VersionRanges lexicographically so that they are ordered first
return self.end and self.end < other by start and then by end. None denotes an open range, so None in
elif type(other) == VersionRange: the start position is less than everything except None, and None in
return self.end and other.start and self.end < other.start the end position is greater than everything but None.
else: """
raise TypeError("Can't compare VersionRange to %s" % type(other)) if other is None:
return False
def __gt__(self, other): return (none_low_lt(self.start, other.start) or
if type(other) == Version: (self.start == other.start and
return self.start and self.start > other none_high_lt(self.end, other.end)))
elif type(other) == VersionRange:
return self.start and other.end and self.start > other.end
else:
raise TypeError("Can't compare VersionRange to %s" % type(other))
@coerced
def __eq__(self, other): def __eq__(self, other):
return (type(other) == VersionRange return (other is not None and
and self.start == other.start type(other) == VersionRange and
and self.end == other.end) self.start == other.start and self.end == other.end)
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
@property
def concrete(self):
return self.start if self.start == self.end else None
@coerced
def __contains__(self, other):
return (none_low_ge(other.start, self.start) and
none_high_le(other.end, self.end))
@coerced
def overlaps(self, other):
return (other in self or self in other or
((self.start == None or other.end == None or
self.start <= other.end) and
(other.start == None or self.end == None or
other.start <= self.end)))
@coerced
def merge(self, other):
return VersionRange(none_low_min(self.start, other.start),
none_high_max(self.end, other.end))
def __hash__(self):
return hash((self.start, self.end))
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
def __str__(self): def __str__(self):
out = "" out = ""
if self.start: if self.start:
@@ -187,3 +316,179 @@ def __str__(self):
if self.end: if self.end:
out += str(self.end) out += str(self.end)
return out return out
@total_ordering
class VersionList(object):
"""Sorted, non-redundant list of Versions and VersionRanges."""
def __init__(self, vlist=None):
self.versions = []
if vlist != None:
vlist = list(vlist)
for v in vlist:
self.add(ver(v))
def add(self, version):
if type(version) in (Version, VersionRange):
# This normalizes single-value version ranges.
if version.concrete:
version = version.concrete
i = bisect_left(self, version)
while i-1 >= 0 and version.overlaps(self[i-1]):
version = version.merge(self[i-1])
del self.versions[i-1]
i -= 1
while i < len(self) and version.overlaps(self[i]):
version = version.merge(self[i])
del self.versions[i]
self.versions.insert(i, version)
elif type(version) == VersionList:
for v in version:
self.add(v)
else:
raise TypeError("Can't add %s to VersionList" % type(version))
@property
def concrete(self):
if len(self) == 1:
return self[0].concrete
else:
return None
def copy(self):
return VersionList(self)
def lowest(self):
"""Get the lowest version in the list."""
if not self:
return None
else:
return self[0].lowest()
def highest(self):
"""Get the highest version in the list."""
if not self:
return None
else:
return self[-1].highest()
@coerced
def overlaps(self, other):
if not other or not self:
return False
i = o = 0
while i < len(self) and o < len(other):
if self[i].overlaps(other[o]):
return True
elif self[i] < other[o]:
i += 1
else:
o += 1
return False
@coerced
def merge(self, other):
return VersionList(self.versions + other.versions)
@coerced
def __contains__(self, other):
if len(self) == 0:
return False
for version in other:
i = bisect_left(self, version)
if i == 0:
if version not in self[0]:
return False
elif all(version not in v for v in self[i-1:]):
return False
return True
def __getitem__(self, index):
return self.versions[index]
def __iter__(self):
for v in self.versions:
yield v
def __len__(self):
return len(self.versions)
@coerced
def __eq__(self, other):
return other is not None and self.versions == other.versions
def __ne__(self, other):
return not (self == other)
@coerced
def __lt__(self, other):
return other is not None and self.versions < other.versions
def __hash__(self):
return hash(tuple(self.versions))
def __str__(self):
return ",".join(str(v) for v in self.versions)
def __repr__(self):
return str(self.versions)
def _string_to_version(string):
"""Converts a string to a Version, VersionList, or VersionRange.
This is private. Client code should use ver().
"""
string = string.replace(' ','')
if ',' in string:
return VersionList(string.split(','))
elif ':' in string:
s, e = string.split(':')
start = Version(s) if s else None
end = Version(e) if e else None
return VersionRange(start, end)
else:
return Version(string)
def ver(obj):
"""Parses a Version, VersionRange, or VersionList from a string
or list of strings.
"""
t = type(obj)
if t == list:
return VersionList(obj)
elif t == str:
return _string_to_version(obj)
elif t in (Version, VersionRange, VersionList):
return obj
else:
raise TypeError("ver() can't convert %s to version!" % t)
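Taken together, the coerced decorator and VersionList.add() are what make the mixed-type comparisons in the tests work: arguments are promoted to the greater of the two types before the wrapped method runs, and lists merge overlapping entries as they are built. A short sketch under the same spack.version import assumption as above:

from spack.version import Version, VersionRange, VersionList

# A bare Version tested against a range is coerced to a one-element range first.
print(Version('1.3') in VersionRange('1.2', '1.4'))   # True

# add() keeps the list sorted and merges anything that overlaps, so this
# collapses to a single range (same data as test_canonicalize_list above).
vl = VersionList(['1.3', '1.3.1', '1.3.9', '1.4:1.5', '1.3:1.4'])
print(vl)           # 1.3:1.5
print(vl.concrete)  # None -- a range is not a single concrete version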