
Tokens per second calculation #2296

Open
@EugenHotaj

Description


I've noticed that tokens_per_second fluctuates wildly depending on the dataset, padding, masking strategy, etc. I believe the reason is that we only count unmasked tokens in the tps calculation (here).

This is misleading because masked tokens are still processed during the forward pass (and possibly the backward pass, though I'm not 100% certain), so we are expending FLOPs that the tps calculation never counts.
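To make the distinction concrete, here's a rough sketch of the two ways of counting. The tensor values, `IGNORE_INDEX`, and `step_time_s` are made-up placeholders for illustration, not the actual recipe code:

```python
import torch

IGNORE_INDEX = -100  # assumed ignore index for padded/masked label positions
labels = torch.tensor([
    [12, 57, 903, IGNORE_INDEX, IGNORE_INDEX],                    # 3 active of 5 processed
    [7, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX],  # 1 active of 5 processed
])
step_time_s = 0.25  # illustrative wall-clock time for this step

# Current behavior described above: only unmasked (loss-contributing) tokens count.
active_tokens = (labels != IGNORE_INDEX).sum().item()  # 4
active_tps = active_tokens / step_time_s               # 16.0

# But the forward pass still ran over every position in the batch.
total_tokens = labels.numel()                          # 10
total_tps = total_tokens / step_time_s                 # 40.0

print(f"active_tps={active_tps}, total_tps={total_tps}")
```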

This also leads to confusing situations where the exact same dataset, once masking is applied, shows a precipitous drop in tps even though the same (or potentially even less) computation is happening under the hood. That makes the metric hard to use for understanding how fast we are actually training.

I'm happy to send a PR to update this if the team agrees with this take. Another option would be to report both numbers, e.g. active_tps and total_tps.
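For reference, a rough sketch of what reporting both could look like. The `throughput_metrics` helper and `IGNORE_INDEX` are hypothetical names, not an existing API:

```python
import time
import torch

IGNORE_INDEX = -100  # assumed label ignore index for padded/masked positions

def throughput_metrics(labels: torch.Tensor, step_start: float) -> dict[str, float]:
    """Hypothetical helper: report both throughput metrics for one training step."""
    step_time = time.perf_counter() - step_start
    active = (labels != IGNORE_INDEX).sum().item()  # tokens that contribute to the loss
    total = labels.numel()                          # every position the model processed
    return {
        "active_tps": active / step_time,  # today's tokens_per_second
        "total_tps": total / step_time,    # throughput over all processed positions
    }
```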

Metadata

Labels

best practice: Things we should be doing but aren't
triage review: This issue should be discussed in weekly review
