
Fixed reporting of single value of loss and ppl across devices. #496

Open · wants to merge 3 commits into base: main
Changes from 2 commits
4 changes: 4 additions & 0 deletions QEfficient/finetune/utils/helper.py
@@ -14,3 +14,7 @@
 def get_num_ddp_devices():
     return int(os.getenv("WORLD_SIZE", 1))
+
+
+def is_rank_zero():
+    return int(os.getenv("LOCAL_RANK", 0)) == 0
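
For context (an illustration added here, not part of the diff): both helpers read environment variables that a distributed launcher such as torchrun normally sets for each process. A minimal sketch of the values they would see under an assumed 4-process launch:

    # Illustration only; simulates the env vars a launcher like `torchrun --nproc-per-node=4` would set.
    import os

    os.environ["WORLD_SIZE"] = "4"   # total number of DDP processes
    os.environ["LOCAL_RANK"] = "0"   # rank of this process on the local node

    print(int(os.getenv("WORLD_SIZE", 1)))        # 4    -> what get_num_ddp_devices() returns
    print(int(os.getenv("LOCAL_RANK", 0)) == 0)   # True -> what is_rank_zero() returns
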
106 changes: 48 additions & 58 deletions QEfficient/finetune/utils/train_utils.py
@@ -19,6 +19,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import get_num_ddp_devices, is_rank_zero

 try:
     import torch_qaic  # noqa: F401
@@ -83,10 +84,7 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached

     tensorboard_updates = None
-    if train_config.enable_ddp:
-        if local_rank == 0:
-            tensorboard_updates = SummaryWriter()
-    else:
+    if is_rank_zero():
         tensorboard_updates = SummaryWriter()

     device_type = torch.device(device).type
@@ -232,7 +230,7 @@ def train(
                 total_loss += loss.detach().float()

                 if train_config.enable_ddp:
-                    if local_rank == 0:
+                    if is_rank_zero():
                         if loss <= train_config.convergence_loss:
                             loss_0_counter += 1
                         else:
@@ -244,10 +242,7 @@ def train(
                 else:
                     loss_0_counter = torch.tensor([0]).to(device)

-                if train_config.enable_ddp:
-                    if local_rank == 0:
-                        tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
-                else:
+                if is_rank_zero():
                     tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)

                 if train_config.save_metrics:
@@ -353,35 +348,42 @@ def train(
                 else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
             )
             if train_config.task_type == "seq_classification":
-                metric_val = acc_helper.compute()
+                train_epoch_metric = acc_helper.compute()
                 acc_helper.reset()
             else:
-                metric_val = torch.exp(train_epoch_loss)
+                train_epoch_metric = torch.exp(train_epoch_loss)

-            train_metric.append(float(metric_val))
+            train_metric.append(float(train_epoch_metric))
             train_loss.append(float(train_epoch_loss))

+            if train_config.enable_ddp:
+                dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
+                train_epoch_loss /= get_num_ddp_devices()
+                dist.all_reduce(train_epoch_metric, op=dist.ReduceOp.SUM)
+                train_epoch_metric /= get_num_ddp_devices()
+
             # Update the learning rate as needed
             lr_scheduler.step()

             if train_config.run_validation:
-                if train_config.enable_ddp:
-                    dist.barrier()
Contributor: Moving lines #368 and #369 won't be of any help; we can keep these here.

Contributor (Author): It is code refactoring; these lines have been moved inside the evaluation function.
-                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                    model, train_config, eval_dataloader, device
-                )
-                if local_rank == 0:
-                    tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-            else:
-                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                    model, train_config, eval_dataloader, device
-                )
-                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+            eval_loss, eval_metric, step_loss, step_metric = evaluation_helper(
+                model, train_config, eval_dataloader, device
+            )
+            # Print evaluation metrics
+            print(
+                f"Epoch {epoch + 1}: Eval Loss: {eval_loss.detach().cpu():.4f}, Eval metric: {eval_metric.detach().cpu():.4f}"
+            )
+            if eval_loss < best_val_loss:
+                best_val_loss = eval_loss
+                print(f"best eval loss on epoch {epoch + 1} is {best_val_loss:.4f}")
+
+            if is_rank_zero():
+                tensorboard_updates.add_scalars("loss", {"eval": eval_loss}, total_train_steps)
             if train_config.save_metrics:
-                val_step_loss.extend(temp_val_loss)
-                val_step_metric.extend(temp_step_metric)
+                val_step_loss.extend(step_loss)
+                val_step_metric.extend(step_metric)
+                val_loss.append(float(eval_loss))
+                val_metric.append(float(eval_metric))

         # saving the adapters after completion of each epoch
         if train_config.save_model:
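
The dist.all_reduce block added above is the core of this change: each rank computes its own epoch loss and metric, and summing across ranks then dividing by get_num_ddp_devices() leaves every rank holding the same single value to report. A standalone sketch of that averaging pattern, assuming the process group has already been initialized by the finetuning entry point (dist.get_world_size() stands in here for the PR's get_num_ddp_devices()):

    # Sketch of averaging a per-rank scalar across DDP processes; assumes dist.init_process_group() has run.
    import torch
    import torch.distributed as dist

    def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(value, op=dist.ReduceOp.SUM)  # in-place sum over all ranks
        value /= dist.get_world_size()                # turn the sum into a mean
        return value

    # Inside the epoch loop each rank would pass its own tensor, e.g.:
    # train_epoch_loss = average_across_ranks(train_epoch_loss)
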
@@ -391,20 +393,9 @@ def train(
             else:
                 model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

-            if train_config.run_validation:
-                if eval_epoch_loss < best_val_loss:
-                    best_val_loss = eval_epoch_loss
-                    print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
-                val_loss.append(float(eval_epoch_loss))
-                val_metric.append(float(eval_metric))
-            if train_config.task_type == "seq_classification":
-                print(
-                    f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-                )
-            else:
-                print(
-                    f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-                )
+            print(
+                f"Epoch {epoch + 1}: Train epoch loss: {train_epoch_loss:.4f}, Train metric: {train_epoch_metric:.4f}, Epoch time {epoch_end_time:.2f} sec"
+            )

             # Saving the results every epoch to plot later
             if train_config.save_metrics:
@@ -421,17 +412,12 @@ def train(
         )
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_metric = sum(train_metric) / len(train_metric)
-    avg_train_loss = sum(train_loss) / len(train_loss)
-    if train_config.run_validation:
-        avg_eval_metric = sum(val_metric) / len(val_metric)
-        avg_eval_loss = sum(val_loss) / len(val_loss)

-    results["avg_train_metric"] = avg_train_metric
-    results["avg_train_loss"] = avg_train_loss
+    results["last_epoch_train_loss"] = train_epoch_loss
+    results["last_epoch_train_metric"] = train_epoch_metric
     if train_config.run_validation:
-        results["avg_eval_metric"] = avg_eval_metric
-        results["avg_eval_loss"] = avg_eval_loss
+        results["last_epoch_eval_loss"] = eval_loss
+        results["last_epoch_eval_metric"] = eval_metric
     results["avg_epoch_time"] = avg_epoch_time
Contributor: Tests also need to be updated for these variable name changes in the results dict.

     results["avg_checkpoint_time"] = avg_checkpoint_time
     if train_config.save_metrics:
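
As the reviewer notes above, any tests that read the results dict will have to follow these key renames. A hypothetical check, with the assertion style invented for illustration rather than taken from the repository's test suite:

    # Hypothetical pytest-style assertion; assumes `results` is the dict returned by train()
    # and that train_config.run_validation was enabled, so the eval keys are present.
    expected_keys = {
        "last_epoch_train_loss",
        "last_epoch_train_metric",
        "last_epoch_eval_loss",
        "last_epoch_eval_metric",
        "avg_epoch_time",
        "avg_checkpoint_time",
    }
    assert expected_keys.issubset(results.keys()), "results dict is missing the renamed keys"
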
@@ -449,6 +435,9 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
     """
+    if train_config.enable_ddp:
+        dist.barrier()
+
     model.eval()

     if train_config.task_type == "seq_classification":
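
Regarding the review thread above: the barrier now runs at the top of evaluation_helper, so when DDP is enabled every rank finishes its training steps before any rank starts evaluating; because a barrier blocks until all ranks reach it, evaluation_helper must still be called on every rank, not just rank zero. A short sketch of that intent (the function name below is made up for illustration, and an initialized process group is assumed):

    # Sketch of the synchronization intent; assumes dist.init_process_group() has run.
    import torch.distributed as dist

    def run_evaluation(train_config):
        if train_config.enable_ddp:
            dist.barrier()  # every rank waits here until all ranks arrive
        # ... all ranks then enter the evaluation loop together ...
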
@@ -464,7 +453,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     val_step_loss = []
     val_step_metric = []

-    eval_loss = 0.0  # Initialize evaluation loss
+    eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)  # Initialize evaluation loss
     device_type = torch.device(device).type

     num_dummy_samples = 0
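
One plausible motivation for this change (not stated in the diff): dist.all_reduce, which this PR adds at the end of the function, operates on tensors rather than Python floats, so initializing the accumulator as a tensor on the target device keeps it reducible however the loop plays out. A small illustration of the distinction:

    # Illustration: a collective reduce needs a tensor; a plain float only supports local accumulation.
    import torch

    eval_loss_float = 0.0                                      # old style: Python float
    eval_loss_tensor = torch.tensor(0.0, dtype=torch.float32)  # new style: a tensor (add device=... in real use)

    loss = torch.tensor(1.25)
    eval_loss_float += float(loss)     # fine locally, but not something dist.all_reduce can take
    eval_loss_tensor += loss.detach()  # stays a tensor, so dist.all_reduce(eval_loss_tensor) is possible
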
@@ -512,18 +501,19 @@ def evaluation_helper(model, train_config, eval_dataloader, device):

         eval_loss += loss.detach().float()
     # Compute average loss and metric
-    eval_epoch_loss = (
-        0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
-    )
+    eval_loss = 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
Contributor: Since we are using the variable name train_epoch_loss for the average train loss of the epoch, it would be good to keep the name eval_epoch_loss for the average evaluation loss of the epoch, to maintain uniformity.

Contributor (Author): Check the other variables returned from this function; the names have been made consistent.

     if train_config.task_type == "seq_classification":
         eval_metric = acc_helper.compute()
     else:
-        eval_metric = torch.exp(eval_epoch_loss)
+        eval_metric = torch.exp(eval_loss)

-    # Print evaluation metrics
-    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    if train_config.enable_ddp:
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        eval_loss /= get_num_ddp_devices()
+        dist.all_reduce(eval_metric, op=dist.ReduceOp.SUM)
+        eval_metric /= get_num_ddp_devices()

-    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
+    return eval_loss, eval_metric, val_step_loss, val_step_metric
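
A small numerical note for context (a general property of the math, not a claim from the PR): each rank exponentiates its own average loss before the all_reduce, so the single reported perplexity is the mean of per-rank perplexities, which can differ slightly from the perplexity of the mean loss. A tiny illustration with made-up per-rank losses:

    # Mean of per-rank perplexities vs. perplexity of the mean loss (illustrative numbers only).
    import math

    rank_losses = [2.0, 2.4]                                  # hypothetical per-rank eval losses
    mean_of_ppl = sum(math.exp(x) for x in rank_losses) / 2   # (e^2.0 + e^2.4) / 2 ≈ 9.21
    ppl_of_mean = math.exp(sum(rank_losses) / 2)              # e^2.2 ≈ 9.03
    print(mean_of_ppl, ppl_of_mean)                           # close, but not identical
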


def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: