Skip to content

Commit 87e2c61

Browse files
committed
Decomposed _make_reward_processing_classes method
1 parent 47c98fc commit 87e2c61

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -353,28 +353,7 @@ def __init__(
353353
else:
354354
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
355355

356-
# Reward processing class
357-
if reward_processing_classes is None:
358-
reward_processing_classes = [None] * len(reward_funcs)
359-
elif not isinstance(reward_processing_classes, list):
360-
reward_processing_classes = [reward_processing_classes]
361-
else:
362-
if len(reward_processing_classes) != len(reward_funcs):
363-
raise ValueError("The number of reward processing classes must match the number of reward functions.")
364-
365-
for i, (reward_processing_class, reward_func) in enumerate(
366-
zip(reward_processing_classes, reward_funcs, strict=True)
367-
):
368-
if isinstance(reward_func, PreTrainedModel):
369-
if reward_processing_class is None:
370-
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
371-
if reward_processing_class.pad_token_id is None:
372-
reward_processing_class.pad_token = reward_processing_class.eos_token
373-
# The reward model computes the reward for the latest non-padded token in the input sequence.
374-
# So it's important to set the pad token ID to the padding token ID of the processing class.
375-
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
376-
reward_processing_classes[i] = reward_processing_class
377-
self.reward_processing_classes = reward_processing_classes
356+
self.reward_processing_classes = self._make_reward_processing_classes(reward_funcs, reward_processing_classes)
378357

379358
# Data collator
380359
def data_collator(features): # No data collation is needed in GRPO
@@ -577,6 +556,33 @@ def new_group_context():
577556
if isinstance(reward_func, PreTrainedModel):
578557
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
579558

559+
@staticmethod
def _make_reward_processing_classes(
    reward_funcs,
    reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
):
    """Normalize `reward_processing_classes` to one entry per reward function.

    Accepts `None` (expanded to a list of `None`s), a single processing class
    (wrapped in a one-element list), or a list whose length must match
    `reward_funcs`. For every reward function that is a `PreTrainedModel`, a
    tokenizer is loaded from the model's name/path when none was supplied, a
    pad token is backfilled from the EOS token if missing, and the model
    config's `pad_token_id` is synced to the tokenizer's.

    Args:
        reward_funcs: Sequence of reward functions (models or callables).
        reward_processing_classes: `None`, a single tokenizer, or a list of
            tokenizers (one per reward function).

    Returns:
        The list of processing classes, aligned with `reward_funcs`.

    Raises:
        ValueError: If a list is given whose length differs from `reward_funcs`.
    """
    num_funcs = len(reward_funcs)
    if reward_processing_classes is None:
        reward_processing_classes = [None] * num_funcs
    elif not isinstance(reward_processing_classes, list):
        reward_processing_classes = [reward_processing_classes]
    elif len(reward_processing_classes) != num_funcs:
        raise ValueError("The number of reward processing classes must match the number of reward functions.")

    paired = zip(reward_processing_classes, reward_funcs, strict=True)
    for idx, (processing_class, reward_func) in enumerate(paired):
        if not isinstance(reward_func, PreTrainedModel):
            # Non-model reward functions keep their entry unchanged.
            continue
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
        if processing_class.pad_token_id is None:
            processing_class.pad_token = processing_class.eos_token
        # The reward model computes the reward for the latest non-padded token in the input sequence.
        # So it's important to set the pad token ID to the padding token ID of the processing class.
        reward_func.config.pad_token_id = processing_class.pad_token_id
        reward_processing_classes[idx] = processing_class
    return reward_processing_classes
585+
580586
def _set_signature_columns_if_needed(self):
581587
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
582588
# By default, this method sets `self._signature_columns` to the model's expected inputs.

0 commit comments

Comments
 (0)