[CLIP ENCODER] Vision Transform for Clip encoder #1127


Merged: 48 commits merged into pytorch:main from clip_encoder on Jul 8, 2024

Conversation

@felipemello1 (Contributor) commented on Jun 27, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Added the Vision Transformer architecture, which can be used for the CLIP model.

Main points to focus on in the review:

  • Docstrings, especially of VisionTransformer. Are the examples clear? Are the names intuitive (tile, patch, token)?
  • Are we comfortable with leaving the CLS projection in the ViT if we are going to create a projection module somewhere else?

Changelog

  • LayerNorm (public)
  • VisionTransformer (public)
  • clip_vision_encoder builder (see the usage sketch after this list)
  • Positional embeddings to support tiled images
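
To make the builder concrete, a minimal usage sketch follows. The parameter names and import path are illustrative assumptions, not necessarily the exact API added in this PR:

```python
import torch
from torchtune.models.clip import clip_vision_encoder  # assumed import path

encoder = clip_vision_encoder(
    tile_size=224,   # spatial size of each tile
    patch_size=16,   # each tile is split into 16x16-pixel patches
    embed_dim=768,   # dimension of each token embedding
    num_layers=12,
    num_heads=12,
)

# Images always carry a tile dimension: (batch, n_tiles, channels, h, w).
# An untiled image is simply the n_tiles=1 case.
images = torch.randn(2, 1, 3, 224, 224)
tokens = encoder(images)
```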

Test plan

  • Shared notebook with a parity check against CLIP in torchmultimodal
  • unit tests asserting shapes and regression values for the positional encodings and the ViT
  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

pytorch-bot (bot) commented on Jun 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1127

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 348e681 with merge base 06a125e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 27, 2024
@felipemello1 marked this pull request as draft on June 27, 2024 01:12
@felipemello1 changed the title from [WIP][CLIP ENCODER] Vision Transform for Clip encoder to [CLIP ENCODER] Vision Transform for Clip encoder on Jul 1, 2024
@codecov-commenter commented on Jul 1, 2024

Codecov Report

Attention: Patch coverage is 29.62963% with 209 lines in your changes missing coverage. Please review.

Project coverage is 26.82%. Comparing base (f158577) to head (beb80c4).
Report is 2 commits behind head on main.

File | Patch % | Missing lines
torchtune/modules/vision_transformer.py | 18.51% | 66 ⚠️
tests/torchtune/modules/test_vision_transformer.py | 22.85% | 54 ⚠️
torchtune/models/clip/_position_embeddings.py | 23.91% | 35 ⚠️
...orchtune/models/clip/test_positional_embeddings.py | 26.31% | 28 ⚠️
torchtune/models/clip/_component_builders.py | 35.00% | 13 ⚠️
tests/torchtune/modules/test_layernorm.py | 62.96% | 10 ⚠️
torchtune/modules/layer_norm.py | 62.50% | 3 ⚠️

❗ There is a different number of reports uploaded between BASE (f158577) and HEAD (beb80c4). Click for more details.

HEAD has 1 upload less than BASE: BASE (f158577) has 4 uploads, HEAD (beb80c4) has 3.
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1127       +/-   ##
===========================================
- Coverage   65.98%   26.82%   -39.17%     
===========================================
  Files         194      212       +18     
  Lines        9023     9595      +572     
===========================================
- Hits         5954     2574     -3380     
- Misses       3069     7021     +3952     

☔ View full report in Codecov by Sentry.

@felipemello1 marked this pull request as ready for review on July 1, 2024 04:54
@kartikayk (Contributor) left a comment

Some comments, but overall this looks good to me. Thank you for patiently addressing all of them. I'll let @pbontrager and/or @ebsmothers take a pass and stamp.

@pbontrager (Contributor) left a comment

This is looking really good, and I really love the docstrings. I left a number of comments on a few things to either address or clarify, but I think it's close to ready to land.


logger = logging.getLogger(__name__)

def clip_vision_encoder(
Contributor commented:

I'm tempted to say that maybe we should have clip_vision_encoder and tiled_clip_vision_encoder builders

@felipemello1 (Author) replied:

I can see why, but we may have to pay the debt elsewhere. In the transforms, adapter, masking, and inference we may have to check the shape and see if it contains tiles. Assuming everything is tiled saves some downstream complexity.
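
An illustrative sketch (not code from this PR) of the convention described above: downstream code always sees a tile dimension, so an untiled image is just the n_tiles=1 case and no shape-checking branch is needed. The helper name to_tiled is hypothetical:

```python
import torch

def to_tiled(image: torch.Tensor) -> torch.Tensor:
    """Normalize (c, h, w) or (n_tiles, c, h, w) to (n_tiles, c, h, w)."""
    return image.unsqueeze(0) if image.ndim == 3 else image

single = torch.randn(3, 224, 224)    # a plain, untiled image
tiled = torch.randn(4, 3, 224, 224)  # an image cropped into 4 tiles

assert to_tiled(single).shape == (1, 3, 224, 224)
assert to_tiled(tiled).shape == (4, 3, 224, 224)
```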

from torch import nn, Tensor


class Fp32LayerNorm(nn.LayerNorm):
Contributor commented:

We don't currently support mixed-precision training in the library, so I don't believe we should include this. On top of that, I believe torch autocast already automatically converts LayerNorm to fp32 (ref).
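
For reference, a minimal sketch of what an Fp32LayerNorm typically does, mirroring the common pattern (e.g. in torchmultimodal); this is illustrative, not necessarily the exact code in this diff:

```python
from torch import nn, Tensor

class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in fp32 for numerical stability, then casts back."""

    def forward(self, x: Tensor) -> Tensor:
        output = nn.functional.layer_norm(
            x.float(),  # upcast activations to fp32
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(x)  # cast back to the input dtype
```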


3) The patches will be flattened and transformed. We call them tokens, because that's how the transformer sees them.


Image: shape (8x8)
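
To make the tile/patch/token terminology concrete, a quick arithmetic sketch with assumed example sizes (the 8x8 image from the docstring, plus hypothetical 4x4 tiles and 2x2 patches):

```python
image_size, tile_size, patch_size = 8, 4, 2

n_tiles = (image_size // tile_size) ** 2           # 4 tiles per image
patches_per_tile = (tile_size // patch_size) ** 2  # 4 patches per tile
n_tokens = n_tiles * patches_per_tile              # 16 tokens seen by the transformer
print(n_tiles, patches_per_tile, n_tokens)         # 4 4 16
```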
Contributor commented:

I don't think any of these are actually appearing in the docs. I didn't see anything below "In summary".

@felipemello1 (Author) replied:

Weird. This is how it shows for me:

[screenshot of the rendered docstring]

@kartikayk (Contributor) left a comment

Thanks for patiently addressing all of the comments!

@felipemello1 merged commit 069b12b into pytorch:main on Jul 8, 2024
29 checks passed
@felipemello1 deleted the clip_encoder branch on July 8, 2024 22:39
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Kartikay Khandelwal <[email protected]>
yinfan98 pushed a commit to yinfan98/sgl-tune-eagle that referenced this pull request May 26, 2025
…opk and routing (pytorch#1127)

This PR adds a single sort_tokens function (simplified from the earlier prepare_expert_routing). It removes the code duplication that was previously present in moe_forward and moe_on_device, both of which were doing the exact same expert-routing prep.

The goal is to avoid tech debt by streamlining this functionality into a single function so that any future updates are auto-propagated.

Testing: verified that inference works as before, with and without CUDA graphs.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
8 participants