BUG: LKJCorr default transform raises error #7002

Closed
juanitorduz opened this issue Nov 8, 2023 · 7 comments · Fixed by #7065
@juanitorduz
Contributor

juanitorduz commented Nov 8, 2023

Describe the issue:

I am trying to run model 4 from https://tomicapretto.github.io/posts/2022-06-12_lkj-prior/#model-4-correlated-priors-with-lkjcorr.-replicate-rstanarm-prior and I am getting the following error:

NotImplementedError: Univariate transform Interval cannot be applied to multivariate lkjcorr_rv{1, (0, 0), floatX, False}

Reproducible code example:

import arviz as az
import pytensor.tensor as pt
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

url = "https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/lme4/sleepstudy.csv"
data = pd.read_csv(url, index_col = 0)

# Subjects and subjects index
subjects, subjects_idx = np.unique(data["Subject"], return_inverse=True)

# Coordinates to handle dimensions of PyMC distributions and use better labels
coords = {"subject": subjects}

# Response mean -- Used in the prior for the intercept
y_mean = data["Reaction"].mean()

# Days variable
days = data["Days"].values

coords.update({"effect": ["intercept", "slope"]})

J = 2 # Order of covariance matrix

with pm.Model(coords=coords) as model_lkj_corr_2:
    # Common part
    β0 = pm.Normal("β0", mu=y_mean, sigma=50)
    β1 = pm.Normal("β1", mu=0, sigma=10)
    
    # Residual SD 
    σ = pm.HalfStudentT("σ", nu=4, sigma=50)
    
    # Group-specific part
    # Begin of rstanarm approach ----------------------------------
    τ = pm.Gamma("τ", alpha=1, beta=1)
    Σ_trace = J * τ ** 2
    π = pm.Dirichlet("π", a=np.ones(J), dims="effect")
    σ_u = pm.Deterministic("b_σ", σ * π * (Σ_trace) ** 0.5)
    # End of rstanarm approach ------------------------------------
    
    # Triangular upper part of the correlation matrix
    Ω_triu = pm.LKJCorr("Ω_triu", eta=1, n=2)
     
    # Correlation matrix
    Ω = pm.Deterministic(
        "Ω", pt.fill_diagonal(Ω_triu[np.zeros((2, 2), dtype=np.int64)], 1.)
    )

    # Construct diagonal matrix of standard deviation
    σ_u_diagonal = pm.Deterministic("σ_u_diagonal", pt.eye(2) * σ_u)
    
    # Covariance matrix
    Σ = pt.nlinalg.matrix_dot(σ_u_diagonal, Ω, σ_u_diagonal)
    
    # Cholesky decomposition, lower triangular matrix.
    L = pm.Deterministic("L", pt.slinalg.cholesky(Σ))
    u_raw = pm.Normal("u_raw", mu=0, sigma=1, dims=("effect", "subject")) 
    
    u = pm.Deterministic("u", pt.dot(L, u_raw).T, dims=("subject", "effect"))
                         
    u0 = pm.Deterministic("u0", u[:, 0], dims="subject")
    σ_u0 = pm.Deterministic("σ_u0", σ_u[0])
    
    u1 = pm.Deterministic("u1", u[:, 1], dims="subject")
    σ_u1 = pm.Deterministic("σ_u1", σ_u[1])
    
    # Correlation
    ρ_u = pm.Deterministic("ρ_u", Ω[0, 1])
         
    # Construct intercept and slope
    intercept = pm.Deterministic("intercept", β0 + u0[subjects_idx]) 
    slope = pm.Deterministic("slope", (β1 + u1[subjects_idx]) * days) 
    
    # Conditional mean
    μ = pm.Deterministic("μ", intercept + slope)
       
    y = pm.Normal("y", mu=μ, sigma=σ, observed=data["Reaction"])

Error message:

NotImplementedError                       Traceback (most recent call last)
/Users/juanitorduz/Documents/website_projects/Python/multilevel_elasticities_single_sku_2.ipynb Cell 9 line 2
      1 with model_cov:
----> 2     idata_cov = pm.sample(
      3         target_accept=0.9,
      4         draws=4_000,
      5         chains=4,
      6         random_seed=rng,
      7     )
      8     posterior_predictive_cov = pm.sample_posterior_predictive(
      9         trace=idata_cov, random_seed=rng
     10     )

File ~/.local/share/virtualenvs/website_projects-1IZj_WTw/lib/python3.11/site-packages/pymc/sampling/mcmc.py:689, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    686         auto_nuts_init = False
    688 initial_points = None
--> 689 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    691 if nuts_sampler != "pymc":
    692     if not isinstance(step, NUTS):

File ~/.local/share/virtualenvs/website_projects-1IZj_WTw/lib/python3.11/site-packages/pymc/sampling/mcmc.py:217, in assign_step_methods(model, step, methods, step_kwargs)
    215 methods_list: List[Type[BlockedStep]] = list(methods or pm.STEP_METHODS)
    216 selected_steps: Dict[Type[BlockedStep], List] = {}
--> 217 model_logp = model.logp()
    219 for var in model.value_vars:
    220     if var not in assigned_vars:
    221         # determine if a gradient can be computed

File ~/.local/share/virtualenvs/website_projects-1IZj_WTw/lib/python3.11/site-packages/pymc/model/core.py:721, in Model.logp(self, vars, jacobian, sum)
    719 rv_logps: List[TensorVariable] = []
    720 if rvs:
--> 721     rv_logps = transformed_conditional_logp(
    722         rvs=rvs,
    723         rvs_to_values=self.rvs_to_values,
    724         rvs_to_transforms=self.rvs_to_transforms,
    725         jacobian=jacobian,
    726     )
    727     assert isinstance(rv_logps, list)
    729 # Replace random variables by their value variables in potential terms

File ~/.local/share/virtualenvs/website_projects-1IZj_WTw/lib/python3.11/site-packages/pymc/logprob/basic.py:612, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    609     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
    611 kwargs.setdefault("warn_rvs", False)
--> 612 temp_logp_terms = conditional_logp(
...
   1269 else:
   1270     # Check there is no broadcasting between logp and jacobian
   1271     if logp.type.broadcastable != log_jac_det.type.broadcastable:

NotImplementedError: Univariate transform Interval cannot be applied to multivariate lkjcorr_rv{1, (0, 0), floatX, False}

PyMC version information:

Last updated: Wed Nov 08 2023

Python implementation: CPython
Python version       : 3.11.5
IPython version      : 8.17.2

pytensor: 2.17.3

pandas    : 2.1.2
pytensor  : 2.17.3
numpy     : 1.23.5
matplotlib: 3.8.1
arviz     : 0.16.1
pymc      : 5.9.1

Watermark: 2.4.3

Context for the issue:

No response

juanitorduz added the bug label Nov 8, 2023
@ricardoV94
Member

ricardoV94 commented Nov 8, 2023

We disabled univariate transforms for multivariate RVs (see discussion: #6903 (comment))

It looks fine for this variable, however (so much so that it is the default... although not tested anywhere xD), so maybe we can remove that restriction. It was more to prevent users from shooting themselves in the foot, but I think it's better to allow it than not.

CC @lucianopaz

@juanitorduz
Contributor Author

juanitorduz commented Nov 8, 2023

Thanks for the info! I agree we should allow it, precisely because of what motivates this model (I love this blog post @tomicapretto!).

Is this something a newbie like me can help with without going into your black magic 😅?

@ricardoV94
Member

ricardoV94 commented Nov 8, 2023

Just have to remove the check added in that PR and sum the jacobian. Then tweak the tests and perhaps add one for the LKJCorr RV, since nothing complained when we put the restriction in place.

@juanitorduz
Contributor Author

ok! I might try it out! 😅

@ricardoV94
Member

@juanitorduz
Contributor Author

Just have to remove the check added in that PR

@ricardoV94 Do you mean removing the condition at https://github.com/pymc-devs/pymc/blob/main/pymc/logprob/transform_value.py#L124 ?

elif log_jac_det.ndim > logp.ndim:  # <- This one?
    ...

... and sum the jacobian.

Where precisely? Sorry, I do not get this one 🙈

@ricardoV94
Member

When we apply a univariate transform to a multivariate RV, the jacobian has more dimensions than the logp (vector vs. scalar). We need to reduce (sum) the jacobian before adding it to the logp; otherwise the addition would broadcast the logp and count it multiple times.

That's why we thought about just forbidding it. The reduction has to be done exactly where the error is now raised.
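
To make the double counting concrete, here is a minimal pure-Python sketch. It is not PyMC code: `joint_logp` is a made-up stand-in density (not the LKJ density), and the helper names are hypothetical. The interval map is assumed to be the usual scaled sigmoid onto (-1, 1), which yields one log-jacobian term per element, while the joint logp is a single scalar.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def interval_backward(y, lower=-1.0, upper=1.0):
    # Map an unconstrained value y into (lower, upper) with a scaled sigmoid.
    return lower + (upper - lower) / (1.0 + math.exp(-y))


def interval_log_jac_det(y, lower=-1.0, upper=1.0):
    # log |dx/dy| of the scaled-sigmoid map above, for a single element.
    return math.log(upper - lower) - 2.0 * softplus(-y) - y


def joint_logp(x):
    # Hypothetical joint density over the whole vector: returns ONE scalar.
    return -0.5 * sum(v * v for v in x)


# Unconstrained values for a 3-element multivariate variable (arbitrary numbers).
y = [0.3, -1.2, 2.0]
x = [interval_backward(v) for v in y]

logp = joint_logp(x)                             # scalar
log_jac = [interval_log_jac_det(v) for v in y]   # one entry per element

# Wrong: adding elementwise broadcasts the scalar logp onto every element,
# so summing the result counts logp len(y) times.
broadcast_terms = [logp + j for j in log_jac]
wrong_total = sum(broadcast_terms)

# Right: reduce the jacobian first, then add it to the scalar logp once.
correct_total = logp + sum(log_jac)

# The two totals differ by exactly (len(y) - 1) extra copies of logp.
print(wrong_total - correct_total, (len(y) - 1) * logp)
```

The design point is only the order of operations: sum the per-element jacobian terms down to the shape of the logp before adding, so the joint density enters the total exactly once.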

ricardoV94 changed the title from "BUG: LKJCorr model NotImplementedError" to "BUG: LKJCorr default transform raises error" on Nov 9, 2023