Preserve comments for Spack YAML objects (#11602)

This updates the configuration loading/dumping logic (now called
load_config/dump_config) in spack_yaml to preserve comments (by using
ruamel.yaml's RoundTripLoader). This has two effects:

* environment spack.yaml files expect to retain comments, which
  load_config now supports. By using load_config, users can now use the
  ':' override syntax that was previously unavailable for environment
  configs (but was available for other config files).

* config files now retain user comments by default (although in cases
  where Spack updates/overwrites config, the comments can still be
  removed).

Details:

* Subclasses `RoundTripLoader`/`RoundTripDumper` to parse yaml into
  ruamel's `CommentedMap` and analogous data structures

* Applies filename info directly to ruamel objects in cases where the
  updated loader returns those

* Copies management of sections in `SingleFileScope` from #10651 to allow
  overrides to occur

* Updates the loader/dumper to handle the processing of overrides by
  specifically checking for the `:` character
  * Possibly the most controversial aspect, but without that, the parsed
    objects have to be reconstructed (i.e. as was done in
    `mark_overrides`). It is possible that `mark_overrides` could remain
    and a deep copy will not cause problems, but IMO that's generally
    worth avoiding.
  * This is also possibly controversial because Spack YAML strings can
    include `:`. My reckoning is that this only occurs for version
    specifications, so it is safe to check for `endswith(':') and not
    ('@' in string)`
  * As a consequence, this PR ends up reserving spack yaml functions
    load_config/dump_config exclusively for the purpose of storing spack
    config
This commit is contained in:
Todd Gamblin 2019-10-22 23:50:31 -07:00
parent b3f5084b96
commit af65146ef6
13 changed files with 171 additions and 178 deletions

View file

@ -11,6 +11,7 @@
import tempfile import tempfile
import hashlib import hashlib
from contextlib import closing from contextlib import closing
import ruamel.yaml as yaml
import json import json
@ -139,7 +140,7 @@ def read_buildinfo_file(prefix):
filename = buildinfo_file_name(prefix) filename = buildinfo_file_name(prefix)
with open(filename, 'r') as inputfile: with open(filename, 'r') as inputfile:
content = inputfile.read() content = inputfile.read()
buildinfo = syaml.load(content) buildinfo = yaml.load(content)
return buildinfo return buildinfo
@ -380,10 +381,9 @@ def build_tarball(spec, outdir, force=False, rel=False, unsigned=False,
checksum = checksum_tarball(tarfile_path) checksum = checksum_tarball(tarfile_path)
# add sha256 checksum to spec.yaml # add sha256 checksum to spec.yaml
spec_dict = {}
with open(spec_file, 'r') as inputfile: with open(spec_file, 'r') as inputfile:
content = inputfile.read() content = inputfile.read()
spec_dict = syaml.load(content) spec_dict = yaml.load(content)
bchecksum = {} bchecksum = {}
bchecksum['hash_algorithm'] = 'sha256' bchecksum['hash_algorithm'] = 'sha256'
bchecksum['hash'] = checksum bchecksum['hash'] = checksum

View file

@ -655,4 +655,4 @@ def release_jobs(parser, args):
output_object['stages'] = stage_names output_object['stages'] = stage_names
with open(args.output_file, 'w') as outf: with open(args.output_file, 'w') as outf:
outf.write(syaml.dump(output_object)) outf.write(syaml.dump_config(output_object, default_flow_style=True))

View file

@ -36,7 +36,6 @@
import sys import sys
import multiprocessing import multiprocessing
from contextlib import contextmanager from contextlib import contextmanager
from six import string_types
from six import iteritems from six import iteritems
from ordereddict_backport import OrderedDict from ordereddict_backport import OrderedDict
@ -155,7 +154,7 @@ def write_section(self, section):
mkdirp(self.path) mkdirp(self.path)
with open(filename, 'w') as f: with open(filename, 'w') as f:
validate(data, section_schemas[section]) validate(data, section_schemas[section])
syaml.dump(data, stream=f, default_flow_style=False) syaml.dump_config(data, stream=f, default_flow_style=False)
except (yaml.YAMLError, IOError) as e: except (yaml.YAMLError, IOError) as e:
raise ConfigFileError( raise ConfigFileError(
"Error writing to config file: '%s'" % str(e)) "Error writing to config file: '%s'" % str(e))
@ -200,6 +199,22 @@ def get_section(self, section):
# ... data ... # ... data ...
# }, # },
# } # }
#
# To preserve overrides up to the section level (e.g. to override
# the "packages" section with the "::" syntax), data in self.sections
# looks like this:
# {
# 'config': {
# 'config': {
# ... data ...
# }
# },
# 'packages': {
# 'packages': {
# ... data ...
# }
# }
# }
if self._raw_data is None: if self._raw_data is None:
self._raw_data = _read_config_file(self.path, self.schema) self._raw_data = _read_config_file(self.path, self.schema)
if self._raw_data is None: if self._raw_data is None:
@ -215,29 +230,10 @@ def get_section(self, section):
self._raw_data = self._raw_data[key] self._raw_data = self._raw_data[key]
# data in self.sections looks (awkwardly) like this: for section_key, data in self._raw_data.items():
# { self.sections[section_key] = {section_key: data}
# 'config': {
# 'config': { return self.sections.get(section, None)
# ... data ...
# }
# },
# 'packages': {
# 'packages': {
# ... data ...
# }
# }
# }
#
# UNLESS there is no section, in which case it is stored as:
# {
# 'config': None,
# ...
# }
value = self._raw_data.get(section)
self.sections.setdefault(
section, None if value is None else {section: value})
return self.sections[section]
def write_section(self, section): def write_section(self, section):
validate(self.sections, self.schema) validate(self.sections, self.schema)
@ -247,7 +243,8 @@ def write_section(self, section):
tmp = os.path.join(parent, '.%s.tmp' % self.path) tmp = os.path.join(parent, '.%s.tmp' % self.path)
with open(tmp, 'w') as f: with open(tmp, 'w') as f:
syaml.dump(self.sections, stream=f, default_flow_style=False) syaml.dump_config(self.sections, stream=f,
default_flow_style=False)
os.path.move(tmp, self.path) os.path.move(tmp, self.path)
except (yaml.YAMLError, IOError) as e: except (yaml.YAMLError, IOError) as e:
raise ConfigFileError( raise ConfigFileError(
@ -533,7 +530,7 @@ def print_section(self, section, blame=False):
try: try:
data = syaml.syaml_dict() data = syaml.syaml_dict()
data[section] = self.get_config(section) data[section] = self.get_config(section)
syaml.dump( syaml.dump_config(
data, stream=sys.stdout, default_flow_style=False, blame=blame) data, stream=sys.stdout, default_flow_style=False, blame=blame)
except (yaml.YAMLError, IOError): except (yaml.YAMLError, IOError):
raise ConfigError("Error reading configuration: %s" % section) raise ConfigError("Error reading configuration: %s" % section)
@ -708,7 +705,7 @@ def _read_config_file(filename, schema):
try: try:
tty.debug("Reading config file %s" % filename) tty.debug("Reading config file %s" % filename)
with open(filename) as f: with open(filename) as f:
data = _mark_overrides(syaml.load(f)) data = syaml.load_config(f)
if data: if data:
validate(data, schema) validate(data, schema)
@ -734,23 +731,6 @@ def _override(string):
return hasattr(string, 'override') and string.override return hasattr(string, 'override') and string.override
def _mark_overrides(data):
if isinstance(data, list):
return syaml.syaml_list(_mark_overrides(elt) for elt in data)
elif isinstance(data, dict):
marked = syaml.syaml_dict()
for key, val in iteritems(data):
if isinstance(key, string_types) and key.endswith(':'):
key = syaml.syaml_str(key[:-1])
key.override = True
marked[key] = _mark_overrides(val)
return marked
else:
return data
def _mark_internal(data, name): def _mark_internal(data, name):
"""Add a simple name mark to raw YAML/JSON data. """Add a simple name mark to raw YAML/JSON data.
@ -820,9 +800,14 @@ def they_are(t):
# ensure that keys are marked in the destination. the key_marks dict # ensure that keys are marked in the destination. the key_marks dict
# ensures we can get the actual source key objects from dest keys # ensures we can get the actual source key objects from dest keys
for dk in dest.keys(): for dk in list(dest.keys()):
if dk in key_marks: if dk in key_marks and syaml.markable(dk):
syaml.mark(dk, key_marks[dk]) syaml.mark(dk, key_marks[dk])
elif dk in key_marks:
# The destination key may not be markable if it was derived
# from a schema default. In this case replace the key.
val = dest.pop(dk)
dest[key_marks[dk]] = val
return dest return dest

View file

@ -11,7 +11,6 @@
import copy import copy
import socket import socket
import ruamel.yaml
import six import six
from ordereddict_backport import OrderedDict from ordereddict_backport import OrderedDict
@ -27,6 +26,7 @@
import spack.schema.env import spack.schema.env
import spack.spec import spack.spec
import spack.util.spack_json as sjson import spack.util.spack_json as sjson
import spack.util.spack_yaml as syaml
import spack.config import spack.config
import spack.build_environment as build_env import spack.build_environment as build_env
@ -401,7 +401,7 @@ def validate(data, filename=None):
def _read_yaml(str_or_file): def _read_yaml(str_or_file):
"""Read YAML from a file for round-trip parsing.""" """Read YAML from a file for round-trip parsing."""
data = ruamel.yaml.load(str_or_file, ruamel.yaml.RoundTripLoader) data = syaml.load_config(str_or_file)
filename = getattr(str_or_file, 'name', None) filename = getattr(str_or_file, 'name', None)
validate(data, filename) validate(data, filename)
return data return data
@ -411,8 +411,7 @@ def _write_yaml(data, str_or_file):
"""Write YAML to a file preserving comments and dict order.""" """Write YAML to a file preserving comments and dict order."""
filename = getattr(str_or_file, 'name', None) filename = getattr(str_or_file, 'name', None)
validate(data, filename) validate(data, filename)
ruamel.yaml.dump(data, str_or_file, Dumper=ruamel.yaml.RoundTripDumper, syaml.dump_config(data, str_or_file, default_flow_style=False)
default_flow_style=False)
def _eval_conditional(string): def _eval_conditional(string):

View file

@ -211,7 +211,7 @@ def write_projections(self):
if self.projections: if self.projections:
mkdirp(os.path.dirname(self.projections_path)) mkdirp(os.path.dirname(self.projections_path))
with open(self.projections_path, 'w') as f: with open(self.projections_path, 'w') as f:
f.write(s_yaml.dump({'projections': self.projections})) f.write(s_yaml.dump_config({'projections': self.projections}))
def read_projections(self): def read_projections(self):
if os.path.exists(self.projections_path): if os.path.exists(self.projections_path):

View file

@ -24,7 +24,7 @@ def check_compiler_yaml_version():
data = None data = None
if os.path.isfile(file_name): if os.path.isfile(file_name):
with open(file_name) as f: with open(file_name) as f:
data = syaml.load(f) data = syaml.load_config(f)
if data: if data:
compilers = data.get('compilers') compilers = data.get('compilers')

View file

@ -235,7 +235,7 @@ def generate_module_index(root, modules):
index_path = os.path.join(root, 'module-index.yaml') index_path = os.path.join(root, 'module-index.yaml')
llnl.util.filesystem.mkdirp(root) llnl.util.filesystem.mkdirp(root)
with open(index_path, 'w') as index_file: with open(index_path, 'w') as index_file:
syaml.dump(index, index_file, default_flow_style=False) syaml.dump(index, default_flow_style=False, stream=index_file)
def _generate_upstream_module_index(): def _generate_upstream_module_index():

View file

@ -79,7 +79,6 @@
import base64 import base64
import sys import sys
import collections import collections
import ctypes
import hashlib import hashlib
import itertools import itertools
import os import os
@ -88,6 +87,7 @@
import six import six
from operator import attrgetter from operator import attrgetter
import ruamel.yaml as yaml
from llnl.util.filesystem import find_headers, find_libraries, is_exe from llnl.util.filesystem import find_headers, find_libraries, is_exe
from llnl.util.lang import key_ordering, HashableMap, ObjectWrapper, dedupe from llnl.util.lang import key_ordering, HashableMap, ObjectWrapper, dedupe
@ -185,9 +185,6 @@
#: every time we call str() #: every time we call str()
_any_version = VersionList([':']) _any_version = VersionList([':'])
#: Max integer helps avoid passing too large a value to cyaml.
maxint = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) - 1
default_format = '{name}{@version}' default_format = '{name}{@version}'
default_format += '{%compiler.name}{@compiler.version}{compiler_flags}' default_format += '{%compiler.name}{@compiler.version}{compiler_flags}'
default_format += '{variants}{arch=architecture}' default_format += '{variants}{arch=architecture}'
@ -1366,8 +1363,8 @@ def _spec_hash(self, hash):
""" """
# TODO: curently we strip build dependencies by default. Rethink # TODO: curently we strip build dependencies by default. Rethink
# this when we move to using package hashing on all specs. # this when we move to using package hashing on all specs.
yaml_text = syaml.dump(self.to_node_dict(hash=hash), yaml_text = syaml.dump(
default_flow_style=True, width=maxint) self.to_node_dict(hash=hash), default_flow_style=True)
sha = hashlib.sha1(yaml_text.encode('utf-8')) sha = hashlib.sha1(yaml_text.encode('utf-8'))
b32_hash = base64.b32encode(sha.digest()).lower() b32_hash = base64.b32encode(sha.digest()).lower()
@ -1937,7 +1934,7 @@ def from_yaml(stream):
stream -- string or file object to read from. stream -- string or file object to read from.
""" """
try: try:
data = syaml.load(stream) data = yaml.load(stream)
return Spec.from_dict(data) return Spec.from_dict(data)
except MarkedYAMLError as e: except MarkedYAMLError as e:
raise syaml.SpackYAMLError("error parsing YAML spec:", str(e)) raise syaml.SpackYAMLError("error parsing YAML spec:", str(e))

View file

@ -5,6 +5,7 @@
import os import os
import pytest import pytest
import re
import spack import spack
import spack.environment as ev import spack.environment as ev
@ -122,6 +123,6 @@ def test_release_jobs_with_env(tmpdir, mutable_mock_env_path, env_deactivate,
release_jobs('--output-file', outputfile) release_jobs('--output-file', outputfile)
with open(outputfile) as f: with open(outputfile) as f:
contents = f.read() contents = f.read().replace(os.linesep, '')
assert('archive-files' in contents) assert('archive-files' in contents)
assert('stages: [stage-0' in contents) assert(re.search(r'stages:\s*\[\s*stage-0', contents))

View file

@ -29,7 +29,7 @@ def concretize_scope(config, tmpdir):
@pytest.fixture() @pytest.fixture()
def configure_permissions(): def configure_permissions():
conf = syaml.load("""\ conf = syaml.load_config("""\
all: all:
permissions: permissions:
read: group read: group
@ -182,9 +182,9 @@ def test_no_virtuals_in_packages_yaml(self):
"""Verify that virtuals are not allowed in packages.yaml.""" """Verify that virtuals are not allowed in packages.yaml."""
# set up a packages.yaml file with a vdep as a key. We use # set up a packages.yaml file with a vdep as a key. We use
# syaml.load here to make sure source lines in the config are # syaml.load_config here to make sure source lines in the config are
# attached to parsed strings, as the error message uses them. # attached to parsed strings, as the error message uses them.
conf = syaml.load("""\ conf = syaml.load_config("""\
mpi: mpi:
paths: paths:
mpi-with-lapack@2.1: /path/to/lapack mpi-with-lapack@2.1: /path/to/lapack
@ -197,7 +197,7 @@ def test_no_virtuals_in_packages_yaml(self):
def test_all_is_not_a_virtual(self): def test_all_is_not_a_virtual(self):
"""Verify that `all` is allowed in packages.yaml.""" """Verify that `all` is allowed in packages.yaml."""
conf = syaml.load("""\ conf = syaml.load_config("""\
all: all:
variants: [+mpi] variants: [+mpi]
""") """)
@ -214,7 +214,7 @@ def test_external_mpi(self):
assert not spec['mpi'].external assert not spec['mpi'].external
# load config # load config
conf = syaml.load("""\ conf = syaml.load_config("""\
all: all:
providers: providers:
mpi: [mpich] mpi: [mpich]

View file

@ -12,7 +12,6 @@
from llnl.util.filesystem import touch, mkdirp from llnl.util.filesystem import touch, mkdirp
import pytest import pytest
import ruamel.yaml as yaml
import spack.paths import spack.paths
import spack.config import spack.config
@ -57,7 +56,7 @@ def _write(config, data, scope):
config_yaml = tmpdir.join(scope, config + '.yaml') config_yaml = tmpdir.join(scope, config + '.yaml')
config_yaml.ensure() config_yaml.ensure()
with config_yaml.open('w') as f: with config_yaml.open('w') as f:
yaml.dump(data, f) syaml.dump_config(data, f)
return _write return _write
@ -721,10 +720,44 @@ def test_single_file_scope(tmpdir, config):
'/x/y/z', '$spack/var/spack/repos/builtin'] '/x/y/z', '$spack/var/spack/repos/builtin']
def test_single_file_scope_section_override(tmpdir, config):
"""Check that individual config sections can be overridden in an
environment config. The config here primarily differs in that the
``packages`` section is intended to override all other scopes (using the
"::" syntax).
"""
env_yaml = str(tmpdir.join("env.yaml"))
with open(env_yaml, 'w') as f:
f.write("""\
env:
config:
verify_ssl: False
packages::
libelf:
compiler: [ 'gcc@4.5.3' ]
repos:
- /x/y/z
""")
scope = spack.config.SingleFileScope(
'env', env_yaml, spack.schema.env.schema, ['env'])
with spack.config.override(scope):
# from the single-file config
assert spack.config.get('config:verify_ssl') is False
assert spack.config.get('packages:libelf:compiler') == ['gcc@4.5.3']
# from the lower config scopes
assert spack.config.get('config:checksum') is True
assert not spack.config.get('packages:externalmodule')
assert spack.config.get('repos') == [
'/x/y/z', '$spack/var/spack/repos/builtin']
def check_schema(name, file_contents): def check_schema(name, file_contents):
"""Check a Spack YAML schema against some data""" """Check a Spack YAML schema against some data"""
f = StringIO(file_contents) f = StringIO(file_contents)
data = syaml.load(f) data = syaml.load_config(f)
spack.config.validate(data, name) spack.config.validate(data, name)

View file

@ -27,7 +27,7 @@ def data():
[ 1, 2, 3 ] [ 1, 2, 3 ]
some_key: some_string some_key: some_string
""" """
return syaml.load(test_file) return syaml.load_config(test_file)
def test_parse(data): def test_parse(data):

View file

@ -12,13 +12,14 @@
default unorderd dict. default unorderd dict.
""" """
import ctypes
from ordereddict_backport import OrderedDict from ordereddict_backport import OrderedDict
from six import string_types, StringIO from six import string_types, StringIO
import ruamel.yaml as yaml import ruamel.yaml as yaml
from ruamel.yaml import Loader, Dumper from ruamel.yaml import RoundTripLoader, RoundTripDumper
from ruamel.yaml.nodes import MappingNode, SequenceNode, ScalarNode
from ruamel.yaml.constructor import ConstructorError
from llnl.util.tty.color import colorize, clen, cextra from llnl.util.tty.color import colorize, clen, cextra
@ -58,6 +59,11 @@ class syaml_int(int):
} }
markable_types = set(syaml_types) | set([
yaml.comments.CommentedSeq,
yaml.comments.CommentedMap])
def syaml_type(obj): def syaml_type(obj):
"""Get the corresponding syaml wrapper type for a primitive type. """Get the corresponding syaml wrapper type for a primitive type.
@ -72,19 +78,15 @@ def syaml_type(obj):
def markable(obj): def markable(obj):
"""Whether an object can be marked.""" """Whether an object can be marked."""
return type(obj) in syaml_types return type(obj) in markable_types
def mark(obj, node): def mark(obj, node):
"""Add start and end markers to an object.""" """Add start and end markers to an object."""
if not markable(obj):
return
if hasattr(node, 'start_mark'): if hasattr(node, 'start_mark'):
obj._start_mark = node.start_mark obj._start_mark = node.start_mark
elif hasattr(node, '_start_mark'): elif hasattr(node, '_start_mark'):
obj._start_mark = node._start_mark obj._start_mark = node._start_mark
if hasattr(node, 'end_mark'): if hasattr(node, 'end_mark'):
obj._end_mark = node.end_mark obj._end_mark = node.end_mark
elif hasattr(node, '_end_mark'): elif hasattr(node, '_end_mark'):
@ -97,8 +99,11 @@ def marked(obj):
hasattr(obj, '_end_mark') and obj._end_mark) hasattr(obj, '_end_mark') and obj._end_mark)
class OrderedLineLoader(Loader): class OrderedLineLoader(RoundTripLoader):
"""YAML loader that preserves order and line numbers. """YAML loader specifically intended for reading Spack configuration
files. It preserves order and line numbers. It also has special-purpose
logic for handling dictionary keys that indicate a Spack config
override: namely any key that contains an "extra" ':' character.
Mappings read in by this loader behave like an ordered dict. Mappings read in by this loader behave like an ordered dict.
Sequences, mappings, and strings also have new attributes, Sequences, mappings, and strings also have new attributes,
@ -107,74 +112,46 @@ class OrderedLineLoader(Loader):
""" """
# #
# Override construct_yaml_* so that they build our derived types, # Override construct_yaml_* so that we can apply _start_mark/_end_mark to
# which allows us to add new attributes to them. # them. The superclass returns CommentedMap/CommentedSeq objects that we
# can add attributes to (and we depend on their behavior to preserve
# comments).
# #
# The standard YAML constructors return empty instances and fill # The inherited sequence/dictionary constructors return empty instances
# in with mappings later. We preserve this behavior. # and fill in with mappings later. We preserve this behavior.
# #
def construct_yaml_str(self, node): def construct_yaml_str(self, node):
value = self.construct_scalar(node) value = super(OrderedLineLoader, self).construct_yaml_str(node)
# There is no specific marker to indicate that we are parsing a key,
# so this assumes we are talking about a Spack config override key if
# it ends with a ':' and does not contain a '@' (which can appear
# in config values that refer to Spack specs)
if value and value.endswith(':') and '@' not in value:
value = syaml_str(value[:-1])
value.override = True
else:
value = syaml_str(value) value = syaml_str(value)
mark(value, node) mark(value, node)
return value return value
def construct_yaml_seq(self, node): def construct_yaml_seq(self, node):
data = syaml_list() gen = super(OrderedLineLoader, self).construct_yaml_seq(node)
data = next(gen)
if markable(data):
mark(data, node) mark(data, node)
yield data yield data
data.extend(self.construct_sequence(node)) for x in gen:
pass
def construct_yaml_map(self, node): def construct_yaml_map(self, node):
data = syaml_dict() gen = super(OrderedLineLoader, self).construct_yaml_map(node)
data = next(gen)
if markable(data):
mark(data, node) mark(data, node)
yield data yield data
value = self.construct_mapping(node) for x in gen:
data.update(value) pass
#
# Override the ``construct_*`` routines. These fill in empty
# objects after yielded by the above ``construct_yaml_*`` methods.
#
def construct_sequence(self, node, deep=False):
if not isinstance(node, SequenceNode):
raise ConstructorError(
None, None,
"expected a sequence node, but found %s" % node.id,
node.start_mark)
value = syaml_list(self.construct_object(child, deep=deep)
for child in node.value)
mark(value, node)
return value
def construct_mapping(self, node, deep=False):
"""Store mappings as OrderedDicts instead of as regular python
dictionaries to preserve file ordering."""
if not isinstance(node, MappingNode):
raise ConstructorError(
None, None,
"expected a mapping node, but found %s" % node.id,
node.start_mark)
mapping = syaml_dict()
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep)
try:
hash(key)
except TypeError as exc:
raise ConstructorError(
"while constructing a mapping", node.start_mark,
"found unacceptable key (%s)" % exc, key_node.start_mark)
value = self.construct_object(value_node, deep=deep)
if key in mapping:
raise ConstructorError(
"while constructing a mapping", node.start_mark,
"found already in-use key (%s)" % key, key_node.start_mark)
mapping[key] = value
mark(mapping, node)
return mapping
# register above new constructors # register above new constructors
@ -186,7 +163,7 @@ def construct_mapping(self, node, deep=False):
'tag:yaml.org,2002:str', OrderedLineLoader.construct_yaml_str) 'tag:yaml.org,2002:str', OrderedLineLoader.construct_yaml_str)
class OrderedLineDumper(Dumper): class OrderedLineDumper(RoundTripDumper):
"""Dumper that preserves ordering and formats ``syaml_*`` objects. """Dumper that preserves ordering and formats ``syaml_*`` objects.
This dumper preserves insertion ordering ``syaml_dict`` objects This dumper preserves insertion ordering ``syaml_dict`` objects
@ -196,41 +173,15 @@ class OrderedLineDumper(Dumper):
""" """
def represent_mapping(self, tag, mapping, flow_style=None):
value = []
node = MappingNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
if hasattr(mapping, 'items'):
# if it's a syaml_dict, preserve OrderedDict order.
# Otherwise do the default thing.
sort = not isinstance(mapping, syaml_dict)
mapping = list(mapping.items())
if sort:
mapping.sort()
for item_key, item_value in mapping:
node_key = self.represent_data(item_key)
node_value = self.represent_data(item_value)
if not (isinstance(node_key, ScalarNode) and not node_key.style):
best_style = False
if not (isinstance(node_value, ScalarNode) and
not node_value.style):
best_style = False
value.append((node_key, node_value))
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def ignore_aliases(self, _data): def ignore_aliases(self, _data):
"""Make the dumper NEVER print YAML aliases.""" """Make the dumper NEVER print YAML aliases."""
return True return True
def represent_str(self, data):
if hasattr(data, 'override') and data.override:
data = data + ':'
return super(OrderedLineDumper, self).represent_str(data)
# Make our special objects look like normal YAML ones. # Make our special objects look like normal YAML ones.
OrderedLineDumper.add_representer(syaml_dict, OrderedLineDumper.represent_dict) OrderedLineDumper.add_representer(syaml_dict, OrderedLineDumper.represent_dict)
@ -239,6 +190,28 @@ def ignore_aliases(self, _data):
OrderedLineDumper.add_representer(syaml_int, OrderedLineDumper.represent_int) OrderedLineDumper.add_representer(syaml_int, OrderedLineDumper.represent_int)
class SafeDumper(RoundTripDumper):
def ignore_aliases(self, _data):
"""Make the dumper NEVER print YAML aliases."""
return True
# Allow syaml_dict objects to be represented by ruamel.yaml.dump. With this,
# syaml_dict allows a user to provide an ordered dictionary to yaml.dump when
# the RoundTripDumper is used.
RoundTripDumper.add_representer(syaml_dict, RoundTripDumper.represent_dict)
#: Max integer helps avoid passing too large a value to cyaml.
maxint = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) - 1
def dump(obj, default_flow_style=False, stream=None):
return yaml.dump(obj, default_flow_style=default_flow_style, width=maxint,
Dumper=SafeDumper, stream=stream)
def file_line(mark): def file_line(mark):
"""Format a mark as <file>:<line> information.""" """Format a mark as <file>:<line> information."""
result = mark.name result = mark.name
@ -288,6 +261,7 @@ def represent_data(self, data):
result = super(LineAnnotationDumper, self).represent_data(data) result = super(LineAnnotationDumper, self).represent_data(data)
if isinstance(result.value, string_types): if isinstance(result.value, string_types):
result.value = syaml_str(data) result.value = syaml_str(data)
if markable(result.value):
mark(result.value, data) mark(result.value, data)
return result return result
@ -319,14 +293,18 @@ def write_line_break(self):
_annotations.append('') _annotations.append('')
def load(*args, **kwargs): def load_config(*args, **kwargs):
"""Load but modify the loader instance so that it will add __line__ """Load but modify the loader instance so that it will add __line__
atrributes to the returned object.""" attributes to the returned object."""
kwargs['Loader'] = OrderedLineLoader kwargs['Loader'] = OrderedLineLoader
return yaml.load(*args, **kwargs) return yaml.load(*args, **kwargs)
def dump(*args, **kwargs): def load(*args, **kwargs):
return yaml.load(*args, **kwargs)
def dump_config(*args, **kwargs):
blame = kwargs.pop('blame', False) blame = kwargs.pop('blame', False)
if blame: if blame: