@@ -14,6 +14,7 @@
 
 import contextlib
 import functools
+import logging
 import os
 import textwrap
 import warnings
@@ -72,6 +73,8 @@
 if is_wandb_available():
     import wandb
 
+logger = logging.getLogger(__name__)
+
 # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
 # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
@@ -277,25 +280,25 @@ def __init__(
 
         # Models
         # Trained model
-        model_init_kwargs = args.model_init_kwargs or {}
+        self._model_init_kwargs = args.model_init_kwargs or {}
         if isinstance(model, str):
             model_id = model
-            torch_dtype = model_init_kwargs.get("torch_dtype")
+            torch_dtype = self._model_init_kwargs.get("torch_dtype")
             if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
                 pass  # torch_dtype is already a torch.dtype or "auto" or None
             elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                 torch_dtype = getattr(torch, torch_dtype)
-                model_init_kwargs["torch_dtype"] = torch_dtype
+                self._model_init_kwargs["torch_dtype"] = torch_dtype
             else:
                 raise ValueError(
                     "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                     f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                 )
             # Disable caching if gradient checkpointing is enabled (not supported)
-            model_init_kwargs["use_cache"] = (
-                False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+            self._model_init_kwargs["use_cache"] = (
+                False if args.gradient_checkpointing else self._model_init_kwargs.get("use_cache")
             )
-            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+            model = AutoModelForCausalLM.from_pretrained(model, **self._model_init_kwargs)
         else:
             model_id = model.config._name_or_path
             if args.model_init_kwargs is not None:
@@ -319,7 +322,7 @@ def __init__(
             # If beta is 0.0, the reference model is not needed
             self.ref_model = None
         elif is_deepspeed_zero3_enabled():
-            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **self._model_init_kwargs)
         elif is_peft_model(model):
             # If PEFT is used, the reference model is not needed since the adapter can be disabled
             # to revert to the initial model.
@@ -338,7 +341,7 @@ def __init__(
         for i, reward_func in enumerate(reward_funcs):
             if isinstance(reward_func, str):
                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
-                    reward_func, num_labels=1, **model_init_kwargs
+                    reward_func, num_labels=1, **self._model_init_kwargs
                 )
         self.reward_funcs = reward_funcs
 
@@ -1181,3 +1184,115 @@ def create_model_card(
         )
 
         model_card.save(os.path.join(self.args.output_dir, "README.md"))
+
+
+class GRPOTrainerWithEval(GRPOTrainer):
+    def __init__(
+        self,
+        model: str | PreTrainedModel,
+        train_reward_funcs: RewardFunc | list[RewardFunc],
+        eval_reward_funcs: RewardFunc | list[RewardFunc] | None = None,
+        args: GRPOConfig | None = None,
+        train_dataset: Dataset | IterableDataset | None = None,
+        eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
+        processing_class: PreTrainedTokenizerBase | None = None,
+        train_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        eval_reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            reward_funcs=train_reward_funcs,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            reward_processing_classes=train_reward_processing_classes,
+            **kwargs,
+        )
+
+        # Store training reward functions reference
+        self.train_reward_funcs = self.reward_funcs
+        self.train_reward_processing_classes = self.reward_processing_classes
+
+        if eval_reward_funcs is not None:
+            # Okay we have some custom evaluation reward functions, set them up
+
+            if "compute_metrics" in kwargs:
+                logger.warning(
+                    "Please make sure your custom compute_metrics function is using the"
+                    " right evaluation reward functions."
+                )
+
+            # Matching reward_funcs processing
+            if not isinstance(eval_reward_funcs, list):
+                eval_reward_funcs = [eval_reward_funcs]
+            for i, reward_func in enumerate(eval_reward_funcs):
+                if isinstance(reward_func, str):
+                    eval_reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+                        reward_func, num_labels=1, **self._model_init_kwargs
+                    )
+            self.eval_reward_funcs = eval_reward_funcs
+            self.eval_reward_processing_classes = self._make_reward_processing_classes(
+                eval_reward_funcs, eval_reward_processing_classes
+            )
+        else:
+            # We don't have any, so we just reuse the training ones
+            self.eval_reward_funcs = self.train_reward_funcs
+            self.eval_reward_processing_classes = self.train_reward_processing_classes
+
+    def _compute_rewards_per_func(self, inputs, prompts: list[str], completions: list[str], device) -> torch.Tensor:
+        if self.control.should_evaluate:
+            reward_funcs = self.eval_reward_funcs
+            reward_processing_classes = self.eval_reward_processing_classes
+        else:
+            reward_funcs = self.train_reward_funcs
+            reward_processing_classes = self.train_reward_processing_classes
+
+        rewards_per_func = torch.zeros(len(prompts), len(reward_funcs), device=device)
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(reward_funcs, reward_processing_classes, strict=True)
+        ):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+            else:
+                reward_func_name = reward_func.__name__
+            with profiling_context(self, reward_func_name):
+                if isinstance(
+                    reward_func, nn.Module
+                ):  # Module instead of PretrainedModel for compat with compiled models
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    else:
+                        texts = [p + c for p, c in zip(prompts, completions, strict=True)]
+                    reward_inputs = reward_processing_class(
+                        texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                    )
+                    reward_inputs = super()._prepare_inputs(reward_inputs)
+                    with torch.inference_mode():
+                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                else:
+                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                    keys = [key for key in inputs[0] if key not in {"prompt", "completion"}]
+                    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+        # completions may be distributed across processes
+        return gather(rewards_per_func)
+
+    def compute_reward_metrics(self, eval_prediction: EvalPrediction) -> dict[str, float]:
+        if not self.control.should_evaluate:
+            raise RuntimeError("We are supposed to be in evaluation mode.")
+
+        avg_reward_per_func = eval_prediction.predictions.mean(axis=0)
+        metrics: dict[str, float] = {}
+        for i, reward_func in enumerate(self.eval_reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            metrics[f"rewards/{reward_func_name}"] = avg_reward_per_func[i].item()
+        return metrics
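
For reference, a minimal usage sketch of the `GRPOTrainerWithEval` subclass added above: training uses one reward function while evaluation uses an additional one, and `_compute_rewards_per_func` switches between the two sets via `self.control.should_evaluate`. The model ID, dataset, reward functions, and evaluation schedule below are placeholders chosen for illustration, not part of this change, and the sketch assumes the class is importable from the patched module.

```python
# Usage sketch only: model ID, dataset, and reward functions are placeholders.
from datasets import load_dataset
from trl import GRPOConfig


def keyword_reward(prompts, completions, **kwargs):
    # Toy training reward: +1.0 when the completion contains a "TL;DR" marker.
    return [1.0 if "TL;DR" in completion else 0.0 for completion in completions]


def brevity_reward(prompts, completions, **kwargs):
    # Toy evaluation-only reward: shorter completions score higher.
    return [-float(len(completion)) for completion in completions]


trainer = GRPOTrainerWithEval(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model ID
    train_reward_funcs=keyword_reward,
    eval_reward_funcs=[keyword_reward, brevity_reward],
    args=GRPOConfig(output_dir="grpo-with-eval", eval_strategy="steps", eval_steps=100),
    train_dataset=load_dataset("trl-lib/tldr", split="train"),  # placeholder dataset
    eval_dataset=load_dataset("trl-lib/tldr", split="validation"),
)
trainer.train()
```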