Add a multiprocess Barrier class to use for testing parallel code.

This commit is contained in:
Todd Gamblin 2015-10-24 19:54:52 -07:00
parent 5fda7daf57
commit 908a93a470

View file

@ -27,9 +27,11 @@
than multiprocessing.Pool.apply() can. For example, apply() will fail
to pickle functions if they're passed indirectly as parameters.
"""
from multiprocessing import Process, Pipe
from multiprocessing import Process, Pipe, Semaphore, Value
from itertools import izip
__all__ = ['spawn', 'parmap', 'Barrier']
def spawn(f):
def fun(pipe,x):
pipe.send(f(x))
@ -43,3 +45,49 @@ def parmap(f,X):
[p.join() for p in proc]
return [p.recv() for (p,c) in pipe]
class Barrier:
"""Simple reusable semaphore barrier.
Python 2.6 doesn't have multiprocessing barriers so we implement this.
See http://greenteapress.com/semaphores/downey08semaphores.pdf, p. 41.
"""
def __init__(self, n, timeout=None):
self.n = n
self.to = timeout
self.count = Value('i', 0)
self.mutex = Semaphore(1)
self.turnstile1 = Semaphore(0)
self.turnstile2 = Semaphore(1)
def wait(self):
if not self.mutex.acquire(timeout=self.to):
raise BarrierTimeoutError()
self.count.value += 1
if self.count.value == self.n:
if not self.turnstile2.acquire(timeout=self.to):
raise BarrierTimeoutError()
self.turnstile1.release()
self.mutex.release()
if not self.turnstile1.acquire(timeout=self.to):
raise BarrierTimeoutError()
self.turnstile1.release()
if not self.mutex.acquire(timeout=self.to):
raise BarrierTimeoutError()
self.count.value -= 1
if self.count.value == 0:
if not self.turnstile1.acquire(timeout=self.to):
raise BarrierTimeoutError()
self.turnstile2.release()
self.mutex.release()
if not self.turnstile2.acquire(timeout=self.to):
raise BarrierTimeoutError()
self.turnstile2.release()
class BarrierTimeoutError: pass