@@ -5,7 +5,7 @@ jupytext:
55 format_name : myst
66 format_version : 0.13
77kernelspec :
8- display_name : default
8+ display_name : pymc
99 language : python
1010 name : python3
1111---
@@ -39,9 +39,8 @@ papermill:
3939---
4040import warnings
4141
42- import arviz as az
42+ import arviz.preview as az
4343import load_covid_data
44- import matplotlib.pyplot as plt
4544import numpy as np
4645import plotly.express as px
4746import plotly.graph_objects as go
@@ -54,18 +53,13 @@ from plotly.subplots import make_subplots
5453
5554# Set renderer to generate static images
5655pio.renderers.default = "png"
57-
58- # Configure image size and quality
59- pio.kaleido.scope.default_width = 800
60- pio.kaleido.scope.default_height = 600
61- pio.kaleido.scope.default_scale = 2
62-
6356warnings.simplefilter("ignore")
6457
6558RANDOM_SEED = 8451997
66- sampler_kwargs = {"chains": 4, "cores": 4, "tune": 2000 , "random_seed": RANDOM_SEED}
59+ sampler_kwargs = {"chains": 4, "cores": 4, "tune": 1000 , "random_seed": RANDOM_SEED}
6760
68- az.style.use("arviz-doc")
61+ az.rcParams["plot.backend"] = "plotly"
62+ az.style.use("arviz-variat")
6963```
7064
7165Bayesian methods offer several fundamental strengths that make it particularly valuable for building robust statistical models. Its great ** flexibility** allows practitioners to build remarkably complex models from simple building blocks. The framework provides a principled way of dealing with ** uncertainty** , capturing not just the most likely outcome but the complete distribution of all possible outcomes. Critically, Bayesian methods allow ** expert information** to guide model development through the use of informative priors, incorporating domain knowledge directly into the analysis.
@@ -254,7 +248,6 @@ fig.update_layout(
254248 yaxis_title="Positive cases",
255249 yaxis=dict(range=[-1000, 1000]),
256250 xaxis=dict(range=[0, 10]),
257- template="plotly_white",
258251)
259252```
260253
@@ -332,7 +325,6 @@ fig.update_layout(
332325 yaxis_title="Positive cases",
333326 yaxis=dict(range=[-100, 1000]),
334327 xaxis=dict(range=[0, 10]),
335- template="plotly_white",
336328)
337329```
338330
@@ -390,12 +382,12 @@ make_subplots(
390382).update_yaxes(
391383 title_text="Count", row=1, col=1
392384).update_layout(
393- template="plotly_white", showlegend=False, height=350
385+ showlegend=False, height=350
394386)
395387```
396388
397389``` {code-cell} ipython3
398- obs_samples = az.extract(prior_pred3.prior_predictive)["obs"] .values
390+ obs_samples = az.extract(prior_pred3.prior_predictive).values
399391
400392fig = go.Figure()
401393for i in range(min(100, obs_samples.shape[1])): # Show max 100 traces
@@ -416,7 +408,6 @@ fig.update_layout(
416408 yaxis_title="Positive cases",
417409 yaxis=dict(range=[0, 1000]),
418410 xaxis=dict(range=[0, 10]),
419- template="plotly_white",
420411)
421412```
422413
@@ -446,11 +437,11 @@ Before trusting our results, we must verify that the sampler has converged prope
446437:::
447438
448439``` {code-cell} ipython3
449- az.plot_trace (trace_exp3, var_names=["a", "b", "alpha"]);
440+ az.plot_rank_dist (trace_exp3, var_names=["a", "b", "alpha"]);
450441```
451442
452443``` {code-cell} ipython3
453- az.summary(trace_exp3, var_names=["a", "b", "alpha"])
444+ az.summary(trace_exp3, var_names=["a", "b", "alpha"], kind="diagnostics" )
454445```
455446
456447``` {code-cell} ipython3
@@ -462,7 +453,7 @@ az.plot_energy(trace_exp3);
462453
463454✓ ** R-hat values** : All close to 1.0 (< 1.01)
464455✓ ** Effective sample size** : Reasonable for all parameters
465- ✓ ** Trace plots** : Show good mixing with no trends
456+ ✓ ** Rank plots** : Show relatively good mixing with no trends
466457✓ ** Energy plot** : Marginal and energy distributions overlap well
467458
468459Our model has converged successfully!
@@ -538,22 +529,8 @@ for config in prior_configs:
538529```
539530
540531``` {code-cell} ipython3
541- fig, ax = plt.subplots(figsize=(8, 4))
542-
543- colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
544-
545- for i, (name, trace) in enumerate(results.items()):
546- az.plot_kde(
547- trace.posterior["b"].values.flatten(),
548- label=name,
549- ax=ax,
550- plot_kwargs={"color": colors[i], "linewidth": 2},
551- )
552-
553- ax.set_xlabel("Growth rate (b)")
554- ax.set_ylabel("Density")
555- ax.set_title("Sensitivity to Prior Choice")
556- ax.legend();
532+ pc = az.plot_dist(results, var_names=["b"])
533+ pc.add_legend("model");
557534```
558535
559536:::{admonition} Sensitivity Analysis Results
@@ -620,7 +597,6 @@ fig.add_trace(
620597 xaxis_title="Days since 100 cases",
621598 yaxis_title="Confirmed cases (log scale)",
622599 yaxis_type="log",
623- template="plotly_white",
624600)
625601```
626602
@@ -651,7 +627,6 @@ fig.update_layout(
651627 xaxis_title="Days since 100 cases",
652628 yaxis_title="Residual",
653629 yaxis=dict(range=[-50000, 200000]),
654- template="plotly_white",
655630)
656631```
657632
@@ -768,7 +743,6 @@ fig.update_layout(
768743 xaxis_title="Days since 100 cases",
769744 yaxis_title="Confirmed cases",
770745 yaxis_type="log",
771- template="plotly_white",
772746 height=400,
773747)
774748```
@@ -851,7 +825,6 @@ fig.update_layout(
851825 xaxis_title="Days since 100 cases",
852826 yaxis_title="Positive cases",
853827 yaxis_type="log",
854- template="plotly_white",
855828)
856829```
857830
@@ -865,7 +838,7 @@ with logistic_model:
865838```
866839
867840``` {code-cell} ipython3
868- az.plot_trace (trace_logistic);
841+ az.plot_rank_dist (trace_logistic);
869842```
870843
871844``` {code-cell} ipython3
@@ -905,7 +878,6 @@ fig.add_trace(
905878 title="Logistic Model Fit - Germany",
906879 xaxis_title="Days since 100 cases",
907880 yaxis_title="Confirmed cases",
908- template="plotly_white",
909881 height=400,
910882)
911883```
@@ -987,7 +959,7 @@ fig = px.line(
987959 labels={"days_since_100": "Days since 100 cases", "confirmed": "Confirmed cases"},
988960)
989961fig.update_traces(line=dict(color="#FF4136", width=3))
990- fig.update_layout(template="plotly_white", height=400)
962+ fig.update_layout(height=400)
991963```
992964
993965The US data looks quite different - there appear to be multiple waves. Let's see how our logistic model handles this:
@@ -1026,7 +998,7 @@ with logistic_model_us:
1026998```
1027999
10281000``` {code-cell} ipython3
1029- az.plot_trace (trace_logistic_us);
1001+ az.plot_rank_dist (trace_logistic_us);
10301002```
10311003
10321004``` {code-cell} ipython3
@@ -1066,7 +1038,6 @@ fig.add_trace(
10661038 title="Logistic Model Fit - US",
10671039 xaxis_title="Days since 100 cases",
10681040 yaxis_title="Confirmed cases",
1069- template="plotly_white",
10701041 height=400,
10711042)
10721043```
0 commit comments