Skip to content

Commit c70895c

Browse files
committed
add more examples
1 parent 0d2a836 commit c70895c

File tree

3 files changed

+131
-336
lines changed

3 files changed

+131
-336
lines changed

examples/bart/bart_heteroscedasticity.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1956,7 +1956,7 @@
19561956
"name": "python",
19571957
"nbconvert_exporter": "python",
19581958
"pygments_lexer": "ipython3",
1959-
"version": "3.13.5"
1959+
"version": "3.11.5"
19601960
}
19611961
},
19621962
"nbformat": 4,

examples/case_studies/bayesian_workflow.ipynb

Lines changed: 115 additions & 291 deletions
Large diffs are not rendered by default.

examples/case_studies/bayesian_workflow.myst.md

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ jupytext:
55
format_name: myst
66
format_version: 0.13
77
kernelspec:
8-
display_name: default
8+
display_name: pymc
99
language: python
1010
name: python3
1111
---
@@ -39,9 +39,8 @@ papermill:
3939
---
4040
import warnings
4141
42-
import arviz as az
42+
import arviz.preview as az
4343
import load_covid_data
44-
import matplotlib.pyplot as plt
4544
import numpy as np
4645
import plotly.express as px
4746
import plotly.graph_objects as go
@@ -54,18 +53,13 @@ from plotly.subplots import make_subplots
5453
5554
# Set renderer to generate static images
5655
pio.renderers.default = "png"
57-
58-
# Configure image size and quality
59-
pio.kaleido.scope.default_width = 800
60-
pio.kaleido.scope.default_height = 600
61-
pio.kaleido.scope.default_scale = 2
62-
6356
warnings.simplefilter("ignore")
6457
6558
RANDOM_SEED = 8451997
66-
sampler_kwargs = {"chains": 4, "cores": 4, "tune": 2000, "random_seed": RANDOM_SEED}
59+
sampler_kwargs = {"chains": 4, "cores": 4, "tune": 1000, "random_seed": RANDOM_SEED}
6760
68-
az.style.use("arviz-doc")
61+
az.rcParams["plot.backend"] = "plotly"
62+
az.style.use("arviz-variat")
6963
```
7064

7165
Bayesian methods offer several fundamental strengths that make it particularly valuable for building robust statistical models. Its great **flexibility** allows practitioners to build remarkably complex models from simple building blocks. The framework provides a principled way of dealing with **uncertainty**, capturing not just the most likely outcome but the complete distribution of all possible outcomes. Critically, Bayesian methods allow **expert information** to guide model development through the use of informative priors, incorporating domain knowledge directly into the analysis.
@@ -254,7 +248,6 @@ fig.update_layout(
254248
yaxis_title="Positive cases",
255249
yaxis=dict(range=[-1000, 1000]),
256250
xaxis=dict(range=[0, 10]),
257-
template="plotly_white",
258251
)
259252
```
260253

@@ -332,7 +325,6 @@ fig.update_layout(
332325
yaxis_title="Positive cases",
333326
yaxis=dict(range=[-100, 1000]),
334327
xaxis=dict(range=[0, 10]),
335-
template="plotly_white",
336328
)
337329
```
338330

@@ -390,12 +382,12 @@ make_subplots(
390382
).update_yaxes(
391383
title_text="Count", row=1, col=1
392384
).update_layout(
393-
template="plotly_white", showlegend=False, height=350
385+
showlegend=False, height=350
394386
)
395387
```
396388

397389
```{code-cell} ipython3
398-
obs_samples = az.extract(prior_pred3.prior_predictive)["obs"].values
390+
obs_samples = az.extract(prior_pred3.prior_predictive).values
399391
400392
fig = go.Figure()
401393
for i in range(min(100, obs_samples.shape[1])): # Show max 100 traces
@@ -416,7 +408,6 @@ fig.update_layout(
416408
yaxis_title="Positive cases",
417409
yaxis=dict(range=[0, 1000]),
418410
xaxis=dict(range=[0, 10]),
419-
template="plotly_white",
420411
)
421412
```
422413

@@ -446,11 +437,11 @@ Before trusting our results, we must verify that the sampler has converged prope
446437
:::
447438

448439
```{code-cell} ipython3
449-
az.plot_trace(trace_exp3, var_names=["a", "b", "alpha"]);
440+
az.plot_rank_dist(trace_exp3, var_names=["a", "b", "alpha"]);
450441
```
451442

