
Stochastic gradients in pytensor #1419


Closed
jessegrabowski opened this issue May 24, 2025 · 1 comment
Comments

@jessegrabowski
Member

Description

There are many problems in machine learning that require differentiating through random variables. This specifically came up in the context of pymc-devs/pymc#7799, but it was also implicated in #555.

Right now, pt.grad refuses to differentiate through a random variable. That's probably the correct default behavior. But we could have another function, pt.stochastic_grad, that would do it and return a stochastic gradient. It would also take an n_samples argument, since what we're actually computing is a sample-based Monte Carlo estimate of the gradients with respect to the parameters.
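
To make the proposal concrete, usage might look roughly like this. Everything here except pt.grad and pt.random.normal is hypothetical: the pt.stochastic_grad name, its wrt and n_samples arguments, and the toy loss are illustrative only.

```python
import pytensor.tensor as pt

mu = pt.scalar("mu")
sigma = pt.scalar("sigma")

x = pt.random.normal(mu, sigma)   # pt.grad currently refuses to differentiate through this
loss = (x - 1.0) ** 2             # any scalar loss involving the RV

# Hypothetical API: a Monte Carlo estimate of d E[loss] / d(mu, sigma),
# averaged over n_samples reparameterized draws.
g_mu, g_sigma = pt.stochastic_grad(loss, wrt=[mu, sigma], n_samples=100)
```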

Most often, the RV would have to know a "reparameterization trick" to split its parameters from the source of randomness. The canonical example is the non-centered normal parameterization. Given a loss function $\mathcal{L}$ that depends on $x \sim N(\mu, \sigma)$, the proposed pt.stochastic_grad would compute the gradient of the loss's expectation over the RV: $\nabla_\theta \, \mathbb{E}_x\left[\mathcal{L}(g(x, \theta))\right] = \mathbb{E}_x\left[\nabla_\theta \, \mathcal{L}(g(x, \theta))\right]$. The so-called reparameterization trick just does a non-centered parameterization, $x = \mu + \sigma z, \quad z \sim N(0,1)$, so that now the gradient contribution of $g(x, \theta)$ can be estimated:

$$ \approx \frac{1}{N} \sum_{i=1}^N \nabla_\theta \mathcal{L}(\mu + \sigma z_i) $$

And the (expected) sensitivity equations for the parameters of the normal are:

$$ \begin{align} \bar{\mu} &= \frac{1}{N} \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial x^{(i)}}, \\ \bar{\sigma} &= \frac{1}{N} \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial x^{(i)}} \cdot z^{(i)} \end{align} $$
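
For concreteness, this is exactly the pattern you can already write by hand with today's API by drawing the noise from a parameter-free normal; the proposed pt.stochastic_grad would automate it. A minimal sketch (the toy loss, with $\mathbb{E}[x^2] = \mu^2 + \sigma^2$, is only there so the estimates can be checked):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

srng = RandomStream(seed=2025)

mu = pt.scalar("mu")
sigma = pt.scalar("sigma")
n_samples = 1000

# Non-centered parameterization: the randomness lives in z, so the gradient
# path from the loss back to (mu, sigma) never crosses a random variable.
z = srng.normal(0.0, 1.0, size=(n_samples,))
x = mu + sigma * z

loss = pt.mean(x**2)  # Monte Carlo estimate of E[x^2]

# This works today because mu and sigma enter the graph deterministically.
g_mu, g_sigma = pt.grad(loss, [mu, sigma])

f = pytensor.function([mu, sigma], [g_mu, g_sigma])
print(f(0.5, 1.0))  # fresh stochastic estimates near (2*mu, 2*sigma) = (1.0, 2.0)
```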

It would be easy enough for a normal_rv to know this, and to supply these formulas when requested by the hypothetical pt.stochastic_grad.
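
As a purely hypothetical sketch of what "knowing this" could mean (no such hook exists today, and the name is made up): the normal RV could expose its non-centered form, which pt.stochastic_grad would splice into the graph before differentiating.

```python
import pytensor.tensor as pt

# Hypothetical hook, illustrative only: return a version of x ~ N(mu, sigma)
# in which the randomness comes from a parameter-free standard normal.
def normal_reparameterization(mu, sigma, size=None):
    z = pt.random.normal(0.0, 1.0, size=size)  # noise with no free parameters
    return mu + sigma * z                      # deterministic in (mu, sigma) given z
```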

I guess other RVs also have reparameterizations (beta, dirichlet, gamma, ...?), but in some cases there are multiple options and it's not obvious which one is best to use when. Some thought would have to be given to how to handle that.

When a reparameterization doesn't exist, there are other, higher-variance options to compute the expected gradients (the REINFORCE gradients, for example). We could offer these as a fallback.
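
For reference, the score-function (REINFORCE) identity only needs the log-density, not a reparameterization, but its naive Monte Carlo estimate typically has much higher variance:

$$ \nabla_\theta \, \mathbb{E}_{x \sim p_\theta}\left[\mathcal{L}(x)\right] = \mathbb{E}_{x \sim p_\theta}\left[\mathcal{L}(x)\, \nabla_\theta \log p_\theta(x)\right] \approx \frac{1}{N} \sum_{i=1}^N \mathcal{L}(x^{(i)})\, \nabla_\theta \log p_\theta(x^{(i)}) $$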

Basically, this issue proposes this API and invites discussion on whether we want this type of feature and, if so, how to do it. The pt.stochastic_grad function would be novel as far as I know. Other packages require that you explicitly generate samples in your computation graph. For example, torch offers torch.distributions.Normal(mu, sigma).rsample((n_draws,)), which generates samples via the reparameterization trick, so the standard loss.backward() works. There the user can't "accidentally" trigger stochastic gradients, because you have to call rsample instead of sample.
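
For comparison, the PyTorch pattern looks roughly like this (the toy loss is illustrative):

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)

# rsample draws via the reparameterization trick, so autograd flows through it
x = torch.distributions.Normal(mu, sigma).rsample((100,))
loss = (x**2).mean()
loss.backward()

print(mu.grad, sigma.grad)  # stochastic estimates near (2*mu, 2*sigma)
```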

I'm less familiar with how numpyro works, but I believe that something like numpyro.sample("z", dist.Normal(mu, sigma)) automatically uses the reparameterization trick when one is available. They don't have a special idiom like rsample to signal when it will or won't be used.

@ricardoV94
Member

CC @aseyboldt

@pymc-devs pymc-devs locked and limited conversation to collaborators May 27, 2025
@ricardoV94 ricardoV94 converted this issue into discussion #1424 May 27, 2025

This issue was moved to a discussion. You can continue the conversation there.
