ASP-based solver: decouple setup phase from clingo.backend (#41952)

Currently, the `SpackSolverSetup` and the `PyclingoDriver` are more coupled than necessary:
1. The driver object needs a setup object to be injected during a solve, 
2. And the setup object will get a reference back to the driver

This design is necessary because we use the low-level `clingo.backend` interface to setup our problem. This interface though is meant to bypass the grounder and add symbols directly in the grounded table, which is a feature we don't currently use.

The PR simplifies the encoding by having the setup object returning the problem-specific facts / rules as a list of strings, and the driver ingesting them using the [clingo.Control.add](https://potassco.org/clingo/python-api/5.6/clingo/control.html#clingo.control.Control.add) method. This removes any use of the low level interface.

Using this encoding makes it easy to hash the output of the setup phase, since it is returned as a string.
This commit is contained in:
Massimiliano Culpo 2024-01-31 15:37:59 +01:00 committed by GitHub
parent 97fb9565ee
commit 5c49bb45c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -96,6 +96,7 @@
# these are from clingo.ast and bootstrapped later # these are from clingo.ast and bootstrapped later
ASTType = None ASTType = None
parse_files = None parse_files = None
parse_term = None
#: Enable the addition of a runtime node #: Enable the addition of a runtime node
WITH_RUNTIME = sys.platform != "win32" WITH_RUNTIME = sys.platform != "win32"
@ -310,11 +311,11 @@ def _id(thing):
if isinstance(thing, AspObject): if isinstance(thing, AspObject):
return thing return thing
elif isinstance(thing, bool): elif isinstance(thing, bool):
return '"%s"' % str(thing) return f'"{str(thing)}"'
elif isinstance(thing, int): elif isinstance(thing, int):
return str(thing) return str(thing)
else: else:
return '"%s"' % str(thing) return f'"{str(thing)}"'
@llnl.util.lang.key_ordering @llnl.util.lang.key_ordering
@ -351,21 +352,20 @@ def __call__(self, *args):
""" """
return AspFunction(self.name, self.args + args) return AspFunction(self.name, self.args + args)
def symbol(self, positive=True): def argify(self, arg):
def argify(arg):
if isinstance(arg, bool): if isinstance(arg, bool):
return clingo.String(str(arg)) return clingo.String(str(arg))
elif isinstance(arg, int): elif isinstance(arg, int):
return clingo.Number(arg) return clingo.Number(arg)
elif isinstance(arg, AspFunction): elif isinstance(arg, AspFunction):
return clingo.Function(arg.name, [argify(x) for x in arg.args], positive=positive) return clingo.Function(arg.name, [self.argify(x) for x in arg.args], positive=True)
else:
return clingo.String(str(arg)) return clingo.String(str(arg))
return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive) def symbol(self):
return clingo.Function(self.name, [self.argify(arg) for arg in self.args], positive=True)
def __str__(self): def __str__(self):
return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args)) return f"{self.name}({', '.join(str(_id(arg)) for arg in self.args)})"
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@ -664,7 +664,7 @@ def _spec_with_default_name(spec_str, name):
def bootstrap_clingo(): def bootstrap_clingo():
global clingo, ASTType, parse_files global clingo, ASTType, parse_files, parse_term
if not clingo: if not clingo:
import spack.bootstrap import spack.bootstrap
@ -677,9 +677,10 @@ def bootstrap_clingo():
try: try:
from clingo.ast import parse_files from clingo.ast import parse_files
from clingo.symbol import parse_term
except ImportError: except ImportError:
# older versions of clingo have this one namespace up # older versions of clingo have this one namespace up
from clingo import parse_files from clingo import parse_files, parse_term
class NodeArgument(NamedTuple): class NodeArgument(NamedTuple):
@ -882,53 +883,9 @@ def __init__(self, cores=True):
error reporting. error reporting.
""" """
bootstrap_clingo() bootstrap_clingo()
self.out = llnl.util.lang.Devnull()
self.cores = cores self.cores = cores
# This attribute will be reset at each call to solve
# These attributes are part of the object, but will be reset
# at each call to solve
self.control = None self.control = None
self.backend = None
self.assumptions = None
def title(self, name, char):
self.out.write("\n")
self.out.write("%" + (char * 76))
self.out.write("\n")
self.out.write("%% %s\n" % name)
self.out.write("%" + (char * 76))
self.out.write("\n")
def h1(self, name):
self.title(name, "=")
def h2(self, name):
self.title(name, "-")
def newline(self):
self.out.write("\n")
def fact(self, head):
"""ASP fact (a rule without a body).
Arguments:
head (AspFunction): ASP function to generate as fact
"""
symbol = head.symbol() if hasattr(head, "symbol") else head
# This is commented out to avoid evaluating str(symbol) when we have no stream
if not isinstance(self.out, llnl.util.lang.Devnull):
self.out.write(f"{str(symbol)}.\n")
atom = self.backend.add_atom(symbol)
# Only functions relevant for constructing bug reports for bad error messages
# are assumptions, and only when using cores.
choice = self.cores and symbol.name == "internal_error"
self.backend.add_rule([atom], [], choice=choice)
if choice:
self.assumptions.append(atom)
def solve(self, setup, specs, reuse=None, output=None, control=None, allow_deprecated=False): def solve(self, setup, specs, reuse=None, output=None, control=None, allow_deprecated=False):
"""Set up the input and solve for dependencies of ``specs``. """Set up the input and solve for dependencies of ``specs``.
@ -948,49 +905,24 @@ def solve(self, setup, specs, reuse=None, output=None, control=None, allow_depre
solve, and the internal statistics from clingo. solve, and the internal statistics from clingo.
""" """
output = output or DEFAULT_OUTPUT_CONFIGURATION output = output or DEFAULT_OUTPUT_CONFIGURATION
# allow solve method to override the output stream
if output.out is not None:
self.out = output.out
timer = spack.util.timer.Timer() timer = spack.util.timer.Timer()
# Initialize the control object for the solver # Initialize the control object for the solver
self.control = control or default_clingo_control() self.control = control or default_clingo_control()
# set up the problem -- this generates facts and rules
self.assumptions = []
timer.start("setup") timer.start("setup")
with self.control.backend() as backend: asp_problem = setup.setup(specs, reuse=reuse, allow_deprecated=allow_deprecated)
self.backend = backend if output.out is not None:
setup.setup(self, specs, reuse=reuse, allow_deprecated=allow_deprecated) output.out.write(asp_problem)
if output.setup_only:
return Result(specs), None, None
timer.stop("setup") timer.stop("setup")
timer.start("load") timer.start("load")
# read in the main ASP program and display logic -- these are # Add the problem instance
# handwritten, not generated, so we load them as resources self.control.add("base", [], asp_problem)
parent_dir = os.path.dirname(__file__)
# extract error messages from concretize.lp by inspecting its AST
with self.backend:
def visit(node):
if ast_type(node) == ASTType.Rule:
for term in node.body:
if ast_type(term) == ASTType.Literal:
if ast_type(term.atom) == ASTType.SymbolicAtom:
name = ast_sym(term.atom).name
if name == "internal_error":
arg = ast_sym(ast_sym(term.atom).arguments[0])
self.fact(AspFunction(name)(arg.string))
self.h1("Error messages")
path = os.path.join(parent_dir, "concretize.lp")
parse_files([path], visit)
# If we're only doing setup, just return an empty solve result
if output.setup_only:
return Result(specs), None, None
# Load the file itself # Load the file itself
parent_dir = os.path.dirname(__file__)
self.control.load(os.path.join(parent_dir, "concretize.lp")) self.control.load(os.path.join(parent_dir, "concretize.lp"))
self.control.load(os.path.join(parent_dir, "heuristic.lp")) self.control.load(os.path.join(parent_dir, "heuristic.lp"))
if spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") != "none": if spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") != "none":
@ -1016,7 +948,7 @@ def on_model(model):
models.append((model.cost, model.symbols(shown=True, terms=True))) models.append((model.cost, model.symbols(shown=True, terms=True)))
solve_kwargs = { solve_kwargs = {
"assumptions": self.assumptions, "assumptions": setup.assumptions,
"on_model": on_model, "on_model": on_model,
"on_core": cores.append, "on_core": cores.append,
} }
@ -1142,6 +1074,7 @@ class SpackSolverSetup:
def __init__(self, tests=False): def __init__(self, tests=False):
self.gen = None # set by setup() self.gen = None # set by setup()
self.assumptions = []
self.declared_versions = collections.defaultdict(list) self.declared_versions = collections.defaultdict(list)
self.possible_versions = collections.defaultdict(set) self.possible_versions = collections.defaultdict(set)
self.deprecated_versions = collections.defaultdict(set) self.deprecated_versions = collections.defaultdict(set)
@ -1878,36 +1811,7 @@ def _spec_clauses(
""" """
clauses = [] clauses = []
# TODO: do this with consistent suffixes. f = _Body if body else _Head
class Head:
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform_set")
node_os = fn.attr("node_os_set")
node_target = fn.attr("node_target_set")
variant_value = fn.attr("variant_set")
node_compiler = fn.attr("node_compiler_set")
node_compiler_version = fn.attr("node_compiler_version_set")
node_flag = fn.attr("node_flag_set")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
class Body:
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform")
node_os = fn.attr("node_os")
node_target = fn.attr("node_target")
variant_value = fn.attr("variant_value")
node_compiler = fn.attr("node_compiler")
node_compiler_version = fn.attr("node_compiler_version")
node_flag = fn.attr("node_flag")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
f = Body if body else Head
if spec.name: if spec.name:
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name)) clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
@ -2503,12 +2407,11 @@ def define_concrete_input_specs(self, specs, possible):
def setup( def setup(
self, self,
driver: PyclingoDriver,
specs: Sequence[spack.spec.Spec], specs: Sequence[spack.spec.Spec],
*, *,
reuse: Optional[List[spack.spec.Spec]] = None, reuse: Optional[List[spack.spec.Spec]] = None,
allow_deprecated: bool = False, allow_deprecated: bool = False,
): ) -> str:
"""Generate an ASP program with relevant constraints for specs. """Generate an ASP program with relevant constraints for specs.
This calls methods on the solve driver to set up the problem with This calls methods on the solve driver to set up the problem with
@ -2516,7 +2419,6 @@ def setup(
specs, as well as constraints from the specs themselves. specs, as well as constraints from the specs themselves.
Arguments: Arguments:
driver: driver instance of this solve
specs: list of Specs to solve specs: list of Specs to solve
reuse: list of concrete specs that can be reused reuse: list of concrete specs that can be reused
allow_deprecated: if True adds deprecated versions into the solve allow_deprecated: if True adds deprecated versions into the solve
@ -2542,9 +2444,7 @@ def setup(
if node.namespace is not None: if node.namespace is not None:
self.explicitly_required_namespaces[node.name] = node.namespace self.explicitly_required_namespaces[node.name] = node.namespace
# driver is used by all the functions below to add facts and self.gen = ProblemInstanceBuilder()
# rules to generate an ASP program.
self.gen = driver
if not allow_deprecated: if not allow_deprecated:
self.gen.fact(fn.deprecated_versions_not_allowed()) self.gen.fact(fn.deprecated_versions_not_allowed())
@ -2648,6 +2548,29 @@ def setup(
self.gen.h1("Target Constraints") self.gen.h1("Target Constraints")
self.define_target_constraints() self.define_target_constraints()
self.gen.h1("Internal errors")
self.internal_errors()
return self.gen.value()
def internal_errors(self):
parent_dir = os.path.dirname(__file__)
def visit(node):
if ast_type(node) == ASTType.Rule:
for term in node.body:
if ast_type(term) == ASTType.Literal:
if ast_type(term.atom) == ASTType.SymbolicAtom:
name = ast_sym(term.atom).name
if name == "internal_error":
arg = ast_sym(ast_sym(term.atom).arguments[0])
symbol = AspFunction(name)(arg.string)
self.assumptions.append((parse_term(str(symbol)), True))
self.gen.asp_problem.append(f"{{ {symbol} }}.\n")
path = os.path.join(parent_dir, "concretize.lp")
parse_files([path], visit)
def define_runtime_constraints(self): def define_runtime_constraints(self):
"""Define the constraints to be imposed on the runtimes""" """Define the constraints to be imposed on the runtimes"""
recorder = RuntimePropertyRecorder(self) recorder = RuntimePropertyRecorder(self)
@ -2778,6 +2701,83 @@ def pkg_class(self, pkg_name: str) -> typing.Type["spack.package_base.PackageBas
return spack.repo.PATH.get_pkg_class(request) return spack.repo.PATH.get_pkg_class(request)
class _Head:
"""ASP functions used to express spec clauses in the HEAD of a rule"""
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform_set")
node_os = fn.attr("node_os_set")
node_target = fn.attr("node_target_set")
variant_value = fn.attr("variant_set")
node_compiler = fn.attr("node_compiler_set")
node_compiler_version = fn.attr("node_compiler_version_set")
node_flag = fn.attr("node_flag_set")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
class _Body:
"""ASP functions used to express spec clauses in the BODY of a rule"""
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform")
node_os = fn.attr("node_os")
node_target = fn.attr("node_target")
variant_value = fn.attr("variant_value")
node_compiler = fn.attr("node_compiler")
node_compiler_version = fn.attr("node_compiler_version")
node_flag = fn.attr("node_flag")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
class ProblemInstanceBuilder:
"""Provides an interface to construct a problem instance.
Once all the facts and rules have been added, the problem instance can be retrieved with:
>>> builder = ProblemInstanceBuilder()
>>> ...
>>> problem_instance = builder.value()
The problem instance can be added directly to the "control" structure of clingo.
"""
def __init__(self):
self.asp_problem = []
def fact(self, atom: AspFunction) -> None:
symbol = atom.symbol() if hasattr(atom, "symbol") else atom
self.asp_problem.append(f"{str(symbol)}.\n")
def append(self, rule: str) -> None:
self.asp_problem.append(rule)
def title(self, header: str, char: str) -> None:
self.asp_problem.append("\n")
self.asp_problem.append("%" + (char * 76))
self.asp_problem.append("\n")
self.asp_problem.append(f"% {header}\n")
self.asp_problem.append("%" + (char * 76))
self.asp_problem.append("\n")
def h1(self, header: str) -> None:
self.title(header, "=")
def h2(self, header: str) -> None:
self.title(header, "-")
def newline(self):
self.asp_problem.append("\n")
def value(self) -> str:
return "".join(self.asp_problem)
class RequirementParser: class RequirementParser:
"""Parses requirements from package.py files and configuration, and returns rules.""" """Parses requirements from package.py files and configuration, and returns rules."""
@ -3085,9 +3085,7 @@ def consume_facts(self):
self._setup.gen.h2("Runtimes: rules") self._setup.gen.h2("Runtimes: rules")
self._setup.gen.newline() self._setup.gen.newline()
for rule in self.rules: for rule in self.rules:
if not isinstance(self._setup.gen.out, llnl.util.lang.Devnull): self._setup.gen.append(rule)
self._setup.gen.out.write(rule)
self._setup.gen.control.add("base", [], rule)
self._setup.gen.h2("Runtimes: conditions") self._setup.gen.h2("Runtimes: conditions")
for runtime_pkg in spack.repo.PATH.packages_with_tags("runtime"): for runtime_pkg in spack.repo.PATH.packages_with_tags("runtime"):