diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 165c7ec3f7..01bc457ee3 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -665,7 +665,13 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) - num_tokens += batch["tokens"].numel() + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") @@ -683,17 +689,17 @@ def train(self) -> None: logits = logits.reshape(-1, logits.size(-1)) # Compute loss - loss = self._loss_fn(logits, labels) + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + running_loss += self._loss_fn(logits, labels) * current_num_tokens # free logits otherwise it peaks backward memory del logits - loss = loss / self._gradient_accumulation_steps - running_loss += loss - loss.backward() - # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + loss = running_loss / num_tokens + loss.backward() if self._clip_grad_norm is not None: if self._optimizer_in_bwd: raise NotImplementedError( @@ -710,7 +716,7 @@ def train(self) -> None: # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() + loss_to_log = loss.item() pbar.update(1) pbar.set_description( f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index b180e66e8f..c9bcf23a30 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -625,15 +625,22 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) - num_tokens += batch["tokens"].numel() - loss = self._loss_step(batch) - loss = loss / self._gradient_accumulation_steps - running_loss += loss - loss.backward() + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + running_loss += self._loss_step(batch) * current_num_tokens # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + loss = running_loss / num_tokens + loss.backward() if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), @@ -648,7 +655,7 @@ def train(self) -> None: self._lr_scheduler.step() self.global_step += 1 - loss_to_log = running_loss.item() + loss_to_log = loss.item() pbar.update(1) pbar.set_description( f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" @@ -662,9 +669,11 @@ def train(self) -> None: # NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently # true since we don't expose the ability to configure this yet. 
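Every recipe touched in this PR (the full finetune recipes above and the LoRA, knowledge distillation, and QAT recipes further down) applies the same fix: instead of dividing each microbatch's mean loss by `gradient_accumulation_steps`, the loop accumulates the loss weighted by that microbatch's count of unmasked tokens and performs a single division by the total token count just before `backward()`. When microbatches contain different numbers of unmasked (non-`ignore_index`) tokens, the old scheme over-weights token-sparse microbatches; the new one reproduces the loss of the equivalent full batch regardless of masking. (The removed `TestFullFinetuneSingleDeviceGradientAccumulation` test needed a dummy all-unmasked dataset precisely because the old normalization only matched in that special case; gradient accumulation coverage now lives in the parametrized `test_loss` variants.) Below is a minimal sketch of the pattern, with `model`, `loss_fn`, `optimizer`, and `dataloader` as illustrative placeholders rather than the recipes' actual attributes.

```python
# Minimal sketch of the accumulation pattern introduced in the recipe diffs.
# `model`, `loss_fn`, `optimizer`, and `dataloader` are placeholders, not the
# recipes' real attributes; `loss_fn` is assumed to return a mean over
# unmasked tokens and expose an `ignore_index` attribute.
import torch


def train_epoch(model, loss_fn, optimizer, dataloader, grad_accum_steps: int):
    running_loss = 0.0
    num_tokens = 0
    for idx, batch in enumerate(dataloader):
        labels = batch["labels"]
        # Count only unmasked tokens; padding/prompt positions carry ignore_index.
        current_num_tokens = (labels != loss_fn.ignore_index).sum()
        num_tokens += current_num_tokens

        logits = model(batch["tokens"])
        logits = logits.reshape(-1, logits.size(-1))
        # loss_fn returns a per-token mean, so scale back up to a sum before
        # accumulating across microbatches.
        running_loss = (
            running_loss + loss_fn(logits, labels.reshape(-1)) * current_num_tokens
        )
        del logits  # free logits before backward, as the recipes do

        if (idx + 1) % grad_accum_steps == 0:
            # Single normalization by the total unmasked-token count; backward()
            # now runs once per optimizer step on the accumulated loss.
            loss = running_loss / num_tokens
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            running_loss = 0.0
            num_tokens = 0
```

With this weighting, N accumulated microbatches of size 1 produce the same loss as a single batch of size N even when the samples have different numbers of unmasked tokens, which is exactly what the new `micro_batch_size` / `gradient_accumulation_steps` test parametrizations exercise.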
"lr": get_lr( - self._optimizer - if not self._optimizer_in_bwd - else self._optim_ckpt_wrapper, + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), ), "tokens_per_second_per_gpu": num_tokens / time_per_step, } diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py index c2ee8c7cc4..a56382f0ae 100644 --- a/recipes/knowledge_distillation_single_device.py +++ b/recipes/knowledge_distillation_single_device.py @@ -687,17 +687,26 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() batch = {k: v.to(self._device) for k, v in batch.items()} - num_tokens += batch["tokens"].numel() + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens class_loss, kd_loss = self._loss_step(batch) - loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss - loss = loss / self._gradient_accumulation_steps - running_class_loss += class_loss / self._gradient_accumulation_steps - running_kd_loss += kd_loss / self._gradient_accumulation_steps - loss.backward() + running_class_loss += class_loss * current_num_tokens + running_kd_loss += kd_loss * current_num_tokens # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + class_loss = running_class_loss / num_tokens + kd_loss = running_kd_loss / num_tokens + loss = ( + 1 - self._kd_ratio + ) * class_loss + self._kd_ratio * kd_loss + loss.backward() if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), @@ -709,8 +718,8 @@ def train(self) -> None: # Update the number of steps when the weights are updated self.global_step += 1 - class_loss_to_log = running_class_loss.item() - kd_loss_to_log = running_kd_loss.item() + class_loss_to_log = class_loss.item() + kd_loss_to_log = kd_loss.item() loss_to_log = ( 1 - self._kd_ratio ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 28f2b58f5e..1624d6fcbb 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -764,7 +764,13 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) - num_tokens += batch["tokens"].numel() + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") @@ -783,17 +789,17 @@ def train(self) -> None: logits = logits.reshape(-1, logits.size(-1)) # Compute loss - loss = self._loss_fn(logits, labels) + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + running_loss += self._loss_fn(logits, labels) * current_num_tokens # free logits otherwise it peaks backward memory del logits - loss = loss / self._gradient_accumulation_steps - running_loss += loss - loss.backward() - # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + loss = running_loss / num_tokens + loss.backward() if self._clip_grad_norm is not None: grad_norm = 
torch.nn.utils.clip_grad_norm_( self._model.parameters(), @@ -806,7 +812,7 @@ def train(self) -> None: # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() + loss_to_log = loss.item() pbar.update(1) pbar.set_description( f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 5d39b72086..00c4659f12 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -631,7 +631,6 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: labels = labels.reshape(-1) logits = logits.reshape(-1, logits.size(-1)) - # Compute loss loss = self._loss_fn(logits, labels) # free logits otherwise it peaks backward memory @@ -679,15 +678,22 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() utils.batch_to_device(batch, self._device) - num_tokens += batch["tokens"].numel() - loss = self._loss_step(batch) - loss = loss / self._gradient_accumulation_steps - running_loss += loss - loss.backward() + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + running_loss += self._loss_step(batch) * current_num_tokens # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + loss = running_loss / num_tokens + loss.backward() if self._clip_grad_norm is not None: grad_norm = torch.nn.utils.clip_grad_norm_( self._model.parameters(), @@ -699,7 +705,7 @@ def train(self) -> None: # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() + loss_to_log = loss.item() pbar.update(1) pbar.set_description( f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index df6eb5c2d6..6e676d0ce2 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -659,7 +659,14 @@ def train(self) -> None: self._model.apply(enable_fq) tokens = tokens.to(self._device) - num_tokens += tokens.numel() + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + labels = labels.to(self._device) mask = mask.to(self._device) if mask is not None else None input_pos = ( @@ -679,23 +686,22 @@ def train(self) -> None: logits = logits.reshape(-1, logits.size(-1)) # Compute loss - loss = self._loss_fn(logits, labels) + running_loss += self._loss_fn(logits, labels) * current_num_tokens # free logits otherwise it peaks backward memory del logits - loss = loss / self._gradient_accumulation_steps - running_loss += loss - loss.backward() - # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + loss = running_loss / num_tokens + loss.backward() + self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() + loss_to_log = loss.item() pbar.update(1) pbar.set_description( f"{curr_epoch + 
1}|{self.global_step}|Loss: {loss_to_log}" diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index 28e046ebd0..8e5a5fca2b 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -31,7 +31,6 @@ class TestFullFinetuneDistributedRecipe: def _get_test_config_overrides(self): return [ - "batch_size=4", "dtype=fp32", "enable_activation_checkpointing=False", "dataset.train_on_input=False", @@ -52,21 +51,22 @@ def _fetch_expected_loss_values(self, model_type): @pytest.mark.integration_test @pytest.mark.parametrize( - "config, model_type, ckpt_type, fsdp_sharding_strategy", + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps", [ - ("llama2/7B_full", "llama2", "hf", None), - ("llama3/8B_full", "llama3", "tune", None), - ("llama3/8B_full", "llama3", "tune", "NO_SHARD"), + ("llama2/7B_full", "llama2", "hf", 1, 4), + ("llama3/8B_full", "llama3", "tune", 1, 4), + ("llama3/8B_full", "llama3", "tune", 4, 1), ], ) @pytest.mark.parametrize("optim_in_bwd", [True, False]) @gpu_test(gpu_count=2) def test_loss( self, + micro_batch_size, + gradient_accumulation_steps, config, model_type, ckpt_type, - fsdp_sharding_strategy, optim_in_bwd, tmpdir, monkeypatch, @@ -84,6 +84,8 @@ def test_loss( cmd = f""" tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -94,8 +96,6 @@ def test_loss( tokenizer.prompt_template=null \ metric_logger.filename={log_file} \ """.split() - if fsdp_sharding_strategy: - cmd.append(f"fsdp_sharding_strategy={fsdp_sharding_strategy}") model_config = MODEL_TEST_CONFIGS[model_type] cmd = cmd + self._get_test_config_overrides() + model_config # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 170e4008d3..bd90fbbfad 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -11,8 +11,6 @@ import sys from pathlib import Path -import numpy as np - import pytest import torch @@ -35,7 +33,6 @@ class TestFullFinetuneSingleDeviceRecipe: def _get_test_config_overrides(self): return [ - "batch_size=8", "device=cpu", "dtype=fp32", "enable_activation_checkpointing=False", @@ -61,6 +58,10 @@ def _fetch_expected_loss_values(self, model_type): @pytest.mark.integration_test @pytest.mark.parametrize("compile", [True, False]) + @pytest.mark.parametrize( + "micro_batch_size, gradient_accumulation_steps, optimizer_in_bwd", + [(8, 1, True), (2, 4, False)], + ) @pytest.mark.parametrize( "config, model_type, ckpt_type", [ @@ -68,7 +69,18 @@ def _fetch_expected_loss_values(self, model_type): ("llama3/8B_full_single_device", "llama3", "tune"), ], ) - def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): + def test_loss( + self, + compile, + micro_batch_size, + gradient_accumulation_steps, + optimizer_in_bwd, + config, + model_type, + ckpt_type, + tmpdir, + monkeypatch, + ): ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] ckpt = model_type + "_" + ckpt_type ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) @@ -79,6 +91,9 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch) cmd = 
f""" tune run full_finetune_single_device \ --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + optimizer_in_bwd={optimizer_in_bwd} \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -133,6 +148,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): cmd_1 = f""" tune run full_finetune_single_device \ --config llama2/7B_full_low_memory \ + batch_size=8 \ output_dir={tmpdir} \ checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -154,6 +170,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): cmd_2 = f""" tune run full_finetune_single_device \ --config llama2/7B_full_low_memory \ + batch_size=8 \ output_dir={tmpdir} \ checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -179,89 +196,3 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) - - -class TestFullFinetuneSingleDeviceGradientAccumulation: - def _get_test_config_overrides(self): - return [ - "device=cpu", - "dtype=fp32", - "enable_activation_checkpointing=False", - "tokenizer.path=/tmp/test-artifacts/tokenizer.model", - "tokenizer.prompt_template=null", - "dataset=tests.recipes.utils.DummyDataset", - "dataset.train_on_input=False", - "seed=9", - "epochs=1", - "max_steps_per_epoch=1", - "optimizer=torch.optim.AdamW", - "optimizer.lr=2e-5", - "log_every_n_steps=1", - "optimizer_in_bwd=False", - ] - - @pytest.mark.integration_test - def test_gradient_accumulation(self, tmpdir, monkeypatch): - """Test whether gradient accumulation runs properly in the recipe. In general - the sum of loss across minibatches should equal the loss over the full batch, - but since our loss is normalized by the number of unmasked tokens, this does not - hold in for our case. We use a dummy dataset where all tokens are unmasked, and - in this test check that a single batch size of N yields the same loss as N accumulated - microbatches of size 1. 
- """ - full_batch_size = 4 - micro_batch_size = 1 - gradient_accumulation_steps = full_batch_size // micro_batch_size - ckpt = "llama2_tune" - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - ckpt_dir = ckpt_path.parent - no_grad_accum_log_file = gen_log_file_name(tmpdir, suffix="no_grad_accum") - grad_accum_log_file = gen_log_file_name(tmpdir, suffix="grad_accum") - - cmd_1 = f""" - tune run full_finetune_single_device \ - --config llama2/7B_full_low_memory \ - checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir={ckpt_dir} \ - checkpointer.checkpoint_files=[{ckpt_path}]\ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA2 \ - batch_size={full_batch_size} \ - output_dir={tmpdir} \ - log_every_n_steps=1 \ - metric_logger.filename={no_grad_accum_log_file} \ - """.split() - - model_config = MODEL_TEST_CONFIGS["llama2"] - cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config - - monkeypatch.setattr(sys, "argv", cmd_1) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - no_accum_loss = get_loss_values_from_metric_logger(no_grad_accum_log_file)[ - 0 - ] # List of a single element - - # Update the cmd with new values for gradient accumulation - cmd_2 = f""" - tune run full_finetune_single_device \ - --config llama2/7B_full_low_memory \ - checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir={ckpt_dir} \ - checkpointer.checkpoint_files=[{ckpt_path}]\ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=llama2 \ - batch_size={micro_batch_size} \ - gradient_accumulation_steps={gradient_accumulation_steps} \ - output_dir={tmpdir} \ - metric_logger.filename={grad_accum_log_file} \ - """.split() - cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config - - monkeypatch.setattr(sys, "argv", cmd_2) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - accum_loss = np.mean(get_loss_values_from_metric_logger(grad_accum_log_file)) - torch.testing.assert_close(no_accum_loss, accum_loss, atol=1e-5, rtol=1e-5) diff --git a/tests/recipes/test_knowledge_distillation_single_device.py b/tests/recipes/test_knowledge_distillation_single_device.py index f53fec4e9e..81b1c8aba2 100644 --- a/tests/recipes/test_knowledge_distillation_single_device.py +++ b/tests/recipes/test_knowledge_distillation_single_device.py @@ -52,14 +52,21 @@ def _fetch_expected_loss_values(self, model_type): return loss_values_map[model_type] @pytest.mark.integration_test - @pytest.mark.parametrize("compile", [True, False]) @pytest.mark.parametrize( - "config, model_type, ckpt_type", - [ - ("qwen2/knowledge_distillation_single_device", "llama3", "tune"), - ], + "micro_batch_size, gradient_accumulation_steps, compile", + [(8, 1, False), (2, 4, True), (2, 4, False)], ) - def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): + def test_loss( + self, + micro_batch_size, + gradient_accumulation_steps, + compile, + tmpdir, + monkeypatch, + ): + config = "qwen2/knowledge_distillation_single_device" + model_type = "llama3" + ckpt_type = "tune" ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] ckpt = model_type + "_" + ckpt_type ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) @@ -71,6 +78,8 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch) tune run knowledge_distillation_single_device \ --config {config} \ output_dir={tmpdir} \ + batch_size={micro_batch_size} \ + 
gradient_accumulation_steps={gradient_accumulation_steps} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}] \ diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 7777b02862..7be6a13f03 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -32,7 +32,6 @@ class TestLoRAFinetuneDistributedRecipe: def _get_test_config_overrides(self): return [ - "batch_size=4", "dataset.train_on_input=False", "seed=9", "epochs=2", @@ -40,7 +39,6 @@ def _get_test_config_overrides(self): "max_steps_per_epoch=2", "optimizer.lr=2e-5", "log_every_n_steps=1", - "gradient_accumulation_steps=1", "compile=False", ] + dummy_alpaca_dataset_config() @@ -56,13 +54,17 @@ def _fetch_expected_loss_values(self, model_type): @pytest.mark.integration_test @gpu_test(gpu_count=2) @pytest.mark.parametrize( - "reshard_after_forward", - [ - True, - False, - ], + "micro_batch_size, gradient_accumulation_steps, reshard_after_forward", + [(4, 1, True), (1, 4, False)], ) - def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): + def test_loss( + self, + micro_batch_size, + gradient_accumulation_steps, + reshard_after_forward, + tmpdir, + monkeypatch, + ): ckpt = "llama2_tune" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent @@ -70,6 +72,8 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): cmd = f""" tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -138,6 +142,8 @@ def test_training_state_on_resume( cmd_1 = f""" tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \ --config {config} \ + batch_size=4 \ + gradient_accumulation_steps=1 \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -160,6 +166,8 @@ def test_training_state_on_resume( cmd_2 = f""" tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \ --config {config} \ + batch_size=4 \ + gradient_accumulation_steps=1 \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir={tmpdir} \ @@ -206,6 +214,8 @@ def test_save_and_load_merged_weights( cmd = f""" tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \ --config {recipe_config} \ + batch_size=4 \ + gradient_accumulation_steps=1 \ output_dir={tmpdir} \ model=torchtune.models.lora_small_test_model \ checkpointer._component_={ckpt_component} \ diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index f2d7409042..80bc5dc072 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -26,13 +26,11 @@ TOKENIZER_PATHS, ) from torchtune import config -from torchtune.utils import torch_version_ge class TestLoRAFinetuneSingleDeviceRecipe: def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): return [ - "batch_size=8", "device=cpu", f"dtype={dtype_str}", "dataset.train_on_input=False", @@ -41,7 +39,6 @@ def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): "max_steps_per_epoch=2", "optimizer.lr=2e-5", 
"log_every_n_steps=1", - "gradient_accumulation_steps=1", "clip_grad_norm=100", ] + dummy_alpaca_dataset_config() @@ -58,15 +55,26 @@ def _fetch_qlora_expected_loss_values(self, dtype): return [10.5198, 10.5271, 10.5131, 10.5244] @pytest.mark.integration_test - @pytest.mark.parametrize("compile", [True, False]) @pytest.mark.parametrize( - "config, model_type, ckpt_type", + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, compile", [ - ("llama2/7B_lora_single_device", "llama2", "meta"), - ("llama3/8B_lora_single_device", "llama3", "tune"), + ("llama2/7B_lora_single_device", "llama2", "meta", 8, 1, False), + ("llama3/8B_lora_single_device", "llama3", "tune", 2, 4, True), + ("llama2/7B_lora_single_device", "llama2", "meta", 8, 1, True), + ("llama3/8B_lora_single_device", "llama3", "tune", 2, 4, False), ], ) - def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): + def test_loss( + self, + compile, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + tmpdir, + monkeypatch, + ): ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] ckpt = model_type + "_" + ckpt_type ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) @@ -77,6 +85,8 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch) cmd = f""" tune run lora_finetune_single_device \ --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -107,13 +117,24 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch) ) @pytest.mark.integration_test - @pytest.mark.parametrize("dtype", ["fp32", "bf16"]) - @pytest.mark.parametrize("compile", [True, False]) - @pytest.mark.skipif( - not torch_version_ge("2.4.0"), - reason="Please install a nightly build of torch to run this test.", + @pytest.mark.parametrize( + "dtype, compile, micro_batch_size, gradient_accumulation_steps", + [ + ("fp32", True, 8, 1), + ("bf16", True, 2, 4), + ("fp32", False, 4, 2), + ("bf16", False, 8, 1), + ], ) - def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): + def test_loss_qlora( + self, + dtype, + compile, + micro_batch_size, + gradient_accumulation_steps, + tmpdir, + monkeypatch, + ): ckpt = "llama2_meta" ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) ckpt_dir = ckpt_path.parent @@ -122,6 +143,8 @@ def test_loss_qlora(self, compile, dtype, tmpdir, monkeypatch): cmd = f""" tune run lora_finetune_single_device --config llama2/7B_qlora_single_device \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelMetaCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -179,6 +202,8 @@ def test_training_state_on_resume( cmd_1 = f""" tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -202,6 +227,8 @@ def test_training_state_on_resume( cmd_2 = f""" tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ + batch_size=8 \ + gradient_accumulation_steps=1 \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ diff --git a/tests/recipes/test_qat_distributed.py 
b/tests/recipes/test_qat_distributed.py index 5d4d7069f1..18e87a71d1 100644 --- a/tests/recipes/test_qat_distributed.py +++ b/tests/recipes/test_qat_distributed.py @@ -26,13 +26,11 @@ gpu_test, TOKENIZER_PATHS, ) -from torchao.utils import TORCH_VERSION_AFTER_2_4 class TestQATDistributedRecipe: def _get_test_config_overrides(self): return [ - "batch_size=4", "dtype=fp32", "enable_activation_checkpointing=False", "dataset.train_on_input=False", @@ -53,17 +51,24 @@ def _fetch_expected_loss_values(self, model_type): @pytest.mark.integration_test @pytest.mark.parametrize( - "config, model_type, ckpt_type", + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps", [ - ("llama2/7B_qat_full", "llama2", "hf"), - ("llama3/8B_qat_full", "llama3", "tune"), + ("llama2/7B_qat_full", "llama2", "hf", 4, 1), + ("llama3/8B_qat_full", "llama3", "tune", 4, 1), + ("llama3/8B_qat_full", "llama3", "tune", 4, 1), ], ) @gpu_test(gpu_count=2) - @pytest.mark.skipif( - not TORCH_VERSION_AFTER_2_4, reason="QAT only supported for PyTorch 2.4+" - ) - def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch): + def test_loss( + self, + config, + model_type, + ckpt_type, + micro_batch_size, + gradient_accumulation_steps, + tmpdir, + monkeypatch, + ): ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] ckpt = model_type + "_" + ckpt_type ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) @@ -78,6 +83,8 @@ def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch): tune run --nnodes 1 --nproc_per_node 2 qat_distributed \ --config {config} \ output_dir={tmpdir} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ diff --git a/torchtune/modules/loss/ce_chunked_output_loss.py b/torchtune/modules/loss/ce_chunked_output_loss.py index d05ed6018f..17a5eced36 100644 --- a/torchtune/modules/loss/ce_chunked_output_loss.py +++ b/torchtune/modules/loss/ce_chunked_output_loss.py @@ -33,7 +33,7 @@ def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): self.ignore_index = ignore_index def compute_cross_entropy( - self, logits: torch.Tensor, labels: torch.Tensor + self, logits: torch.Tensor, labels: torch.Tensor, normalize: bool = True ) -> torch.Tensor: """ Upcast logits to fp32 and compute cross entropy loss.
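The diff ends just as `CEWithChunkedOutputLoss.compute_cross_entropy` gains a `normalize: bool = True` parameter; its body is cut off here. The following is only a hedged sketch of how such a flag fits the recipe changes above (per-chunk losses summed with `reduction="sum"`, with the division by the unmasked-token count either applied internally or left to the caller). It is written to be consistent with this PR, not as the exact torchtune implementation, and the function name is illustrative.

```python
# Hedged sketch of a chunked cross entropy with an optional `normalize` flag,
# consistent with the recipe changes above; not necessarily the exact body of
# torchtune's CEWithChunkedOutputLoss.compute_cross_entropy.
from typing import List

import torch
import torch.nn.functional as F


def chunked_cross_entropy(
    logit_chunks: List[torch.Tensor],   # each [batch * seq_chunk, vocab]
    label_chunks: List[torch.Tensor],   # each [batch * seq_chunk]
    ignore_index: int = -100,
    normalize: bool = True,
) -> torch.Tensor:
    # Sum (not mean) within each chunk so chunks with different numbers of
    # unmasked tokens combine with the correct weights. Logits are upcast to
    # fp32 as described in the docstring above.
    total = sum(
        F.cross_entropy(
            logits.float(), labels, ignore_index=ignore_index, reduction="sum"
        )
        for logits, labels in zip(logit_chunks, label_chunks)
    )
    if not normalize:
        # Leave normalization to the caller, e.g. a gradient-accumulation loop
        # that divides once by the total token count at the optimizer step.
        return total
    num_unmasked = sum((labels != ignore_index).sum() for labels in label_chunks)
    return total / num_unmasked
```

Under this assumption, a recipe accumulating gradients would either request the unnormalized sum or, as the diffs above do, rescale the mean loss by the chunk's unmasked-token count, so that the single division by `num_tokens` at the optimizer step is the only normalization applied.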