SpackCommand uses log_output to capture command output.

This commit is contained in:
Todd Gamblin 2017-08-22 14:01:02 -07:00
parent 4f444c5f58
commit fa1faa61c4
5 changed files with 29 additions and 43 deletions

View file

@ -34,9 +34,10 @@
import inspect import inspect
import pstats import pstats
import argparse import argparse
import tempfile from six import StringIO
import llnl.util.tty as tty import llnl.util.tty as tty
from llnl.util.tty.log import log_output
from llnl.util.tty.color import * from llnl.util.tty.color import *
import spack import spack
@ -367,7 +368,7 @@ class SpackCommand(object):
install('-v', 'mpich') install('-v', 'mpich')
Use this to invoke Spack commands directly from Python and check Use this to invoke Spack commands directly from Python and check
their stdout and stderr. their output.
""" """
def __init__(self, command): def __init__(self, command):
"""Create a new SpackCommand that invokes ``command`` when called.""" """Create a new SpackCommand that invokes ``command`` when called."""
@ -376,9 +377,6 @@ def __init__(self, command):
self.command_name = command self.command_name = command
self.command = spack.cmd.get_command(command) self.command = spack.cmd.get_command(command)
self.returncode = None
self.error = None
def __call__(self, *argv, **kwargs): def __call__(self, *argv, **kwargs):
"""Invoke this SpackCommand. """Invoke this SpackCommand.
@ -389,24 +387,24 @@ def __call__(self, *argv, **kwargs):
fail_on_error (optional bool): Don't raise an exception on error fail_on_error (optional bool): Don't raise an exception on error
Returns: Returns:
(str, str): output and error as a strings (str): combined output and error as a string
On return, if ``fail_on_error`` is False, return value of comman On return, if ``fail_on_error`` is False, return value of comman
is set in ``returncode`` property, and the error is set in the is set in ``returncode`` property, and the error is set in the
``error`` property. Otherwise, raise an error. ``error`` property. Otherwise, raise an error.
""" """
# set these before every call to clear them out
self.returncode = None
self.error = None
args, unknown = self.parser.parse_known_args( args, unknown = self.parser.parse_known_args(
[self.command_name] + list(argv)) [self.command_name] + list(argv))
fail_on_error = kwargs.get('fail_on_error', True) fail_on_error = kwargs.get('fail_on_error', True)
out, err = sys.stdout, sys.stderr out = StringIO()
ofd, ofn = tempfile.mkstemp()
efd, efn = tempfile.mkstemp()
try: try:
sys.stdout = open(ofn, 'w') with log_output(out):
sys.stderr = open(efn, 'w')
self.returncode = _invoke_spack_command( self.returncode = _invoke_spack_command(
self.command, self.parser, args, unknown) self.command, self.parser, args, unknown)
@ -418,25 +416,13 @@ def __call__(self, *argv, **kwargs):
if fail_on_error: if fail_on_error:
raise raise
finally:
sys.stdout.flush()
sys.stdout.close()
sys.stderr.flush()
sys.stderr.close()
sys.stdout, sys.stderr = out, err
return_out = open(ofn).read()
return_err = open(efn).read()
os.unlink(ofn)
os.unlink(efn)
if fail_on_error and self.returncode not in (None, 0): if fail_on_error and self.returncode not in (None, 0):
raise SpackCommandError( raise SpackCommandError(
"Command exited with code %d: %s(%s)" % ( "Command exited with code %d: %s(%s)" % (
self.returncode, self.command_name, self.returncode, self.command_name,
', '.join("'%s'" % a for a in argv))) ', '.join("'%s'" % a for a in argv)))
return return_out, return_err return out.getvalue()
def _main(command, parser, args, unknown_args): def _main(command, parser, args, unknown_args):

View file

@ -36,14 +36,14 @@
def test_immediate_dependencies(builtin_mock): def test_immediate_dependencies(builtin_mock):
out, err = dependencies('mpileaks') out = dependencies('mpileaks')
actual = set(re.split(r'\s+', out.strip())) actual = set(re.split(r'\s+', out.strip()))
expected = set(['callpath'] + mpis) expected = set(['callpath'] + mpis)
assert expected == actual assert expected == actual
def test_transitive_dependencies(builtin_mock): def test_transitive_dependencies(builtin_mock):
out, err = dependencies('--transitive', 'mpileaks') out = dependencies('--transitive', 'mpileaks')
actual = set(re.split(r'\s+', out.strip())) actual = set(re.split(r'\s+', out.strip()))
expected = set( expected = set(
['callpath', 'dyninst', 'libdwarf', 'libelf'] + mpis + mpi_deps) ['callpath', 'dyninst', 'libdwarf', 'libelf'] + mpis + mpi_deps)
@ -52,7 +52,7 @@ def test_transitive_dependencies(builtin_mock):
def test_immediate_installed_dependencies(builtin_mock, database): def test_immediate_installed_dependencies(builtin_mock, database):
with color_when(False): with color_when(False):
out, err = dependencies('--installed', 'mpileaks^mpich') out = dependencies('--installed', 'mpileaks^mpich')
lines = [l for l in out.strip().split('\n') if not l.startswith('--')] lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
hashes = set([re.split(r'\s+', l)[0] for l in lines]) hashes = set([re.split(r'\s+', l)[0] for l in lines])
@ -65,7 +65,7 @@ def test_immediate_installed_dependencies(builtin_mock, database):
def test_transitive_installed_dependencies(builtin_mock, database): def test_transitive_installed_dependencies(builtin_mock, database):
with color_when(False): with color_when(False):
out, err = dependencies('--installed', '--transitive', 'mpileaks^zmpi') out = dependencies('--installed', '--transitive', 'mpileaks^zmpi')
lines = [l for l in out.strip().split('\n') if not l.startswith('--')] lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
hashes = set([re.split(r'\s+', l)[0] for l in lines]) hashes = set([re.split(r'\s+', l)[0] for l in lines])

