Skip to content

Commit 2a5cf21

Browse files
committed
Decomposed _make_reward_processing_classes method
1 parent 65635d6 commit 2a5cf21

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 28 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -353,28 +353,7 @@ def __init__(
353353
else:
354354
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
355355

356-
# Reward processing class
357-
if reward_processing_classes is None:
358-
reward_processing_classes = [None] * len(reward_funcs)
359-
elif not isinstance(reward_processing_classes, list):
360-
reward_processing_classes = [reward_processing_classes]
361-
else:
362-
if len(reward_processing_classes) != len(reward_funcs):
363-
raise ValueError("The number of reward processing classes must match the number of reward functions.")
364-
365-
for i, (reward_processing_class, reward_func) in enumerate(
366-
zip(reward_processing_classes, reward_funcs, strict=True)
367-
):
368-
if isinstance(reward_func, PreTrainedModel):
369-
if reward_processing_class is None:
370-
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
371-
if reward_processing_class.pad_token_id is None:
372-
reward_processing_class.pad_token = reward_processing_class.eos_token
373-
# The reward model computes the reward for the latest non-padded token in the input sequence.
374-
# So it's important to set the pad token ID to the padding token ID of the processing class.
375-
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
376-
reward_processing_classes[i] = reward_processing_class
377-
self.reward_processing_classes = reward_processing_classes
356+
self.reward_processing_classes = self._make_reward_processing_classes(reward_funcs, reward_processing_classes)
378357

379358
# Data collator
380359
def data_collator(features): # No data collation is needed in GRPO
@@ -576,6 +555,33 @@ def new_group_context():
576555
if isinstance(reward_func, PreTrainedModel):
577556
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
578557

558+
@staticmethod
def _make_reward_processing_classes(
    reward_funcs,
    reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
):
    """Normalize `reward_processing_classes` into a list aligned with `reward_funcs`.

    Accepts `None` (one `None` placeholder per reward function), a single processing
    class (wrapped in a list), or a list (which must match `reward_funcs` in length).
    For every reward function that is a `PreTrainedModel`, a tokenizer is loaded when
    missing and the model's pad token id is synchronized with the tokenizer's.
    Entries for non-model (callable) reward functions are left untouched.

    Raises:
        ValueError: if the number of processing classes does not match the number of
            reward functions.
    """
    # Coerce the argument into a list with exactly one entry per reward function.
    if reward_processing_classes is None:
        reward_processing_classes = [None] * len(reward_funcs)
    elif not isinstance(reward_processing_classes, list):
        reward_processing_classes = [reward_processing_classes]
    elif len(reward_processing_classes) != len(reward_funcs):
        raise ValueError("The number of reward processing classes must match the number of reward functions.")

    # `strict=True` also rejects the single-class/many-functions mismatch from the
    # wrapping branch above.
    pairs = zip(reward_processing_classes, reward_funcs, strict=True)
    for idx, (processing_class, reward_func) in enumerate(pairs):
        if not isinstance(reward_func, PreTrainedModel):
            # Custom callable reward functions need no tokenizer.
            continue
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
        if processing_class.pad_token_id is None:
            processing_class.pad_token = processing_class.eos_token
        # The reward model computes the reward for the latest non-padded token in the
        # input sequence, so its pad token ID must match the processing class's.
        reward_func.config.pad_token_id = processing_class.pad_token_id
        reward_processing_classes[idx] = processing_class
    return reward_processing_classes
584+
579585
def _set_signature_columns_if_needed(self):
580586
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
581587
# By default, this method sets `self._signature_columns` to the model's expected inputs.

0 commit comments

Comments
 (0)