From 5c49bb45c7955490319f229ffbe72cfdf3102b22 Mon Sep 17 00:00:00 2001 From: Massimiliano Culpo Date: Wed, 31 Jan 2024 15:37:59 +0100 Subject: [PATCH] ASP-based solver: decouple setup phase from `clingo.backend` (#41952) Currently, the `SpackSolverSetup` and the `PyclingoDriver` are more coupled than necessary: 1. The driver object needs a setup object to be injected during a solve, 2. And the setup object will get a reference back to the driver This design is necessary because we use the low-level `clingo.backend` interface to setup our problem. This interface though is meant to bypass the grounder and add symbols directly in the grounded table, which is a feature we don't currently use. The PR simplifies the encoding by having the setup object returning the problem-specific facts / rules as a list of strings, and the driver ingesting them using the [clingo.Control.add](https://potassco.org/clingo/python-api/5.6/clingo/control.html#clingo.control.Control.add) method. This removes any use of the low level interface. Using this encoding makes it easy to hash the output of the setup phase, since it is returned as a string. --- lib/spack/spack/solver/asp.py | 268 +++++++++++++++++----------------- 1 file changed, 133 insertions(+), 135 deletions(-) diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 151aef20a6..504d151ae3 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -96,6 +96,7 @@ # these are from clingo.ast and bootstrapped later ASTType = None parse_files = None +parse_term = None #: Enable the addition of a runtime node WITH_RUNTIME = sys.platform != "win32" @@ -310,11 +311,11 @@ def _id(thing): if isinstance(thing, AspObject): return thing elif isinstance(thing, bool): - return '"%s"' % str(thing) + return f'"{str(thing)}"' elif isinstance(thing, int): return str(thing) else: - return '"%s"' % str(thing) + return f'"{str(thing)}"' @llnl.util.lang.key_ordering @@ -351,21 +352,20 @@ def __call__(self, *args): """ return AspFunction(self.name, self.args + args) - def symbol(self, positive=True): - def argify(arg): - if isinstance(arg, bool): - return clingo.String(str(arg)) - elif isinstance(arg, int): - return clingo.Number(arg) - elif isinstance(arg, AspFunction): - return clingo.Function(arg.name, [argify(x) for x in arg.args], positive=positive) - else: - return clingo.String(str(arg)) + def argify(self, arg): + if isinstance(arg, bool): + return clingo.String(str(arg)) + elif isinstance(arg, int): + return clingo.Number(arg) + elif isinstance(arg, AspFunction): + return clingo.Function(arg.name, [self.argify(x) for x in arg.args], positive=True) + return clingo.String(str(arg)) - return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive) + def symbol(self): + return clingo.Function(self.name, [self.argify(arg) for arg in self.args], positive=True) def __str__(self): - return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args)) + return f"{self.name}({', '.join(str(_id(arg)) for arg in self.args)})" def __repr__(self): return str(self) @@ -664,7 +664,7 @@ def _spec_with_default_name(spec_str, name): def bootstrap_clingo(): - global clingo, ASTType, parse_files + global clingo, ASTType, parse_files, parse_term if not clingo: import spack.bootstrap @@ -677,9 +677,10 @@ def bootstrap_clingo(): try: from clingo.ast import parse_files + from clingo.symbol import parse_term except ImportError: # older versions of clingo have this one namespace up - from clingo import parse_files + from clingo import parse_files, parse_term class NodeArgument(NamedTuple): @@ -882,53 +883,9 @@ def __init__(self, cores=True): error reporting. """ bootstrap_clingo() - - self.out = llnl.util.lang.Devnull() self.cores = cores - - # These attributes are part of the object, but will be reset - # at each call to solve + # This attribute will be reset at each call to solve self.control = None - self.backend = None - self.assumptions = None - - def title(self, name, char): - self.out.write("\n") - self.out.write("%" + (char * 76)) - self.out.write("\n") - self.out.write("%% %s\n" % name) - self.out.write("%" + (char * 76)) - self.out.write("\n") - - def h1(self, name): - self.title(name, "=") - - def h2(self, name): - self.title(name, "-") - - def newline(self): - self.out.write("\n") - - def fact(self, head): - """ASP fact (a rule without a body). - - Arguments: - head (AspFunction): ASP function to generate as fact - """ - symbol = head.symbol() if hasattr(head, "symbol") else head - - # This is commented out to avoid evaluating str(symbol) when we have no stream - if not isinstance(self.out, llnl.util.lang.Devnull): - self.out.write(f"{str(symbol)}.\n") - - atom = self.backend.add_atom(symbol) - - # Only functions relevant for constructing bug reports for bad error messages - # are assumptions, and only when using cores. - choice = self.cores and symbol.name == "internal_error" - self.backend.add_rule([atom], [], choice=choice) - if choice: - self.assumptions.append(atom) def solve(self, setup, specs, reuse=None, output=None, control=None, allow_deprecated=False): """Set up the input and solve for dependencies of ``specs``. @@ -948,49 +905,24 @@ def solve(self, setup, specs, reuse=None, output=None, control=None, allow_depre solve, and the internal statistics from clingo. """ output = output or DEFAULT_OUTPUT_CONFIGURATION - # allow solve method to override the output stream - if output.out is not None: - self.out = output.out - timer = spack.util.timer.Timer() # Initialize the control object for the solver self.control = control or default_clingo_control() - # set up the problem -- this generates facts and rules - self.assumptions = [] + timer.start("setup") - with self.control.backend() as backend: - self.backend = backend - setup.setup(self, specs, reuse=reuse, allow_deprecated=allow_deprecated) + asp_problem = setup.setup(specs, reuse=reuse, allow_deprecated=allow_deprecated) + if output.out is not None: + output.out.write(asp_problem) + if output.setup_only: + return Result(specs), None, None timer.stop("setup") timer.start("load") - # read in the main ASP program and display logic -- these are - # handwritten, not generated, so we load them as resources - parent_dir = os.path.dirname(__file__) - - # extract error messages from concretize.lp by inspecting its AST - with self.backend: - - def visit(node): - if ast_type(node) == ASTType.Rule: - for term in node.body: - if ast_type(term) == ASTType.Literal: - if ast_type(term.atom) == ASTType.SymbolicAtom: - name = ast_sym(term.atom).name - if name == "internal_error": - arg = ast_sym(ast_sym(term.atom).arguments[0]) - self.fact(AspFunction(name)(arg.string)) - - self.h1("Error messages") - path = os.path.join(parent_dir, "concretize.lp") - parse_files([path], visit) - - # If we're only doing setup, just return an empty solve result - if output.setup_only: - return Result(specs), None, None - + # Add the problem instance + self.control.add("base", [], asp_problem) # Load the file itself + parent_dir = os.path.dirname(__file__) self.control.load(os.path.join(parent_dir, "concretize.lp")) self.control.load(os.path.join(parent_dir, "heuristic.lp")) if spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") != "none": @@ -1016,7 +948,7 @@ def on_model(model): models.append((model.cost, model.symbols(shown=True, terms=True))) solve_kwargs = { - "assumptions": self.assumptions, + "assumptions": setup.assumptions, "on_model": on_model, "on_core": cores.append, } @@ -1142,6 +1074,7 @@ class SpackSolverSetup: def __init__(self, tests=False): self.gen = None # set by setup() + self.assumptions = [] self.declared_versions = collections.defaultdict(list) self.possible_versions = collections.defaultdict(set) self.deprecated_versions = collections.defaultdict(set) @@ -1878,36 +1811,7 @@ def _spec_clauses( """ clauses = [] - # TODO: do this with consistent suffixes. - class Head: - node = fn.attr("node") - virtual_node = fn.attr("virtual_node") - node_platform = fn.attr("node_platform_set") - node_os = fn.attr("node_os_set") - node_target = fn.attr("node_target_set") - variant_value = fn.attr("variant_set") - node_compiler = fn.attr("node_compiler_set") - node_compiler_version = fn.attr("node_compiler_version_set") - node_flag = fn.attr("node_flag_set") - node_flag_source = fn.attr("node_flag_source") - node_flag_propagate = fn.attr("node_flag_propagate") - variant_propagation_candidate = fn.attr("variant_propagation_candidate") - - class Body: - node = fn.attr("node") - virtual_node = fn.attr("virtual_node") - node_platform = fn.attr("node_platform") - node_os = fn.attr("node_os") - node_target = fn.attr("node_target") - variant_value = fn.attr("variant_value") - node_compiler = fn.attr("node_compiler") - node_compiler_version = fn.attr("node_compiler_version") - node_flag = fn.attr("node_flag") - node_flag_source = fn.attr("node_flag_source") - node_flag_propagate = fn.attr("node_flag_propagate") - variant_propagation_candidate = fn.attr("variant_propagation_candidate") - - f = Body if body else Head + f = _Body if body else _Head if spec.name: clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name)) @@ -2503,12 +2407,11 @@ def define_concrete_input_specs(self, specs, possible): def setup( self, - driver: PyclingoDriver, specs: Sequence[spack.spec.Spec], *, reuse: Optional[List[spack.spec.Spec]] = None, allow_deprecated: bool = False, - ): + ) -> str: """Generate an ASP program with relevant constraints for specs. This calls methods on the solve driver to set up the problem with @@ -2516,7 +2419,6 @@ def setup( specs, as well as constraints from the specs themselves. Arguments: - driver: driver instance of this solve specs: list of Specs to solve reuse: list of concrete specs that can be reused allow_deprecated: if True adds deprecated versions into the solve @@ -2542,9 +2444,7 @@ def setup( if node.namespace is not None: self.explicitly_required_namespaces[node.name] = node.namespace - # driver is used by all the functions below to add facts and - # rules to generate an ASP program. - self.gen = driver + self.gen = ProblemInstanceBuilder() if not allow_deprecated: self.gen.fact(fn.deprecated_versions_not_allowed()) @@ -2648,6 +2548,29 @@ def setup( self.gen.h1("Target Constraints") self.define_target_constraints() + self.gen.h1("Internal errors") + self.internal_errors() + + return self.gen.value() + + def internal_errors(self): + parent_dir = os.path.dirname(__file__) + + def visit(node): + if ast_type(node) == ASTType.Rule: + for term in node.body: + if ast_type(term) == ASTType.Literal: + if ast_type(term.atom) == ASTType.SymbolicAtom: + name = ast_sym(term.atom).name + if name == "internal_error": + arg = ast_sym(ast_sym(term.atom).arguments[0]) + symbol = AspFunction(name)(arg.string) + self.assumptions.append((parse_term(str(symbol)), True)) + self.gen.asp_problem.append(f"{{ {symbol} }}.\n") + + path = os.path.join(parent_dir, "concretize.lp") + parse_files([path], visit) + def define_runtime_constraints(self): """Define the constraints to be imposed on the runtimes""" recorder = RuntimePropertyRecorder(self) @@ -2778,6 +2701,83 @@ def pkg_class(self, pkg_name: str) -> typing.Type["spack.package_base.PackageBas return spack.repo.PATH.get_pkg_class(request) +class _Head: + """ASP functions used to express spec clauses in the HEAD of a rule""" + + node = fn.attr("node") + virtual_node = fn.attr("virtual_node") + node_platform = fn.attr("node_platform_set") + node_os = fn.attr("node_os_set") + node_target = fn.attr("node_target_set") + variant_value = fn.attr("variant_set") + node_compiler = fn.attr("node_compiler_set") + node_compiler_version = fn.attr("node_compiler_version_set") + node_flag = fn.attr("node_flag_set") + node_flag_source = fn.attr("node_flag_source") + node_flag_propagate = fn.attr("node_flag_propagate") + variant_propagation_candidate = fn.attr("variant_propagation_candidate") + + +class _Body: + """ASP functions used to express spec clauses in the BODY of a rule""" + + node = fn.attr("node") + virtual_node = fn.attr("virtual_node") + node_platform = fn.attr("node_platform") + node_os = fn.attr("node_os") + node_target = fn.attr("node_target") + variant_value = fn.attr("variant_value") + node_compiler = fn.attr("node_compiler") + node_compiler_version = fn.attr("node_compiler_version") + node_flag = fn.attr("node_flag") + node_flag_source = fn.attr("node_flag_source") + node_flag_propagate = fn.attr("node_flag_propagate") + variant_propagation_candidate = fn.attr("variant_propagation_candidate") + + +class ProblemInstanceBuilder: + """Provides an interface to construct a problem instance. + + Once all the facts and rules have been added, the problem instance can be retrieved with: + + >>> builder = ProblemInstanceBuilder() + >>> ... + >>> problem_instance = builder.value() + + The problem instance can be added directly to the "control" structure of clingo. + """ + + def __init__(self): + self.asp_problem = [] + + def fact(self, atom: AspFunction) -> None: + symbol = atom.symbol() if hasattr(atom, "symbol") else atom + self.asp_problem.append(f"{str(symbol)}.\n") + + def append(self, rule: str) -> None: + self.asp_problem.append(rule) + + def title(self, header: str, char: str) -> None: + self.asp_problem.append("\n") + self.asp_problem.append("%" + (char * 76)) + self.asp_problem.append("\n") + self.asp_problem.append(f"% {header}\n") + self.asp_problem.append("%" + (char * 76)) + self.asp_problem.append("\n") + + def h1(self, header: str) -> None: + self.title(header, "=") + + def h2(self, header: str) -> None: + self.title(header, "-") + + def newline(self): + self.asp_problem.append("\n") + + def value(self) -> str: + return "".join(self.asp_problem) + + class RequirementParser: """Parses requirements from package.py files and configuration, and returns rules.""" @@ -3085,9 +3085,7 @@ def consume_facts(self): self._setup.gen.h2("Runtimes: rules") self._setup.gen.newline() for rule in self.rules: - if not isinstance(self._setup.gen.out, llnl.util.lang.Devnull): - self._setup.gen.out.write(rule) - self._setup.gen.control.add("base", [], rule) + self._setup.gen.append(rule) self._setup.gen.h2("Runtimes: conditions") for runtime_pkg in spack.repo.PATH.packages_with_tags("runtime"):