Commit 646cd14

Created GRPOTrainerWithEval subclasses for adding eval functions

1 parent 2a5cf21 commit 646cd14

File tree

1 file changed: +123 -8 lines changed

trl/trainer/grpo_trainer.py

Lines changed: 123 additions & 8 deletions
@@ -14,6 +14,7 @@
import contextlib
import functools
+import logging
import os
import textwrap
import warnings
@@ -72,6 +73,8 @@
if is_wandb_available():
    import wandb

+logger = logging.getLogger(__name__)
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
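
For illustration only (not part of this diff): a minimal callable that satisfies the RewardFunc contract described in the comment above, assuming completions arrive as plain strings rather than chat messages. The function name and reward shaping are hypothetical.

def brevity_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
    # Hypothetical example: return one float per completion, gently preferring shorter outputs.
    # Extra dataset columns are forwarded by the trainer as keyword arguments and ignored here.
    return [-len(completion) / 100.0 for completion in completions]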
@@ -277,25 +280,25 @@ def __init__(

        # Models
        # Trained model
-        model_init_kwargs = args.model_init_kwargs or {}
+        self._model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
-            torch_dtype = model_init_kwargs.get("torch_dtype")
+            torch_dtype = self._model_init_kwargs.get("torch_dtype")
            if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
-                model_init_kwargs["torch_dtype"] = torch_dtype
+                self._model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
-            model_init_kwargs["use_cache"] = (
-                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            self._model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else self._model_init_kwargs.get("use_cache")
            )
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+            model = AutoModelForCausalLM.from_pretrained(model, **self._model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
@@ -319,7 +322,7 @@ def __init__(
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
-            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **self._model_init_kwargs)
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
@@ -338,7 +341,7 @@ def __init__(
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
-                    reward_func, num_labels=1, **model_init_kwargs
+                    reward_func, num_labels=1, **self._model_init_kwargs
                )
        self.reward_funcs = reward_funcs
@@ -1181,3 +1184,115 @@ def create_model_card(
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+
+class GRPOTrainerWithEval(GRPOTrainer):
+    def __init__(
+        self,
+        model: str | PreTrainedModel,
+        train_reward_funcs: RewardFunc | list[RewardFunc],
+        eval_reward_funcs: RewardFunc | list[RewardFunc] | None = None,
+        args: GRPOConfig | None = None,
+        train_dataset: Dataset | IterableDataset | None = None,
+        eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
+        processing_class: PreTrainedTokenizerBase | None = None,
+        train_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        eval_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            reward_funcs=train_reward_funcs,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            reward_processing_classes=train_reward_processing_classes,
+            **kwargs,
+        )
+
+        # Store training reward functions reference
+        self.train_reward_funcs = self.reward_funcs
+        self.train_reward_processing_classes = self.reward_processing_classes
+
+        if eval_reward_funcs is not None:
+            # Okay we have some custom evaluation reward functions, set them up
+
+            if "compute_metrics" in kwargs:
+                logger.warning(
+                    "Please make sure your custom compute_metrics function is using the"
+                    " right evaluation reward functions."
+                )
+
+            # Matching reward_funcs processing
+            if not isinstance(eval_reward_funcs, list):
+                eval_reward_funcs = [eval_reward_funcs]
+            for i, reward_func in enumerate(eval_reward_funcs):
+                if isinstance(reward_func, str):
+                    eval_reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                        reward_func, num_labels=1, **self._model_init_kwargs
+                    )
+            self.eval_reward_funcs = eval_reward_funcs
+            self.eval_reward_processing_classes = self._make_reward_processing_classes(
+                eval_reward_funcs, eval_reward_processing_classes
+            )
+        else:
+            # We don't have any, so we just reuse the training ones
+            self.eval_reward_funcs = self.train_reward_funcs
+            self.eval_reward_processing_classes = self.train_reward_processing_classes
+
+    def _compute_rewards_per_func(self, inputs, prompts: list[str], completions: list[str], device) -> torch.Tensor:
+        if self.control.should_evaluate:
+            reward_funcs = self.eval_reward_funcs
+            reward_processing_classes = self.eval_reward_processing_classes
+        else:
+            reward_funcs = self.train_reward_funcs
+            reward_processing_classes = self.train_reward_processing_classes
+
+        rewards_per_func = torch.zeros(len(prompts), len(reward_funcs), device=device)
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(reward_funcs, reward_processing_classes, strict=True)
+        ):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+            else:
+                reward_func_name = reward_func.__name__
+            with profiling_context(self, reward_func_name):
+                if isinstance(
+                    reward_func, nn.Module
+                ):  # Module instead of PretrainedModel for compat with compiled models
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    else:
+                        texts = [p + c for p, c in zip(prompts, completions, strict=True)]
+                    reward_inputs = reward_processing_class(
+                        texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                    )
+                    reward_inputs = super()._prepare_inputs(reward_inputs)
+                    with torch.inference_mode():
+                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                else:
+                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                    keys = [key for key in inputs[0] if key not in {"prompt", "completion"}]
+                    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+        # completions may be distributed across processes
+        return gather(rewards_per_func)
+
+    def compute_reward_metrics(self, eval_prediction: EvalPrediction) -> dict[str, float]:
+        if not self.control.should_evaluate:
+            raise RuntimeError("We are supposed to be in evaluation mode.")
+
+        avg_reward_per_func = eval_prediction.predictions.mean(axis=0)
+        metrics: dict[str, float] = {}
+        for i, reward_func in enumerate(self.eval_reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            metrics[f"rewards/{reward_func_name}"] = avg_reward_per_func[i].item()
+        return metrics
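
For context, a minimal usage sketch of the new subclass (not part of this commit). The model ID, reward-model ID, dataset, and brevity_reward function are placeholders, and routing compute_reward_metrics through the Trainer's compute_metrics hook is one possible wiring rather than something this diff sets up.

from datasets import load_dataset

from trl import GRPOConfig
from trl.trainer.grpo_trainer import GRPOTrainerWithEval


def brevity_reward(prompts, completions, **kwargs):
    # Toy training-time reward: prefer shorter completions (placeholder logic).
    return [-len(c) / 100.0 for c in completions]


dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset with a "prompt" column

trainer = GRPOTrainerWithEval(
    model="Qwen/Qwen2-0.5B-Instruct",           # placeholder model ID
    train_reward_funcs=brevity_reward,          # callable used during training steps
    eval_reward_funcs="my-org/my-reward-model", # placeholder: a string is loaded as a sequence-classification reward model
    args=GRPOConfig(output_dir="grpo-with-eval"),
    train_dataset=dataset,
    eval_dataset=dataset.select(range(64)),
)

# One possible wiring (assumption, not shown in this diff): report per-function eval rewards
# through the standard compute_metrics hook.
trainer.compute_metrics = trainer.compute_reward_metrics
trainer.train()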
