⚠️ Fix Attention Masking in GRPO #2708
Conversation
The current implementation computes per-token log probabilities without masking out padding tokens. The current code:

```python
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
    log_probs = logits_row.log_softmax(dim=-1)
    token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
    per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
```

Suggested change to mask out padding tokens:

```python
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
    log_probs = logits_row.log_softmax(dim=-1)
    token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
    # Zero out the log-probs at padding positions
    mask = input_ids_row != self.processing_class.pad_token_id
    token_log_prob *= mask
    per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
```

This comment is being added here since it relates to the KL divergence calculation, which is within the scope of this PR. If this understanding is incorrect, I would be very happy for someone to explain why 😄
thanks @andyl98, that's a very good point indeed. Can you build the attention mask outside of the function instead, rather than deducing it from the content of input_ids? In some edge cases, input_ids may contain the pad_token in the middle of the text.
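A minimal sketch of what this suggestion could look like, assuming the caller passes in the tokenizer's attention mask instead of re-deriving it from input_ids; the function name and signature here are illustrative, not the PR's final code:

```python
import torch

def get_per_token_logps(logits, input_ids, attention_mask, num_logits_to_keep):
    """Gather the log-prob of each realized token, zeroing padded positions
    with an externally supplied attention mask (illustrative sketch)."""
    per_token_logps = []
    ids = input_ids[:, -num_logits_to_keep:]
    mask = attention_mask[:, -num_logits_to_keep:]
    for logits_row, ids_row, mask_row in zip(logits, ids, mask):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=ids_row.unsqueeze(1)).squeeze(1)
        # Padding positions contribute zero instead of an arbitrary log-prob.
        per_token_logps.append(token_log_prob * mask_row)
    return torch.stack(per_token_logps)
```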
move this part

```python
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
# Default to the sequence length when a row contains no EOS token
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
```

before
No, because we mask here: trl/trl/trainer/grpo_trainer.py, lines 511 to 512 in 265663a
Thanks @qgallouedec. Fixed as you suggested. Also I changed the
…into fix-grpo-logits-calc
Btw, can you run small experiments so that we know the impact of not having this attention mask? If you don't have the compute/time, I can handle it. 🙂
It would be great if you could test it, as I need to work on some other stuff right now. Appreciate the reviews :)
I will, thanks a lot for spotting this issue!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I'll soon share the comparative results.
Follow-up to huggingface#2708. It seems the attention_mask in the prompt_inputs can be on the CPU even when the rest of training is on the GPU.
* Update grpo_trainer.py
* Update grpo_trainer.py
* Update grpo_trainer.py
* Slight name change
* Fix typo
* Improve readability + move attn mask to args
* revert adding "completion_"

--------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
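A minimal sketch of the kind of device handling that follow-up describes, assuming a standard transformers tokenizer and model; the model name and variable names are illustrative, not the trainer's exact code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2-0.5B-Instruct"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
if torch.cuda.is_available():
    model = model.to("cuda")

prompt_inputs = tokenizer(["Hello there"], return_tensors="pt")
# The tokenizer returns CPU tensors, so input_ids *and* attention_mask both
# need to be moved to the model's device before the forward pass.
prompt_inputs = {k: v.to(model.device) for k, v in prompt_inputs.items()}

with torch.no_grad():
    out = model(**prompt_inputs)
```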
What does this PR do?
Fix a small yet dangerous issue -- not passing the `attention_mask` when computing logits. Will add more notes.
Details
Note how there are a lot of padding tokens in the first element.
Now if we run the forward pass without the attention_mask and compare it with the run that passes the attention_mask, we'll get console output showing that the per-token log-probabilities differ. Computed row-wise, the differences make the impact of the missing mask clear (see the sketch below).
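A minimal, self-contained sketch of that kind of comparison, assuming a standard transformers causal LM with left padding; the model name, prompts, and helper function below are illustrative assumptions, not the PR's exact code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2-0.5B-Instruct"  # illustrative choice; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# A short and a long prompt, so the first element ends up heavily left-padded.
prompts = [
    "Hi",
    "Please write a detailed explanation of attention masking in transformer models.",
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

def per_token_logps(logits, input_ids):
    # logits[:, t] predicts token t + 1, so shift before gathering.
    log_probs = logits[:, :-1, :].log_softmax(dim=-1)
    targets = input_ids[:, 1:]
    return torch.gather(log_probs, 2, targets.unsqueeze(-1)).squeeze(-1)

with torch.no_grad():
    # Without the attention mask, the padding tokens are attended to.
    logits_no_mask = model(input_ids=inputs["input_ids"]).logits
    logits_masked = model(input_ids=inputs["input_ids"],
                          attention_mask=inputs["attention_mask"]).logits

diff = (per_token_logps(logits_no_mask, inputs["input_ids"])
        - per_token_logps(logits_masked, inputs["input_ids"])).abs()
print(diff.max(dim=1).values)  # row-wise maximum absolute difference
```

Rows with more padding should show a larger discrepancy, which is what motivates passing the attention mask explicitly.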