ASP-based solver: decouple setup phase from clingo.backend (#41952)

Currently, the `SpackSolverSetup` and the `PyclingoDriver` are more coupled than necessary:
1. The driver object needs a setup object to be injected during a solve, 
2. And the setup object will get a reference back to the driver

This design is necessary because we use the low-level `clingo.backend` interface to setup our problem. This interface though is meant to bypass the grounder and add symbols directly in the grounded table, which is a feature we don't currently use.

The PR simplifies the encoding by having the setup object returning the problem-specific facts / rules as a list of strings, and the driver ingesting them using the [clingo.Control.add](https://potassco.org/clingo/python-api/5.6/clingo/control.html#clingo.control.Control.add) method. This removes any use of the low level interface.

Using this encoding makes it easy to hash the output of the setup phase, since it is returned as a string.
This commit is contained in:
Massimiliano Culpo 2024-01-31 15:37:59 +01:00 committed by GitHub
parent 97fb9565ee
commit 5c49bb45c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -96,6 +96,7 @@
# these are from clingo.ast and bootstrapped later
ASTType = None
parse_files = None
parse_term = None
#: Enable the addition of a runtime node
WITH_RUNTIME = sys.platform != "win32"
@ -310,11 +311,11 @@ def _id(thing):
if isinstance(thing, AspObject):
return thing
elif isinstance(thing, bool):
return '"%s"' % str(thing)
return f'"{str(thing)}"'
elif isinstance(thing, int):
return str(thing)
else:
return '"%s"' % str(thing)
return f'"{str(thing)}"'
@llnl.util.lang.key_ordering
@ -351,21 +352,20 @@ def __call__(self, *args):
"""
return AspFunction(self.name, self.args + args)
def symbol(self, positive=True):
def argify(arg):
if isinstance(arg, bool):
return clingo.String(str(arg))
elif isinstance(arg, int):
return clingo.Number(arg)
elif isinstance(arg, AspFunction):
return clingo.Function(arg.name, [argify(x) for x in arg.args], positive=positive)
else:
return clingo.String(str(arg))
def argify(self, arg):
if isinstance(arg, bool):
return clingo.String(str(arg))
elif isinstance(arg, int):
return clingo.Number(arg)
elif isinstance(arg, AspFunction):
return clingo.Function(arg.name, [self.argify(x) for x in arg.args], positive=True)
return clingo.String(str(arg))
return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive)
def symbol(self):
return clingo.Function(self.name, [self.argify(arg) for arg in self.args], positive=True)
def __str__(self):
return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args))
return f"{self.name}({', '.join(str(_id(arg)) for arg in self.args)})"
def __repr__(self):
return str(self)
@ -664,7 +664,7 @@ def _spec_with_default_name(spec_str, name):
def bootstrap_clingo():
global clingo, ASTType, parse_files
global clingo, ASTType, parse_files, parse_term
if not clingo:
import spack.bootstrap
@ -677,9 +677,10 @@ def bootstrap_clingo():
try:
from clingo.ast import parse_files
from clingo.symbol import parse_term
except ImportError:
# older versions of clingo have this one namespace up
from clingo import parse_files
from clingo import parse_files, parse_term
class NodeArgument(NamedTuple):
@ -882,53 +883,9 @@ def __init__(self, cores=True):
error reporting.
"""
bootstrap_clingo()
self.out = llnl.util.lang.Devnull()
self.cores = cores
# These attributes are part of the object, but will be reset
# at each call to solve
# This attribute will be reset at each call to solve
self.control = None
self.backend = None
self.assumptions = None
def title(self, name, char):
self.out.write("\n")
self.out.write("%" + (char * 76))
self.out.write("\n")
self.out.write("%% %s\n" % name)
self.out.write("%" + (char * 76))
self.out.write("\n")
def h1(self, name):
self.title(name, "=")
def h2(self, name):
self.title(name, "-")
def newline(self):
self.out.write("\n")
def fact(self, head):
"""ASP fact (a rule without a body).
Arguments:
head (AspFunction): ASP function to generate as fact
"""
symbol = head.symbol() if hasattr(head, "symbol") else head
# This is commented out to avoid evaluating str(symbol) when we have no stream
if not isinstance(self.out, llnl.util.lang.Devnull):
self.out.write(f"{str(symbol)}.\n")
atom = self.backend.add_atom(symbol)
# Only functions relevant for constructing bug reports for bad error messages
# are assumptions, and only when using cores.
choice = self.cores and symbol.name == "internal_error"
self.backend.add_rule([atom], [], choice=choice)
if choice:
self.assumptions.append(atom)
def solve(self, setup, specs, reuse=None, output=None, control=None, allow_deprecated=False):
"""Set up the input and solve for dependencies of ``specs``.
@ -948,49 +905,24 @@ def solve(self, setup, specs, reuse=None, output=None, control=None, allow_depre
solve, and the internal statistics from clingo.
"""
output = output or DEFAULT_OUTPUT_CONFIGURATION
# allow solve method to override the output stream
if output.out is not None:
self.out = output.out
timer = spack.util.timer.Timer()
# Initialize the control object for the solver
self.control = control or default_clingo_control()
# set up the problem -- this generates facts and rules
self.assumptions = []
timer.start("setup")
with self.control.backend() as backend:
self.backend = backend
setup.setup(self, specs, reuse=reuse, allow_deprecated=allow_deprecated)
asp_problem = setup.setup(specs, reuse=reuse, allow_deprecated=allow_deprecated)
if output.out is not None:
output.out.write(asp_problem)
if output.setup_only:
return Result(specs), None, None
timer.stop("setup")
timer.start("load")
# read in the main ASP program and display logic -- these are
# handwritten, not generated, so we load them as resources
parent_dir = os.path.dirname(__file__)
# extract error messages from concretize.lp by inspecting its AST
with self.backend:
def visit(node):
if ast_type(node) == ASTType.Rule:
for term in node.body:
if ast_type(term) == ASTType.Literal:
if ast_type(term.atom) == ASTType.SymbolicAtom:
name = ast_sym(term.atom).name
if name == "internal_error":
arg = ast_sym(ast_sym(term.atom).arguments[0])
self.fact(AspFunction(name)(arg.string))
self.h1("Error messages")
path = os.path.join(parent_dir, "concretize.lp")
parse_files([path], visit)
# If we're only doing setup, just return an empty solve result
if output.setup_only:
return Result(specs), None, None
# Add the problem instance
self.control.add("base", [], asp_problem)
# Load the file itself
parent_dir = os.path.dirname(__file__)
self.control.load(os.path.join(parent_dir, "concretize.lp"))
self.control.load(os.path.join(parent_dir, "heuristic.lp"))
if spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") != "none":
@ -1016,7 +948,7 @@ def on_model(model):
models.append((model.cost, model.symbols(shown=True, terms=True)))
solve_kwargs = {
"assumptions": self.assumptions,
"assumptions": setup.assumptions,
"on_model": on_model,
"on_core": cores.append,
}
@ -1142,6 +1074,7 @@ class SpackSolverSetup:
def __init__(self, tests=False):
self.gen = None # set by setup()
self.assumptions = []
self.declared_versions = collections.defaultdict(list)
self.possible_versions = collections.defaultdict(set)
self.deprecated_versions = collections.defaultdict(set)
@ -1878,36 +1811,7 @@ def _spec_clauses(
"""
clauses = []
# TODO: do this with consistent suffixes.
class Head:
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform_set")
node_os = fn.attr("node_os_set")
node_target = fn.attr("node_target_set")
variant_value = fn.attr("variant_set")
node_compiler = fn.attr("node_compiler_set")
node_compiler_version = fn.attr("node_compiler_version_set")
node_flag = fn.attr("node_flag_set")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
class Body:
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform")
node_os = fn.attr("node_os")
node_target = fn.attr("node_target")
variant_value = fn.attr("variant_value")
node_compiler = fn.attr("node_compiler")
node_compiler_version = fn.attr("node_compiler_version")
node_flag = fn.attr("node_flag")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
f = Body if body else Head
f = _Body if body else _Head
if spec.name:
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
@ -2503,12 +2407,11 @@ def define_concrete_input_specs(self, specs, possible):
def setup(
self,
driver: PyclingoDriver,
specs: Sequence[spack.spec.Spec],
*,
reuse: Optional[List[spack.spec.Spec]] = None,
allow_deprecated: bool = False,
):
) -> str:
"""Generate an ASP program with relevant constraints for specs.
This calls methods on the solve driver to set up the problem with
@ -2516,7 +2419,6 @@ def setup(
specs, as well as constraints from the specs themselves.
Arguments:
driver: driver instance of this solve
specs: list of Specs to solve
reuse: list of concrete specs that can be reused
allow_deprecated: if True adds deprecated versions into the solve
@ -2542,9 +2444,7 @@ def setup(
if node.namespace is not None:
self.explicitly_required_namespaces[node.name] = node.namespace
# driver is used by all the functions below to add facts and
# rules to generate an ASP program.
self.gen = driver
self.gen = ProblemInstanceBuilder()
if not allow_deprecated:
self.gen.fact(fn.deprecated_versions_not_allowed())
@ -2648,6 +2548,29 @@ def setup(
self.gen.h1("Target Constraints")
self.define_target_constraints()
self.gen.h1("Internal errors")
self.internal_errors()
return self.gen.value()
def internal_errors(self):
parent_dir = os.path.dirname(__file__)
def visit(node):
if ast_type(node) == ASTType.Rule:
for term in node.body:
if ast_type(term) == ASTType.Literal:
if ast_type(term.atom) == ASTType.SymbolicAtom:
name = ast_sym(term.atom).name
if name == "internal_error":
arg = ast_sym(ast_sym(term.atom).arguments[0])
symbol = AspFunction(name)(arg.string)
self.assumptions.append((parse_term(str(symbol)), True))
self.gen.asp_problem.append(f"{{ {symbol} }}.\n")
path = os.path.join(parent_dir, "concretize.lp")
parse_files([path], visit)
def define_runtime_constraints(self):
"""Define the constraints to be imposed on the runtimes"""
recorder = RuntimePropertyRecorder(self)
@ -2778,6 +2701,83 @@ def pkg_class(self, pkg_name: str) -> typing.Type["spack.package_base.PackageBas
return spack.repo.PATH.get_pkg_class(request)
class _Head:
"""ASP functions used to express spec clauses in the HEAD of a rule"""
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform_set")
node_os = fn.attr("node_os_set")
node_target = fn.attr("node_target_set")
variant_value = fn.attr("variant_set")
node_compiler = fn.attr("node_compiler_set")
node_compiler_version = fn.attr("node_compiler_version_set")
node_flag = fn.attr("node_flag_set")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
class _Body:
"""ASP functions used to express spec clauses in the BODY of a rule"""
node = fn.attr("node")
virtual_node = fn.attr("virtual_node")
node_platform = fn.attr("node_platform")
node_os = fn.attr("node_os")
node_target = fn.attr("node_target")
variant_value = fn.attr("variant_value")
node_compiler = fn.attr("node_compiler")
node_compiler_version = fn.attr("node_compiler_version")
node_flag = fn.attr("node_flag")
node_flag_source = fn.attr("node_flag_source")
node_flag_propagate = fn.attr("node_flag_propagate")
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
class ProblemInstanceBuilder:
"""Provides an interface to construct a problem instance.
Once all the facts and rules have been added, the problem instance can be retrieved with:
>>> builder = ProblemInstanceBuilder()
>>> ...
>>> problem_instance = builder.value()
The problem instance can be added directly to the "control" structure of clingo.
"""
def __init__(self):
self.asp_problem = []
def fact(self, atom: AspFunction) -> None:
symbol = atom.symbol() if hasattr(atom, "symbol") else atom
self.asp_problem.append(f"{str(symbol)}.\n")
def append(self, rule: str) -> None:
self.asp_problem.append(rule)
def title(self, header: str, char: str) -> None:
self.asp_problem.append("\n")
self.asp_problem.append("%" + (char * 76))
self.asp_problem.append("\n")
self.asp_problem.append(f"% {header}\n")
self.asp_problem.append("%" + (char * 76))
self.asp_problem.append("\n")
def h1(self, header: str) -> None:
self.title(header, "=")
def h2(self, header: str) -> None:
self.title(header, "-")
def newline(self):
self.asp_problem.append("\n")
def value(self) -> str:
return "".join(self.asp_problem)
class RequirementParser:
"""Parses requirements from package.py files and configuration, and returns rules."""
@ -3085,9 +3085,7 @@ def consume_facts(self):
self._setup.gen.h2("Runtimes: rules")
self._setup.gen.newline()
for rule in self.rules:
if not isinstance(self._setup.gen.out, llnl.util.lang.Devnull):
self._setup.gen.out.write(rule)
self._setup.gen.control.add("base", [], rule)
self._setup.gen.append(rule)
self._setup.gen.h2("Runtimes: conditions")
for runtime_pkg in spack.repo.PATH.packages_with_tags("runtime"):