From b42a96df980458602c6c5314b3d3fac08f2d7442 Mon Sep 17 00:00:00 2001 From: Massimiliano Culpo Date: Wed, 25 Mar 2020 17:48:05 +0100 Subject: [PATCH] provider index: removed import from + refactored a few parts (#15570) Removed provider_index use of 'import from' and refactored a few routines to a further subclassing of _IndexBase for implementing user defined bindings of provider specs. --- lib/spack/spack/provider_index.py | 313 +++++++++++++++++------------- lib/spack/spack/test/spec_dag.py | 2 +- 2 files changed, 182 insertions(+), 133 deletions(-) diff --git a/lib/spack/spack/provider_index.py b/lib/spack/spack/provider_index.py index 9bf4af8911..326f6aa8f1 100644 --- a/lib/spack/spack/provider_index.py +++ b/lib/spack/spack/provider_index.py @@ -2,54 +2,147 @@ # Spack Project Developers. See the top-level COPYRIGHT file for details. # # SPDX-License-Identifier: (Apache-2.0 OR MIT) +"""Classes and functions to manage providers of virtual dependencies""" +import itertools -""" -The ``virtual`` module contains utility classes for virtual dependencies. -""" - -from itertools import product as iproduct -from six import iteritems -from pprint import pformat - +import six import spack.error import spack.util.spack_json as sjson -class ProviderIndex(object): - """This is a dict of dicts used for finding providers of particular - virtual dependencies. The dict of dicts looks like: - - { vpkg name : - { full vpkg spec : set(packages providing spec) } } - - Callers can use this to first find which packages provide a vpkg, - then find a matching full spec. e.g., in this scenario: - - { 'mpi' : - { mpi@:1.1 : set([mpich]), - mpi@:2.3 : set([mpich2@1.9:]) } } - - Calling providers_for(spec) will find specs that provide a - matching implementation of MPI. +def _cross_provider_maps(lmap, rmap): + """Return a dictionary that combines constraint requests from both input. + Args: + lmap: main provider map + rmap: provider map with additional constraints """ + # TODO: this is pretty darned nasty, and inefficient, but there + # TODO: are not that many vdeps in most specs. + result = {} + for lspec, rspec in itertools.product(lmap, rmap): + try: + constrained = lspec.constrained(rspec) + except spack.error.UnsatisfiableSpecError: + continue + # lp and rp are left and right provider specs. + for lp_spec, rp_spec in itertools.product(lmap[lspec], rmap[rspec]): + if lp_spec.name == rp_spec.name: + try: + const = lp_spec.constrained(rp_spec, deps=False) + result.setdefault(constrained, set()).add(const) + except spack.error.UnsatisfiableSpecError: + continue + return result + + +class _IndexBase(object): + #: This is a dict of dicts used for finding providers of particular + #: virtual dependencies. The dict of dicts looks like: + #: + #: { vpkg name : + #: { full vpkg spec : set(packages providing spec) } } + #: + #: Callers can use this to first find which packages provide a vpkg, + #: then find a matching full spec. e.g., in this scenario: + #: + #: { 'mpi' : + #: { mpi@:1.1 : set([mpich]), + #: mpi@:2.3 : set([mpich2@1.9:]) } } + #: + #: Calling providers_for(spec) will find specs that provide a + #: matching implementation of MPI. Derived class need to construct + #: this attribute according to the semantics above. + providers = None + + def providers_for(self, virtual_spec): + """Return a list of specs of all packages that provide virtual + packages with the supplied spec. + + Args: + virtual_spec: virtual spec to be provided + """ + result = set() + # Allow string names to be passed as input, as well as specs + if isinstance(virtual_spec, six.string_types): + virtual_spec = spack.spec.Spec(virtual_spec) + + # Add all the providers that satisfy the vpkg spec. + if virtual_spec.name in self.providers: + for p_spec, spec_set in self.providers[virtual_spec.name].items(): + if p_spec.satisfies(virtual_spec, deps=False): + result.update(spec_set) + + # Return providers in order. Defensively copy. + return sorted(s.copy() for s in result) + + def __contains__(self, name): + return name in self.providers + + def satisfies(self, other): + """Determine if the providers of virtual specs are compatible. + + Args: + other: another provider index + + Returns: + True if the providers are compatible, False otherwise. + """ + common = set(self.providers) & set(other.providers) + if not common: + return True + + # This ensures that some provider in other COULD satisfy the + # vpkg constraints on self. + result = {} + for name in common: + crossed = _cross_provider_maps( + self.providers[name], other.providers[name] + ) + if crossed: + result[name] = crossed + + return all(c in result for c in common) + + def __eq__(self, other): + return self.providers == other.providers + + def _transform(self, transform_fun, out_mapping_type=dict): + """Transform this provider index dictionary and return it. + + Args: + transform_fun: transform_fun takes a (vpkg, pset) mapping and runs + it on each pair in nested dicts. + out_mapping_type: type to be used internally on the + transformed (vpkg, pset) + + Returns: + Transformed mapping + """ + return _transform(self.providers, transform_fun, out_mapping_type) + + def __str__(self): + return str(self.providers) + + def __repr__(self): + return repr(self.providers) + + +class ProviderIndex(_IndexBase): def __init__(self, specs=None, restrict=False): - """Create a new ProviderIndex. + """Provider index based on a single mapping of providers. - Optional arguments: + Args: + specs (list of specs): if provided, will call update on each + single spec to initialize this provider index. - specs - List (or sequence) of specs. If provided, will call - `update` on this ProviderIndex with each spec in the list. + restrict: "restricts" values to the verbatim input specs; do not + pre-apply package's constraints. - restrict - "restricts" values to the verbatim input specs; do not - pre-apply package's constraints. - - TODO: rename this. It is intended to keep things as broad - as possible without overly restricting results, so it is - not the best name. + TODO: rename this. It is intended to keep things as broad + TODO: as possible without overly restricting results, so it is + TODO: not the best name. """ if specs is None: specs = [] @@ -67,6 +160,11 @@ def __init__(self, specs=None, restrict=False): self.update(spec) def update(self, spec): + """Update the provider index with additional virtual specs. + + Args: + spec: spec potentially providing additional virtual specs + """ if not isinstance(spec, spack.spec.Spec): spec = spack.spec.Spec(spec) @@ -74,10 +172,10 @@ def update(self, spec): # Empty specs do not have a package return - assert(not spec.virtual) + assert not spec.virtual, "cannot update an index using a virtual spec" pkg_provided = spec.package_class.provided - for provided_spec, provider_specs in iteritems(pkg_provided): + for provided_spec, provider_specs in six.iteritems(pkg_provided): for provider_spec in provider_specs: # TODO: fix this comment. # We want satisfaction other than flags @@ -110,94 +208,24 @@ def update(self, spec): constrained.constrain(provider_spec) provider_map[provided_spec].add(constrained) - def providers_for(self, *vpkg_specs): - """Gives specs of all packages that provide virtual packages - with the supplied specs.""" - providers = set() - for vspec in vpkg_specs: - # Allow string names to be passed as input, as well as specs - if type(vspec) == str: - vspec = spack.spec.Spec(vspec) - - # Add all the providers that satisfy the vpkg spec. - if vspec.name in self.providers: - for p_spec, spec_set in self.providers[vspec.name].items(): - if p_spec.satisfies(vspec, deps=False): - providers.update(spec_set) - - # Return providers in order. Defensively copy. - return sorted(s.copy() for s in providers) - - # TODO: this is pretty darned nasty, and inefficient, but there - # are not that many vdeps in most specs. - def _cross_provider_maps(self, lmap, rmap): - result = {} - for lspec, rspec in iproduct(lmap, rmap): - try: - constrained = lspec.constrained(rspec) - except spack.error.UnsatisfiableSpecError: - continue - - # lp and rp are left and right provider specs. - for lp_spec, rp_spec in iproduct(lmap[lspec], rmap[rspec]): - if lp_spec.name == rp_spec.name: - try: - const = lp_spec.constrained(rp_spec, deps=False) - result.setdefault(constrained, set()).add(const) - except spack.error.UnsatisfiableSpecError: - continue - return result - - def __contains__(self, name): - """Whether a particular vpkg name is in the index.""" - return name in self.providers - - def satisfies(self, other): - """Check that providers of virtual specs are compatible.""" - common = set(self.providers) & set(other.providers) - if not common: - return True - - # This ensures that some provider in other COULD satisfy the - # vpkg constraints on self. - result = {} - for name in common: - crossed = self._cross_provider_maps(self.providers[name], - other.providers[name]) - if crossed: - result[name] = crossed - - return all(c in result for c in common) - def to_json(self, stream=None): + """Dump a JSON representation of this object. + + Args: + stream: stream where to dump + """ provider_list = self._transform( lambda vpkg, pset: [ vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list) sjson.dump({'provider_index': {'providers': provider_list}}, stream) - @staticmethod - def from_json(stream): - data = sjson.load(stream) - - if not isinstance(data, dict): - raise ProviderIndexError("JSON ProviderIndex data was not a dict.") - - if 'provider_index' not in data: - raise ProviderIndexError( - "YAML ProviderIndex does not start with 'provider_index'") - - index = ProviderIndex() - providers = data['provider_index']['providers'] - index.providers = _transform( - providers, - lambda vpkg, plist: ( - spack.spec.Spec.from_node_dict(vpkg), - set(spack.spec.Spec.from_node_dict(p) for p in plist))) - return index - def merge(self, other): - """Merge `other` ProviderIndex into this one.""" + """Merge another provider index into this one. + + Args: + other (ProviderIndex): provider index to be merged + """ other = other.copy() # defensive copy. for pkg in other.providers: @@ -236,40 +264,61 @@ def remove_provider(self, pkg_name): del self.providers[pkg] def copy(self): - """Deep copy of this ProviderIndex.""" + """Return a deep copy of this index.""" clone = ProviderIndex() clone.providers = self._transform( lambda vpkg, pset: (vpkg, set((p.copy() for p in pset)))) return clone - def __eq__(self, other): - return self.providers == other.providers + @staticmethod + def from_json(stream): + """Construct a provider index from its JSON representation. - def _transform(self, transform_fun, out_mapping_type=dict): - return _transform(self.providers, transform_fun, out_mapping_type) + Args: + stream: stream where to read from the JSON data + """ + data = sjson.load(stream) - def __str__(self): - return pformat( - _transform(self.providers, - lambda k, v: (k, list(v)))) + if not isinstance(data, dict): + raise ProviderIndexError("JSON ProviderIndex data was not a dict.") + + if 'provider_index' not in data: + raise ProviderIndexError( + "YAML ProviderIndex does not start with 'provider_index'") + + index = ProviderIndex() + providers = data['provider_index']['providers'] + index.providers = _transform( + providers, + lambda vpkg, plist: ( + spack.spec.Spec.from_node_dict(vpkg), + set(spack.spec.Spec.from_node_dict(p) for p in plist))) + return index def _transform(providers, transform_fun, out_mapping_type=dict): """Syntactic sugar for transforming a providers dict. - transform_fun takes a (vpkg, pset) mapping and runs it on each - pair in nested dicts. + Args: + providers: provider dictionary + transform_fun: transform_fun takes a (vpkg, pset) mapping and runs + it on each pair in nested dicts. + out_mapping_type: type to be used internally on the + transformed (vpkg, pset) + Returns: + Transformed mapping """ def mapiter(mappings): if isinstance(mappings, dict): - return iteritems(mappings) + return six.iteritems(mappings) else: return iter(mappings) return dict( - (name, out_mapping_type([ - transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)])) + (name, out_mapping_type( + [transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)] + )) for name, mappings in providers.items()) diff --git a/lib/spack/spack/test/spec_dag.py b/lib/spack/spack/test/spec_dag.py index 419a39968e..25917f9424 100644 --- a/lib/spack/spack/test/spec_dag.py +++ b/lib/spack/spack/test/spec_dag.py @@ -189,7 +189,7 @@ def test_conditional_dep_with_user_constraints(): assert ('y@3' in spec) -@pytest.mark.usefixtures('mutable_mock_repo') +@pytest.mark.usefixtures('mutable_mock_repo', 'config') class TestSpecDag(object): def test_conflicting_package_constraints(self, set_dependency):