Fixed reporting of single value of loss and ppl across devices. #496
base: main
Changes from 2 commits
@@ -19,6 +19,7 @@
 from tqdm import tqdm

 from QEfficient.finetune.configs.training import TrainConfig
+from QEfficient.finetune.utils.helper import get_num_ddp_devices, is_rank_zero

 try:
     import torch_qaic  # noqa: F401
@@ -83,10 +84,7 @@ def train(
     max_steps_reached = False  # Flag to indicate max training steps reached

     tensorboard_updates = None
-    if train_config.enable_ddp:
-        if local_rank == 0:
-            tensorboard_updates = SummaryWriter()
-    else:
+    if is_rank_zero():
         tensorboard_updates = SummaryWriter()

     device_type = torch.device(device).type
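The `is_rank_zero()` and `get_num_ddp_devices()` helpers imported above are not shown in this diff. A minimal sketch of what they might look like, assuming they read the standard environment variables set by `torchrun` (the actual implementations live in `QEfficient.finetune.utils.helper` and may differ):

```python
import os


def is_rank_zero() -> bool:
    # Rank 0 is the conventional "main" process under DDP; single-device runs
    # have no RANK variable set, so they also report rank zero.
    return int(os.getenv("RANK", 0)) == 0


def get_num_ddp_devices() -> int:
    # WORLD_SIZE is exported by torchrun; it defaults to 1 for non-DDP runs.
    return int(os.getenv("WORLD_SIZE", 1))
```

Collapsing the `enable_ddp`/`local_rank` branches into a single `is_rank_zero()` check means the `SummaryWriter` is created exactly once, whether or not DDP is enabled.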
@@ -232,7 +230,7 @@ def train(
             total_loss += loss.detach().float()

             if train_config.enable_ddp:
-                if local_rank == 0:
+                if is_rank_zero():
                     if loss <= train_config.convergence_loss:
                         loss_0_counter += 1
                     else:
@@ -244,10 +242,7 @@ def train(
            else:
                loss_0_counter = torch.tensor([0]).to(device)

-            if train_config.enable_ddp:
-                if local_rank == 0:
-                    tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)
-            else:
+            if is_rank_zero():
                tensorboard_updates.add_scalars("loss", {"train": loss}, total_train_steps)

            if train_config.save_metrics:
@@ -353,35 +348,42 @@ def train(
            else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size))
        )
        if train_config.task_type == "seq_classification":
-            metric_val = acc_helper.compute()
+            train_epoch_metric = acc_helper.compute()
            acc_helper.reset()
        else:
-            metric_val = torch.exp(train_epoch_loss)
+            train_epoch_metric = torch.exp(train_epoch_loss)

-        train_metric.append(float(metric_val))
+        train_metric.append(float(train_epoch_metric))
        train_loss.append(float(train_epoch_loss))

+        if train_config.enable_ddp:
+            dist.all_reduce(train_epoch_loss, op=dist.ReduceOp.SUM)
+            train_epoch_loss /= get_num_ddp_devices()
+            dist.all_reduce(train_epoch_metric, op=dist.ReduceOp.SUM)
+            train_epoch_metric /= get_num_ddp_devices()
+
        # Update the learning rate as needed
        lr_scheduler.step()

        if train_config.run_validation:
-            if train_config.enable_ddp:
-                dist.barrier()
-                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                    model, train_config, eval_dataloader, device
-                )
-                if local_rank == 0:
-                    tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
-
-            else:
-                eval_epoch_loss, eval_metric, temp_val_loss, temp_step_metric = evaluation_helper(
-                    model, train_config, eval_dataloader, device
-                )
-                tensorboard_updates.add_scalars("loss", {"eval": eval_epoch_loss}, total_train_steps)
+            eval_loss, eval_metric, step_loss, step_metric = evaluation_helper(
+                model, train_config, eval_dataloader, device
+            )
+            # Print evaluation metrics
+            print(
+                f"Epoch {epoch + 1}: Eval Loss: {eval_loss.detach().cpu():.4f}, Eval metric: {eval_metric.detach().cpu():.4f}"
+            )
+            if eval_loss < best_val_loss:
+                best_val_loss = eval_loss
+                print(f"best eval loss on epoch {epoch + 1} is {best_val_loss:.4f}")
+
+            if is_rank_zero():
+                tensorboard_updates.add_scalars("loss", {"eval": eval_loss}, total_train_steps)
            if train_config.save_metrics:
-                val_step_loss.extend(temp_val_loss)
-                val_step_metric.extend(temp_step_metric)
+                val_step_loss.extend(step_loss)
+                val_step_metric.extend(step_metric)
+                val_loss.append(float(eval_loss))
+                val_metric.append(float(eval_metric))

        # saving the adapters after completion of each epoch
        if train_config.save_model:
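The `all_reduce` pattern introduced above is the heart of this fix: each rank computes its local epoch loss and metric, then a SUM reduction followed by division by the device count leaves every rank holding the same averaged value. A self-contained sketch of the pattern (the script below is illustrative, not code from the PR):

```python
# Run with, e.g., `torchrun --nproc_per_node=2 average_demo.py`.
import torch
import torch.distributed as dist


def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
    """SUM-reduce a tensor across all ranks, then divide by the world size."""
    dist.all_reduce(value, op=dist.ReduceOp.SUM)  # in-place sum over ranks
    value /= dist.get_world_size()  # turn the sum into a mean
    return value


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # gloo chosen for portability
    # Each rank contributes a different "loss"; the reduced value is their mean.
    local_loss = torch.tensor(float(dist.get_rank() + 1))
    print(f"rank {dist.get_rank()}: mean = {average_across_ranks(local_loss).item()}")
    dist.destroy_process_group()
```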
@@ -391,20 +393,9 @@ def train(
        else:
            model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

-        if train_config.run_validation:
-            if eval_epoch_loss < best_val_loss:
-                best_val_loss = eval_epoch_loss
-                print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
-            val_loss.append(float(eval_epoch_loss))
-            val_metric.append(float(eval_metric))
-        if train_config.task_type == "seq_classification":
-            print(
-                f"Epoch {epoch + 1}: train_acc={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-            )
-        else:
-            print(
-                f"Epoch {epoch + 1}: train_metric={metric_val:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
-            )
+        print(
+            f"Epoch {epoch + 1}: Train epoch loss: {train_epoch_loss:.4f}, Train metric: {train_epoch_metric:.4f}, Epoch time {epoch_end_time:.2f} sec"
+        )

        # Saving the results every epoch to plot later
        if train_config.save_metrics:
@@ -421,17 +412,12 @@ def train(
    )
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
-    avg_train_metric = sum(train_metric) / len(train_metric)
-    avg_train_loss = sum(train_loss) / len(train_loss)
-    if train_config.run_validation:
-        avg_eval_metric = sum(val_metric) / len(val_metric)
-        avg_eval_loss = sum(val_loss) / len(val_loss)

-    results["avg_train_metric"] = avg_train_metric
-    results["avg_train_loss"] = avg_train_loss
+    results["last_epoch_train_loss"] = train_epoch_loss
+    results["last_epoch_train_metric"] = train_epoch_metric
    if train_config.run_validation:
-        results["avg_eval_metric"] = avg_eval_metric
-        results["avg_eval_loss"] = avg_eval_loss
+        results["last_epoch_eval_loss"] = eval_loss
+        results["last_epoch_eval_metric"] = eval_metric
    results["avg_epoch_time"] = avg_epoch_time
Review comment: tests also need to be updated for these variable name changes in the results dict.
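For illustration, a test pinned to the old keys would have to switch to the new names. A hypothetical fragment (the actual test files and assertions are not part of this diff):

```python
# Hypothetical check; real test fixtures are not shown in this PR.
EXPECTED_KEYS = {
    "last_epoch_train_loss",
    "last_epoch_train_metric",
    "last_epoch_eval_loss",
    "last_epoch_eval_metric",
    "avg_epoch_time",
    "avg_checkpoint_time",
}


def check_results_keys(results: dict) -> None:
    # The old avg_train_loss / avg_eval_loss keys were removed by this PR.
    missing = EXPECTED_KEYS - results.keys()
    assert not missing, f"results dict is missing keys: {missing}"
```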
    results["avg_checkpoint_time"] = avg_checkpoint_time
    if train_config.save_metrics:
@@ -449,6 +435,9 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     Returns: eval_epoch_loss, eval_metric, eval_step_loss, eval_step_metric
     """
+    if train_config.enable_ddp:
+        dist.barrier()
+
     model.eval()

     if train_config.task_type == "seq_classification":
@@ -464,7 +453,7 @@ def evaluation_helper(model, train_config, eval_dataloader, device):
     val_step_loss = []
     val_step_metric = []

-    eval_loss = 0.0  # Initialize evaluation loss
+    eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)  # Initialize evaluation loss
     device_type = torch.device(device).type

     num_dummy_samples = 0
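Initializing `eval_loss` as a tensor on the target device (rather than the Python float `0.0`) matters for the reduction added further down: `dist.all_reduce` operates only on tensors, and the tensor must live on the device the communication backend expects. A small illustration of the distinction (assumed behavior, not code from the PR):

```python
import torch

device = "cpu"  # stand-in; the real code receives the training device

# A plain float accumulator would fail later: dist.all_reduce expects a Tensor.
# eval_loss = 0.0

# A tensor on the right device accumulates the same way and is reducible:
eval_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
eval_loss += torch.tensor(1.25).detach().float()
print(eval_loss)  # tensor(1.2500)
```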
@@ -512,18 +501,19 @@ def evaluation_helper(model, train_config, eval_dataloader, device):

             eval_loss += loss.detach().float()

     # Compute average loss and metric
-    eval_epoch_loss = (
-        0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
-    )
+    eval_loss = 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size)
Review comment: Since we are using the variable name train_epoch_loss for the average train loss of the epoch, it would be good to keep the name eval_epoch_loss for the average evaluation loss of the epoch, to maintain uniformity.

Reply: Check the other variables being returned from this function. Made the names consistent.
     if train_config.task_type == "seq_classification":
         eval_metric = acc_helper.compute()
     else:
-        eval_metric = torch.exp(eval_epoch_loss)
+        eval_metric = torch.exp(eval_loss)

-    # Print evaluation metrics
-    print(f" {eval_metric.detach().cpu()=} {eval_epoch_loss.detach().cpu()=}")
+    if train_config.enable_ddp:
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        eval_loss /= get_num_ddp_devices()
+        dist.all_reduce(eval_metric, op=dist.ReduceOp.SUM)
+        eval_metric /= get_num_ddp_devices()
-    return eval_epoch_loss, eval_metric, val_step_loss, val_step_metric
+    return eval_loss, eval_metric, val_step_loss, val_step_metric


 def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
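One consequence of reducing `eval_metric` directly: for non-classification tasks the per-rank perplexity `exp(loss)` is averaged, which is not identical to the perplexity of the averaged loss when ranks see different losses. A tiny worked example of the difference, purely for illustration:

```python
import torch

losses = torch.tensor([2.0, 3.0])           # two ranks, different mean losses
mean_of_exp = torch.exp(losses).mean()      # what the reduction above reports
exp_of_mean = torch.exp(losses.mean())      # perplexity of the averaged loss
print(f"{mean_of_exp:.4f} vs {exp_of_mean:.4f}")  # 13.7373 vs 12.1825
```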
Review comment: Moving lines #368 and #369 won't be any help. We can keep these here only.

Reply: It is code refactoring. Moved inside the evaluation function.