Commit 3b1a796

Created GRPOTrainerWithEval subclasses for adding eval functions
1 parent 87e2c61

File tree: 1 file changed (+123 −8)

trl/trainer/grpo_trainer.py

Lines changed: 123 additions & 8 deletions
@@ -14,6 +14,7 @@
 
 import contextlib
 import functools
+import logging
 import os
 import textwrap
 import warnings
@@ -72,6 +73,8 @@
 if is_wandb_available():
     import wandb
 
+logger = logging.getLogger(__name__)
+
 # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
 # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
@@ -277,25 +280,25 @@ def __init__(
 
         # Models
         # Trained model
-        model_init_kwargs = args.model_init_kwargs or {}
+        self._model_init_kwargs = args.model_init_kwargs or {}
         if isinstance(model, str):
             model_id = model
-            torch_dtype = model_init_kwargs.get("torch_dtype")
+            torch_dtype = self._model_init_kwargs.get("torch_dtype")
             if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                 pass  # torch_dtype is already a torch.dtype or "auto" or None
             elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                 torch_dtype = getattr(torch, torch_dtype)
-                model_init_kwargs["torch_dtype"] = torch_dtype
+                self._model_init_kwargs["torch_dtype"] = torch_dtype
             else:
                 raise ValueError(
                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                 )
             # Disable caching if gradient checkpointing is enabled (not supported)
-            model_init_kwargs["use_cache"] = (
-                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            self._model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else self._model_init_kwargs.get("use_cache")
             )
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+            model = AutoModelForCausalLM.from_pretrained(model, **self._model_init_kwargs)
         else:
             model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
@@ -319,7 +322,7 @@ def __init__(
             # If beta is 0.0, the reference model is not needed
             self.ref_model = None
         elif is_deepspeed_zero3_enabled():
-            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **self._model_init_kwargs)
         elif is_peft_model(model):
             # If PEFT is used, the reference model is not needed since the adapter can be disabled
             # to revert to the initial model.
@@ -338,7 +341,7 @@ def __init__(
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
-                    reward_func, num_labels=1, **model_init_kwargs
+                    reward_func, num_labels=1, **self._model_init_kwargs
                )
        self.reward_funcs = reward_funcs
 
@@ -1185,3 +1188,115 @@ def create_model_card(
         )
 
         model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+
+class GRPOTrainerWithEval(GRPOTrainer):
+    def __init__(
+        self,
+        model: str | PreTrainedModel,
+        train_reward_funcs: RewardFunc | list[RewardFunc],
+        eval_reward_funcs: RewardFunc | list[RewardFunc] | None = None,
+        args: GRPOConfig | None = None,
+        train_dataset: Dataset | IterableDataset | None = None,
+        eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
+        processing_class: PreTrainedTokenizerBase | None = None,
+        train_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        eval_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            reward_funcs=train_reward_funcs,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            reward_processing_classes=train_reward_processing_classes,
+            **kwargs,
+        )
+
+        # Store training reward functions reference
+        self.train_reward_funcs = self.reward_funcs
+        self.train_reward_processing_classes = self.reward_processing_classes
+
+        if eval_reward_funcs is not None:
+            # Okay we have some custom evaluation reward functions, set them up
+
+            if "compute_metrics" in kwargs:
+                logger.warning(
+                    "Please make sure your custom compute_metrics function is using the"
+                    " right evaluation reward functions."
+                )
+
+            # Matching reward_funcs processing
+            if not isinstance(eval_reward_funcs, list):
+                eval_reward_funcs = [eval_reward_funcs]
+            for i, reward_func in enumerate(eval_reward_funcs):
+                if isinstance(reward_func, str):
+                    eval_reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                        reward_func, num_labels=1, **self._model_init_kwargs
+                    )
+            self.eval_reward_funcs = eval_reward_funcs
+            self.eval_reward_processing_classes = self._make_reward_processing_classes(
+                eval_reward_funcs, eval_reward_processing_classes
+            )
+        else:
+            # We don't have any, so we just reuse the training ones
+            self.eval_reward_funcs = self.train_reward_funcs
+            self.eval_reward_processing_classes = self.train_reward_processing_classes
+
+    def _compute_rewards_per_func(self, inputs, prompts: list[str], completions: list[str], device) -> torch.Tensor:
+        if self.control.should_evaluate:
+            reward_funcs = self.eval_reward_funcs
+            reward_processing_classes = self.eval_reward_processing_classes
+        else:
+            reward_funcs = self.train_reward_funcs
+            reward_processing_classes = self.train_reward_processing_classes
+
+        rewards_per_func = torch.zeros(len(prompts), len(reward_funcs), device=device)
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(reward_funcs, reward_processing_classes, strict=True)
+        ):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+            else:
+                reward_func_name = reward_func.__name__
+            with profiling_context(self, reward_func_name):
+                if isinstance(
+                    reward_func, nn.Module
+                ):  # Module instead of PretrainedModel for compat with compiled models
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    else:
+                        texts = [p + c for p, c in zip(prompts, completions, strict=True)]
+                    reward_inputs = reward_processing_class(
+                        texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                    )
+                    reward_inputs = super()._prepare_inputs(reward_inputs)
+                    with torch.inference_mode():
+                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                else:
+                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                    keys = [key for key in inputs[0] if key not in {"prompt", "completion"}]
+                    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+        # completions may be distributed across processes
+        return gather(rewards_per_func)
+
+    def compute_reward_metrics(self, eval_prediction: EvalPrediction) -> dict[str, float]:
+        if not self.control.should_evaluate:
+            raise RuntimeError("We are supposed to be in evaluation mode.")
+
+        avg_reward_per_func = eval_prediction.predictions.mean(axis=0)
+        metrics: dict[str, float] = {}
+        for i, reward_func in enumerate(self.eval_reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            metrics[f"rewards/{reward_func_name}"] = avg_reward_per_func[i].item()
+        return metrics
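
For context, a minimal usage sketch of the new subclass (not part of the commit): the model ID, dataset, reward functions, and import path below are placeholders/assumptions, chosen only to show how the separate train/eval reward functions would be passed in.

# Hypothetical usage sketch -- not part of the commit. Model ID, dataset, reward
# functions, and the import path are placeholders/assumptions.
from datasets import load_dataset
from trl import GRPOConfig
from trl.trainer.grpo_trainer import GRPOTrainerWithEval  # assumes this patched module


def length_reward(prompts, completions, **kwargs):
    # Toy training reward: favor longer completions.
    return [float(len(c)) for c in completions]


def ends_with_period_reward(prompts, completions, **kwargs):
    # Toy eval-only reward: favor completions that end with a period.
    return [1.0 if c.strip().endswith(".") else 0.0 for c in completions]


dataset = load_dataset("trl-lib/tldr")  # placeholder prompt dataset

trainer = GRPOTrainerWithEval(
    model="Qwen/Qwen2-0.5B-Instruct",            # placeholder model ID
    train_reward_funcs=length_reward,            # used during training steps
    eval_reward_funcs=ends_with_period_reward,   # used when control.should_evaluate is True
    args=GRPOConfig(output_dir="grpo-with-eval", eval_strategy="steps", eval_steps=100),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.train()

The subclass also defines compute_reward_metrics, which a caller could presumably supply as the trainer's compute_metrics hook to log per-function "rewards/<name>" values during evaluation; that wiring is not part of this diff and is left to the caller.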
