Rework the encoding to introduce node(ID, Package) nested facts

So far the encoding has a single ID per package, i.e. all the
facts will be node(0, Package). This will prepare the stage for
extending this logic and having multiple nodes from the same
package in a DAG.
This commit is contained in:
Massimiliano Culpo 2023-06-16 15:25:30 +02:00 committed by Todd Gamblin
parent 907a80ca71
commit 27ab53b68a
3 changed files with 631 additions and 539 deletions

View file

@ -515,15 +515,17 @@ def _compute_specs_from_answer_set(self):
best = min(self.answers)
opt, _, answer = best
for input_spec in self.abstract_specs:
key = input_spec.name
node = SpecBuilder.root_node(pkg=input_spec.name)
if input_spec.virtual:
providers = [spec.name for spec in answer.values() if spec.package.provides(key)]
key = providers[0]
candidate = answer.get(key)
providers = [
spec.name for spec in answer.values() if spec.package.provides(input_spec.name)
]
node = SpecBuilder.root_node(pkg=providers[0])
candidate = answer.get(node)
if candidate and candidate.satisfies(input_spec):
self._concrete_specs.append(answer[key])
self._concrete_specs_by_input[input_spec] = answer[key]
self._concrete_specs.append(answer[node])
self._concrete_specs_by_input[input_spec] = answer[node]
else:
self._unsolved_specs.append(input_spec)
@ -2426,6 +2428,18 @@ class SpecBuilder:
)
)
node_regex = re.compile(r"node\(\d,\"(.*)\"\)")
@staticmethod
def root_node(*, pkg: str) -> str:
"""Given a package name, returns the string representation of the root node in
the ASP encoding.
Args:
pkg: name of a package
"""
return f'node(0,"{pkg}")'
def __init__(self, specs, hash_lookup=None):
self._specs = {}
self._result = None
@ -2438,100 +2452,121 @@ def __init__(self, specs, hash_lookup=None):
# from this dictionary during reconstruction
self._hash_lookup = hash_lookup or {}
def hash(self, pkg, h):
if pkg not in self._specs:
self._specs[pkg] = self._hash_lookup[h]
self._hash_specs.append(pkg)
@staticmethod
def extract_pkg(node: str) -> str:
"""Extracts the package name from a node fact, and returns it.
def node(self, pkg):
if pkg not in self._specs:
self._specs[pkg] = spack.spec.Spec(pkg)
Args:
node: node from which the package name is to be extracted
"""
m = SpecBuilder.node_regex.match(node)
if m is None:
raise spack.error.SpackError(f"cannot extract package information from '{node}'")
def _arch(self, pkg):
arch = self._specs[pkg].architecture
return m.group(1)
def hash(self, node, h):
if node not in self._specs:
self._specs[node] = self._hash_lookup[h]
self._hash_specs.append(node)
def node(self, node):
pkg = self.extract_pkg(node)
if node not in self._specs:
self._specs[node] = spack.spec.Spec(pkg)
def _arch(self, node):
arch = self._specs[node].architecture
if not arch:
arch = spack.spec.ArchSpec()
self._specs[pkg].architecture = arch
self._specs[node].architecture = arch
return arch
def node_platform(self, pkg, platform):
self._arch(pkg).platform = platform
def node_platform(self, node, platform):
self._arch(node).platform = platform
def node_os(self, pkg, os):
self._arch(pkg).os = os
def node_os(self, node, os):
self._arch(node).os = os
def node_target(self, pkg, target):
self._arch(pkg).target = target
def node_target(self, node, target):
self._arch(node).target = target
def variant_value(self, pkg, name, value):
def variant_value(self, node, name, value):
# FIXME: is there a way not to special case 'dev_path' everywhere?
if name == "dev_path":
self._specs[pkg].variants.setdefault(
self._specs[node].variants.setdefault(
name, spack.variant.SingleValuedVariant(name, value)
)
return
if name == "patches":
self._specs[pkg].variants.setdefault(
self._specs[node].variants.setdefault(
name, spack.variant.MultiValuedVariant(name, value)
)
return
self._specs[pkg].update_variant_validate(name, value)
self._specs[node].update_variant_validate(name, value)
def version(self, pkg, version):
self._specs[pkg].versions = vn.VersionList([vn.Version(version)])
def version(self, node, version):
self._specs[node].versions = vn.VersionList([vn.Version(version)])
def node_compiler_version(self, pkg, compiler, version):
self._specs[pkg].compiler = spack.spec.CompilerSpec(compiler)
self._specs[pkg].compiler.versions = vn.VersionList([vn.Version(version)])
def node_compiler_version(self, node, compiler, version):
self._specs[node].compiler = spack.spec.CompilerSpec(compiler)
self._specs[node].compiler.versions = vn.VersionList([vn.Version(version)])
def node_flag_compiler_default(self, pkg):
self._flag_compiler_defaults.add(pkg)
def node_flag_compiler_default(self, node):
self._flag_compiler_defaults.add(node)
def node_flag(self, pkg, flag_type, flag):
self._specs[pkg].compiler_flags.add_flag(flag_type, flag, False)
def node_flag(self, node, flag_type, flag):
self._specs[node].compiler_flags.add_flag(flag_type, flag, False)
def node_flag_source(self, pkg, flag_type, source):
self._flag_sources[(pkg, flag_type)].add(source)
def node_flag_source(self, node, flag_type, source):
self._flag_sources[(node, flag_type)].add(source)
def no_flags(self, pkg, flag_type):
self._specs[pkg].compiler_flags[flag_type] = []
def no_flags(self, node, flag_type):
self._specs[node].compiler_flags[flag_type] = []
def external_spec_selected(self, pkg, idx):
def external_spec_selected(self, node, idx):
"""This means that the external spec and index idx
has been selected for this package.
"""
packages_yaml = spack.config.get("packages")
packages_yaml = _normalize_packages_yaml(packages_yaml)
pkg = self.extract_pkg(node)
spec_info = packages_yaml[pkg]["externals"][int(idx)]
self._specs[pkg].external_path = spec_info.get("prefix", None)
self._specs[pkg].external_modules = spack.spec.Spec._format_module_list(
self._specs[node].external_path = spec_info.get("prefix", None)
self._specs[node].external_modules = spack.spec.Spec._format_module_list(
spec_info.get("modules", None)
)
self._specs[pkg].extra_attributes = spec_info.get("extra_attributes", {})
self._specs[node].extra_attributes = spec_info.get("extra_attributes", {})
# If this is an extension, update the dependencies to include the extendee
package = self._specs[pkg].package_class(self._specs[pkg])
package = self._specs[node].package_class(self._specs[node])
extendee_spec = package.extendee_spec
if extendee_spec:
package.update_external_dependencies(self._specs.get(extendee_spec.name, None))
def depends_on(self, pkg, dep, type):
dependencies = self._specs[pkg].edges_to_dependencies(name=dep)
if extendee_spec:
extendee_node = SpecBuilder.root_node(pkg=extendee_spec.name)
package.update_external_dependencies(self._specs.get(extendee_node, None))
def depends_on(self, parent_node, dependency_node, type):
dependencies = self._specs[parent_node].edges_to_dependencies(name=dependency_node)
# TODO: assertion to be removed when cross-compilation is handled correctly
msg = "Current solver does not handle multiple dependency edges of the same name"
assert len(dependencies) < 2, msg
if not dependencies:
self._specs[pkg].add_dependency_edge(self._specs[dep], deptypes=(type,), virtuals=())
self._specs[parent_node].add_dependency_edge(
self._specs[dependency_node], deptypes=(type,), virtuals=()
)
else:
# TODO: This assumes that each solve unifies dependencies
dependencies[0].update_deptypes(deptypes=(type,))
def virtual_on_edge(self, pkg, provider, virtual):
dependencies = self._specs[pkg].edges_to_dependencies(name=provider)
def virtual_on_edge(self, parent_node, provider_node, virtual):
provider = self.extract_pkg(provider_node)
dependencies = self._specs[parent_node].edges_to_dependencies(name=provider)
assert len(dependencies) == 1
dependencies[0].update_virtuals((virtual,))
@ -2562,17 +2597,22 @@ def reorder_flags(self):
# order is determined by the DAG. A spec's flags come after any of its ancestors
# on the compile line
source_key = (spec.name, flag_type)
node = SpecBuilder.root_node(pkg=spec.name)
source_key = (node, flag_type)
if source_key in self._flag_sources:
order = [s.name for s in spec.traverse(order="post", direction="parents")]
order = [
SpecBuilder.root_node(pkg=s.name)
for s in spec.traverse(order="post", direction="parents")
]
sorted_sources = sorted(
self._flag_sources[source_key], key=lambda s: order.index(s)
)
# add flags from each source, lowest to highest precedence
for name in sorted_sources:
for node in sorted_sources:
all_src_flags = list()
per_pkg_sources = [self._specs[name]]
per_pkg_sources = [self._specs[node]]
name = self.extract_pkg(node)
if name in cmd_specs:
per_pkg_sources.append(cmd_specs[name])
for source in per_pkg_sources:
@ -2645,14 +2685,14 @@ def build_specs(self, function_tuples):
# solving but don't construct anything. Do not ignore error
# predicates on virtual packages.
if name != "error":
pkg = args[0]
pkg = self.extract_pkg(args[0])
if spack.repo.PATH.is_virtual(pkg):
continue
# if we've already gotten a concrete spec for this pkg,
# do not bother calling actions on it except for node_flag_source,
# since node_flag_source is tracking information not in the spec itself
spec = self._specs.get(pkg)
spec = self._specs.get(args[0])
if spec and spec.concrete:
if name != "node_flag_source":
continue

File diff suppressed because it is too large Load diff

View file

@ -2983,9 +2983,10 @@ def _new_concretize(self, tests=False):
providers = [spec.name for spec in answer.values() if spec.package.provides(name)]
name = providers[0]
assert name in answer
node = spack.solver.asp.SpecBuilder.root_node(pkg=name)
assert node in answer, f"cannot find {name} in the list of specs {','.join(answer.keys())}"
concretized = answer[name]
concretized = answer[node]
self._dup(concretized)
def concretize(self, tests=False):
@ -3519,7 +3520,8 @@ def update_variant_validate(self, variant_name, values):
for value in values:
if self.variants.get(variant_name):
msg = (
"Cannot append a value to a single-valued " "variant with an already set value"
f"cannot append the new value '{value}' to the single-valued "
f"variant '{self.variants[variant_name]}'"
)
assert pkg_variant.multi, msg
self.variants[variant_name].append(value)