Decoupling generation and loss batch sizes #1
This introduces a `per_device_loss_batch_size` argument to define the microbatches used when computing the loss. Ideally, I would have liked to compute the loss in chunks of `per_device_loss_batch_size` and accumulate gradients. However, computing the advantage requires all `per_device_train_batch_size * num_generations` samples. So instead, we compute the three tensors needed for the loss (reward, logp, KL) in chunks of `per_device_loss_batch_size`, concatenate the chunks, and compute the full loss all at once. I think this should result in a similar memory reduction, but it remains to be tested.

I also think this code is pretty compilation-unfriendly, since I'm slicing tensors dynamically. Oh well.
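A minimal sketch of the chunk-then-concatenate idea described above (the function and its callbacks are hypothetical names for illustration, not the actual trainer code): the three per-sample tensors are computed in microbatches of `per_device_loss_batch_size`, then concatenated so the advantage and loss can be computed over the full batch.

```python
import torch


def compute_loss_tensors(policy_logp_fn, ref_logp_fn, reward_fn,
                         samples, loss_batch_size):
    """Compute reward, logp, and KL in microbatches, then concatenate.

    Hypothetical sketch: `policy_logp_fn`, `ref_logp_fn`, and `reward_fn`
    each map a chunk of samples to a 1-D per-sample tensor.
    """
    rewards, logps, kls = [], [], []
    for start in range(0, samples.shape[0], loss_batch_size):
        # Dynamic slicing like this is what makes the code compile-unfriendly.
        chunk = samples[start:start + loss_batch_size]
        logp = policy_logp_fn(chunk)      # keeps the autograd graph per chunk
        ref_logp = ref_logp_fn(chunk)     # frozen reference model, no grad needed
        rewards.append(reward_fn(chunk))
        logps.append(logp)
        kls.append(logp - ref_logp)       # simple per-sample KL estimate
    # Concatenate so the advantage can be computed over all
    # per_device_train_batch_size * num_generations samples at once.
    return torch.cat(rewards), torch.cat(logps), torch.cat(kls)
```

Since the chunked forward passes still build one big autograd graph before the single backward pass, the activation memory may not shrink as much as true gradient accumulation would, which is presumably why the memory reduction "remains to be tested".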