Description
When a) `train_on_input=False` and b) the message is so long that the output is truncated, a batch may contain no trainable tokens, which causes a division-by-zero error in the loss.
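A minimal sketch of how this happens (hypothetical names, assuming the loss is averaged over non-ignored label positions; this is not the library's actual loss code):

```python
import torch
import torch.nn.functional as F

IGNORE_IDX = -100  # label value excluded from the loss (assumption)

def masked_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Sum per-token losses over trainable positions only
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=IGNORE_IDX,
        reduction="sum",
    )
    num_trainable = (labels != IGNORE_IDX).sum()
    # If every label in the batch is IGNORE_IDX (e.g. the whole output was
    # truncated away), num_trainable == 0 and this division blows up.
    return loss / num_trainable
```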
Beyond being an inconvenient bug, this wastes compute, and patching the loss would fix a symptom rather than the root cause.
In the dataloader, should we skip rows that don't have any trainable tokens? A sketch of what that could look like is below.
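For illustration, a hypothetical dataset wrapper that drops such rows before batching. The names here (`IGNORE_IDX`, the `"labels"` key, `SkipUntrainableRows`) are assumptions for the sketch, not the library's actual API:

```python
from torch.utils.data import Dataset

IGNORE_IDX = -100  # label value excluded from the loss (assumption)

class SkipUntrainableRows(Dataset):
    """Wraps a tokenized map-style dataset, keeping only rows that
    contain at least one trainable (non-ignored) label."""

    def __init__(self, dataset: Dataset):
        self._dataset = dataset
        # Precompute indices of rows with >= 1 trainable token
        self._indices = [
            i for i in range(len(dataset))
            if any(label != IGNORE_IDX for label in dataset[i]["labels"])
        ]

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, idx: int):
        return self._dataset[self._indices[idx]]
```

This filters at the dataset level rather than inside `collate`, so fully-ignored rows never consume a slot in a batch; the trade-off is one upfront pass over the data (or it could be done lazily during iteration).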