pulling in diff from trl-2747 #3

Open · wants to merge 5 commits into main

trl/trainer/grpo_trainer.py: 13 changes (11 additions, 2 deletions)
@@ -37,7 +37,7 @@
     is_wandb_available,
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.utils import is_peft_available
+from transformers.utils import is_peft_available, logging
 
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..import_utils import is_vllm_available
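
Note on the new `logging` import: `transformers.utils.logging.get_logger` returns a standard Python logger, and since the name used below ("GRPOTrainer") is not under the `transformers` namespace, the usual `transformers` verbosity helpers will not necessarily surface it. A minimal, illustrative way for a training script to see the INFO messages added in this PR (assuming no other handlers are already configured):

```python
import logging

# Plain stdlib configuration; the "GRPOTrainer" logger added below propagates to the
# root logger, so setting the root level to INFO surfaces the new log lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)
```
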
@@ -55,6 +55,9 @@
 if is_wandb_available():
     import wandb
 
+
+logger = logging.get_logger("GRPOTrainer")
+
 # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
 # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
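
For reference, a callable that satisfies the `RewardFunc` alias above could look like the sketch below; the length-based scoring is purely illustrative and not part of this PR:

```python
def short_completion_reward(prompts: list, completions: list) -> list[float]:
    # Toy reward: 1.0 for completions at most 200 characters long, 0.0 otherwise.
    return [1.0 if len(str(completion)) <= 200 else 0.0 for completion in completions]
```
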
@@ -383,6 +386,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
             prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
 
+        logger.info("Starting generation")
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
@@ -419,9 +423,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         else:
             # Regular generation path
             with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+                unwrapped_model.eval()  # Needed to make sure use_cache works with gradient_checkpointing
                 prompt_completion_ids = unwrapped_model.generate(
                     **prompt_inputs, generation_config=self.generation_config
                 )
+            model.train()
+            logger.info("Finishing generation")
 
         # Compute prompt length and extract completion ids
         prompt_length = prompt_inputs["input_ids"].size(1)
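
Note on the `eval()` / `train()` toggle in this hunk: with gradient checkpointing enabled, `use_cache` is typically forced off in train mode, so generation runs with the module temporarily in eval mode and training mode is restored afterwards. A standalone sketch of the same pattern (the helper name and `generate_fn` argument are illustrative, not part of the trainer):

```python
def generate_in_eval_mode(model, generate_fn):
    # Run generation with the module in eval mode (so the KV cache can be used even when
    # gradient checkpointing disables it during training), then restore the previous mode.
    was_training = model.training
    model.eval()
    try:
        return generate_fn(model)
    finally:
        model.train(was_training)
```
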
@@ -435,7 +442,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
 
         # Concatenate prompt_mask with completion_mask for logit computation
-        prompt_mask_repeated = prompt_inputs["attention_mask"].repeat_interleave(self.num_generations, dim=0)
+        prompt_mask_repeated = prompt_inputs["attention_mask"].to(device).repeat_interleave(self.num_generations, dim=0)
         attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1)  # (B*G, P+C)
 
         # Get the per-token log probabilities for the completions for the model and the reference model
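
Note on the `.to(device)` change: `repeat_interleave` keeps the tensor on whatever device it already lives on, so if the prompt attention mask is still on CPU while `completion_mask` sits on the accelerator, the `torch.cat` below fails with a device mismatch. A small self-contained sketch of the intended shapes and device handling (the B/G/P/C sizes are arbitrary):

```python
import torch

B, G, P, C = 2, 4, 5, 3  # batch size, generations per prompt, prompt length, completion length
device = "cuda" if torch.cuda.is_available() else "cpu"

prompt_mask = torch.ones(B, P, dtype=torch.long)  # may still be on CPU after tokenization
completion_mask = torch.ones(B * G, C, dtype=torch.long, device=device)

# Move to the target device before repeating, mirroring the fix in this hunk.
prompt_mask_repeated = prompt_mask.to(device).repeat_interleave(G, dim=0)  # (B*G, P)
attention_mask = torch.cat([prompt_mask_repeated, completion_mask], dim=1)  # (B*G, P+C)
print(attention_mask.shape, attention_mask.device)
```
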
@@ -541,6 +548,8 @@ def get_per_token_logps(model, input_ids, attention_mask, logits_to_keep):
         mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
 
+        logger.info("Finishing loss")
+
         return loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
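
The `mean_kl` line in the last hunk averages the per-token KL only over real (non-padded) completion tokens before logging it as a metric. A small sketch of that masked-mean pattern with stand-in tensors:

```python
import torch

per_token_kl = torch.rand(8, 16)  # (B*G, C) stand-in for the real per-token KL values
completion_mask = (torch.rand(8, 16) > 0.3).int()  # 1 for completion tokens, 0 for padding
completion_mask[:, 0] = 1  # guard against an all-zero row in this toy example

# Sum KL over valid tokens, normalize by the per-sequence token count, then average.
per_sequence_kl = (per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
mean_kl = per_sequence_kl.mean()
print(mean_kl.item())
```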