diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 9757f5bf4e..1ce0db98d8 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -151,28 +151,20 @@ def __init__(self, cfg: DictConfig) -> None:
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

-        # activation checkpointing/offloading
-        self._enable_activation_checkpointing = cfg.get(
-            "enable_activation_checkpointing", False
-        )
-        self._enable_activation_offloading = cfg.get(
-            "enable_activation_offloading", False
-        )
-        if self._enable_activation_offloading:
-            if self._device.type != "cuda":
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
                 raise RuntimeError(
-                    "enable_activation_offloading should only be True when training on CUDA"
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
                 )
-            if not self._enable_activation_checkpointing:
+            if self._gradient_accumulation_steps > 1:
                 raise RuntimeError(
-                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
                 )
-        elif self._enable_activation_checkpointing:
-            log.info(
-                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
-                "Enabling activation offloading should reduce memory further."
-            )

         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
@@ -203,7 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
-        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -720,7 +711,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()

-        _, rank = training.get_world_size_and_rank()
+        world_size, rank = training.get_world_size_and_rank()

         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:
@@ -787,32 +778,43 @@ def train(self) -> None:
                # Compute loss
                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
-               running_loss += self._loss_fn(logits, labels) * current_num_tokens
+               current_loss = self._loss_fn(logits, labels) * current_num_tokens

                # free logits otherwise it peaks backward memory
                del logits

+               running_loss += current_loss
+
+               # For optimizer in backward, we need to normalize before calling backward
+               # This case and gradient accumulation are mutually exclusive
+               if self._optimizer_in_bwd:
+                   torch.distributed.all_reduce(num_tokens)
+                   torch.distributed.all_reduce(running_loss)
+                   current_loss = current_loss / num_tokens
+
+               current_loss.backward()
+
                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   loss = running_loss / num_tokens
-                   loss.backward()
-                   if self._clip_grad_norm is not None:
-                       if self._optimizer_in_bwd:
-                           raise NotImplementedError(
-                               "Gradient clipping is not supported after optimizer-in-the-backward."
-                           )
-                       grad_norm = torch.nn.utils.clip_grad_norm_(
-                           self._model.parameters(),
-                           max_norm=float(self._clip_grad_norm),
-                       )
                    if not self._optimizer_in_bwd:
+                       # Get total number of tokens across all ranks to normalize gradients
+                       torch.distributed.all_reduce(num_tokens)
+                       # This will ensure that the logged loss matches what we're optimizing
+                       torch.distributed.all_reduce(running_loss)
+                       # Manually scale the gradients from unnormalized loss by total # of tokens
+                       training.scale_grads(self._model, 1 / num_tokens)
+                       if self._clip_grad_norm is not None:
+                           grad_norm = torch.nn.utils.clip_grad_norm_(
+                               self._model.parameters(),
+                               max_norm=float(self._clip_grad_norm),
+                           )
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)

                    # Update the number of steps when the weights are updated
                    self.global_step += 1

-                   loss_to_log = loss.item()
+                   loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -833,7 +835,8 @@ def train(self) -> None:
                                else self._optim_ckpt_wrapper
                            ),
                        ),
-                       "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                       "tokens_per_second_per_gpu": num_tokens
+                       / (time_per_step * world_size),
                    }
                    if self._log_peak_memory_stats:
                        log_dict.update(
diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index 6819b6c210..4a44b233e8 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -141,6 +141,20 @@ def __init__(self, cfg: DictConfig) -> None:
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.optimizer_in_bwd
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )

         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
@@ -164,14 +178,6 @@ def __init__(self, cfg: DictConfig) -> None:
                 "Enabling activation offloading should reduce memory further."
             )

-        # TODO: find a better place / way to perform validation of args that don't yet
-        # compose with each other.
-        if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
-            raise RuntimeError(
-                "Gradient accumulation is not supported with optimizer in bwd."
-                "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
-            )
-
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = training.set_seed(seed=cfg.seed)
@@ -179,7 +185,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
-        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -686,18 +691,19 @@ def train(self) -> None:

                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
-               running_loss += self._loss_step(batch) * current_num_tokens
+               current_loss = self._loss_step(batch) * current_num_tokens
+               running_loss += current_loss
+               current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   loss = running_loss / num_tokens
-                   loss.backward()
-                   if self._clip_grad_norm is not None:
-                       grad_norm = torch.nn.utils.clip_grad_norm_(
-                           self._model.parameters(),
-                           max_norm=float(self._clip_grad_norm),
-                       )
                    if not self._optimizer_in_bwd:
+                       training.scale_grads(self._model, 1 / num_tokens)
+                       if self._clip_grad_norm is not None:
+                           grad_norm = torch.nn.utils.clip_grad_norm_(
+                               self._model.parameters(),
+                               max_norm=float(self._clip_grad_norm),
+                           )
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)

@@ -706,7 +712,7 @@ def train(self) -> None:
                        self._lr_scheduler.step()
                    self.global_step += 1

-                   loss_to_log = loss.item()
+                   loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py
index d17e480ba6..b40be9e89e 100644
--- a/recipes/knowledge_distillation_distributed.py
+++ b/recipes/knowledge_distillation_distributed.py
@@ -821,7 +821,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()

-        _, rank = training.get_world_size_and_rank()
+        world_size, rank = training.get_world_size_and_rank()

         # zero out the gradients before starting training
         self._optimizer.zero_grad()
@@ -857,7 +857,7 @@ def train(self) -> None:
                ):
                    torch.cuda.memory._record_memory_history()

-               batch = {k: v.to(self._device) for k, v in batch.items()}
+               utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
@@ -869,13 +869,22 @@ def train(self) -> None:
                class_loss, kd_loss = self._loss_step(batch)
                running_class_loss += class_loss * current_num_tokens
                running_kd_loss += kd_loss * current_num_tokens
+               current_loss = (
+                   1 - self._kd_ratio
+               ) * class_loss + self._kd_ratio * kd_loss
+               current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   class_loss = running_class_loss / num_tokens
-                   kd_loss = running_kd_loss / num_tokens
-                   loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
-                   loss.backward()
+                   # Get total number of tokens across all ranks to normalize gradients
+                   torch.distributed.all_reduce(num_tokens)
+                   # This will ensure that the logged loss matches what we're optimizing
+                   torch.distributed.all_reduce(running_class_loss)
+                   torch.distributed.all_reduce(running_kd_loss)
+                   # Manually scale the gradients from unnormalized loss by total # of tokens
+                   training.scale_grads(self._model, 1 / num_tokens)
+                   class_loss_to_log = running_class_loss.item() / num_tokens
+                   kd_loss_to_log = running_kd_loss.item() / num_tokens
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()
@@ -903,7 +912,8 @@ def train(self) -> None:
                        "class_loss": class_loss_to_log,
                        "kd_loss": kd_loss_to_log,
                        "lr": self._optimizer.param_groups[0]["lr"],
-                       "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                       "tokens_per_second_per_gpu": num_tokens
+                       / (time_per_step * world_size),
                    }
                    if self._log_peak_memory_stats:
                        log_dict.update(
diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py
index 4c97d6829d..1a2c3f0e4b 100644
--- a/recipes/knowledge_distillation_single_device.py
+++ b/recipes/knowledge_distillation_single_device.py
@@ -704,15 +704,14 @@ def train(self) -> None:
                class_loss, kd_loss = self._loss_step(batch)
                running_class_loss += class_loss * current_num_tokens
                running_kd_loss += kd_loss * current_num_tokens
+               current_loss = (
+                   1 - self._kd_ratio
+               ) * class_loss + self._kd_ratio * kd_loss
+               current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   class_loss = running_class_loss / num_tokens
-                   kd_loss = running_kd_loss / num_tokens
-                   loss = (
-                       1 - self._kd_ratio
-                   ) * class_loss + self._kd_ratio * kd_loss
-                   loss.backward()
+                   training.scale_grads(self._model, 1 / num_tokens)
                    if self._clip_grad_norm is not None:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self._model.parameters(),
@@ -724,8 +723,8 @@ def train(self) -> None:

                    # Update the number of steps when the weights are updated
                    self.global_step += 1
-                   class_loss_to_log = class_loss.item()
-                   kd_loss_to_log = kd_loss.item()
+                   class_loss_to_log = running_class_loss.item() / num_tokens
+                   kd_loss_to_log = running_kd_loss.item() / num_tokens
                    loss_to_log = (
                        1 - self._kd_ratio
                    ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 7f724c2e66..418c823344 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -748,7 +748,7 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()

-        _, rank = training.get_world_size_and_rank()
+        world_size, rank = training.get_world_size_and_rank()

         # zero out the gradients before starting training
         self._optimizer.zero_grad()
@@ -812,15 +812,22 @@ def train(self) -> None:
                # Compute loss
                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
-               running_loss += self._loss_fn(logits, labels) * current_num_tokens
+               current_loss = self._loss_fn(logits, labels) * current_num_tokens

                # free logits otherwise it peaks backward memory
                del logits

+               running_loss += current_loss
+               current_loss.backward()
+
                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   loss = running_loss / num_tokens
-                   loss.backward()
+                   # Get total number of tokens across all ranks to normalize gradients
+                   torch.distributed.all_reduce(num_tokens)
+                   # This will ensure that the logged loss matches what we're optimizing
+                   torch.distributed.all_reduce(running_loss)
+                   # Manually scale the gradients from unnormalized loss by total # of tokens
+                   training.scale_grads(self._model, 1 / num_tokens)
                    if self._clip_grad_norm is not None:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self._model.parameters(),
@@ -833,7 +840,7 @@ def train(self) -> None:
                    # Update the number of steps when the weights are updated
                    self.global_step += 1

-                   loss_to_log = loss.item()
+                   loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -848,7 +855,8 @@ def train(self) -> None:
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": self._optimizer.param_groups[0]["lr"],
-                           "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                           "tokens_per_second_per_gpu": num_tokens
+                           / (time_per_step * world_size),
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(
diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py
index bc4018b810..50a61c1c0b 100644
--- a/recipes/lora_finetune_single_device.py
+++ b/recipes/lora_finetune_single_device.py
@@ -692,12 +692,13 @@ def train(self) -> None:

                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
-               running_loss += self._loss_step(batch) * current_num_tokens
+               current_loss = self._loss_step(batch) * current_num_tokens
+               running_loss += current_loss
+               current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   loss = running_loss / num_tokens
-                   loss.backward()
+                   training.scale_grads(self._model, 1 / num_tokens)
                    if self._clip_grad_norm is not None:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self._model.parameters(),
@@ -709,7 +710,7 @@ def train(self) -> None:
                    # Update the number of steps when the weights are updated
                    self.global_step += 1

-                   loss_to_log = loss.item()
+                   loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index afb8e8d0e8..4126f95bd5 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -599,8 +599,7 @@ def train(self) -> None:
         """
         # clean up before training begins
         training.cleanup_before_training()
-
-        _, rank = training.get_world_size_and_rank()
+        world_size, rank = training.get_world_size_and_rank()

         # zero out the gradients before starting training
         self._optimizer.zero_grad()
@@ -668,18 +667,16 @@ def train(self) -> None:

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
+
+               utils.batch_to_device(batch, self._device)
+
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

+               labels = batch.pop("labels")
-               labels = labels.to(self._device)
-               mask = mask.to(self._device) if mask is not None else None
-               input_pos = (
-                   input_pos.to(self._device) if input_pos is not None else None
-               )
-
-               logits = self._model(tokens, mask=mask, input_pos=input_pos)
+               logits = self._model(**batch)

                # Shift labels to compute loss
                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -692,14 +689,22 @@ def train(self) -> None:
                logits = logits.reshape(-1, logits.size(-1))

                # Compute loss
-               running_loss += self._loss_fn(logits, labels) * current_num_tokens
+               current_loss = self._loss_fn(logits, labels) * current_num_tokens
+
                # free logits otherwise it peaks backward memory
                del logits

+               running_loss += current_loss
+               current_loss.backward()
+
                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                   loss = running_loss / num_tokens
-                   loss.backward()
+                   # Get total number of tokens across all ranks to normalize gradients
+                   torch.distributed.all_reduce(num_tokens)
+                   # This will ensure that the logged loss matches what we're optimizing
+                   torch.distributed.all_reduce(running_loss)
+                   # Manually scale the gradients from unnormalized loss by total # of tokens
+                   training.scale_grads(self._model, 1 / num_tokens)
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

@@ -707,7 +712,7 @@ def train(self) -> None:
                    # Update the number of steps when the weights are updated
                    self.global_step += 1

-                   loss_to_log = loss.item()
+                   loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -722,7 +727,9 @@ def train(self) -> None:
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": self._optimizer.param_groups[0]["lr"],
-                           "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                           "tokens_per_second_per_gpu": (
+                               num_tokens / (time_per_step * world_size)
+                           ),
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(
diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py
index a381b6ce58..f1f4256411 100644
--- a/tests/recipes/test_full_finetune_distributed.py
+++ b/tests/recipes/test_full_finetune_distributed.py
@@ -45,21 +45,20 @@ def _get_test_config_overrides(self):

     def _fetch_expected_loss_values(self, model_type):
         loss_values_map = {
-            "llama2": [10.5136, 10.4813, 10.5088, 10.5250],
-            "llama3": [12.0673, 11.9072, 11.9302, 11.9355],
+            "llama2": [10.5209, 10.5217, 10.4945, 10.5136],
+            "llama3": [11.9839, 11.9684, 11.9596, 11.93656],
         }
         return loss_values_map[model_type]

     @pytest.mark.integration_test
     @pytest.mark.parametrize(
-        "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
+        "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
         [
-            ("llama2/7B_full", "llama2", "hf", 1, 4),
-            ("llama3/8B_full", "llama3", "tune", 1, 4),
-            ("llama3/8B_full", "llama3", "tune", 4, 1),
+            ("llama2/7B_full", "llama2", "hf", 1, 4, False),
+            ("llama3/8B_full", "llama3", "tune", 1, 4, False),
+            ("llama3/8B_full", "llama3", "tune", 4, 1, True),
         ],
     )
-    @pytest.mark.parametrize("optim_in_bwd", [True, False])
     @gpu_test(gpu_count=2)
     def test_loss(
         self,
diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py
index 6d3bea10c6..819c70fdf0 100644
--- a/tests/recipes/test_full_finetune_single_device.py
+++ b/tests/recipes/test_full_finetune_single_device.py
@@ -46,7 +46,6 @@ def _get_test_config_overrides(self):
             "lr_scheduler.num_warmup_steps=0",
             "lr_scheduler.num_cycles=0",
             "log_every_n_steps=1",
-            "clip_grad_norm=100",
         ] + dummy_alpaca_dataset_config()

     def _fetch_expected_loss_values(self, model_type):
@@ -94,7 +93,6 @@ def test_loss(
            --config {config} \
            batch_size={micro_batch_size} \
            gradient_accumulation_steps={gradient_accumulation_steps} \
-           optimizer_in_bwd={optimizer_in_bwd} \
            output_dir={tmpdir} \
            checkpointer._component_={ckpt_component} \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -109,7 +107,14 @@ def test_loss(

        model_config = MODEL_TEST_CONFIGS[model_type]
        cmd = cmd + self._get_test_config_overrides() + model_config
-
+       # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
+       # wrong grad_norm, so we only test one of them each time. But loss values
+       # should be the same.
+       if not optimizer_in_bwd:
+           cmd.append("clip_grad_norm=100")
+           cmd.append("optimizer_in_bwd=False")
+       else:
+           cmd.append("optimizer_in_bwd=True")
        monkeypatch.setattr(sys, "argv", cmd)
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")
diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py
index c8515b43c4..4943e1559b 100644
--- a/tests/recipes/test_lora_finetune_distributed.py
+++ b/tests/recipes/test_lora_finetune_distributed.py
@@ -46,8 +46,8 @@ def _fetch_expected_loss_values(self, model_type):
         # These values have been validated against single device recipe test via
         # https://gist.github.com/ebsmothers/f1c3db7c66655a23a91e0290360960c4
         loss_values_map = {
-            "llama2": [10.5136, 10.4856, 10.5292, 10.5345],
-            "llama3": [11.9325, 11.9325, 11.9325, 11.9369],
+            "llama2": [10.5209, 10.5269, 10.5130, 10.5242],
+            "llama3": [11.9839, 11.9691, 11.9617, 11.9383],
         }
         return loss_values_map[model_type]

diff --git a/tests/recipes/test_qat_distributed.py b/tests/recipes/test_qat_distributed.py
index f5174fb46a..34dd190125 100644
--- a/tests/recipes/test_qat_distributed.py
+++ b/tests/recipes/test_qat_distributed.py
@@ -45,8 +45,8 @@ def _get_test_config_overrides(self):

     def _fetch_expected_loss_values(self, model_type):
         loss_values_map = {
-            "llama2": [10.5164, 10.4830, 10.5138, 10.5199],
-            "llama3": [12.0672, 11.9067, 11.9304, 11.9351],
+            "llama2": [10.5211, 10.5217, 10.4944, 10.5134],
+            "llama3": [11.9836, 11.9683, 11.9594, 11.9366],
         }
         return loss_values_map[model_type]

@@ -56,7 +56,7 @@ def _fetch_expected_loss_values(self, model_type):
         [
             ("llama2/7B_qat_full", "llama2", "hf", 4, 1),
             ("llama3/8B_qat_full", "llama3", "tune", 4, 1),
-            ("llama3/8B_qat_full", "llama3", "tune", 4, 1),
+            ("llama3/8B_qat_full", "llama3", "tune", 1, 4),
         ],
     )
     @gpu_test(gpu_count=2)
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py
index db52e44cbd..a1e1cdbd73 100644
--- a/torchtune/training/__init__.py
+++ b/torchtune/training/__init__.py
@@ -27,6 +27,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
+from torchtune.training._grad_scaler import scale_grads
 from torchtune.training._profiler import (
     DEFAULT_PROFILE_DIR,
     DEFAULT_PROFILER_ACTIVITIES,
@@ -137,4 +138,5 @@
     "NoOpManager",
     "OffloadActivations",
     "FormattedCheckpointFiles",
+    "scale_grads",
 ]
diff --git a/torchtune/training/_grad_scaler.py b/torchtune/training/_grad_scaler.py
new file mode 100644
index 0000000000..aab938bc90
--- /dev/null
+++ b/torchtune/training/_grad_scaler.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+
+
+def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
+    """
+    Utility to scale the gradients of a model.
+    This is useful for gradient accumulation where we want to normalize
+    the gradients by the total number of tokens seen.
+
+    Inputs:
+        model (nn.Module): model whose gradients should be scaled
+        scaler (torch.Tensor): scaling factor to apply to the gradients
+
+    Outputs:
+        None (grad fields are modified in place)
+    """
+    for p in model.parameters():
+        if p.grad is not None:
+            p.grad *= scaler
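
For reference, the sketch below illustrates the token-normalized gradient-accumulation pattern that the recipe changes above implement, with scale_grads applied right before the optimizer step. This is a minimal, single-device illustration only: the toy linear model, SGD optimizer, and random batches are made up for the example and are not part of this change; the recipes wire the same logic through their configs and (for distributed runs) all-reduce num_tokens and the running loss first.

import torch
from torch import nn

from torchtune.training import scale_grads  # re-exported from _grad_scaler above

model = nn.Linear(16, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # mean over unmasked tokens

gradient_accumulation_steps = 4
running_loss = 0.0
num_tokens = 0

for idx in range(8):
    inputs = torch.randn(2, 16)          # fake batch of 2 "tokens"
    labels = torch.randint(0, 8, (2,))   # fake labels, none masked here
    current_num_tokens = (labels != loss_fn.ignore_index).sum()
    num_tokens += current_num_tokens

    # The loss is mean-normalized, so multiply by the batch token count now and
    # divide by the *total* token count once the accumulation window closes.
    current_loss = loss_fn(model(inputs), labels) * current_num_tokens
    running_loss += current_loss
    current_loss.backward()

    if (idx + 1) % gradient_accumulation_steps == 0:
        # Gradients currently correspond to the sum of unnormalized losses;
        # rescale them by 1 / total tokens so the update matches the normalized loss.
        scale_grads(model, 1 / num_tokens)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        print(f"loss: {(running_loss / num_tokens).item():.4f}")
        running_loss, num_tokens = 0.0, 0

The reported loss divides the accumulated unnormalized loss by the same total token count used to scale the gradients, which is why the logged value matches what is actually optimized.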