View file

@ -33,13 +33,13 @@
def test_immediate_dependents(builtin_mock): def test_immediate_dependents(builtin_mock):
out, err = dependents('libelf') out = dependents('libelf')
actual = set(re.split(r'\s+', out.strip())) actual = set(re.split(r'\s+', out.strip()))
assert actual == set(['dyninst', 'libdwarf']) assert actual == set(['dyninst', 'libdwarf'])
def test_transitive_dependents(builtin_mock): def test_transitive_dependents(builtin_mock):
out, err = dependents('--transitive', 'libelf') out = dependents('--transitive', 'libelf')
actual = set(re.split(r'\s+', out.strip())) actual = set(re.split(r'\s+', out.strip()))
assert actual == set( assert actual == set(
['callpath', 'dyninst', 'libdwarf', 'mpileaks', 'multivalue_variant', ['callpath', 'dyninst', 'libdwarf', 'mpileaks', 'multivalue_variant',
@ -48,7 +48,7 @@ def test_transitive_dependents(builtin_mock):
def test_immediate_installed_dependents(builtin_mock, database): def test_immediate_installed_dependents(builtin_mock, database):
with color_when(False): with color_when(False):
out, err = dependents('--installed', 'libelf') out = dependents('--installed', 'libelf')
lines = [l for l in out.strip().split('\n') if not l.startswith('--')] lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
hashes = set([re.split(r'\s+', l)[0] for l in lines]) hashes = set([re.split(r'\s+', l)[0] for l in lines])
@ -64,7 +64,7 @@ def test_immediate_installed_dependents(builtin_mock, database):
def test_transitive_installed_dependents(builtin_mock, database): def test_transitive_installed_dependents(builtin_mock, database):
with color_when(False): with color_when(False):
out, err = dependents('--installed', '--transitive', 'fake') out = dependents('--installed', '--transitive', 'fake')
lines = [l for l in out.strip().split('\n') if not l.startswith('--')] lines = [l for l in out.strip().split('\n') if not l.startswith('--')]
hashes = set([re.split(r'\s+', l)[0] for l in lines]) hashes = set([re.split(r'\s+', l)[0] for l in lines])

View file

@ -29,5 +29,5 @@
def test_python(): def test_python():
out, err = python('-c', 'import spack; print(spack.spack_version)') out = python('-c', 'import spack; print(spack.spack_version)')
assert out.strip() == str(spack.spack_version) assert out.strip() == str(spack.spack_version)

View file

@ -83,30 +83,30 @@ def test_url_with_no_version_fails():
def test_url_list(): def test_url_list():
out, err = url('list') out = url('list')
total_urls = len(out.split('\n')) total_urls = len(out.split('\n'))
# The following two options should not change the number of URLs printed. # The following two options should not change the number of URLs printed.
out, err = url('list', '--color', '--extrapolation') out = url('list', '--color', '--extrapolation')
colored_urls = len(out.split('\n')) colored_urls = len(out.split('\n'))
assert colored_urls == total_urls assert colored_urls == total_urls
# The following options should print fewer URLs than the default. # The following options should print fewer URLs than the default.
# If they print the same number of URLs, something is horribly broken. # If they print the same number of URLs, something is horribly broken.
# If they say we missed 0 URLs, something is probably broken too. # If they say we missed 0 URLs, something is probably broken too.
out, err = url('list', '--incorrect-name') out = url('list', '--incorrect-name')
incorrect_name_urls = len(out.split('\n')) incorrect_name_urls = len(out.split('\n'))
assert 0 < incorrect_name_urls < total_urls assert 0 < incorrect_name_urls < total_urls
out, err = url('list', '--incorrect-version') out = url('list', '--incorrect-version')
incorrect_version_urls = len(out.split('\n')) incorrect_version_urls = len(out.split('\n'))
assert 0 < incorrect_version_urls < total_urls assert 0 < incorrect_version_urls < total_urls
out, err = url('list', '--correct-name') out = url('list', '--correct-name')
correct_name_urls = len(out.split('\n')) correct_name_urls = len(out.split('\n'))
assert 0 < correct_name_urls < total_urls assert 0 < correct_name_urls < total_urls
out, err = url('list', '--correct-version') out = url('list', '--correct-version')
correct_version_urls = len(out.split('\n')) correct_version_urls = len(out.split('\n'))
assert 0 < correct_version_urls < total_urls assert 0 < correct_version_urls < total_urls
@ -121,7 +121,7 @@ def test_url_summary():
assert 0 < correct_versions <= sum(version_count_dict.values()) <= total_urls # noqa assert 0 < correct_versions <= sum(version_count_dict.values()) <= total_urls # noqa
# make sure it agrees with the actual command. # make sure it agrees with the actual command.
out, err = url('summary') out = url('summary')
out_total_urls = int( out_total_urls = int(
re.search(r'Total URLs found:\s*(\d+)', out).group(1)) re.search(r'Total URLs found:\s*(\d+)', out).group(1))
assert out_total_urls == total_urls assert out_total_urls == total_urls