
Add function that goes from transformed space to untransformed space #6721

Open
@ricardoV94

Description

Because we don't save the transformed variables in the returned InferenceData (why not?), it's not easy to evaluate the model logp once we have a trace.

One could rewrite the model without transforms (and we could do this automatically for the user). This is possible with remove_value_transforms: https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html

But someone might still want to evaluate it in the original model (with Jacobians and all that).

One dirty implementation is given here: https://discourse.pymc.io/t/logp-questions-synthetic-dataset-to-evaluate-modeling/12129/6?u=ricardov94
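To make the "with Jacobians" point concrete: for a positive variable under PyMC's default log transform (a value variable such as `sigma_log__`), the density of the unconstrained value z = log(x) is the constrained density evaluated at x = exp(z), plus the log-Jacobian log|dx/dz| = z. The pure-Python sketch below assumes a HalfNormal prior; all function names are hypothetical illustrations, not PyMC API.

```python
import math


# Hypothetical helper (not PyMC API): HalfNormal(scale) log-density in the
# constrained space, defined for x >= 0.
def halfnormal_logp(x: float, scale: float = 1.0) -> float:
    if x < 0:
        return -math.inf
    return 0.5 * math.log(2.0 / math.pi) - math.log(scale) - x**2 / (2.0 * scale**2)


# Log transform, the default for positive variables (e.g. sigma -> sigma_log__).
def forward(x: float) -> float:
    """Constrained -> unconstrained."""
    return math.log(x)


def backward(z: float) -> float:
    """Unconstrained -> constrained."""
    return math.exp(z)


def log_jac_det(z: float) -> float:
    """log |d backward(z) / dz| = log(exp(z)) = z."""
    return z


def unconstrained_logp(z: float, scale: float = 1.0) -> float:
    # Map back to the constrained value, then add the Jacobian correction.
    return halfnormal_logp(backward(z), scale) + log_jac_det(z)


if __name__ == "__main__":
    sigma = 0.1  # a constrained draw, e.g. taken from a trace
    print(unconstrained_logp(forward(sigma)))
```

Dropping the `log_jac_det` term gives the logp of the model with transforms removed; keeping it gives the logp the sampler actually saw in the unconstrained space.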

Activity

ricardoV94 commented on Dec 16, 2023

Member, Author

Results should be saved in the unconstrained_posterior group: https://python.arviz.org/en/latest/schema/schema.html#unconstrained-posterior

We should make sure there's an option in pm.sample to store those, besides allowing users to populate them afterwards with a helper, as initially suggested in this issue.

pipme commented on Feb 26, 2025

Contributor

Hi, any updates on this? Or do you have any suggestions for vectorizing the transformation of parameters between the constrained and unconstrained spaces?

I am currently using an inefficient for-loop, which also feels a bit hacky:

  • constrained to unconstrained space:
    import pymc as pm

    model: pm.Model  # an existing PyMC model
    transformed_rvs = []
    for free_rv in model.free_RVs:
        transform = model.rvs_to_transforms.get(free_rv)
        if transform is None:
            # no transform: the variable is already unconstrained
            transformed_rvs.append(free_rv)
        else:
            # apply the forward (constrained -> unconstrained) transform
            transformed_rv = transform.forward(free_rv, *free_rv.owner.inputs)
            transformed_rvs.append(transformed_rv)

    fn = model.compile_fn(inputs=model.free_RVs, outs=transformed_rvs)
    # N parameter values to transform, one call per sample
    for i in range(N_samples):
        # value_dict for sample i is e.g. {"sigma": 0.1, "a": [0.1, 0.2]}
        value_unconstrained_list = fn(value_dict)
  • unconstrained to constrained space:
    outputs = model.unobserved_value_vars
    fn_inv = model.compile_fn(outs=outputs)
    for i in range(N_samples):
        # value_dict for sample i is e.g. {"sigma_log__": np.log(0.1), "a": [0.1, 0.2]}
        value_constrained_list = fn_inv(value_dict)

Thanks!
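For elementwise transforms such as PyMC's default log and interval transforms, the per-sample loop can in principle be replaced by array operations over a whole (chain, draw) array. The NumPy sketch below assumes a log-transformed `sigma` and an interval-transformed `p` on (0, 1); the helper functions and dict keys are hypothetical, written by hand rather than taken from PyMC, so they would have to be kept in sync with the transforms the model actually uses.

```python
import numpy as np


# Hypothetical helpers (not PyMC API). Each is an elementwise NumPy expression,
# so it transforms every draw at once instead of looping over samples.
def log_forward(x):
    # constrained (x > 0) -> unconstrained, e.g. sigma -> sigma_log__
    return np.log(x)


def log_backward(z):
    # unconstrained -> constrained
    return np.exp(z)


def interval_forward(x, lower, upper):
    # constrained (lower < x < upper) -> unconstrained: log(x - a) - log(b - x)
    return np.log(x - lower) - np.log(upper - x)


def interval_backward(z, lower, upper):
    # inverse of interval_forward: lower + (upper - lower) * sigmoid(z)
    return lower + (upper - lower) / (1.0 + np.exp(-z))


# Illustrative posterior draws with shape (chain, draw).
rng = np.random.default_rng(0)
sigma = rng.gamma(2.0, 0.5, size=(4, 1000))
p = rng.uniform(0.0, 1.0, size=(4, 1000))

# Constrained -> unconstrained, all draws at once.
unconstrained = {
    "sigma_log__": log_forward(sigma),
    "p_interval__": interval_forward(p, 0.0, 1.0),
}

# Unconstrained -> constrained, all draws at once.
constrained = {
    "sigma": log_backward(unconstrained["sigma_log__"]),
    "p": interval_backward(unconstrained["p_interval__"], 0.0, 1.0),
}
```

The trade-off is that these helpers duplicate the model's transforms by hand; transforms that are not elementwise (for example the simplex or ordered transforms) would need their own vectorized implementations, which is why compiling the model's own graph once and batching it is attractive.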

pipme commented on Feb 26, 2025

Contributor
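from pymc.sampling.jax import _postprocess_samples, get_jaxified_graph
from pymc.util import get_default_varnames

# constrained outputs (free RVs and deterministics), computed from the value variables
filtered_var_names = model.unobserved_value_vars
vars_to_sample = list(
    get_default_varnames(filtered_var_names, include_transformed=False)
)

# JAX function mapping the (unconstrained) value variables to the constrained outputs
jax_fn_inv = get_jaxified_graph(
    inputs=model.value_vars, outputs=vars_to_sample
)
# params_unconstrained: one array per entry of model.value_vars, in the same order,
# each with leading (chain, draw) dimensions
params_constrained = _postprocess_samples(
    jax_fn_inv, params_unconstrained
)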

The above seems to work for transforming from the unconstrained to the constrained space. Inside _postprocess_samples, jax.vmap is used for vectorization. However, the following doesn't work for going from the constrained to the unconstrained space:

jax_fn = get_jaxified_graph(
    inputs=vars_to_sample, outputs=model.value_vars
)

Is it possible to have such a jaxified function?

          Add function that goes from transformed space to untransformed space · Issue #6721 · pymc-devs/pymc