Skip to content

Use stanio for creating Stan's data JSON #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2025

Conversation

WardBrian
Copy link
Contributor

Closes #204.

Note: I believe the output-processing code in stan.rs is probably still mishandling tuples. I see it has some logic to treat complex variables as 2-arrays, which seems fine if a bit unfortunate.

Verified

This commit was signed with the committer’s verified signature.
lunika Manuel Raynaud
@WardBrian
Copy link
Contributor Author

@aseyboldt the failures here seem to be unrelated to the Stan code -- thoughts?

@aseyboldt
Copy link
Member

Thanks!
The test failures are unrelated. In the latest PR for the normalizing flows I introduce a problem on ARM somehow, that I haven't figured out yet.

Complex values are represented as an array with two elements right now. I'll open an issue to change that to a proper complex type.

What do you mean about the tuples? This for instance seems to work as expected (apart from the complex variable)?

import nutpie

code = """
parameters {
    tuple(complex, tuple(tuple(real, real), tuple(real, real)), array[2, 3] real) xi;
}
transformed parameters {
    real re = get_real(xi.1);
    real im = get_imag(xi.1);
}
model {
    re ~ normal(0, 0.1);
    im ~ normal(-1, 0.1);

    xi.2.1.1 ~ normal(-2, 0.1);
    xi.2.1.2 ~ normal(-3, 0.1);
    xi.2.2.1 ~ normal(-4, 0.1);
    xi.2.2.2 ~ normal(-5, 0.1);

    xi.3[1, 1] ~ normal(1, 0.1);
    xi.3[1, 2] ~ normal(2, 0.1);
    xi.3[1, 3] ~ normal(3, 0.1);
    xi.3[2, 1] ~ normal(4, 0.1);
    xi.3[2, 2] ~ normal(5, 0.1);
    xi.3[2, 3] ~ normal(6, 0.1);
}
"""

compiled = nutpie.compile_stan_model(code=code)

tr = nutpie.sample(compiled)

means = tr.posterior.mean(["draw", "chain"])

image

There is no way to represent unnamed hierarchical structure in xarray, so I don't think we can do better than flattening the tuple?

@aseyboldt aseyboldt merged commit 35c2508 into pymc-devs:main May 6, 2025
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve JSON dumping in CompiledStanModel.with_data
2 participants