Adding a new transposed convolution function to lax #5772
Conversation
Thank you!
The changes to lax.py include many file-wide formatting adjustments. Could you undo those? We probably don't want to take those at the moment, and they obscure the main change in the diff. You'll also want to squash commits so that there isn't one commit that makes formatting changes followed by another that undoes them.
(force-pushed from 89dc631 to 6584a54)
Just removed the formatting changes (they were applied automatically by my IDE).
Are there any updates on this? It would be awesome to have this function.
I'm not an expert on this, but I'm wondering what the pros and cons are of introducing a new API endpoint vs. adding features like `output_padding` to the existing `conv_transpose`.
Because the meaning of `padding` differs between `conv_transpose` and the gradient-based formulation, folding these features into the existing function would silently change its semantics.
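To make the difference concrete: the gradient-based variant is, by definition, the transpose (vjp) of a forward convolution, so its `padding` argument describes that forward convolution rather than the transposed op itself. A minimal sketch of this definition (the helper name and shapes here are illustrative, not part of the PR):

```python
import jax
import jax.numpy as jnp
from jax import lax

def conv_transpose_via_vjp(y, kernel, strides, padding, x_shape):
    # The gradient-based transposed conv is the vjp of this forward conv;
    # `padding` here is the forward conv's padding, which is how
    # PyTorch/TF/Keras-style APIs interpret it.
    def fwd(x):
        return lax.conv_general_dilated(
            x, kernel, window_strides=strides, padding=padding,
            dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
    _, vjp = jax.vjp(fwd, jnp.zeros(x_shape))
    (x_bar,) = vjp(y)
    return x_bar

# Upsample a 4x4 feature map back to the 8x8 input of the forward conv.
y = jnp.ones((1, 1, 4, 4))
k = jnp.ones((1, 1, 3, 3))  # OIHW layout
print(conv_transpose_via_vjp(y, k, (2, 2), 'SAME', (1, 1, 8, 8)).shape)
# (1, 1, 8, 8)
```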
Sorry, I wrote the original at a time when no frameworks really existed in JAX. (Nowadays, I'd probably not even add this function to "lax", since it's strictly a specialization of general convolutions, and delegate these matters to NN frameworks.) Aside from a pending review of correctness, this mainly comes down to a question of organization:
1. Add the function to the NN frameworks (e.g. Flax, Objax) rather than to JAX.
2. Add it to `lax` in JAX itself, as this PR does.
The issue with 1. might be that, at least for Flax and Objax, all the convolution modules are just wrappers of the functions in `lax`.
Hey, is this issue being actively worked on?
Hi all! Is there a plan to merge this PR? It seems to be the root cause of issues when converting some PyTorch models to JAX/Flax; it would be nice if we could merge it ;)
The Flax docs link to this PR.
Also came here from the Flax docs. Is there any other way to use transposed convolutions that are compatible with PyTorch's `ConvTranspose2d`?
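One workaround that should give PyTorch-compatible results, assuming weights in PyTorch's `(C_in, C_out, kH, kW)` transposed-conv layout, is to build the op from `lax.conv_general_dilated` directly (the helper name is illustrative):

```python
import jax.numpy as jnp
from jax import lax

def torch_style_conv_transpose2d(x, w, stride, padding, output_padding=0):
    # Express the gradient-based transposed conv as a fractionally strided
    # forward conv: dilate the input by `stride`, pad each spatial side by
    # k-1-p (plus `output_padding` on the high side), and flip the kernel.
    kh, kw = w.shape[2], w.shape[3]
    pads = [(kh - 1 - padding, kh - 1 - padding + output_padding),
            (kw - 1 - padding, kw - 1 - padding + output_padding)]
    return lax.conv_general_dilated(
        x, jnp.flip(w, axis=(2, 3)),       # spatially flipped kernel
        window_strides=(1, 1),
        padding=pads,
        lhs_dilation=(stride, stride),     # fractional striding
        dimension_numbers=('NCHW', 'IOHW', 'NCHW'))

x = jnp.ones((1, 3, 5, 5))
w = jnp.ones((3, 8, 4, 4))  # (C_in, C_out, kH, kW), as in ConvTranspose2d
print(torch_style_conv_transpose2d(x, w, stride=2, padding=1).shape)
# (1, 8, 10, 10), matching (5 - 1) * 2 - 2 * 1 + 4 = 10
```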
This PR implements `lax.gradient_based_conv_transpose`. Compared to `conv_transpose`, it provides support for `output_shape` and `output_padding`. It matches the APIs for transposed convolutions derived from the gradient of a forward convolution, which are common in other deep learning frameworks such as TensorFlow, PyTorch, and Keras. This additional transposed-convolution function can make it much easier to reproduce code written in other (and currently more popular) frameworks.
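For reference, these gradient-based APIs all share the output-size relationship documented for PyTorch's `ConvTranspose2d` and Keras's `Conv2DTranspose`; the helper below is illustrative, not part of this PR:

```python
def conv_transpose_output_size(in_size, kernel, stride, padding,
                               output_padding=0, dilation=1):
    # Standard output size for gradient-based transposed convolutions.
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)

# With stride 2, several forward-conv input sizes map to the same output
# size, so inverting is ambiguous; `output_padding` (or an explicit
# `output_shape`) picks one, which plain `conv_transpose` cannot express.
print(conv_transpose_output_size(4, 3, 2, 1))                    # 7
print(conv_transpose_output_size(4, 3, 2, 1, output_padding=1))  # 8
```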