modules doc updates #1588
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1588
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 76782d4 with merge base f6d3a7a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
the cross entropy normally, but upcasting only one chunk at a time saves considerable memory.

The CE and upcasting have to be compiled together for better performance.
- When using this class, we recommend using torch.compile only on the method `compute_cross_entropy`.
+ When using this class, we recommend using :func:`torch.compile` only on the method ``compute_cross_entropy``.
Can you self-reference the method here?
Was debating it but honestly it's like 5 lines below and not well-documented anyways so I thought it was overkill
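For readers without the file in front of them, here is a rough sketch of the pattern being documented (names and details are illustrative, not the exact torchtune implementation): the logits arrive pre-chunked, each chunk is upcast to fp32 only inside ``compute_cross_entropy``, and only that method gets compiled, as the docstring recommends.

```python
from typing import List

import torch
import torch.nn.functional as F


class ChunkedCE(torch.nn.Module):
    """Illustrative chunked cross-entropy, not the torchtune source."""

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Upcast one chunk at a time; compiling this method lets the upcast and CE fuse.
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

    def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
        num_tokens = (labels != self.ignore_index).sum()
        label_chunks = labels.chunk(len(logits), dim=1)
        total = sum(
            self.compute_cross_entropy(chunk.reshape(-1, chunk.size(-1)), lbl.reshape(-1))
            for chunk, lbl in zip(logits, label_chunks)
        )
        return total / num_tokens


loss_fn = ChunkedCE()
# Per the docstring's recommendation, compile only the per-chunk method:
loss_fn.compute_cross_entropy = torch.compile(loss_fn.compute_cross_entropy)
```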
torchtune/modules/lr_schedulers.py
Outdated
- 0.0 to lr over num_warmup_steps, then decreases to 0.0 on a cosine schedule over
- the remaining num_training_steps-num_warmup_steps (assuming num_cycles = 0.5).
+ 0.0 to lr over ``num_warmup_steps``, then decreases to 0.0 on a cosine schedule over
+ the remaining ``num_training_steps-num_warmup_steps`` (assuming num_cycles = 0.5).
Backticks on ``num_cycles`` too
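For context, the multiplier this docstring describes is the standard warmup-plus-cosine schedule; a minimal sketch of that lambda (not the exact torchtune code, which typically wraps something like this in ``torch.optim.lr_scheduler.LambdaLR``):

```python
import math


def cosine_with_warmup_lambda(
    step: int, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5
) -> float:
    # Linear warmup from 0.0 to the base lr over num_warmup_steps...
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    # ...then cosine decay to 0.0 over the remaining steps; num_cycles = 0.5 is one half-cosine.
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))
```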
@@ -149,7 +149,7 @@ class FusionEmbedding(nn.Module):
second embedding for the additional tokens. During forward this module routes
the tokens to the appropriate embedding table.

- Use this as a drop-in replacement for `nn.Embedding` in your model.
+ Use this as a drop-in replacement for ``nn.Embedding`` in your model.
Point to :class:`torch.nn.Embedding`?
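On the drop-in claim itself, the routing behavior the docstring mentions looks roughly like this (a conceptual sketch, not the torchtune source; the real module handles devices and dtypes more carefully):

```python
import torch
import torch.nn as nn


class FusionEmbeddingSketch(nn.Module):
    """Illustrative: route ids below vocab_size to the base table, the rest to the fusion table."""

    def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
        self.vocab_size = vocab_size

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        is_fusion = tokens >= self.vocab_size
        out = torch.empty(
            *tokens.shape, self.embedding.embedding_dim, dtype=self.embedding.weight.dtype
        )
        out[~is_fusion] = self.embedding(tokens[~is_fusion])
        out[is_fusion] = self.fusion_embedding(tokens[is_fusion] - self.vocab_size)
        return out
```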
@@ -85,16 +85,13 @@ def forward(
If none, assume the index of the token is its position id. Default is None.

Returns:
- torch.Tensor: output tensor with RoPE applied
+ torch.Tensor: output tensor with shape [b, s, n_h, h_d]
Backticks on shape
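For anyone skimming, the new Returns line uses the ``[b, s, n_h, h_d]`` = [batch, seq_len, num_heads, head_dim] convention; a quick hedged usage sketch (the constructor arguments are written from memory, so treat them as an assumption and check against the actual API):

```python
import torch
from torchtune.modules import RotaryPositionalEmbeddings

b, s, n_h, h_d = 2, 16, 8, 64  # [batch, seq_len, num_heads, head_dim]
rope = RotaryPositionalEmbeddings(dim=h_d, max_seq_len=128)
q = torch.randn(b, s, n_h, h_d)
q_rope = rope(q)
assert q_rope.shape == (b, s, n_h, h_d)  # RoPE preserves the [b, s, n_h, h_d] shape
```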
torchtune/modules/tied_linear.py
Outdated
@@ -31,4 +32,12 @@ def __init__(self, tied_module: nn.Module):
    )

def __call__(self, x: torch.tensor) -> torch.tensor:
    """
    Args:
        x (torch.tensor): Input tensor. Should have shape ``(..., in_dim)``, where ``in_dim``
Why are these typed with lowercase tensor?
ugh you caught me. idk and idc, I just wanted these changes to be docstring-only
Leaving it as is for now, we can check with @felipemello1 on whether this was by design or not
I did it just to see if you guys truly review the PRs, obviously. Congratulations, you passed the test! Wanna update it to upper case?
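On the substance of the new docstring: a tied linear call just reuses another module's weight, so an input of shape ``(..., in_dim)`` maps to ``(..., out_dim)`` without a bias. A minimal sketch of that behavior (illustrative only, not the torchtune source):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLinearSketch:
    """Illustrative tied linear layer: no weight of its own, no bias."""

    def __init__(self, tied_module: nn.Module):
        if not hasattr(tied_module, "weight"):
            raise AttributeError("tied_module must expose a `weight` attribute")
        self.tied_module = tied_module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # weight has shape (out_dim, in_dim); no bias is applied.
        return F.linear(x, self.tied_module.weight)


embedding = nn.Embedding(num_embeddings=32000, embedding_dim=512)
output_proj = TiedLinearSketch(embedding)  # output projection shares the embedding weight
logits = output_proj(torch.randn(2, 16, 512))  # (..., in_dim) -> (..., out_dim)
print(logits.shape)  # torch.Size([2, 16, 32000])
```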
torchtune/modules/attention.py
Outdated
@@ -28,7 +28,7 @@ class MultiHeadAttention(nn.Module):
Following is an example of MHA, GQA and MQA with num_heads = 4

(credit for the documentation:
- https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/config.py).
+ https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py).
Suggested change:
- https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py).
+ `litgpt.Config <https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py>`_).
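For anyone who hasn't read the credited litgpt docs, the MHA/GQA/MQA example with num_heads = 4 boils down to how many KV heads you configure; a small shape-only sketch of the grouping (illustrative, independent of torchtune's actual module):

```python
import torch

b, s, h_d = 2, 8, 16
num_heads = 4

for name, num_kv_heads in [("MHA", 4), ("GQA", 2), ("MQA", 1)]:
    q = torch.randn(b, num_heads, s, h_d)
    k = torch.randn(b, num_kv_heads, s, h_d)
    # Expand KV heads so every query head has a matching KV head to attend over.
    k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=1)
    print(name, q.shape, k_expanded.shape)  # both (b, 4, s, h_d) after expansion
```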
torchtune/modules/layer_norm.py
Outdated
@@ -25,7 +25,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): Input tensor.

Returns:
- torch.Tensor: The normalized output tensor.
+ torch.Tensor: The normalized output tensor having the same shape as x.
Suggested change:
- torch.Tensor: The normalized output tensor having the same shape as x.
+ torch.Tensor: The normalized output tensor having the same shape as ``x``.
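For context, the module in layer_norm.py applies LayerNorm in fp32 and casts back to the input dtype, which is why the same-shape note is exact; a rough sketch of that pattern (not the exact source):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Fp32LayerNormSketch(nn.LayerNorm):
    """Illustrative fp32 LayerNorm wrapper: compute in fp32, return the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.type_as(x)  # same shape (and dtype) as x
```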
@@ -356,8 +356,8 @@ def forward(
KV values for each position.

Returns:
- Tensor: output tensor with shape [b x s x v] or a list of layer
- output tensors defined by ``output_hidden_states`` with the
+ Tensor: output tensor with shape [b x s x v] or a list of layer \
Suggested change:
- Tensor: output tensor with shape [b x s x v] or a list of layer \
+ Tensor: Output tensor with shape ``[b x s x v]`` or a list of layer \
torchtune/modules/rlhf/rewards.py
Outdated
@@ -27,7 +27,7 @@ def get_reward_penalty_mask(

Args:
padding_masks (torch.Tensor): torch.Tensor where True indicates a padding token in the generated
- sequence, and False otherwise. Shape: (b, reponse_len)
+ sequence, and False otherwise. Shape: (b, response_len)
Suggested change:
- sequence, and False otherwise. Shape: (b, response_len)
+ sequence, and False otherwise. Shape: ``(b, response_len)``
torchtune/modules/rlhf/rewards.py
Outdated
@@ -27,7 +27,7 @@ def get_reward_penalty_mask(

Args:
padding_masks (torch.Tensor): torch.Tensor where True indicates a padding token in the generated
- sequence, and False otherwise. Shape: (b, reponse_len)
+ sequence, and False otherwise. Shape: (b, response_len)
seq_lens (torch.Tensor): The length of each generated sequence. Shape: (b,)
Suggested change:
- seq_lens (torch.Tensor): The length of each generated sequence. Shape: (b,)
+ seq_lens (torch.Tensor): The length of each generated sequence. Shape: ``(b,)``
torchtune/modules/rlhf/rewards.py
Outdated
@@ -58,8 +58,8 @@ def get_rewards_ppo(

Args:
scores (torch.Tensor): Reward model scores, shape (b,).
- logprobs (torch.Tensor): Policy logprobs, shape (b, reponse_len).
- ref_logprobs (torch.Tensor): Reference base model, shape (b, reponse_len).
+ logprobs (torch.Tensor): Policy logprobs, shape (b, response_len).
Suggested change:
- logprobs (torch.Tensor): Policy logprobs, shape (b, response_len).
+ logprobs (torch.Tensor): Policy logprobs, shape ``(b, response_len)``.
torchtune/modules/rlhf/rewards.py
Outdated
- logprobs (torch.Tensor): Policy logprobs, shape (b, reponse_len).
- ref_logprobs (torch.Tensor): Reference base model, shape (b, reponse_len).
+ logprobs (torch.Tensor): Policy logprobs, shape (b, response_len).
+ ref_logprobs (torch.Tensor): Reference base model, shape (b, response_len).
Suggested change:
- ref_logprobs (torch.Tensor): Reference base model, shape (b, response_len).
+ ref_logprobs (torch.Tensor): Reference base model logprobs, shape ``(b, response_len)``.
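Since the arg docs are the only context here: tensors with these shapes typically feed the standard PPO-RLHF reward shaping, sketched below. This mirrors the common formulation only; torchtune's actual ``get_rewards_ppo`` may differ in details such as clipping or masking.

```python
import torch


def rewards_sketch(scores, logprobs, ref_logprobs, kl_coeff, seq_lens):
    # Per-token KL penalty between policy and reference logprobs: shape (b, response_len).
    kl = logprobs - ref_logprobs
    kl_rewards = -kl_coeff * kl
    total_reward = kl_rewards.clone()
    # Add the scalar reward-model score at each sequence's last non-padded position.
    last_pos = (seq_lens - 1).clamp(min=0)  # (b,)
    total_reward[torch.arange(scores.size(0)), last_pos] += scores
    return total_reward, kl, kl_rewards
```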
torchtune/modules/rlhf/rewards.py
Outdated
values (torch.Tensor): The predicted values for each state. Shape: (b, response_len)
rewards (torch.Tensor): The rewards received at each time step. Shape: (b, response_len)
Suggested change:
- values (torch.Tensor): The predicted values for each state. Shape: (b, response_len)
- rewards (torch.Tensor): The rewards received at each time step. Shape: (b, response_len)
+ values (torch.Tensor): The predicted values for each state. Shape: ``(b, response_len)``
+ rewards (torch.Tensor): The rewards received at each time step. Shape: ``(b, response_len)``
torchtune/modules/rlhf/rewards.py
Outdated
- advantages (torch.Tensor): The estimated advantages. Shape: (b, response_len)
- returns (torch.Tensor): The estimated returns. Shape: (b, response_len)
Suggested change:
- - advantages (torch.Tensor): The estimated advantages. Shape: (b, response_len)
- - returns (torch.Tensor): The estimated returns. Shape: (b, response_len)
+ - advantages (torch.Tensor): The estimated advantages. Shape: ``(b, response_len)``
+ - returns (torch.Tensor): The estimated returns. Shape: ``(b, response_len)``
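For readers who haven't seen it: advantages and returns with those shapes usually come from generalized advantage estimation. A minimal sketch of that computation follows (the torchtune function these lines belong to may add masking or whitening on top):

```python
import torch


def gae_sketch(values, rewards, gamma=1.0, lmbda=0.95):
    # values, rewards: (b, response_len); sweep backwards over time steps.
    advantages = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(rewards.size(1))):
        next_value = values[:, t + 1] if t + 1 < values.size(1) else torch.zeros_like(running)
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        running = delta + gamma * lmbda * running
        advantages[:, t] = running
    returns = advantages + values  # (b, response_len)
    return advantages, returns
```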
torchtune/modules/rms_norm.py
Outdated
@@ -34,7 +34,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): input tensor to normalize

Returns:
- torch.Tensor: The output tensor after applying RMSNorm.
+ torch.Tensor: The normalized and scaled tensor having the same shape as x.
Suggested change:
- torch.Tensor: The normalized and scaled tensor having the same shape as x.
+ torch.Tensor: The normalized and scaled tensor having the same shape as ``x``.
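For completeness, the "normalized and scaled" wording corresponds to the usual RMSNorm computation: divide by the root-mean-square of the last dimension, then apply a learned per-feature scale. A minimal sketch (upcast and eps handling may differ from torchtune's implementation):

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative RMSNorm: normalize by RMS of the last dim, then scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * self.scale).type_as(x)  # same shape as x
```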
torchtune/modules/tanh_gate.py
Outdated
@@ -22,6 +22,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): input tensor to gate

Returns:
- torch.Tensor: The output tensor after gating.
+ torch.Tensor: The output tensor after gating. Has the same shape as x
Suggested change:
- torch.Tensor: The output tensor after gating. Has the same shape as x
+ torch.Tensor: The output tensor after gating. Has the same shape as ``x``.
maybe I'm getting carried away
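And for the record, the gating being documented is just the input multiplied by tanh of a learned scale, so the same-shape note is exact; a minimal sketch, not the torchtune source:

```python
import torch
import torch.nn as nn


class TanhGateSketch(nn.Module):
    """Illustrative tanh gate: output = x * tanh(learned scale)."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(1))  # starts closed: tanh(0) = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.tanh(self.scale)  # same shape as x
```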
torchtune/modules/tied_linear.py
Outdated
@@ -12,7 +12,8 @@
class TiedLinear:
"""
A tied linear layer, without bias, that shares the same weight as another linear layer.
- This is useful for models that use tied weights, such as qwen and gemma.
+ This is useful for models that use tied weights, such as :func:`~torchtune.models.qwen2_0_5b`,
+ :func:`~torchtune.models.qwen2_1_5b` and :func:`~torchtune.models.gemma`.
Suggested change:
- :func:`~torchtune.models.qwen2_1_5b` and :func:`~torchtune.models.gemma`.
+ :func:`~torchtune.models.qwen2_1_5b` and all of the :func:`~torchtune.models.gemma` models.
That was kind of cathartic
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1588       +/-   ##
===========================================
- Coverage   73.12%   27.00%   -46.13%
===========================================
  Files         289      290        +1
  Lines       14175    14252       +77
===========================================
- Hits        10366     3849     -6517
- Misses       3809    10403     +6594

☔ View full report in Codecov by Sentry.
A bunch of miscellaneous changes to torchtune/modules docstrings so that our API docs look a bit nicer. Tbh I probably could have done a lot more here and many of the changes are more surface-level, but every little bit helps.