diff --git a/lib/spack/llnl/util/lock.py b/lib/spack/llnl/util/lock.py index 86a45e2d7c..3a58093491 100644 --- a/lib/spack/llnl/util/lock.py +++ b/lib/spack/llnl/util/lock.py @@ -275,7 +275,12 @@ def acquire_write(self, timeout=None): wait_time, nattempts = self._lock(fcntl.LOCK_EX, timeout=timeout) self._acquired_debug('WRITE LOCK', wait_time, nattempts) self._writes += 1 - return True + + # return True only if we weren't nested in a read lock. + # TODO: we may need to return two values: whether we got + # the write lock, and whether this is acquiring a read OR + # write lock for the first time. Now it returns the latter. + return self._reads == 0 else: self._writes += 1 return False diff --git a/lib/spack/spack/test/llnl/util/lock.py b/lib/spack/spack/test/llnl/util/lock.py index 2b0892a25e..ca879cdc0b 100644 --- a/lib/spack/spack/test/llnl/util/lock.py +++ b/lib/spack/spack/test/llnl/util/lock.py @@ -1087,6 +1087,62 @@ def write(t, v, tb): assert vals['wrote'] +def test_nested_reads(lock_path): + """Ensure that write transactions won't re-read data.""" + + def read(): + vals['read'] += 1 + + vals = collections.defaultdict(lambda: 0) + lock = AssertLock(lock_path, vals) + + # read/read + vals.clear() + assert vals['read'] == 0 + with lk.ReadTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.ReadTransaction(lock, acquire=read): + assert vals['read'] == 1 + + # write/write + vals.clear() + assert vals['read'] == 0 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + + # read/write + vals.clear() + assert vals['read'] == 0 + with lk.ReadTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + + # write/read/write + vals.clear() + assert vals['read'] == 0 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.ReadTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + + # read/write/read/write + vals.clear() + assert vals['read'] == 0 + with lk.ReadTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.ReadTransaction(lock, acquire=read): + assert vals['read'] == 1 + with lk.WriteTransaction(lock, acquire=read): + assert vals['read'] == 1 + + def test_lock_debug_output(lock_path): host = socket.getfqdn()