Description
I've noticed that `tokens_per_second` fluctuates wildly depending on the dataset, padding, masking strategies, etc. I believe the reason for this is that we only consider unmasked tokens as part of the tps calculation (here).
This is somewhat misleading because the masked tokens are still processed during the forward pass (and possibly the backward pass, though I'm not 100% certain). So we are expending FLOPs that are not counted in the tps calculation.
This also leads to confusing situations where the exact same dataset, once masking is applied, reports a precipitous drop in tps even though the same (or potentially even less) computation is happening under the hood. That makes the metric a poor indicator of how fast we are actually training.
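
For concreteness, here is a minimal sketch of how the two counts diverge for the exact same batch. I'm assuming the usual `ignore_index = -100` convention for loss masking; the batch shape and vocab size are made up for illustration:

```python
import torch

ignore_index = -100  # assumed masking convention; the actual value in the codebase may differ

# Hypothetical batch: 4 sequences of length 512, with the first half of each
# sequence masked out of the loss (e.g. prompt tokens or padding).
labels = torch.randint(0, 32000, (4, 512))
labels[:, :256] = ignore_index

active_tokens = (labels != ignore_index).sum().item()  # what tps counts today -> 1024
total_tokens = labels.numel()                          # what the forward pass processes -> 2048

print(active_tokens, total_tokens)  # reported tps halves, but the FLOPs are unchanged
```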
I'm happy to send a PR to update this if the team agrees with this take. Another option could be to report both numbers, e.g. `active_tps` and `total_tps` (see the sketch below).
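
If the two-metric option is preferred, something like the following could work. This is just a sketch: the names `active_tps` / `total_tps`, the `tps_metrics` helper, and the `ignore_index` value are placeholders, not the project's actual API:

```python
import torch

ignore_index = -100  # assumed masking convention

def tps_metrics(labels: torch.Tensor, step_time: float) -> dict[str, float]:
    """Return both throughput numbers for one training step.

    active_tps counts only tokens that contribute to the loss (current behavior);
    total_tps counts every position that went through the forward pass.
    """
    active_tokens = (labels != ignore_index).sum().item()
    total_tokens = labels.numel()
    return {
        "active_tps": active_tokens / step_time,
        "total_tps": total_tokens / step_time,
    }

# Example: a 4 x 512 batch where half the positions are masked, processed in 0.5s.
labels = torch.full((4, 512), 1)
labels[:, :256] = ignore_index
print(tps_metrics(labels, step_time=0.5))  # {'active_tps': 2048.0, 'total_tps': 4096.0}
```

Logging both keeps the current number for anyone relying on it, while `total_tps` stays stable across masking strategies and reflects the actual compute.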