Adding a new transposed convolution function to lax #5772

Open · yang-song wants to merge 1 commit into main from patch-1

Conversation

yang-song

This PR implements lax.gradient_based_conv_transpose. Compared to conv_transpose, it adds support for output_shape and output_padding, matching the APIs for transposed convolutions derived from the gradient of a forward convolution, which is the convention in other deep learning frameworks such as TensorFlow, PyTorch, and Keras. This additional function makes it much easier to reproduce code written in those (currently more popular) frameworks.
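To make the need for output_shape / output_padding concrete: a strided forward convolution maps several input sizes to the same output size, so its transpose is ambiguous without extra information. A small arithmetic sketch in plain Python, using the standard gradient-based output-size formula (as documented, e.g., for torch.nn.ConvTranspose2d; dilation omitted for brevity):

```python
# out = (in - 1) * stride - 2 * padding + kernel + output_padding
def transpose_output_size(in_size, kernel, stride, padding, output_padding=0):
    return (in_size - 1) * stride - 2 * padding + kernel + output_padding

# A forward conv with kernel 3, stride 2, padding 1 maps inputs of size
# 31 AND 32 to an output of size 16, so transposing a size-16 input is
# ambiguous; output_padding (or an explicit output_shape) disambiguates:
print(transpose_output_size(16, kernel=3, stride=2, padding=1))                    # 31
print(transpose_output_size(16, kernel=3, stride=2, padding=1, output_padding=1))  # 32
```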

@google-cla google-cla bot added the cla: yes label Feb 18, 2021
Member

@froystig froystig left a comment

Thank you!

The changes to lax.py include many file-wide formatting adjustments. Could you undo those? We probably don't want to take them at the moment, and they obscure the main change in the diff. You'll also want to squash commits so that there isn't one commit that makes formatting changes followed by another that undoes them.

@froystig froystig requested a review from hawkinsp February 18, 2021 23:49
@froystig
Member

I'm curious what @hawkinsp and @mattjj think about this change overall, including whether it is best to add it in lax, and for any review comments as well.

@yang-song yang-song force-pushed the patch-1 branch 5 times, most recently from 89dc631 to 6584a54 on February 19, 2021 01:15
@yang-song
Author

> The changes to lax.py include many file-wide formatting adjustments. Could you undo those? We probably don't want to take them at the moment, and they obscure the main change in the diff. You'll also want to squash commits so that there isn't one commit that makes formatting changes followed by another that undoes them.

Just removed the formatting changes (they were made automatically by my IDE).

@schrute99

Are there any updates on this? It would be awesome to have this function.

@yang-song
Author

Pending review from @froystig and @hawkinsp. I think the requested changes have been made.

@hawkinsp
Collaborator

I'm not an expert on this, but I'm wondering what the pros and cons are of introducing a new API endpoint vs adding features like output_shape to the existing conv_transpose. What do you think? Is there a reason we need a new function? Is it conceptually different in some important way?

@yang-song
Author

Because the meaning of padding in conv_transpose is different from that of padding in gradient_based_conv_transpose. If we merged the APIs, we would either break all existing code using conv_transpose (if adopting the padding semantics of gradient_based_conv_transpose) or fail to match the APIs of other frameworks.
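For readers unfamiliar with the distinction: under the gradient-based convention, the transposed convolution is literally the VJP of a forward convolution, so padding names the forward conv's padding rather than padding applied directly to the (fractionally strided) transpose input. A minimal sketch of that convention via jax.vjp; gradient_based_transpose and forward_input_shape are names invented for this illustration, with the latter playing the role of the output_shape argument this PR adds:

```python
import jax
import jax.numpy as jnp
from jax import lax

def gradient_based_transpose(x, w, strides, padding, forward_input_shape):
    # The forward convolution whose input-gradient defines the transpose;
    # `padding` describes THIS conv, per the gradient-based convention.
    fwd = lambda z: lax.conv_general_dilated(
        z, w, window_strides=strides, padding=padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    _, vjp = jax.vjp(fwd, jnp.zeros(forward_input_shape))
    return vjp(x)[0]

x = jnp.ones((1, 8, 8, 4))     # input to the transpose (the forward conv's output shape)
w = jnp.ones((3, 3, 4, 4))     # HWIO kernel
y = gradient_based_transpose(x, w, (2, 2), 'SAME', (1, 16, 16, 4))
print(y.shape)                 # (1, 16, 16, 4)
```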

@hawkinsp hawkinsp requested a review from levskaya March 25, 2021 21:21
@levskaya
Collaborator

levskaya commented Apr 2, 2021

Sorry, I wrote the original at a time when no frameworks really existed in JAX. (Nowadays, I'd probably not even add this function to "lax", since it's strictly a specialization of general convolutions, and would delegate these matters to NN frameworks.)

Aside from a pending review of correctness, this mainly comes down to a question of organization:

  1. Keep our old conv_transpose and delegate these specialized "conv templates" to frameworks.
  2. Axe the old conv_transpose and use this new one to match other frameworks... but we probably have users of the old one that prevent that.
  3. Fold the alternative behavior into the existing function under an optional flag.
  4. In retrospect, I wish I had called the existing one "fractionally strided convolution", since that's more accurate. Maybe that's what we should do: rename the old one and add the new one if it matches what most people expect of "conv transpose"? This comes at the expense of polluting the lax namespace a bit.

@schrute99

The issue with option 1 might be that, at least for Flax and Objax, all the convolution modules are just wrappers of the functions in jax.lax. Also, if every JAX framework implements its own transposed convolution, they might not be consistent.

@codeboy5

Hey, is this issue being actively worked on?

younesbelkada added a commit to younesbelkada/jax that referenced this pull request Jun 25, 2022
@younesbelkada

Hi all! Is there a plan to merge this PR? It seems to be the root cause of issues when converting some PyTorch models to JAX/Flax; it would be nice if we could merge it ;)

@ericd-1qbit

The Flax docs sent me here, and I'd like to add a +1 in the hope that this PR will be merged soon :)

@leiteg

leiteg commented Jun 21, 2023

Also came here from the Flax docs. Is there any other way to use transposed convolutions that are compatible with PyTorch's nn.ConvTranspose2d?
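One possible route, sketched here under stated assumptions: lax.conv_transpose accepts transpose_kernel=True, which flips the kernel spatially and swaps its channel axes to match the gradient-derived convention, and explicit padding pairs can be chosen to mimic PyTorch's padding and output_padding. The helper below is a hypothetical, untested mapping for the square-kernel, dilation-1, groups-1 case (conversion of PyTorch's weight layout to HWIO not shown):

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical helper: map nn.ConvTranspose2d-style hyperparameters onto
# lax.conv_transpose. k: kernel size, s: stride, p: padding, op: output_padding.
def torch_style_conv_transpose(x, w, k, s, p, op=0):
    # PyTorch's forward-conv padding p corresponds to explicit padding of
    # (k - 1 - p) around the fractionally strided input, plus op at the end.
    pads = [(k - 1 - p, k - 1 - p + op)] * 2
    return lax.conv_transpose(
        x, w, strides=(s, s), padding=pads,
        transpose_kernel=True,   # use the gradient-derived kernel convention
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

x = jnp.ones((1, 8, 8, 4))
w = jnp.ones((3, 3, 4, 4))
y = torch_style_conv_transpose(x, w, k=3, s=2, p=1, op=1)
print(y.shape)   # (1, 16, 16, 4): matches (in - 1) * s - 2 * p + k + op
```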
