
New parallel sampling cannot sample more chains than cores #3028

Closed

Description

@junpenglao

The new parallel sampling using multiprocessing (#3011) seems to break on the example below:

import numpy as np
import pymc3 as pm

with pm.Model() as m:
    x = pm.Normal('x', 0, 5)
    pm.Normal('y', x, 1, observed=np.asarray([.1, 5.]))
    trace = pm.sample(cores=2, nchains=4)

Traceback:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [x]
Sampling 4 chains:  50%|█████     | 2000/4000 [00:00<00:00, 4039.89draws/s]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-13-aea560e3f116> in <module>()
      2     x = pm.Normal('x', 0, 5)
      3     pm.Normal('y', x, 1, observed=np.asarray([.1, 5.]))
----> 4     trace = pm.sample(cores=2, nchains=4)

~/Documents/Github/pymc3/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    473                 "convergence reliably.")
    474         else:
--> 475             trace.report._run_convergence_checks(trace, model)
    476 
    477     trace.report._log_summary()

~/Documents/Github/pymc3/pymc3/backends/report.py in _run_convergence_checks(self, trace, model)
     81                 varnames.append(rv_name)
     82 
---> 83         self._effective_n = effective_n = diagnostics.effective_n(trace, varnames)
     84         self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace, varnames)
     85 

~/Documents/Github/pymc3/pymc3/diagnostics.py in effective_n(mtrace, varnames, include_transformed)
    298 
    299     for var in varnames:
--> 300         n_eff[var] = generate_neff(mtrace.get_values(var, combine=False))
    301 
    302     return n_eff

~/Documents/Github/pymc3/pymc3/diagnostics.py in generate_neff(trace_values)
    276         # Iterate over tuples of indices of the shape of var
    277         for tup in np.ndindex(*list(x.shape[:-2])):
--> 278             _n_eff[tup] = get_neff(x[tup])
    279 
    280         if len(shape) == 2:

~/Documents/Github/pymc3/pymc3/diagnostics.py in get_neff(x)
    217         """
    218         trace_value = x.T
--> 219         nchain, n_samples = trace_value.shape
    220 
    221         acov = np.asarray([autocov(trace_value[chain]) for chain in range(nchain)])

ValueError: not enough values to unpack (expected 2, got 1)
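
For context, the final ValueError is a shape-unpacking failure: get_neff expects each variable's values as a 2-D (nchain, n_samples) array, but here it receives a 1-D array. A minimal NumPy sketch, using a hypothetical stand-in array and independent of PyMC3, that reproduces the same failure mode:

import numpy as np

# Hypothetical stand-in for the per-variable values reaching get_neff():
# a 1-D array instead of the expected 2-D (nchain, n_samples) array.
trace_value = np.zeros(500)

try:
    nchain, n_samples = trace_value.shape
except ValueError as err:
    print(err)  # not enough values to unpack (expected 2, got 1)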

With compute_convergence_checks=False, sampling completes, but the trace contains only as many chains as there are cores:

with pm.Model() as m:
    x = pm.Normal('x', 0, 5)
    pm.Normal('y', x, 1, observed=np.asarray([.1, 5.]))
    trace = pm.sample(cores=2, nchains=4, compute_convergence_checks=False)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 2 jobs)
NUTS: [x]
Sampling 4 chains:  50%|█████     | 2000/4000 [00:00<00:00, 3903.79draws/s]
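
A quick check of the returned trace, assuming the MultiTrace attributes nchains and chains, confirms the mismatch:

print(trace.nchains)  # 4 chains were requested, but only 2 (== cores) come back
print(trace.chains)   # chain ids actually present in the trace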
