Rewrite parallel sampling using multiprocessing #3011

Merged
merged 11 commits on Jun 14, 2018
5 changes: 5 additions & 0 deletions RELEASE-NOTES.md
@@ -13,6 +13,11 @@
- Improve error message `NaN occurred in optimization.` during ADVI
- Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`
- Add `Kumaraswamy` distribution
- Rewrite parallel sampling of multiple chains on py3. This resolves
  long-standing issues when transferring large traces to the main process,
  avoids pickling issues on UNIX, and allows us to show a progress bar
  for all chains. If parallel sampling is interrupted, we now return
  partial results.

### Fixes

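As a quick illustration of the release-note item above, a minimal usage sketch (the toy model and sample counts are hypothetical; `pm.sample` with `chains` and `cores` is the entry point that dispatches to the new sampler on Python 3):

import numpy as np
import pymc3 as pm

with pm.Model():
    mu = pm.Normal('mu', mu=0., sd=1.)
    pm.Normal('obs', mu=mu, sd=1., observed=np.random.randn(100))
    # One progress bar now counts draws across all chains; interrupting
    # with Ctrl-C returns the draws collected so far as a partial trace.
    trace = pm.sample(draws=1000, tune=500, chains=4, cores=4)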
5 changes: 3 additions & 2 deletions pymc3/backends/text.py
@@ -99,8 +99,9 @@ def record(self, point):
self._fh.write(','.join(columns) + '\n')

     def close(self):
-        self._fh.close()
-        self._fh = None  # Avoid serialization issue.
+        if self._fh is not None:
+            self._fh.close()
+            self._fh = None  # Avoid serialization issue.

# Selection methods
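A minimal sketch of why this guard matters (hypothetical `_Trace` stand-in, not the real backend class): `close` becomes safe to call more than once, and clearing `_fh` removes the open file handle that would otherwise make the trace unpicklable when it is shipped between processes.

class _Trace:
    def __init__(self, path):
        self._fh = open(path, 'w')

    def close(self):
        if self._fh is not None:  # idempotent: a second close is a no-op
            self._fh.close()
            self._fh = None       # open file handles cannot be pickled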

332 changes: 332 additions & 0 deletions pymc3/parallel_sampling.py
@@ -0,0 +1,332 @@
import multiprocessing
import multiprocessing.sharedctypes
import ctypes
import time
import logging
from collections import namedtuple
import traceback

import six
import numpy as np

from . import theanof

logger = logging.getLogger('pymc3')


# Taken from https://hg.python.org/cpython/rev/c4f92b597074
class RemoteTraceback(Exception):
def __init__(self, tb):
self.tb = tb

def __str__(self):
return self.tb


class ExceptionWithTraceback:
def __init__(self, exc, tb):
tb = traceback.format_exception(type(exc), exc, tb)
tb = ''.join(tb)
self.exc = exc
self.tb = '\n"""\n%s"""' % tb

def __reduce__(self):
return rebuild_exc, (self.exc, self.tb)


def rebuild_exc(exc, tb):
exc.__cause__ = RemoteTraceback(tb)
return exc


# Messages
# ('writing_done', is_last, sample_idx, tuning, stats, warns)
# ('error', *exception_info)

# ('abort', reason)
# ('write_next',)
# ('start',)


class _Process(multiprocessing.Process):
"""Seperate process for each chain.
We communicate with the main process using a pipe,
and send finished samples using shared memory.
"""
def __init__(self, name, msg_pipe, step_method, shared_point,
draws, tune, seed):
super(_Process, self).__init__(daemon=True, name=name)
self._msg_pipe = msg_pipe
self._step_method = step_method
self._shared_point = shared_point
self._seed = seed
self._tt_seed = seed + 1
self._draws = draws
self._tune = tune

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
self._msg_pipe.send(('error', e))
finally:
self._msg_pipe.close()

def _make_numpy_refs(self):
shape_dtypes = self._step_method.vars_shape_dtype
point = {}
for name, (shape, dtype) in shape_dtypes.items():
array = self._shared_point[name]
self._shared_point[name] = array
point[name] = np.frombuffer(array, dtype).reshape(shape)
return point

def _write_point(self, point):
for name, vals in point.items():
self._point[name][...] = vals

def _recv_msg(self):
return self._msg_pipe.recv()

def _start_loop(self):
np.random.seed(self._seed)
theanof.set_tt_rng(self._tt_seed)

draw = 0
tuning = True

msg = self._recv_msg()
if msg[0] == 'abort':
raise KeyboardInterrupt()
if msg[0] != 'start':
raise ValueError('Unexpected msg ' + msg[0])

while True:
if draw < self._draws + self._tune:
point, stats = self._compute_point()
else:
return

if draw == self._tune:
self._step_method.stop_tuning()
tuning = False

msg = self._recv_msg()
if msg[0] == 'abort':
raise KeyboardInterrupt()
elif msg[0] == 'write_next':
self._write_point(point)
is_last = draw + 1 == self._draws + self._tune
if is_last:
warns = self._collect_warnings()
else:
warns = None
self._msg_pipe.send(
('writing_done', is_last, draw, tuning, stats, warns))
draw += 1
else:
raise ValueError('Unknown message ' + msg[0])

def _compute_point(self):
if self._step_method.generates_stats:
point, stats = self._step_method.step(self._point)
else:
point = self._step_method.step(self._point)
stats = None
return point, stats

def _collect_warnings(self):
if hasattr(self._step_method, 'warnings'):
return self._step_method.warnings()
else:
return []


class ProcessAdapter(object):
"""Control a Chain process from the main thread."""
def __init__(self, draws, tune, step_method, chain, seed, start):
self.chain = chain
process_name = "worker_chain_%s" % chain
self._msg_pipe, remote_conn = multiprocessing.Pipe()

self._shared_point = {}
self._point = {}
for name, (shape, dtype) in step_method.vars_shape_dtype.items():
size = 1
for dim in shape:
size *= int(dim)
size *= dtype.itemsize
if size != ctypes.c_size_t(size).value:
raise ValueError('Variable %s is too large' % name)

array = multiprocessing.sharedctypes.RawArray('c', size)
self._shared_point[name] = array
array_np = np.frombuffer(array, dtype).reshape(shape)
array_np[...] = start[name]
self._point[name] = array_np

self._readable = True
self._num_samples = 0

self._process = _Process(
process_name, remote_conn, step_method, self._shared_point,
draws, tune, seed)
# We fork right away, so that the main process can start tqdm threads
self._process.start()

@property
def shared_point_view(self):
"""May only be written to or read between a `recv_draw`
call from the process and a `write_next` or `abort` call.
"""
if not self._readable:
raise RuntimeError()
return self._point

def start(self):
self._msg_pipe.send(('start',))

def write_next(self):
self._readable = False
self._msg_pipe.send(('write_next',))

def abort(self):
self._msg_pipe.send(('abort',))

def join(self, timeout=None):
self._process.join(timeout)

def terminate(self):
self._process.terminate()

@staticmethod
def recv_draw(processes, timeout=3600):
if not processes:
raise ValueError('No processes.')
pipes = [proc._msg_pipe for proc in processes]
ready = multiprocessing.connection.wait(pipes, timeout)
if not ready:
raise multiprocessing.TimeoutError('No message from samplers.')
idxs = {id(proc._msg_pipe): proc for proc in processes}
proc = idxs[id(ready[0])]
msg = ready[0].recv()

if msg[0] == 'error':
old = msg[1]
six.raise_from(RuntimeError('Chain %s failed.' % proc.chain), old)
elif msg[0] == 'writing_done':
proc._readable = True
proc._num_samples += 1
return (proc,) + msg[1:]
else:
raise ValueError('Sampler sent bad message.')

@staticmethod
def terminate_all(processes, patience=2):
for process in processes:
try:
process.abort()
except EOFError:
pass

start_time = time.time()
try:
for process in processes:
timeout = time.time() + patience - start_time
if timeout < 0:
raise multiprocessing.TimeoutError()
process.join(timeout)
except multiprocessing.TimeoutError:
logger.warning('Chain processes did not terminate as expected. '
               'Terminating forcefully...')
for process in processes:
process.terminate()
for process in processes:
process.join()


Draw = namedtuple(
'Draw',
['chain', 'is_last', 'draw_idx', 'tuning', 'stats', 'point', 'warnings']
)


class ParallelSampler(object):
def __init__(self, draws, tune, chains, cores, seeds, start_points,
step_method, start_chain_num=0, progressbar=True):
if progressbar:
import tqdm
tqdm_ = tqdm.tqdm

self._samplers = [
ProcessAdapter(draws, tune, step_method,
chain + start_chain_num, seed, start)
for chain, seed, start in zip(range(chains), seeds, start_points)
]

self._inactive = self._samplers.copy()
self._finished = []
self._active = []
self._max_active = cores

self._in_context = False
self._start_chain_num = start_chain_num

self._progress = None
if progressbar:
self._progress = tqdm_(
                total=chains * (draws + tune), unit='draws',
                desc='Sampling %s chains' % chains)

Contributor: you have to add the position argument here in order to not have tqdms interfering with each other.

Member (Author): There is only one progress bar now, but that progress bar counts samples from all chains.

def _make_active(self):
while self._inactive and len(self._active) < self._max_active:
proc = self._inactive.pop(0)
proc.start()
proc.write_next()
self._active.append(proc)

def __iter__(self):
if not self._in_context:
raise ValueError('Use ParallelSampler as context manager.')
self._make_active()

while self._active:
draw = ProcessAdapter.recv_draw(self._active)
proc, is_last, draw, tuning, stats, warns = draw
if self._progress is not None:
self._progress.update()

if is_last:
proc.join()
self._active.remove(proc)
self._finished.append(proc)
self._make_active()

# We could also yield proc.shared_point_view directly,
# and only call proc.write_next() after the yield returns.
# This seems to be faster overall though, as the worker
# loses less time waiting.
point = {name: val.copy()
for name, val in proc.shared_point_view.items()}

# Already called for new proc in _make_active
if not is_last:
proc.write_next()

yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)

def __enter__(self):
self._in_context = True
return self

def __exit__(self, *args):
ProcessAdapter.terminate_all(self._samplers)
if self._progress is not None:
self._progress.close()
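The mechanism above, reduced to its essentials: the parent allocates a `RawArray`, both processes wrap it with `np.frombuffer` for zero-copy access, and the `Pipe` carries only small control messages while the actual samples travel through shared memory. A self-contained sketch of that pattern (not the PyMC3 API, just the underlying idea):

import multiprocessing
import multiprocessing.sharedctypes

import numpy as np


def worker(conn, shared):
    # Zero-copy view onto the shared buffer; writes are visible to the parent.
    point = np.frombuffer(shared, dtype='float64')
    while conn.recv() == 'write_next':
        point[...] = point + 1.0  # stand-in for step_method.step()
        conn.send(('writing_done',))
    conn.close()


if __name__ == '__main__':
    shared = multiprocessing.sharedctypes.RawArray('c', 3 * 8)  # three float64
    parent_conn, child_conn = multiprocessing.Pipe()
    proc = multiprocessing.Process(target=worker, args=(child_conn, shared))
    proc.start()
    for _ in range(5):
        parent_conn.send('write_next')
        parent_conn.recv()  # block until 'writing_done'
        print(np.frombuffer(shared, dtype='float64'))
    parent_conn.send('stop')  # anything but 'write_next' ends the loop
    proc.join()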
114 changes: 91 additions & 23 deletions pymc3/sampling.py
@@ -663,7 +663,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
except KeyboardInterrupt:
strace.close()
if hasattr(step, 'warnings'):
-            warns = step.warnings(strace)
+            warns = step.warnings()
strace._add_warnings(warns)
raise
except BaseException:
@@ -672,7 +672,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
else:
strace.close()
if hasattr(step, 'warnings'):
-            warns = step.warnings(strace)
+            warns = step.warnings()
strace._add_warnings(warns)


@@ -965,36 +965,104 @@ def _choose_backend(trace, chain, shortcuts=None, **kwds):
raise ValueError('Argument `trace` is invalid.')


-def _mp_sample(**kwargs):
-    cores = kwargs.pop('cores')
-    chain = kwargs.pop('chain')
-    rseed = kwargs.pop('random_seed')
-    start = kwargs.pop('start')
-    chains = kwargs.pop('chains')
-    use_mmap = kwargs.pop('use_mmap')
+def _mp_sample(draws, tune, step, chains, cores, chain, random_seed,
+               start, progressbar, trace=None, model=None, use_mmap=False,
+               **kwargs):
 
-    chain_nums = list(range(chain, chain + chains))
-    pbars = [kwargs.pop('progressbar')] + [False] * (chains - 1)
-    jobs = (delayed(_sample)(*args, **kwargs)
-            for args in zip(chain_nums, pbars, rseed, start))
+    if sys.version_info.major >= 3:
+        import pymc3.parallel_sampling as ps
+
+        # We did draws += tune in pm.sample
+        draws -= tune
+
+        traces = []
+        for idx in range(chain, chain + chains):
+            if trace is not None:
+                strace = _choose_backend(copy(trace), idx, model=model)
+            else:
+                strace = _choose_backend(None, idx, model=model)
+            # TODO what is this for?
+            update_start_vals(start[idx - chain], model.test_point, model)
+            if step.generates_stats and strace.supports_sampler_stats:
+                strace.setup(draws + tune, idx + chain, step.stats_dtypes)
+            else:
+                strace.setup(draws + tune, idx + chain)
+            traces.append(strace)
+
+        sampler = ps.ParallelSampler(
+            draws, tune, chains, cores, random_seed, start, step,
+            chain, progressbar)
+        try:
+            with sampler:
+                for draw in sampler:
+                    trace = traces[draw.chain - chain]
+                    if trace.supports_sampler_stats and draw.stats is not None:
+                        trace.record(draw.point, draw.stats)
+                    else:
+                        trace.record(draw.point)
+                    if draw.is_last:
+                        trace.close()
+                        if draw.warnings is not None:
+                            trace._add_warnings(draw.warnings)
+            return MultiTrace(traces)
+        except KeyboardInterrupt:
+            traces, length = _choose_chains(traces, tune)
+            return MultiTrace(traces)[:length]
+        finally:
+            for trace in traces:
+                trace.close()
 
-    if use_mmap:
-        traces = Parallel(n_jobs=cores)(jobs)
-    else:
-        traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
-
-    return MultiTrace(traces)
+    chain_nums = list(range(chain, chain + chains))
+    pbars = [progressbar] + [False] * (chains - 1)
+    jobs = (
+        delayed(_sample)(
+            chain=args[0], progressbar=args[1], random_seed=args[2],
+            start=args[3], draws=draws, step=step, trace=trace,
+            tune=tune, model=model, **kwargs
+        )
+        for args in zip(chain_nums, pbars, random_seed, start)
+    )
+    if use_mmap:
+        traces = Parallel(n_jobs=cores)(jobs)
+    else:
+        traces = Parallel(n_jobs=cores, mmap_mode=None)(jobs)
+    return MultiTrace(traces)

+def _choose_chains(traces, tune):
+    if tune is None:
+        tune = 0
+
+    if not traces:
+        return []
+
+    lengths = [max(0, len(trace) - tune) for trace in traces]
+    if not sum(lengths):
+        raise ValueError('Not enough samples to build a trace.')
+
+    idxs = np.argsort(lengths)[::-1]
+    l_sort = np.array(lengths)[idxs]
+
+    final_length = l_sort[0]
+    last_total = 0
+    for i, length in enumerate(l_sort):
+        total = (i + 1) * length
+        if total < last_total:
+            use_until = i
+            break
+        last_total = total
+        final_length = length
+    else:
+        use_until = len(lengths)
+
+    return [traces[idx] for idx in idxs[:use_until]], final_length + tune
+
+
 def stop_tuning(step):
     """ stop tuning the current step method """
 
-    if hasattr(step, 'tune'):
-        step.tune = False
-
-    if hasattr(step, 'methods'):
-        step.methods = [stop_tuning(s) for s in step.methods]
-
+    step.stop_tuning()
     return step
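To make the `_choose_chains` heuristic concrete: after an interrupt it keeps the prefix of the longest chains that maximizes (chains kept) times (shortest kept length), so the returned trace is rectangular. A simplified standalone check of that rule (the loop above stops at the first decrease in the running total rather than taking a global argmax, which coincides for these hypothetical lengths):

import numpy as np

def best_prefix(lengths):
    l_sort = np.sort(lengths)[::-1]
    totals = (np.arange(len(l_sort)) + 1) * l_sort
    use_until = int(np.argmax(totals)) + 1
    return use_until, int(l_sort[use_until - 1])

print(best_prefix([900, 100, 100]))  # (1, 900): one long chain beats 2 * 100
print(best_prefix([500, 480, 470]))  # (3, 470): keep all three, cut to 470 draws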


15 changes: 14 additions & 1 deletion pymc3/step_methods/arraystep.py
@@ -87,13 +87,26 @@ def _competence(cls, vars, have_grad):
vars = np.atleast_1d(vars)
have_grad = np.atleast_1d(have_grad)
competences = []
-        for var,has_grad in zip(vars, have_grad):
+        for var, has_grad in zip(vars, have_grad):
try:
competences.append(cls.competence(var, has_grad))
except TypeError:
competences.append(cls.competence(var))
return competences

@property
def vars_shape_dtype(self):
shape_dtypes = {}
for var in self.vars:
dtype = np.dtype(var.dtype)
shape = var.dshape
shape_dtypes[var.name] = (shape, dtype)
return shape_dtypes

def stop_tuning(self):
        if hasattr(self, 'tune'):
            self.tune = False

Member: Oh this is much better.


class ArrayStep(BlockedStep):
"""
21 changes: 17 additions & 4 deletions pymc3/step_methods/compound.py
@@ -6,11 +6,13 @@


class CompoundStep(object):
"""Step method composed of a list of several other step methods applied in sequence."""
"""Step method composed of a list of several other step
methods applied in sequence."""

def __init__(self, methods):
self.methods = list(methods)
-        self.generates_stats = any(method.generates_stats for method in self.methods)
+        self.generates_stats = any(
+            method.generates_stats for method in self.methods)
self.stats_dtypes = []
for method in self.methods:
if method.generates_stats:
@@ -31,9 +33,20 @@ def step(self, point):
point = method.step(point)
return point

-    def warnings(self, strace):
+    def warnings(self):
warns = []
for method in self.methods:
if hasattr(method, 'warnings'):
-                warns.extend(method.warnings(strace))
+                warns.extend(method.warnings())
return warns

def stop_tuning(self):
for method in self.methods:
method.stop_tuning()

@property
def vars_shape_dtype(self):
dtype_shapes = {}
for method in self.methods:
dtype_shapes.update(method.vars_shape_dtype)
return dtype_shapes
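For reference, `vars_shape_dtype` is what lets `ProcessAdapter` size the shared `RawArray` buffers before any worker starts. A short sketch of the mapping it returns and the byte count derived from it (the variable names here are hypothetical):

import numpy as np

# The kind of dict step.vars_shape_dtype returns, e.g. for a model with
# a length-1 vector 'a' and a scalar 'b_log__':
shape_dtypes = {
    'a': ((1,), np.dtype('float64')),
    'b_log__': ((), np.dtype('float64')),
}

for name, (shape, dtype) in shape_dtypes.items():
    n_bytes = int(np.prod(shape, dtype=int)) * dtype.itemsize
    print(name, n_bytes)  # size for multiprocessing.sharedctypes.RawArray('c', n_bytes)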
2 changes: 1 addition & 1 deletion pymc3/step_methods/hmc/base_hmc.py
@@ -164,7 +164,7 @@ def reset(self, start=None):
self.tune = True
self.potential.reset()

-    def warnings(self, strace):
+    def warnings(self):
# list.copy() is not available in python2
warnings = self._warnings[:]

4 changes: 2 additions & 2 deletions pymc3/step_methods/hmc/nuts.py
@@ -184,8 +184,8 @@ def competence(var, has_grad):
return Competence.IDEAL
return Competence.INCOMPATIBLE

-    def warnings(self, strace):
-        warnings = super(NUTS, self).warnings(strace)
+    def warnings(self):
+        warnings = super(NUTS, self).warnings()
n_samples = self._samples_after_tune
n_treedepth = self._reached_max_treedepth

4 changes: 2 additions & 2 deletions pymc3/tests/sampler_fixtures.py
@@ -82,13 +82,13 @@ def make_model(cls):

class StudentTFixture(KnownMean, KnownCDF):
means = {'a': 0}
-    cdfs = {'a': stats.t(df=3).cdf}
+    cdfs = {'a': stats.t(df=4).cdf}
ks_thin = 10

@classmethod
def make_model(cls):
with pm.Model() as model:
a = pm.StudentT("a", nu=3, mu=0, sd=1)
a = pm.StudentT("a", nu=4, mu=0, sd=1)
return model


72 changes: 72 additions & 0 deletions pymc3/tests/test_parallel_sampling.py
@@ -0,0 +1,72 @@
import time
import sys
import pytest

import pymc3.parallel_sampling as ps
import pymc3 as pm


@pytest.mark.skipif(sys.version_info < (3,3),
reason="requires python3.3")
def test_abort():
with pm.Model() as model:
a = pm.Normal('a', shape=1)
pm.HalfNormal('b')
step1 = pm.NUTS([a])
step2 = pm.Metropolis([model.b_log__])

step = pm.CompoundStep([step1, step2])

proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
start={'a': 1., 'b_log__': 2.})
proc.start()
proc.write_next()
proc.abort()
proc.join()


@pytest.mark.skipif(sys.version_info < (3,3),
reason="requires python3.3")
def test_explicit_sample():
with pm.Model() as model:
a = pm.Normal('a', shape=1)
pm.HalfNormal('b')
step1 = pm.NUTS([a])
step2 = pm.Metropolis([model.b_log__])

step = pm.CompoundStep([step1, step2])

start = time.time()
proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
start={'a': 1., 'b_log__': 2.})
proc.start()
while True:
proc.write_next()
out = ps.ProcessAdapter.recv_draw([proc])
view = proc.shared_point_view
for name in view:
view[name].copy()
if out[1]:
break
proc.join()
print(time.time() - start)


@pytest.mark.skipif(sys.version_info < (3,3),
reason="requires python3.3")
def test_iterator():
with pm.Model() as model:
a = pm.Normal('a', shape=1)
pm.HalfNormal('b')
step1 = pm.NUTS([a])
step2 = pm.Metropolis([model.b_log__])

step = pm.CompoundStep([step1, step2])

start = time.time()
start = {'a': 1., 'b_log__': 2.}
sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3,
step, 0, False)
with sampler:
for draw in sampler:
pass