
⚠️ Fix Attention Masking in GRPO #2708


Merged: 12 commits into huggingface:main on Feb 2, 2025

Conversation

@andyl98 (Contributor) commented Jan 31, 2025

What does this PR do?

Fixes a small yet dangerous issue: the attention_mask is not passed to the model when computing logits.

Will add more notes.

Details

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Coder-0.5B", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-0.5B")

prompts = [
    "hello, 1 + 1 = ?",
    "hello, 12781297598217521580 + 120481209481290482 = ?",
]

prompt_inputs = tokenizer(
    prompts, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
).to(model.device)

generation_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=True,
    temperature=0.2,
    num_return_sequences=4,
    pad_token_id=tokenizer.pad_token_id,
)

prompt_completion_ids = model.generate(
    **prompt_inputs, generation_config=generation_config
)

prompt_length = prompt_inputs["input_ids"].size(1)  # 45
completion_ids = prompt_completion_ids[:, prompt_length:]
num_logits_to_keep = completion_ids.size(1)  # 100
> prompt_inputs

{'input_ids': tensor([[151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
         151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643,
          14990,     11,    220,     16,    488,    220,     16,    284,    937],
        [ 14990,     11,    220,     16,     17,     22,     23,     16,     17,
             24,     22,     20,     24,     23,     17,     16,     22,     20,
             17,     16,     20,     23,     15,    488,    220,     16,     17,
             15,     19,     23,     16,     17,     15,     24,     19,     23,
             16,     17,     24,     15,     19,     23,     17,    284,    937]],
       device='mps:0'), 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='mps:0')}

Note the long run of left-padding tokens (151643) in the first element. Generation itself is unaffected, since **prompt_inputs passes the attention mask to generate(); the problem is the later forward pass that scores the completions without it.

Now if we run

input_ids = prompt_completion_ids

# Old Implementation's last prompt logit (for predicting first completion)
old_logps = model(prompt_completion_ids, num_logits_to_keep=num_logits_to_keep+1).logits[:, 0, :].log_softmax(dim=-1)

# This PR
attention_mask = (input_ids != tokenizer.pad_token_id).long()
new_logps = model(input_ids=input_ids, attention_mask=attention_mask, num_logits_to_keep=num_logits_to_keep+1).logits[:, 0, :].log_softmax(dim=-1)

# Ground truth: the last prompt logit, repeated to match num_return_sequences=4
correct_logps = model(**prompt_inputs).logits[:, -1, :].repeat_interleave(4, dim=0).log_softmax(dim=-1)

And we compare with

# Check whether the log-probs agree
print(torch.allclose(old_logps, new_logps, atol=1e-5))      # old vs. this PR
print(torch.allclose(old_logps, correct_logps, atol=1e-5))  # old vs. ground truth
print(torch.allclose(new_logps, correct_logps, atol=1e-5))  # this PR vs. ground truth

We'll get console output of

> False
> False
> True

Computing the mean absolute difference row-wise, we get

for i in range(old_logps.size(0)):
    diff = torch.mean(torch.abs(old_logps[i] - correct_logps[i]))
    print(diff)

# tensor(1.5913, device='mps:0', grad_fn=<MeanBackward0>)
# tensor(1.5913, device='mps:0', grad_fn=<MeanBackward0>)
# tensor(1.5913, device='mps:0', grad_fn=<MeanBackward0>)
# tensor(1.5913, device='mps:0', grad_fn=<MeanBackward0>)
# tensor(0., device='mps:0', grad_fn=<MeanBackward0>)
# tensor(0., device='mps:0', grad_fn=<MeanBackward0>)
# tensor(0., device='mps:0', grad_fn=<MeanBackward0>)
# tensor(0., device='mps:0', grad_fn=<MeanBackward0>)
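
For reference, here is a minimal sketch of a mask-aware per-token log-prob helper; the name get_per_token_logps and its exact signature are illustrative, not the trainer's actual code:

import torch

def get_per_token_logps(model, input_ids, attention_mask, num_logits_to_keep):
    # Pass the attention mask so left padding cannot leak into the logits
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_logits_to_keep=num_logits_to_keep + 1,
    ).logits[:, :-1, :]  # drop the last logit: there is no target token for it
    # Gather the log-prob of each actually generated token
    target_ids = input_ids[:, -num_logits_to_keep:]
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=2, index=target_ids.unsqueeze(2)).squeeze(2)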

@andyl98 changed the title from "🔧 Add Attention Mask to GRPO Logits Computation" to "⚠️ Fix Attention Mask in GRPO Logits Computation" on Jan 31, 2025
@andyl98 changed the title from "⚠️ Fix Attention Mask in GRPO Logits Computation" to "⚠️ Fix Attention Mask in GRPO" on Jan 31, 2025
@aboros98 commented

The current implementation computes per_token_logps without masking the values corresponding to the padding tokens. This means padding tokens are included in the KL divergence calculation, which could slightly skew the results since padding isn't meaningful content.

The current code:

per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
    log_probs = logits_row.log_softmax(dim=-1)
    token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
    per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

Suggested change to mask out padding tokens:

per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
    log_probs = logits_row.log_softmax(dim=-1)
    token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
    # Zero out log-probs at padding positions
    mask = (input_ids_row != self.processing_class.pad_token_id).to(token_log_prob.dtype)
    token_log_prob *= mask
    per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

This comment is being added here since it relates to KL divergence calculation, which is within the scope of this PR. If this understanding is incorrect, I would be very happy for someone to explain why 😄

@qgallouedec (Member)

Thanks @andyl98, that's a very good point indeed. Can you build the attention mask outside the function instead, rather than deducing it from the content of input_ids? In some edge cases, input_ids may contain the pad token in the middle of the text.
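
To illustrate that edge case, a toy sketch (assuming a model whose pad token doubles as its EOS token, as is common when no dedicated pad token exists):

import torch

pad_token_id = 151643  # hypothetically also the EOS token
ids = torch.tensor([[pad_token_id, pad_token_id, 14990, 11, 937, pad_token_id]])
#                    ^ left padding              ^ real text     ^ genuine EOS

mask = (ids != pad_token_id).long()
print(mask)  # tensor([[0, 0, 1, 1, 1, 0]]): the genuine EOS is wrongly masked out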

@qgallouedec (Member)

Move this part

# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
# Default to the full completion length when a row contains no EOS
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

before get_per_token_logps, then concatenate prompt_inputs["attention_mask"] and completion_mask, as sketched below.
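
A minimal sketch of that wiring, assuming the helper accepts an attention_mask argument as in the earlier sketch:

# Full-sequence mask: the tokenizer's prompt mask followed by the completion mask
attention_mask = torch.cat([prompt_inputs["attention_mask"], completion_mask], dim=1)

per_token_logps = get_per_token_logps(
    model, prompt_completion_ids, attention_mask, num_logits_to_keep
)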

@qgallouedec (Member)

> This means padding tokens are included in the KL divergence calculation

No, because we mask here:

per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
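
A toy demonstration of why this masking already excludes post-EOS (padding) positions from the loss (illustrative numbers only):

import torch

per_token_loss = torch.tensor([[0.5, 0.2, 0.9, 0.9],
                               [0.1, 0.3, 0.4, 0.2]])
completion_mask = torch.tensor([[1, 1, 0, 0],   # EOS after two tokens
                                [1, 1, 1, 1]])  # no EOS: every token counts

loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
print(loss)  # tensor(0.3000): the 0.9 entries after EOS contribute nothing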

@andyl98 (Contributor, Author) commented Jan 31, 2025

Thanks @qgallouedec. Fixed as you suggested.

Also, I changed completion_mask's dtype from int to long, since that is the default dtype for attention_mask (and renamed some variables for clarity). Feel free to edit and merge.
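
Concretely, that means the last line of the earlier EOS-masking snippet becomes (sketch):

# long is the default dtype of attention_mask tensors returned by the tokenizer
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).long()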

@andyl98 changed the title from "⚠️ Fix Attention Mask in GRPO" to "⚠️ Fix Attention Masking in GRPO" on Jan 31, 2025
@qgallouedec (Member) commented Jan 31, 2025

Btw, can you run small experiments so that we know the impact of not having this attention mask? If you don't have the compute/time I can handle it. 🙂

@andyl98 (Contributor, Author) commented Jan 31, 2025

It would be great if you could test it, as I need to work on some other things right now. Appreciate the reviews :)

@qgallouedec (Member)

I will, thanks a lot for spotting this issue!


@qgallouedec (Member) left a review comment

I'll soon share the comparative results

@qgallouedec (Member)

very nice finding (green is this PR, red is main)

[Screenshot 2025-02-02: comparison curves; green is this PR, red is main]

@qgallouedec merged commit bbdd6db into huggingface:main on Feb 2, 2025
13 checks passed
tgaddair added a commit to tgaddair/trl that referenced this pull request Feb 3, 2025
Follow-up to huggingface#2708. Seems the attention_mask in the prompt_inputs can be CPU even when the rest of training is on GPU.
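
A sketch of the kind of follow-up fix described there (a hypothetical one-liner; the actual commit may differ):

# The tokenized prompt can sit on CPU even when the model is on GPU,
# so move the attention mask to the model's device before the forward pass
attention_mask = prompt_inputs["attention_mask"].to(model.device)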
@andyl98 mentioned this pull request on Mar 11, 2025
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025
* Update grpo_trainer.py

* Update grpo_trainer.py

* Update grpo_trainer.py

* Slight name change

* Fix typo

* Improve readability + move attn mask to args

* revert adding "completion_"

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>