Add a multiprocess Barrier class to use for testing parallel code.
This commit is contained in:
parent
5fda7daf57
commit
908a93a470
1 changed files with 49 additions and 1 deletions
|
@ -27,9 +27,11 @@
|
|||
than multiprocessing.Pool.apply() can. For example, apply() will fail
|
||||
to pickle functions if they're passed indirectly as parameters.
|
||||
"""
|
||||
from multiprocessing import Process, Pipe
|
||||
from multiprocessing import Process, Pipe, Semaphore, Value
|
||||
from itertools import izip
|
||||
|
||||
__all__ = ['spawn', 'parmap', 'Barrier']
|
||||
|
||||
def spawn(f):
|
||||
def fun(pipe,x):
|
||||
pipe.send(f(x))
|
||||
|
@ -43,3 +45,49 @@ def parmap(f,X):
|
|||
[p.join() for p in proc]
|
||||
return [p.recv() for (p,c) in pipe]
|
||||
|
||||
|
||||
class Barrier:
|
||||
"""Simple reusable semaphore barrier.
|
||||
|
||||
Python 2.6 doesn't have multiprocessing barriers so we implement this.
|
||||
|
||||
See http://greenteapress.com/semaphores/downey08semaphores.pdf, p. 41.
|
||||
"""
|
||||
def __init__(self, n, timeout=None):
|
||||
self.n = n
|
||||
self.to = timeout
|
||||
self.count = Value('i', 0)
|
||||
self.mutex = Semaphore(1)
|
||||
self.turnstile1 = Semaphore(0)
|
||||
self.turnstile2 = Semaphore(1)
|
||||
|
||||
|
||||
def wait(self):
|
||||
if not self.mutex.acquire(timeout=self.to):
|
||||
raise BarrierTimeoutError()
|
||||
self.count.value += 1
|
||||
if self.count.value == self.n:
|
||||
if not self.turnstile2.acquire(timeout=self.to):
|
||||
raise BarrierTimeoutError()
|
||||
self.turnstile1.release()
|
||||
self.mutex.release()
|
||||
|
||||
if not self.turnstile1.acquire(timeout=self.to):
|
||||
raise BarrierTimeoutError()
|
||||
self.turnstile1.release()
|
||||
|
||||
if not self.mutex.acquire(timeout=self.to):
|
||||
raise BarrierTimeoutError()
|
||||
self.count.value -= 1
|
||||
if self.count.value == 0:
|
||||
if not self.turnstile1.acquire(timeout=self.to):
|
||||
raise BarrierTimeoutError()
|
||||
self.turnstile2.release()
|
||||
self.mutex.release()
|
||||
|
||||
if not self.turnstile2.acquire(timeout=self.to):
|
||||
raise BarrierTimeoutError()
|
||||
self.turnstile2.release()
|
||||
|
||||
|
||||
class BarrierTimeoutError: pass
|
||||
|
|
Loading…
Reference in a new issue