ASP-based solver: decouple setup phase from clingo.backend
(#41952)
Currently, the `SpackSolverSetup` and the `PyclingoDriver` are more coupled than necessary: 1. The driver object needs a setup object to be injected during a solve, 2. And the setup object will get a reference back to the driver This design is necessary because we use the low-level `clingo.backend` interface to setup our problem. This interface though is meant to bypass the grounder and add symbols directly in the grounded table, which is a feature we don't currently use. The PR simplifies the encoding by having the setup object returning the problem-specific facts / rules as a list of strings, and the driver ingesting them using the [clingo.Control.add](https://potassco.org/clingo/python-api/5.6/clingo/control.html#clingo.control.Control.add) method. This removes any use of the low level interface. Using this encoding makes it easy to hash the output of the setup phase, since it is returned as a string.
This commit is contained in:
parent
97fb9565ee
commit
5c49bb45c7
1 changed files with 133 additions and 135 deletions
|
@ -96,6 +96,7 @@
|
|||
# these are from clingo.ast and bootstrapped later
|
||||
ASTType = None
|
||||
parse_files = None
|
||||
parse_term = None
|
||||
|
||||
#: Enable the addition of a runtime node
|
||||
WITH_RUNTIME = sys.platform != "win32"
|
||||
|
@ -310,11 +311,11 @@ def _id(thing):
|
|||
if isinstance(thing, AspObject):
|
||||
return thing
|
||||
elif isinstance(thing, bool):
|
||||
return '"%s"' % str(thing)
|
||||
return f'"{str(thing)}"'
|
||||
elif isinstance(thing, int):
|
||||
return str(thing)
|
||||
else:
|
||||
return '"%s"' % str(thing)
|
||||
return f'"{str(thing)}"'
|
||||
|
||||
|
||||
@llnl.util.lang.key_ordering
|
||||
|
@ -351,21 +352,20 @@ def __call__(self, *args):
|
|||
"""
|
||||
return AspFunction(self.name, self.args + args)
|
||||
|
||||
def symbol(self, positive=True):
|
||||
def argify(arg):
|
||||
if isinstance(arg, bool):
|
||||
return clingo.String(str(arg))
|
||||
elif isinstance(arg, int):
|
||||
return clingo.Number(arg)
|
||||
elif isinstance(arg, AspFunction):
|
||||
return clingo.Function(arg.name, [argify(x) for x in arg.args], positive=positive)
|
||||
else:
|
||||
return clingo.String(str(arg))
|
||||
def argify(self, arg):
|
||||
if isinstance(arg, bool):
|
||||
return clingo.String(str(arg))
|
||||
elif isinstance(arg, int):
|
||||
return clingo.Number(arg)
|
||||
elif isinstance(arg, AspFunction):
|
||||
return clingo.Function(arg.name, [self.argify(x) for x in arg.args], positive=True)
|
||||
return clingo.String(str(arg))
|
||||
|
||||
return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive)
|
||||
def symbol(self):
|
||||
return clingo.Function(self.name, [self.argify(arg) for arg in self.args], positive=True)
|
||||
|
||||
def __str__(self):
|
||||
return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args))
|
||||
return f"{self.name}({', '.join(str(_id(arg)) for arg in self.args)})"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
@ -664,7 +664,7 @@ def _spec_with_default_name(spec_str, name):
|
|||
|
||||
|
||||
def bootstrap_clingo():
|
||||
global clingo, ASTType, parse_files
|
||||
global clingo, ASTType, parse_files, parse_term
|
||||
|
||||
if not clingo:
|
||||
import spack.bootstrap
|
||||
|
@ -677,9 +677,10 @@ def bootstrap_clingo():
|
|||
|
||||
try:
|
||||
from clingo.ast import parse_files
|
||||
from clingo.symbol import parse_term
|
||||
except ImportError:
|
||||
# older versions of clingo have this one namespace up
|
||||
from clingo import parse_files
|
||||
from clingo import parse_files, parse_term
|
||||
|
||||
|
||||
class NodeArgument(NamedTuple):
|
||||
|
@ -882,53 +883,9 @@ def __init__(self, cores=True):
|
|||
error reporting.
|
||||
"""
|
||||
bootstrap_clingo()
|
||||
|
||||
self.out = llnl.util.lang.Devnull()
|
||||
self.cores = cores
|
||||
|
||||
# These attributes are part of the object, but will be reset
|
||||
# at each call to solve
|
||||
# This attribute will be reset at each call to solve
|
||||
self.control = None
|
||||
self.backend = None
|
||||
self.assumptions = None
|
||||
|
||||
def title(self, name, char):
|
||||
self.out.write("\n")
|
||||
self.out.write("%" + (char * 76))
|
||||
self.out.write("\n")
|
||||
self.out.write("%% %s\n" % name)
|
||||
self.out.write("%" + (char * 76))
|
||||
self.out.write("\n")
|
||||
|
||||
def h1(self, name):
|
||||
self.title(name, "=")
|
||||
|
||||
def h2(self, name):
|
||||
self.title(name, "-")
|
||||
|
||||
def newline(self):
|
||||
self.out.write("\n")
|
||||
|
||||
def fact(self, head):
|
||||
"""ASP fact (a rule without a body).
|
||||
|
||||
Arguments:
|
||||
head (AspFunction): ASP function to generate as fact
|
||||
"""
|
||||
symbol = head.symbol() if hasattr(head, "symbol") else head
|
||||
|
||||
# This is commented out to avoid evaluating str(symbol) when we have no stream
|
||||
if not isinstance(self.out, llnl.util.lang.Devnull):
|
||||
self.out.write(f"{str(symbol)}.\n")
|
||||
|
||||
atom = self.backend.add_atom(symbol)
|
||||
|
||||
# Only functions relevant for constructing bug reports for bad error messages
|
||||
# are assumptions, and only when using cores.
|
||||
choice = self.cores and symbol.name == "internal_error"
|
||||
self.backend.add_rule([atom], [], choice=choice)
|
||||
if choice:
|
||||
self.assumptions.append(atom)
|
||||
|
||||
def solve(self, setup, specs, reuse=None, output=None, control=None, allow_deprecated=False):
|
||||
"""Set up the input and solve for dependencies of ``specs``.
|
||||
|
@ -948,49 +905,24 @@ def solve(self, setup, specs, reuse=None, output=None, control=None, allow_depre
|
|||
solve, and the internal statistics from clingo.
|
||||
"""
|
||||
output = output or DEFAULT_OUTPUT_CONFIGURATION
|
||||
# allow solve method to override the output stream
|
||||
if output.out is not None:
|
||||
self.out = output.out
|
||||
|
||||
timer = spack.util.timer.Timer()
|
||||
|
||||
# Initialize the control object for the solver
|
||||
self.control = control or default_clingo_control()
|
||||
# set up the problem -- this generates facts and rules
|
||||
self.assumptions = []
|
||||
|
||||
timer.start("setup")
|
||||
with self.control.backend() as backend:
|
||||
self.backend = backend
|
||||
setup.setup(self, specs, reuse=reuse, allow_deprecated=allow_deprecated)
|
||||
asp_problem = setup.setup(specs, reuse=reuse, allow_deprecated=allow_deprecated)
|
||||
if output.out is not None:
|
||||
output.out.write(asp_problem)
|
||||
if output.setup_only:
|
||||
return Result(specs), None, None
|
||||
timer.stop("setup")
|
||||
|
||||
timer.start("load")
|
||||
# read in the main ASP program and display logic -- these are
|
||||
# handwritten, not generated, so we load them as resources
|
||||
parent_dir = os.path.dirname(__file__)
|
||||
|
||||
# extract error messages from concretize.lp by inspecting its AST
|
||||
with self.backend:
|
||||
|
||||
def visit(node):
|
||||
if ast_type(node) == ASTType.Rule:
|
||||
for term in node.body:
|
||||
if ast_type(term) == ASTType.Literal:
|
||||
if ast_type(term.atom) == ASTType.SymbolicAtom:
|
||||
name = ast_sym(term.atom).name
|
||||
if name == "internal_error":
|
||||
arg = ast_sym(ast_sym(term.atom).arguments[0])
|
||||
self.fact(AspFunction(name)(arg.string))
|
||||
|
||||
self.h1("Error messages")
|
||||
path = os.path.join(parent_dir, "concretize.lp")
|
||||
parse_files([path], visit)
|
||||
|
||||
# If we're only doing setup, just return an empty solve result
|
||||
if output.setup_only:
|
||||
return Result(specs), None, None
|
||||
|
||||
# Add the problem instance
|
||||
self.control.add("base", [], asp_problem)
|
||||
# Load the file itself
|
||||
parent_dir = os.path.dirname(__file__)
|
||||
self.control.load(os.path.join(parent_dir, "concretize.lp"))
|
||||
self.control.load(os.path.join(parent_dir, "heuristic.lp"))
|
||||
if spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") != "none":
|
||||
|
@ -1016,7 +948,7 @@ def on_model(model):
|
|||
models.append((model.cost, model.symbols(shown=True, terms=True)))
|
||||
|
||||
solve_kwargs = {
|
||||
"assumptions": self.assumptions,
|
||||
"assumptions": setup.assumptions,
|
||||
"on_model": on_model,
|
||||
"on_core": cores.append,
|
||||
}
|
||||
|
@ -1142,6 +1074,7 @@ class SpackSolverSetup:
|
|||
def __init__(self, tests=False):
|
||||
self.gen = None # set by setup()
|
||||
|
||||
self.assumptions = []
|
||||
self.declared_versions = collections.defaultdict(list)
|
||||
self.possible_versions = collections.defaultdict(set)
|
||||
self.deprecated_versions = collections.defaultdict(set)
|
||||
|
@ -1878,36 +1811,7 @@ def _spec_clauses(
|
|||
"""
|
||||
clauses = []
|
||||
|
||||
# TODO: do this with consistent suffixes.
|
||||
class Head:
|
||||
node = fn.attr("node")
|
||||
virtual_node = fn.attr("virtual_node")
|
||||
node_platform = fn.attr("node_platform_set")
|
||||
node_os = fn.attr("node_os_set")
|
||||
node_target = fn.attr("node_target_set")
|
||||
variant_value = fn.attr("variant_set")
|
||||
node_compiler = fn.attr("node_compiler_set")
|
||||
node_compiler_version = fn.attr("node_compiler_version_set")
|
||||
node_flag = fn.attr("node_flag_set")
|
||||
node_flag_source = fn.attr("node_flag_source")
|
||||
node_flag_propagate = fn.attr("node_flag_propagate")
|
||||
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
|
||||
|
||||
class Body:
|
||||
node = fn.attr("node")
|
||||
virtual_node = fn.attr("virtual_node")
|
||||
node_platform = fn.attr("node_platform")
|
||||
node_os = fn.attr("node_os")
|
||||
node_target = fn.attr("node_target")
|
||||
variant_value = fn.attr("variant_value")
|
||||
node_compiler = fn.attr("node_compiler")
|
||||
node_compiler_version = fn.attr("node_compiler_version")
|
||||
node_flag = fn.attr("node_flag")
|
||||
node_flag_source = fn.attr("node_flag_source")
|
||||
node_flag_propagate = fn.attr("node_flag_propagate")
|
||||
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
|
||||
|
||||
f = Body if body else Head
|
||||
f = _Body if body else _Head
|
||||
|
||||
if spec.name:
|
||||
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
|
||||
|
@ -2503,12 +2407,11 @@ def define_concrete_input_specs(self, specs, possible):
|
|||
|
||||
def setup(
|
||||
self,
|
||||
driver: PyclingoDriver,
|
||||
specs: Sequence[spack.spec.Spec],
|
||||
*,
|
||||
reuse: Optional[List[spack.spec.Spec]] = None,
|
||||
allow_deprecated: bool = False,
|
||||
):
|
||||
) -> str:
|
||||
"""Generate an ASP program with relevant constraints for specs.
|
||||
|
||||
This calls methods on the solve driver to set up the problem with
|
||||
|
@ -2516,7 +2419,6 @@ def setup(
|
|||
specs, as well as constraints from the specs themselves.
|
||||
|
||||
Arguments:
|
||||
driver: driver instance of this solve
|
||||
specs: list of Specs to solve
|
||||
reuse: list of concrete specs that can be reused
|
||||
allow_deprecated: if True adds deprecated versions into the solve
|
||||
|
@ -2542,9 +2444,7 @@ def setup(
|
|||
if node.namespace is not None:
|
||||
self.explicitly_required_namespaces[node.name] = node.namespace
|
||||
|
||||
# driver is used by all the functions below to add facts and
|
||||
# rules to generate an ASP program.
|
||||
self.gen = driver
|
||||
self.gen = ProblemInstanceBuilder()
|
||||
|
||||
if not allow_deprecated:
|
||||
self.gen.fact(fn.deprecated_versions_not_allowed())
|
||||
|
@ -2648,6 +2548,29 @@ def setup(
|
|||
self.gen.h1("Target Constraints")
|
||||
self.define_target_constraints()
|
||||
|
||||
self.gen.h1("Internal errors")
|
||||
self.internal_errors()
|
||||
|
||||
return self.gen.value()
|
||||
|
||||
def internal_errors(self):
|
||||
parent_dir = os.path.dirname(__file__)
|
||||
|
||||
def visit(node):
|
||||
if ast_type(node) == ASTType.Rule:
|
||||
for term in node.body:
|
||||
if ast_type(term) == ASTType.Literal:
|
||||
if ast_type(term.atom) == ASTType.SymbolicAtom:
|
||||
name = ast_sym(term.atom).name
|
||||
if name == "internal_error":
|
||||
arg = ast_sym(ast_sym(term.atom).arguments[0])
|
||||
symbol = AspFunction(name)(arg.string)
|
||||
self.assumptions.append((parse_term(str(symbol)), True))
|
||||
self.gen.asp_problem.append(f"{{ {symbol} }}.\n")
|
||||
|
||||
path = os.path.join(parent_dir, "concretize.lp")
|
||||
parse_files([path], visit)
|
||||
|
||||
def define_runtime_constraints(self):
|
||||
"""Define the constraints to be imposed on the runtimes"""
|
||||
recorder = RuntimePropertyRecorder(self)
|
||||
|
@ -2778,6 +2701,83 @@ def pkg_class(self, pkg_name: str) -> typing.Type["spack.package_base.PackageBas
|
|||
return spack.repo.PATH.get_pkg_class(request)
|
||||
|
||||
|
||||
class _Head:
|
||||
"""ASP functions used to express spec clauses in the HEAD of a rule"""
|
||||
|
||||
node = fn.attr("node")
|
||||
virtual_node = fn.attr("virtual_node")
|
||||
node_platform = fn.attr("node_platform_set")
|
||||
node_os = fn.attr("node_os_set")
|
||||
node_target = fn.attr("node_target_set")
|
||||
variant_value = fn.attr("variant_set")
|
||||
node_compiler = fn.attr("node_compiler_set")
|
||||
node_compiler_version = fn.attr("node_compiler_version_set")
|
||||
node_flag = fn.attr("node_flag_set")
|
||||
node_flag_source = fn.attr("node_flag_source")
|
||||
node_flag_propagate = fn.attr("node_flag_propagate")
|
||||
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
|
||||
|
||||
|
||||
class _Body:
|
||||
"""ASP functions used to express spec clauses in the BODY of a rule"""
|
||||
|
||||
node = fn.attr("node")
|
||||
virtual_node = fn.attr("virtual_node")
|
||||
node_platform = fn.attr("node_platform")
|
||||
node_os = fn.attr("node_os")
|
||||
node_target = fn.attr("node_target")
|
||||
variant_value = fn.attr("variant_value")
|
||||
node_compiler = fn.attr("node_compiler")
|
||||
node_compiler_version = fn.attr("node_compiler_version")
|
||||
node_flag = fn.attr("node_flag")
|
||||
node_flag_source = fn.attr("node_flag_source")
|
||||
node_flag_propagate = fn.attr("node_flag_propagate")
|
||||
variant_propagation_candidate = fn.attr("variant_propagation_candidate")
|
||||
|
||||
|
||||
class ProblemInstanceBuilder:
|
||||
"""Provides an interface to construct a problem instance.
|
||||
|
||||
Once all the facts and rules have been added, the problem instance can be retrieved with:
|
||||
|
||||
>>> builder = ProblemInstanceBuilder()
|
||||
>>> ...
|
||||
>>> problem_instance = builder.value()
|
||||
|
||||
The problem instance can be added directly to the "control" structure of clingo.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.asp_problem = []
|
||||
|
||||
def fact(self, atom: AspFunction) -> None:
|
||||
symbol = atom.symbol() if hasattr(atom, "symbol") else atom
|
||||
self.asp_problem.append(f"{str(symbol)}.\n")
|
||||
|
||||
def append(self, rule: str) -> None:
|
||||
self.asp_problem.append(rule)
|
||||
|
||||
def title(self, header: str, char: str) -> None:
|
||||
self.asp_problem.append("\n")
|
||||
self.asp_problem.append("%" + (char * 76))
|
||||
self.asp_problem.append("\n")
|
||||
self.asp_problem.append(f"% {header}\n")
|
||||
self.asp_problem.append("%" + (char * 76))
|
||||
self.asp_problem.append("\n")
|
||||
|
||||
def h1(self, header: str) -> None:
|
||||
self.title(header, "=")
|
||||
|
||||
def h2(self, header: str) -> None:
|
||||
self.title(header, "-")
|
||||
|
||||
def newline(self):
|
||||
self.asp_problem.append("\n")
|
||||
|
||||
def value(self) -> str:
|
||||
return "".join(self.asp_problem)
|
||||
|
||||
|
||||
class RequirementParser:
|
||||
"""Parses requirements from package.py files and configuration, and returns rules."""
|
||||
|
||||
|
@ -3085,9 +3085,7 @@ def consume_facts(self):
|
|||
self._setup.gen.h2("Runtimes: rules")
|
||||
self._setup.gen.newline()
|
||||
for rule in self.rules:
|
||||
if not isinstance(self._setup.gen.out, llnl.util.lang.Devnull):
|
||||
self._setup.gen.out.write(rule)
|
||||
self._setup.gen.control.add("base", [], rule)
|
||||
self._setup.gen.append(rule)
|
||||
|
||||
self._setup.gen.h2("Runtimes: conditions")
|
||||
for runtime_pkg in spack.repo.PATH.packages_with_tags("runtime"):
|
||||
|
|
Loading…
Reference in a new issue