Normalize CE loss by total number of (non-padding) tokens #1875


Merged: 4 commits, Oct 25, 2024
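This PR changes how the cross-entropy loss is normalized under gradient accumulation: each micro-batch's mean loss is re-weighted by its count of unmasked (non-padding) tokens, the weighted sum is accumulated, and the result is divided by the window's total unmasked-token count before a single backward per optimizer step. The change is the same across all recipes below; what follows is a minimal, self-contained sketch of the scheme, not the recipes' actual code (the toy model, data, and hyperparameters are illustrative assumptions):

```python
import torch
import torch.nn as nn

vocab_size, hidden_dim, ignore_index = 32, 16, -100
model = nn.Sequential(nn.Embedding(vocab_size, hidden_dim), nn.Linear(hidden_dim, vocab_size))
loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)  # returns a mean over non-ignored tokens
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4

running_loss, num_tokens = 0.0, 0
for idx in range(16):
    tokens = torch.randint(0, vocab_size, (2, 8))   # [batch, seq]
    labels = tokens.clone()
    labels[:, -3:] = ignore_index                   # pretend the tail of each sequence is padding

    # Count only the unmasked tokens that actually contribute to the loss
    current_num_tokens = (labels != ignore_index).sum()
    num_tokens += current_num_tokens

    logits = model(tokens).reshape(-1, vocab_size)
    # Undo the per-batch mean so batches with more real tokens carry more weight
    running_loss = running_loss + loss_fn(logits, labels.reshape(-1)) * current_num_tokens

    if (idx + 1) % gradient_accumulation_steps == 0:
        # Normalize by the total number of unmasked tokens in the accumulation window
        loss = running_loss / num_tokens
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        running_loss, num_tokens = 0.0, 0
```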
20 changes: 13 additions & 7 deletions recipes/full_finetune_distributed.py
@@ -665,7 +665,13 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
@@ -683,17 +689,17 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
if self._optimizer_in_bwd:
raise NotImplementedError(
@@ -710,7 +716,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
27 changes: 18 additions & 9 deletions recipes/full_finetune_single_device.py
@@ -625,15 +625,22 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

loss = self._loss_step(batch)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -648,7 +655,7 @@ def train(self) -> None:
self._lr_scheduler.step()
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
@@ -662,9 +669,11 @@ def train(self) -> None:
# NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
# true since we don't expose the ability to configure this yet.
"lr": get_lr(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper,
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
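A side effect of the logging change above: `loss_to_log` is now the token-weighted average over the accumulation window rather than the old sum of per-batch means divided by `gradient_accumulation_steps`. A worked example with made-up numbers:

```python
# Two accumulated micro-batches: mean loss 2.0 over 10 unmasked tokens,
# and mean loss 4.0 over 30 unmasked tokens (gradient_accumulation_steps = 2).
running_loss = 2.0 * 10 + 4.0 * 30    # 140.0, token-weighted sum
num_tokens = 10 + 30

print(running_loss / num_tokens)       # 3.5 -> value logged by the new code
print(2.0 / 2 + 4.0 / 2)               # 3.0 -> value the old code would have logged
```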
25 changes: 17 additions & 8 deletions recipes/knowledge_distillation_single_device.py
@@ -687,17 +687,26 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

batch = {k: v.to(self._device) for k, v in batch.items()}
num_tokens += batch["tokens"].numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

class_loss, kd_loss = self._loss_step(batch)
loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
loss = loss / self._gradient_accumulation_steps
running_class_loss += class_loss / self._gradient_accumulation_steps
running_kd_loss += kd_loss / self._gradient_accumulation_steps
loss.backward()
running_class_loss += class_loss * current_num_tokens
running_kd_loss += kd_loss * current_num_tokens

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
class_loss = running_class_loss / num_tokens
kd_loss = running_kd_loss / num_tokens
loss = (
1 - self._kd_ratio
) * class_loss + self._kd_ratio * kd_loss
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -709,8 +718,8 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

class_loss_to_log = running_class_loss.item()
kd_loss_to_log = running_kd_loss.item()
class_loss_to_log = class_loss.item()
kd_loss_to_log = kd_loss.item()
loss_to_log = (
1 - self._kd_ratio
) * class_loss_to_log + self._kd_ratio * kd_loss_to_log
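In the knowledge-distillation recipe the same normalization is applied to each loss component separately before they are mixed with `kd_ratio` at the optimizer step. A rough sketch under the same caveats as above (the random scalars stand in for the per-batch means returned by `_loss_step`; `kd_ratio` and the token counts are made up):

```python
import torch

kd_ratio, gradient_accumulation_steps = 0.5, 2
running_class_loss, running_kd_loss, num_tokens = 0.0, 0.0, 0

for idx in range(4):
    # Stand-ins for the per-batch mean losses returned by _loss_step
    class_loss = torch.rand((), requires_grad=True)
    kd_loss = torch.rand((), requires_grad=True)
    current_num_tokens = torch.tensor(30 + idx)  # unmasked tokens in this batch

    num_tokens += current_num_tokens
    # Token-weight each component separately while accumulating
    running_class_loss = running_class_loss + class_loss * current_num_tokens
    running_kd_loss = running_kd_loss + kd_loss * current_num_tokens

    if (idx + 1) % gradient_accumulation_steps == 0:
        # Normalize each component by the window's token total, then mix
        class_loss = running_class_loss / num_tokens
        kd_loss = running_kd_loss / num_tokens
        loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
        loss.backward()
        running_class_loss, running_kd_loss, num_tokens = 0.0, 0.0, 0
```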
20 changes: 13 additions & 7 deletions recipes/lora_finetune_distributed.py
@@ -764,7 +764,13 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
@@ -783,17 +789,17 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_fn(logits, labels) * current_num_tokens

# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
@@ -806,7 +812,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
20 changes: 13 additions & 7 deletions recipes/lora_finetune_single_device.py
@@ -631,7 +631,6 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)

# free logits otherwise it peaks backward memory
@@ -679,15 +678,22 @@ def train(self) -> None:
torch.cuda.memory._record_memory_history()

utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

loss = self._loss_step(batch)
loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
running_loss += self._loss_step(batch) * current_num_tokens

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
Expand All @@ -699,7 +705,7 @@ def train(self) -> None:
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
20 changes: 13 additions & 7 deletions recipes/qat_distributed.py
@@ -659,7 +659,14 @@ def train(self) -> None:
self._model.apply(enable_fq)

tokens = tokens.to(self._device)
num_tokens += tokens.numel()

# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens

labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
@@ -679,23 +686,22 @@ def train(self) -> None:
logits = logits.reshape(-1, logits.size(-1))

# Compute loss
loss = self._loss_fn(logits, labels)
running_loss += self._loss_fn(logits, labels) * current_num_tokens
# free logits otherwise it peaks backward memory
del logits

loss = loss / self._gradient_accumulation_steps
running_loss += loss
loss.backward()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
loss = running_loss / num_tokens
loss.backward()

self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
loss_to_log = loss.item()
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
16 changes: 8 additions & 8 deletions tests/recipes/test_full_finetune_distributed.py
@@ -31,7 +31,6 @@
class TestFullFinetuneDistributedRecipe:
def _get_test_config_overrides(self):
return [
"batch_size=4",
"dtype=fp32",
"enable_activation_checkpointing=False",
"dataset.train_on_input=False",
@@ -52,21 +51,22 @@ def _fetch_expected_loss_values(self, model_type):

@pytest.mark.integration_test
@pytest.mark.parametrize(
"config, model_type, ckpt_type, fsdp_sharding_strategy",
"config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps",
[
("llama2/7B_full", "llama2", "hf", None),
("llama3/8B_full", "llama3", "tune", None),
("llama3/8B_full", "llama3", "tune", "NO_SHARD"),
("llama2/7B_full", "llama2", "hf", 1, 4),
("llama3/8B_full", "llama3", "tune", 1, 4),
("llama3/8B_full", "llama3", "tune", 4, 1),
],
)
@pytest.mark.parametrize("optim_in_bwd", [True, False])
@gpu_test(gpu_count=2)
def test_loss(
self,
micro_batch_size,
gradient_accumulation_steps,
config,
model_type,
ckpt_type,
fsdp_sharding_strategy,
optim_in_bwd,
tmpdir,
monkeypatch,
@@ -84,6 +84,8 @@ def test_loss(
cmd = f"""
tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
--config {config} \
batch_size={micro_batch_size} \
gradient_accumulation_steps={gradient_accumulation_steps} \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -94,8 +96,6 @@ def test_loss(
tokenizer.prompt_template=null \
metric_logger.filename={log_file} \
""".split()
if fsdp_sharding_strategy:
cmd.append(f"fsdp_sharding_strategy={fsdp_sharding_strategy}")
model_config = MODEL_TEST_CONFIGS[model_type]
cmd = cmd + self._get_test_config_overrides() + model_config
# "optimizer_in_bwd=True" would free gradient info before clip_grad, causing