prefetch_factor*worker datums thrown away because of _check_dataloader_iterable #18414

@ben-davidson-6

Bug description

If you have a dataset that just pops items off a queue, and you set persistent_workers=True and num_workers > 0 in your DataLoader, then the first num_workers*prefetch_factor items on the queue are thrown away. In the reproduction below that is two items (num_workers=1 with the default prefetch_factor=2).

This is because we call _check_dataloader_iterable (https://github.com/Lightning-AI/lightning/blob/722fdeac44cce49928184d89684eeb668742bf37/src/lightning/pytorch/trainer/connectors/data_connector.py#L391) in the training loop. This starts the data loading process, which fills the prefetch buffer. The buffer is then discarded once the first epoch starts.
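To make the mechanism concrete, here is a minimal single-process sketch. It is not the actual DataLoader or Lightning code: `PrefetchingLoader`, `_PrefetchIterator`, and `consume_after_probe` are hypothetical names, and the eager buffering only imitates what worker prefetching does. It shows how probing an object with iter() can consume items that the real iteration never sees.

```python
from collections import deque
from queue import Empty, Queue
from typing import List


class _PrefetchIterator:
    """Iterator that keeps a small buffer topped up from the source queue."""

    def __init__(self, source: Queue, buffer: deque) -> None:
        self.source = source
        self.buffer = buffer

    def __iter__(self) -> "_PrefetchIterator":
        return self

    def __next__(self) -> int:
        if not self.buffer:
            raise StopIteration
        item = self.buffer.popleft()
        try:
            # Refill the slot that was just freed, if anything is left.
            self.buffer.append(self.source.get_nowait())
        except Empty:
            pass
        return item


class PrefetchingLoader:
    """Toy stand-in (an assumption) for a loader that prefetches eagerly."""

    def __init__(self, source: Queue, prefetch: int) -> None:
        self.source = source
        self.prefetch = prefetch

    def __iter__(self) -> _PrefetchIterator:
        # The buffer is filled as soon as iter() is called,
        # before the caller has requested a single item.
        buffer: deque = deque()
        for _ in range(self.prefetch):
            try:
                buffer.append(self.source.get_nowait())
            except Empty:
                break
        return _PrefetchIterator(self.source, buffer)


def consume_after_probe(n_items: int, prefetch: int) -> List[int]:
    """Put 0..n_items-1 on a queue, probe with iter(), then iterate for real."""
    q: Queue = Queue()
    for i in range(n_items):
        q.put(i)
    loader = PrefetchingLoader(q, prefetch)
    iter(loader)  # iterability probe: its prefetched buffer is discarded
    return list(loader)  # the real pass starts from whatever is left


if __name__ == "__main__":
    print(consume_after_probe(n_items=5, prefetch=2))
```

With five items and a prefetch of two, the real pass only sees the last three items, which mirrors the log in this report where the epoch starts at index 2.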

What version are you seeing the problem on?

v2.0

How to reproduce the bug

```python
import multiprocessing as mp
from queue import Queue
from typing import Iterator

import numpy as np
from torch.utils.data import DataLoader, IterableDataset

from lightning import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class QueueDataset(IterableDataset):
    def __init__(self, queue: Queue) -> None:
        super().__init__()
        self.queue = queue

    def __iter__(self) -> Iterator:
        for k in range(5):
            print(f"getting {k}")
            tensor, index = self.queue.get(timeout=10)
            print(f"got {index}")
            yield tensor


if __name__ == "__main__":
    q = mp.Queue()
    arr = np.random.random([1, 32]).astype(np.float32)
    for ind in range(5):
        q.put((arr, ind))
    dataloader = DataLoader(QueueDataset(q), num_workers=1, batch_size=None, persistent_workers=True)
    trainer = Trainer(max_epochs=1, enable_progress_bar=False)
    trainer.fit(BoringModel(), dataloader)
```

Error messages and logs

```
getting 0
got 0
getting 1
got 1
getting 0
got 2
getting 1
got 3
getting 2
got 4
getting 3
```

We then get a _queue.Empty exception: the get() call times out because the queue is already empty.

cc @justusschock @awaelchli @Borda

Activity

Added the "data handling" label (generic data-related topic) and removed the "needs triage" label (waiting to be triaged by maintainers) on Aug 28, 2023.
Self-assigned this on Aug 28, 2023.
Added this to the 2.0.x milestone on Aug 28, 2023.

awaelchli (Contributor) commented on Aug 28, 2023

Hi @ben-davidson-6,
We need to run this check to give users meaningful feedback if they return something invalid from the dataloader methods, and there is no known reliable way to check whether an object is iterable other than calling iter() on it. However, we can mitigate the issue with a fast-path check on DataLoader (#18415).

Thanks for raising the issue.
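For readers following along, the fast-path idea could look roughly like the sketch below. This is not the code from #18415: `is_iterable`, its `trusted_types` parameter, and `CountingLoader` are hypothetical names used only for illustration. The assumption is that types known to be iterable (in this issue's setting, torch.utils.data.DataLoader) can be accepted with an isinstance check, so iter() is only called on unknown objects.

```python
from typing import Any, Iterator, Tuple, Type


def is_iterable(obj: Any, trusted_types: Tuple[Type, ...] = ()) -> bool:
    """Return True if obj is iterable, avoiding iter() for trusted types."""
    # Fast path: known-iterable types are accepted without calling iter(),
    # so no iterator (and no prefetching) is started for them.
    if isinstance(obj, trusted_types):
        return True
    # Slow path: the only generally reliable probe is to call iter().
    try:
        iter(obj)
    except TypeError:
        return False
    return True


class CountingLoader:
    """Hypothetical loader that records how often iter() is called on it."""

    def __init__(self) -> None:
        self.iter_calls = 0

    def __iter__(self) -> Iterator[int]:
        self.iter_calls += 1
        return iter(())


if __name__ == "__main__":
    loader = CountingLoader()
    print(is_iterable(loader, trusted_types=(CountingLoader,)), loader.iter_calls)
    print(is_iterable(loader), loader.iter_calls)
    print(is_iterable(42))
```

The trade-off in this sketch is that a trusted type is never probed, so a custom iterable passed directly to the trainer would still go through the iter() call and could still lose prefetched items.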
