provider index: removed import from + refactored a few parts (#15570)

Removed provider_index use of 'import from' and refactored a few routines to a further subclassing of _IndexBase for implementing user defined bindings of provider specs.
This commit is contained in:
Massimiliano Culpo 2020-03-25 17:48:05 +01:00 committed by GitHub
parent 3aa225cd5c
commit b42a96df98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 182 additions and 133 deletions

View file

@ -2,54 +2,147 @@
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Classes and functions to manage providers of virtual dependencies"""
import itertools
"""
The ``virtual`` module contains utility classes for virtual dependencies.
"""
from itertools import product as iproduct
from six import iteritems
from pprint import pformat
import six
import spack.error
import spack.util.spack_json as sjson
class ProviderIndex(object):
"""This is a dict of dicts used for finding providers of particular
virtual dependencies. The dict of dicts looks like:
{ vpkg name :
{ full vpkg spec : set(packages providing spec) } }
Callers can use this to first find which packages provide a vpkg,
then find a matching full spec. e.g., in this scenario:
{ 'mpi' :
{ mpi@:1.1 : set([mpich]),
mpi@:2.3 : set([mpich2@1.9:]) } }
Calling providers_for(spec) will find specs that provide a
matching implementation of MPI.
def _cross_provider_maps(lmap, rmap):
"""Return a dictionary that combines constraint requests from both input.
Args:
lmap: main provider map
rmap: provider map with additional constraints
"""
# TODO: this is pretty darned nasty, and inefficient, but there
# TODO: are not that many vdeps in most specs.
result = {}
for lspec, rspec in itertools.product(lmap, rmap):
try:
constrained = lspec.constrained(rspec)
except spack.error.UnsatisfiableSpecError:
continue
# lp and rp are left and right provider specs.
for lp_spec, rp_spec in itertools.product(lmap[lspec], rmap[rspec]):
if lp_spec.name == rp_spec.name:
try:
const = lp_spec.constrained(rp_spec, deps=False)
result.setdefault(constrained, set()).add(const)
except spack.error.UnsatisfiableSpecError:
continue
return result
class _IndexBase(object):
#: This is a dict of dicts used for finding providers of particular
#: virtual dependencies. The dict of dicts looks like:
#:
#: { vpkg name :
#: { full vpkg spec : set(packages providing spec) } }
#:
#: Callers can use this to first find which packages provide a vpkg,
#: then find a matching full spec. e.g., in this scenario:
#:
#: { 'mpi' :
#: { mpi@:1.1 : set([mpich]),
#: mpi@:2.3 : set([mpich2@1.9:]) } }
#:
#: Calling providers_for(spec) will find specs that provide a
#: matching implementation of MPI. Derived class need to construct
#: this attribute according to the semantics above.
providers = None
def providers_for(self, virtual_spec):
"""Return a list of specs of all packages that provide virtual
packages with the supplied spec.
Args:
virtual_spec: virtual spec to be provided
"""
result = set()
# Allow string names to be passed as input, as well as specs
if isinstance(virtual_spec, six.string_types):
virtual_spec = spack.spec.Spec(virtual_spec)
# Add all the providers that satisfy the vpkg spec.
if virtual_spec.name in self.providers:
for p_spec, spec_set in self.providers[virtual_spec.name].items():
if p_spec.satisfies(virtual_spec, deps=False):
result.update(spec_set)
# Return providers in order. Defensively copy.
return sorted(s.copy() for s in result)
def __contains__(self, name):
return name in self.providers
def satisfies(self, other):
"""Determine if the providers of virtual specs are compatible.
Args:
other: another provider index
Returns:
True if the providers are compatible, False otherwise.
"""
common = set(self.providers) & set(other.providers)
if not common:
return True
# This ensures that some provider in other COULD satisfy the
# vpkg constraints on self.
result = {}
for name in common:
crossed = _cross_provider_maps(
self.providers[name], other.providers[name]
)
if crossed:
result[name] = crossed
return all(c in result for c in common)
def __eq__(self, other):
return self.providers == other.providers
def _transform(self, transform_fun, out_mapping_type=dict):
"""Transform this provider index dictionary and return it.
Args:
transform_fun: transform_fun takes a (vpkg, pset) mapping and runs
it on each pair in nested dicts.
out_mapping_type: type to be used internally on the
transformed (vpkg, pset)
Returns:
Transformed mapping
"""
return _transform(self.providers, transform_fun, out_mapping_type)
def __str__(self):
return str(self.providers)
def __repr__(self):
return repr(self.providers)
class ProviderIndex(_IndexBase):
def __init__(self, specs=None, restrict=False):
"""Create a new ProviderIndex.
"""Provider index based on a single mapping of providers.
Optional arguments:
Args:
specs (list of specs): if provided, will call update on each
single spec to initialize this provider index.
specs
List (or sequence) of specs. If provided, will call
`update` on this ProviderIndex with each spec in the list.
restrict: "restricts" values to the verbatim input specs; do not
pre-apply package's constraints.
restrict
"restricts" values to the verbatim input specs; do not
pre-apply package's constraints.
TODO: rename this. It is intended to keep things as broad
as possible without overly restricting results, so it is
not the best name.
TODO: rename this. It is intended to keep things as broad
TODO: as possible without overly restricting results, so it is
TODO: not the best name.
"""
if specs is None:
specs = []
@ -67,6 +160,11 @@ def __init__(self, specs=None, restrict=False):
self.update(spec)
def update(self, spec):
"""Update the provider index with additional virtual specs.
Args:
spec: spec potentially providing additional virtual specs
"""
if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
@ -74,10 +172,10 @@ def update(self, spec):
# Empty specs do not have a package
return
assert(not spec.virtual)
assert not spec.virtual, "cannot update an index using a virtual spec"
pkg_provided = spec.package_class.provided
for provided_spec, provider_specs in iteritems(pkg_provided):
for provided_spec, provider_specs in six.iteritems(pkg_provided):
for provider_spec in provider_specs:
# TODO: fix this comment.
# We want satisfaction other than flags
@ -110,94 +208,24 @@ def update(self, spec):
constrained.constrain(provider_spec)
provider_map[provided_spec].add(constrained)
def providers_for(self, *vpkg_specs):
"""Gives specs of all packages that provide virtual packages
with the supplied specs."""
providers = set()
for vspec in vpkg_specs:
# Allow string names to be passed as input, as well as specs
if type(vspec) == str:
vspec = spack.spec.Spec(vspec)
# Add all the providers that satisfy the vpkg spec.
if vspec.name in self.providers:
for p_spec, spec_set in self.providers[vspec.name].items():
if p_spec.satisfies(vspec, deps=False):
providers.update(spec_set)
# Return providers in order. Defensively copy.
return sorted(s.copy() for s in providers)
# TODO: this is pretty darned nasty, and inefficient, but there
# are not that many vdeps in most specs.
def _cross_provider_maps(self, lmap, rmap):
result = {}
for lspec, rspec in iproduct(lmap, rmap):
try:
constrained = lspec.constrained(rspec)
except spack.error.UnsatisfiableSpecError:
continue
# lp and rp are left and right provider specs.
for lp_spec, rp_spec in iproduct(lmap[lspec], rmap[rspec]):
if lp_spec.name == rp_spec.name:
try:
const = lp_spec.constrained(rp_spec, deps=False)
result.setdefault(constrained, set()).add(const)
except spack.error.UnsatisfiableSpecError:
continue
return result
def __contains__(self, name):
"""Whether a particular vpkg name is in the index."""
return name in self.providers
def satisfies(self, other):
"""Check that providers of virtual specs are compatible."""
common = set(self.providers) & set(other.providers)
if not common:
return True
# This ensures that some provider in other COULD satisfy the
# vpkg constraints on self.
result = {}
for name in common:
crossed = self._cross_provider_maps(self.providers[name],
other.providers[name])
if crossed:
result[name] = crossed
return all(c in result for c in common)
def to_json(self, stream=None):
"""Dump a JSON representation of this object.
Args:
stream: stream where to dump
"""
provider_list = self._transform(
lambda vpkg, pset: [
vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list)
sjson.dump({'provider_index': {'providers': provider_list}}, stream)
@staticmethod
def from_json(stream):
data = sjson.load(stream)
if not isinstance(data, dict):
raise ProviderIndexError("JSON ProviderIndex data was not a dict.")
if 'provider_index' not in data:
raise ProviderIndexError(
"YAML ProviderIndex does not start with 'provider_index'")
index = ProviderIndex()
providers = data['provider_index']['providers']
index.providers = _transform(
providers,
lambda vpkg, plist: (
spack.spec.Spec.from_node_dict(vpkg),
set(spack.spec.Spec.from_node_dict(p) for p in plist)))
return index
def merge(self, other):
"""Merge `other` ProviderIndex into this one."""
"""Merge another provider index into this one.
Args:
other (ProviderIndex): provider index to be merged
"""
other = other.copy() # defensive copy.
for pkg in other.providers:
@ -236,40 +264,61 @@ def remove_provider(self, pkg_name):
del self.providers[pkg]
def copy(self):
"""Deep copy of this ProviderIndex."""
"""Return a deep copy of this index."""
clone = ProviderIndex()
clone.providers = self._transform(
lambda vpkg, pset: (vpkg, set((p.copy() for p in pset))))
return clone
def __eq__(self, other):
return self.providers == other.providers
@staticmethod
def from_json(stream):
"""Construct a provider index from its JSON representation.
def _transform(self, transform_fun, out_mapping_type=dict):
return _transform(self.providers, transform_fun, out_mapping_type)
Args:
stream: stream where to read from the JSON data
"""
data = sjson.load(stream)
def __str__(self):
return pformat(
_transform(self.providers,
lambda k, v: (k, list(v))))
if not isinstance(data, dict):
raise ProviderIndexError("JSON ProviderIndex data was not a dict.")
if 'provider_index' not in data:
raise ProviderIndexError(
"YAML ProviderIndex does not start with 'provider_index'")
index = ProviderIndex()
providers = data['provider_index']['providers']
index.providers = _transform(
providers,
lambda vpkg, plist: (
spack.spec.Spec.from_node_dict(vpkg),
set(spack.spec.Spec.from_node_dict(p) for p in plist)))
return index
def _transform(providers, transform_fun, out_mapping_type=dict):
"""Syntactic sugar for transforming a providers dict.
transform_fun takes a (vpkg, pset) mapping and runs it on each
pair in nested dicts.
Args:
providers: provider dictionary
transform_fun: transform_fun takes a (vpkg, pset) mapping and runs
it on each pair in nested dicts.
out_mapping_type: type to be used internally on the
transformed (vpkg, pset)
Returns:
Transformed mapping
"""
def mapiter(mappings):
if isinstance(mappings, dict):
return iteritems(mappings)
return six.iteritems(mappings)
else:
return iter(mappings)
return dict(
(name, out_mapping_type([
transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]))
(name, out_mapping_type(
[transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]
))
for name, mappings in providers.items())

View file

@ -189,7 +189,7 @@ def test_conditional_dep_with_user_constraints():
assert ('y@3' in spec)
@pytest.mark.usefixtures('mutable_mock_repo')
@pytest.mark.usefixtures('mutable_mock_repo', 'config')
class TestSpecDag(object):
def test_conflicting_package_constraints(self, set_dependency):