452443
```{code-cell} ipython3
453-
az.summary(trace_exp3, var_names=["a", "b", "alpha"])
444+
az.summary(trace_exp3, var_names=["a", "b", "alpha"], kind="diagnostics")
454445
```
455446

456447
```{code-cell} ipython3
@@ -462,7 +453,7 @@ az.plot_energy(trace_exp3);
462453

463454
**R-hat values**: All close to 1.0 (< 1.01)
464455
**Effective sample size**: Reasonable for all parameters
465-
**Trace plots**: Show good mixing with no trends
456+
**Rank plots**: Show relatively good mixing with no trends
466457
**Energy plot**: Marginal and energy distributions overlap well
467458

468459
Our model has converged successfully!
@@ -538,22 +529,8 @@ for config in prior_configs:
538529
```
539530

540531
```{code-cell} ipython3
541-
fig, ax = plt.subplots(figsize=(8, 4))
542-
543-
colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
544-
545-
for i, (name, trace) in enumerate(results.items()):
546-
az.plot_kde(
547-
trace.posterior["b"].values.flatten(),
548-
label=name,
549-
ax=ax,
550-
plot_kwargs={"color": colors[i], "linewidth": 2},
551-
)
552-
553-
ax.set_xlabel("Growth rate (b)")
554-
ax.set_ylabel("Density")
555-
ax.set_title("Sensitivity to Prior Choice")
556-
ax.legend();
532+
pc = az.plot_dist(results, var_names=["b"])
533+
pc.add_legend("model");
557534
```
558535

559536
:::{admonition} Sensitivity Analysis Results
@@ -620,7 +597,6 @@ fig.add_trace(
620597
xaxis_title="Days since 100 cases",
621598
yaxis_title="Confirmed cases (log scale)",
622599
yaxis_type="log",
623-
template="plotly_white",
624600
)
625601
```
626602

@@ -651,7 +627,6 @@ fig.update_layout(
651627
xaxis_title="Days since 100 cases",
652628
yaxis_title="Residual",
653629
yaxis=dict(range=[-50000, 200000]),
654-
template="plotly_white",
655630
)
656631
```
657632

@@ -768,7 +743,6 @@ fig.update_layout(
768743
xaxis_title="Days since 100 cases",
769744
yaxis_title="Confirmed cases",
770745
yaxis_type="log",
771-
template="plotly_white",
772746
height=400,
773747
)
774748
```
@@ -851,7 +825,6 @@ fig.update_layout(
851825
xaxis_title="Days since 100 cases",
852826
yaxis_title="Positive cases",
853827
yaxis_type="log",
854-
template="plotly_white",
855828
)
856829
```
857830

@@ -865,7 +838,7 @@ with logistic_model:
865838
```
866839

867840
```{code-cell} ipython3
868-
az.plot_trace(trace_logistic);
841+
az.plot_rank_dist(trace_logistic);
869842
```
870843

871844
```{code-cell} ipython3
@@ -905,7 +878,6 @@ fig.add_trace(
905878
title="Logistic Model Fit - Germany",
906879
xaxis_title="Days since 100 cases",
907880
yaxis_title="Confirmed cases",
908-
template="plotly_white",
909881
height=400,
910882
)
911883
```
@@ -987,7 +959,7 @@ fig = px.line(
987959
labels={"days_since_100": "Days since 100 cases", "confirmed": "Confirmed cases"},
988960
)
989961
fig.update_traces(line=dict(color="#FF4136", width=3))
990-
fig.update_layout(template="plotly_white", height=400)
962+
fig.update_layout(height=400)
991963
```
992964

993965
The US data looks quite different - there appear to be multiple waves. Let's see how our logistic model handles this:
@@ -1026,7 +998,7 @@ with logistic_model_us:
1026998
```
1027999

10281000
```{code-cell} ipython3
1029-
az.plot_trace(trace_logistic_us);
1001+
az.plot_rank_dist(trace_logistic_us);
10301002
```
10311003

10321004
```{code-cell} ipython3
@@ -1066,7 +1038,6 @@ fig.add_trace(
10661038
title="Logistic Model Fit - US",
10671039
xaxis_title="Days since 100 cases",
10681040
yaxis_title="Confirmed cases",
1069-
template="plotly_white",
10701041
height=400,
10711042
)
10721043
```

0 commit comments

Comments
 (0)