SpackCommand uses log_output to capture command output.
This commit is contained in:
parent
4f444c5f58
commit
fa1faa61c4
5 changed files with 29 additions and 43 deletions
|
@ -34,9 +34,10 @@
|
|||
import inspect
|
||||
import pstats
|
||||
import argparse
|
||||
import tempfile
|
||||
from six import StringIO
|
||||
|
||||
import llnl.util.tty as tty
|
||||
from llnl.util.tty.log import log_output
|
||||
from llnl.util.tty.color import *
|
||||
|
||||
import spack
|
||||
|
@ -367,7 +368,7 @@ class SpackCommand(object):
|
|||
install('-v', 'mpich')
|
||||
|
||||
Use this to invoke Spack commands directly from Python and check
|
||||
their stdout and stderr.
|
||||
their output.
|
||||
"""
|
||||
def __init__(self, command):
|
||||
"""Create a new SpackCommand that invokes ``command`` when called."""
|
||||
|
@ -376,9 +377,6 @@ def __init__(self, command):
|
|||
self.command_name = command
|
||||
self.command = spack.cmd.get_command(command)
|
||||
|
||||
self.returncode = None
|
||||
self.error = None
|
||||
|
||||
def __call__(self, *argv, **kwargs):
|
||||
"""Invoke this SpackCommand.
|
||||
|
||||
|
@ -389,24 +387,24 @@ def __call__(self, *argv, **kwargs):
|
|||
fail_on_error (optional bool): Don't raise an exception on error
|
||||
|
||||
Returns:
|
||||
(str, str): output and error as a strings
|
||||
(str): combined output and error as a string
|
||||
|
||||
On return, if ``fail_on_error`` is False, return value of comman
|
||||
is set in ``returncode`` property, and the error is set in the
|
||||
``error`` property. Otherwise, raise an error.
|
||||
"""
|
||||
# set these before every call to clear them out
|
||||
self.returncode = None
|
||||
self.error = None
|
||||
|
||||
args, unknown = self.parser.parse_known_args(
|
||||
[self.command_name] + list(argv))
|
||||
|
||||
fail_on_error = kwargs.get('fail_on_error', True)
|
||||
|
||||
out, err = sys.stdout, sys.stderr
|
||||
ofd, ofn = tempfile.mkstemp()
|
||||
efd, efn = tempfile.mkstemp()
|
||||
|
||||
out = StringIO()
|
||||
try:
|
||||
sys.stdout = open(ofn, 'w')
|
||||
sys.stderr = open(efn, 'w')
|
||||
with log_output(out):
|
||||
self.returncode = _invoke_spack_command(
|
||||
self.command, self.parser, args, unknown)
|
||||
|
||||
|
@ -418,25 +416,13 @@ def __call__(self, *argv, **kwargs):
|
|||
if fail_on_error:
|
||||
raise
|
||||
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
sys.stdout.close()
|
||||
sys.stderr.flush()
|
||||
sys.stderr.close()
|
||||
sys.stdout, sys.stderr = out, err
|
||||
|
||||
return_out = open(ofn).read()
|
||||
return_err = open(efn).read()
|
||||
os.unlink(ofn)
|
||||
os.unlink(efn)
|
||||
|
||||
if fail_on_error and self.returncode not in (None, 0):
|
||||
raise SpackCommandError(
|
||||
"Command exited with code %d: %s(%s)" % (
|
||||
self.returncode, self.command_name,
|
||||
', '.join("'%s'" % a for a in argv)))
|
||||
|
||||
return return_out, return_err
|
||||
return out.getvalue()
|
||||
|
||||
|
||||
def _main(command, parser, args, unknown_args):
|
||||
|
|
|
@ -36,14 +36,14 @@
|
|||
|
||||
|
||||
def test_immediate_dependencies(builtin_mock):
|
||||
out, err = dependencies('mpileaks')
|
||||
out = dependencies('mpileaks')
|
||||
actual = set(re.split(r'\s+', out.strip()))
|
||||
expected = set(['callpath'] + mpis)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_transitive_dependencies(builtin_mock):
|
||||
out, err = dependencies('--transitive', 'mpileaks')
|
||||
out = dependencies('--transitive', 'mpileaks')
|
||||
actual = set(re.split(r'\s+', out.strip()))
|
||||
expected = set(
|
||||
['callpath', 'dyninst', 'libdwarf', 'libelf'] + mpis + mpi_deps)
|
||||
|
@ -52,7 +52,7 @@ def test_transitive_dependencies(builtin_mock):
|
|||
|
||||
def test_immediate_installed_dependencies(builtin_mock, database):
|
||||
with color_when(False):
|
||||
out, err = dependencies('--installed', 'mpileaks^mpich')
|
||||
out = dependencies('--installed', 'mpileaks^mpich')
|
||||
|
||||
lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
|
||||
hashes = set([re.split(r'\s+', l)[0] for l in lines])
|
||||
|
@ -65,7 +65,7 @@ def test_immediate_installed_dependencies(builtin_mock, database):
|
|||
|
||||
def test_transitive_installed_dependencies(builtin_mock, database):
|
||||
with color_when(False):
|
||||
out, err = dependencies('--installed', '--transitive', 'mpileaks^zmpi')
|
||||
out = dependencies('--installed', '--transitive', 'mpileaks^zmpi')
|
||||
|
||||
lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
|
||||
hashes = set([re.split(r'\s+', l)[0] for l in lines])
|
||||
|
|
|
@ -33,13 +33,13 @@
|
|||
|
||||
|
||||
def test_immediate_dependents(builtin_mock):
|
||||
out, err = dependents('libelf')
|
||||
out = dependents('libelf')
|
||||
actual = set(re.split(r'\s+', out.strip()))
|
||||
assert actual == set(['dyninst', 'libdwarf'])
|
||||
|
||||
|
||||
def test_transitive_dependents(builtin_mock):
|
||||
out, err = dependents('--transitive', 'libelf')
|
||||
out = dependents('--transitive', 'libelf')
|
||||
actual = set(re.split(r'\s+', out.strip()))
|
||||
assert actual == set(
|
||||
['callpath', 'dyninst', 'libdwarf', 'mpileaks', 'multivalue_variant',
|
||||
|
@ -48,7 +48,7 @@ def test_transitive_dependents(builtin_mock):
|
|||
|
||||
def test_immediate_installed_dependents(builtin_mock, database):
|
||||
with color_when(False):
|
||||
out, err = dependents('--installed', 'libelf')
|
||||
out = dependents('--installed', 'libelf')
|
||||
|
||||
lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
|
||||
hashes = set([re.split(r'\s+', l)[0] for l in lines])
|
||||
|
@ -64,7 +64,7 @@ def test_immediate_installed_dependents(builtin_mock, database):
|
|||
|
||||
def test_transitive_installed_dependents(builtin_mock, database):
|
||||
with color_when(False):
|
||||
out, err = dependents('--installed', '--transitive', 'fake')
|
||||
out = dependents('--installed', '--transitive', 'fake')
|
||||
|
||||
lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
|
||||
hashes = set([re.split(r'\s+', l)[0] for l in lines])
|
||||
|
|
|
@ -29,5 +29,5 @@
|
|||
|
||||
|
||||
def test_python():
|
||||
out, err = python('-c', 'import spack; print(spack.spack_version)')
|
||||
out = python('-c', 'import spack; print(spack.spack_version)')
|
||||
assert out.strip() == str(spack.spack_version)
|
||||
|
|
|
@ -83,30 +83,30 @@ def test_url_with_no_version_fails():
|
|||
|
||||
|
||||
def test_url_list():
|
||||
out, err = url('list')
|
||||
out = url('list')
|
||||
total_urls = len(out.split('\n'))
|
||||
|
||||
# The following two options should not change the number of URLs printed.
|
||||
out, err = url('list', '--color', '--extrapolation')
|
||||
out = url('list', '--color', '--extrapolation')
|
||||
colored_urls = len(out.split('\n'))
|
||||
assert colored_urls == total_urls
|
||||
|
||||
# The following options should print fewer URLs than the default.
|
||||
# If they print the same number of URLs, something is horribly broken.
|
||||
# If they say we missed 0 URLs, something is probably broken too.
|
||||
out, err = url('list', '--incorrect-name')
|
||||
out = url('list', '--incorrect-name')
|
||||
incorrect_name_urls = len(out.split('\n'))
|
||||
assert 0 < incorrect_name_urls < total_urls
|
||||
|
||||
out, err = url('list', '--incorrect-version')
|
||||
out = url('list', '--incorrect-version')
|
||||
incorrect_version_urls = len(out.split('\n'))
|
||||
assert 0 < incorrect_version_urls < total_urls
|
||||
|
||||
out, err = url('list', '--correct-name')
|
||||
out = url('list', '--correct-name')
|
||||
correct_name_urls = len(out.split('\n'))
|
||||
assert 0 < correct_name_urls < total_urls
|
||||
|
||||
out, err = url('list', '--correct-version')
|
||||
out = url('list', '--correct-version')
|
||||
correct_version_urls = len(out.split('\n'))
|
||||
assert 0 < correct_version_urls < total_urls
|
||||
|
||||
|
@ -121,7 +121,7 @@ def test_url_summary():
|
|||
assert 0 < correct_versions <= sum(version_count_dict.values()) <= total_urls # noqa
|
||||
|
||||
# make sure it agrees with the actual command.
|
||||
out, err = url('summary')
|
||||
out = url('summary')
|
||||
out_total_urls = int(
|
||||
re.search(r'Total URLs found:\s*(\d+)', out).group(1))
|
||||
assert out_total_urls == total_urls
|
||||
|
|
Loading…
Reference in a new issue