Iterable Dataset #2852
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2852
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 4 Unrelated Failures as of commit 72211c9 with merge base 3d73591.
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -94,3 +95,72 @@ def slimorca_dataset(
        )
        return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
    return ds


def slimorca_iterable_dataset(
Added here to demonstrate the datamix iterable dataset with this example. Personally, I dislike exposing all of the args and defaults. I would prefer to expose only what's specific to this builder.
logger.warning(
    f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. "
    "This is unexpected for an infinite dataset. Re-initializing its iterator."
)
Not 100% sure I like this.
Let's do this: simply have a subclass for InfiniteIterable, so this is super explicit.
Where did this one land? I don't see InfiniteIterable anywhere (personally I don't know enough yet to have a strong preference here, just wanna understand where things currently stand).
I made changes but didn't push them yet. I added a dummy class that does nothing:

    class InfiniteTuneIterableDataset(TuneIterableDataset):
        """Abstract base class for infinite datasets, which yield samples indefinitely.

        Its only purpose is to make it explicit that the dataset is expected to be infinite, i.e.
        it never exhausts. This is helpful to avoid complexity due to some rank hanging because
        of lack of data."""
        pass
and replaced this logger.warning with raise ValueError.
I think it's better to have zero tolerance. Datasets that are not infinite need work to make sure no rank hangs.
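A minimal sketch of what that zero-tolerance check could look like in the interleaved dataset's __iter__ (illustrative only: _datasets matches the snippet above, but _sample_child is a hypothetical helper, not the actual implementation):

    # Sketch: fail hard instead of silently re-initializing an exhausted child iterator.
    def __iter__(self):
        child_iters = {name: iter(ds) for name, ds in self._datasets.items()}
        while True:
            ds_name = self._sample_child()  # hypothetical weighted choice among children
            try:
                yield next(child_iters[ds_name])
            except StopIteration:
                raise ValueError(
                    f"Child dataset '{ds_name}' was exhausted. All children of an "
                    "interleaved dataset must be InfiniteTuneIterableDataset; a finite "
                    "child can leave some ranks hanging while they wait for data."
                ) from None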
@@ -101,3 +102,64 @@ def alpaca_dataset(
    original Alpaca dataset, `yahma/alpaca-cleaned <https://huggingface.co/datasets/yahma/alpaca-cleaned>`_.
    See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details.
    """


def alpaca_iterable_dataset(
Added here to demonstrate the datamix iterable dataset with this example. Personally, I dislike exposing all of the args and defaults. I would prefer to expose only what's specific to this builder.
But you are doing this with `load_dataset_kwargs`, right? Or did you mean something else?
nit: it's a function, so... `get_alpaca_iterable_dataset`?
The get makes sense, but it's not the pattern we have in tune :/
Great PR! I mainly had a question on the interaction with packing and on the SFT transform
from torch.utils.data import IterableDataset


class TuneIterableDataset(IterableDataset, ABC):
We need this guy to interact with packing, and IIUC I don't believe this is currently happening?
The algo we should implement is this (a sketch follows after this list):
- One batch can be made of multiple calls to next. We keep taking until we exceed the max seq len. When we do, we put the last one aside (we'll use it to start the next batch), pad the current one to max len and return.
- The calls to next will go to the interleaved dataset, therefore we automatically construct mixed batches from multiple datasets without much effort.
- Also, every time we call next we should make space for logging transforms (which we are, you already wrote them). I think it's ok to make your metrics transforms and aggregators an optional property here so the semantics are clearer.
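A rough sketch of that greedy packing loop (illustrative only; the field names, pad_id handling, and the interleaved_iter argument are assumptions, not the actual packing implementation):

    # Keep pulling samples from the interleaved dataset until adding one more
    # would exceed max_seq_len, set that sample aside to seed the next pack,
    # then pad the current pack and yield it.
    def iter_packed(interleaved_iter, max_seq_len, pad_id=0):
        leftover = None
        while True:
            pack, pack_len = [], 0
            while True:
                sample = leftover if leftover is not None else next(interleaved_iter)
                leftover = None
                n_tokens = len(sample["tokens"])
                if pack and pack_len + n_tokens > max_seq_len:
                    leftover = sample  # starts the next pack
                    break
                pack.append(sample)
                pack_len += n_tokens
            tokens = [tok for s in pack for tok in s["tokens"]]
            # Truncate (edge case: a single over-long sample), then pad to max_seq_len.
            tokens = tokens[:max_seq_len]
            tokens += [pad_id] * (max_seq_len - len(tokens))
            yield {"tokens": tokens}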
we have packing here: #2819
I read every line of this PR. (Kidding but I tried to at least look at most of the important stuff.) Thanks for taking on this massive set of changes, I think the dataset classes are a big improvement
# Load and shard dataset
ds = load_dataset(**load_dataset_kwargs)

# Use to_iterable_dataset for streaming datasets
I don't understand this comment. Based on L185 it seems you're using to_iterable_dataset for non-streaming datasets
# If the dataset is not streaming and has a defined length,
# we cannot have num_shards > dataset_size.
if not load_dataset_kwargs.get("streaming", False) and hasattr(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant inside the current if statement
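For context, a sketch of the streaming vs. non-streaming branch being discussed here (example values are made up; the real builder takes them from load_dataset_kwargs and the sharding logic):

    from datasets import load_dataset

    # Example values; in the real builder these come from the config / sharding logic.
    load_dataset_kwargs = {"path": "yahma/alpaca-cleaned", "split": "train"}
    num_shards = 36

    ds = load_dataset(**load_dataset_kwargs)
    if not load_dataset_kwargs.get("streaming", False):
        # A map-style (non-streaming) dataset is converted into an iterable dataset
        # with an explicit shard count so it can be split across ranks and
        # dataloader workers. A streaming dataset is already iterable.
        ds = ds.to_iterable_dataset(num_shards=num_shards)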
# we will try 2*16 = 32 shards. Since 32 is not a multiple of 3, we will do 36 shards.
# Each rank gets 16 shards, each dataloader worker in that rankgets 6 shards.
this math doesn't seem quite right. Suggested change:

    # we will try 2*16 = 32 shards. Since 32 is not a multiple of 6, we will do 36 shards.
    # Each rank gets 18 shards, each dataloader worker in that rank gets 6 shards.
if num_shards > dataset_size:
    raise ValueError(
        f"Number of shards ({num_shards}) is greater than the dataset size ({dataset_size})."
        f"Please decrease num_shards_per_rank."
nit: but this can also happen due to total_workers > dataset_size (since we round up). Probably an edge case, but if e.g. torch.utils.data.get_worker_info() returns a really large number then this is not actually the right fix.
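A sketch of the shard-count computation and guard being discussed (function and argument names are assumptions for illustration, not torchtune's actual code):

    import math

    def compute_num_shards(dataset_size, world_size, num_workers, num_shards_per_rank):
        # Round the desired shard count up to a multiple of the total number of
        # dataloader workers so every worker gets the same number of shards.
        # e.g. world_size=2, num_shards_per_rank=16 -> 32 desired shards; with
        # num_workers=3 the total is 6 workers, so round 32 up to 36:
        # 18 shards per rank, 6 shards per dataloader worker.
        total_workers = world_size * max(num_workers, 1)
        desired = world_size * num_shards_per_rank
        num_shards = math.ceil(desired / total_workers) * total_workers
        # Guard against tiny datasets: the rounded-up shard count (and hence the
        # worker count) can exceed the number of samples.
        if num_shards > dataset_size:
            raise ValueError(
                f"Number of shards ({num_shards}) is greater than the dataset size "
                f"({dataset_size}). Please decrease num_shards_per_rank or the number "
                "of dataloader workers."
            )
        return num_shards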
epoch_seed = self._seed + self._num_epochs
self._ds.set_epoch(epoch_seed)
I'm confused... isn't this bumping epoch by self._seed each epoch (when it should be bumping by 1)?
Yeah, from their docs, it should be just dataset.set_epoch(epoch). I will make the change.
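Presumably the change amounts to something like this (sketch, reusing the attribute names from the snippet above):

    # Bump the shuffle epoch by 1 per pass over the data; the base seed is
    # applied separately when the dataset is built.
    self._ds.set_epoch(self._num_epochs)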
self.new_metric(
    name="tokens_seen", value=token_len, agg_type=AggregationType.SUM
),
self.new_metric(
    name="seq_len", value=token_len, agg_type=AggregationType.DISTRIBUTION
),
A minor thing, but to me metrics having the same value but different aggregation types should not actually be represented as distinct metrics. Like I should be able to just define how a metric is computed for a given sample, then separately choose different types of aggregation as needed
Hmm, so agg_type being a List[AggregationType]?

    self.new_metric(
        name="tokens_seen",
        value=token_len,
        agg_type=[AggregationType.SUM, AggregationType.DISTRIBUTION],
    ),

I don't know if the extra complexity is worth it. Adding two metrics is cheap. But I guess there can be situations where I just want 'mean' and 'sum', and then if I create two metrics named "metric_mean" and "metric_sum", they would be logged as "metric_mean_mean" and "metric_sum_sum".
I have to think a bit about it.
from torchtune.data.metrics._metric_transform import AggregationType, Metric


class MetricsAggregator:
A high level comment: the relationship between this and the agg handlers is not super clear to me. It seems like we are using a registry pattern where the handlers are responsible for defining the actual aggregation logic. But then the all-gather happens in here. (Separately I stand by my claim that it would be better to hold off on more complex cases like distribution aggregators so as not to boil the ocean here.)
> the handlers are responsible for defining the actual aggregation logic. But then the all-gather happens in here.

Why is that a contradiction?
- The MetricsAggregator calls handler.finalize_local_agg
- then does a single all_gather to get the results from all ranks for all metrics
- then calls handler._finalize_dist_agg([aggregated_results_per_rank]*n_ranks)

Do you wanna suggest a different way of doing it? Or is it hard to spot this pattern in the code?
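A stripped-down sketch of that flow (the two handler method names come from the comment above; everything else, including the get_metrics_for_logging name and the state objects, is illustrative, not the real implementation):

    import torch.distributed as dist

    class MetricsAggregator:
        def __init__(self, handlers):
            # Registry pattern: one handler per AggregationType owns the math;
            # the aggregator only orchestrates local reduction plus one all_gather.
            self._handlers = handlers  # e.g. {AggregationType.SUM: <sum handler>, ...}

        def get_metrics_for_logging(self, states):
            # 1) Each handler reduces its own local state on this rank.
            local = {
                key: self._handlers[state.agg_type].finalize_local_agg(state)
                for key, state in states.items()
            }
            # 2) A single all_gather collects every rank's local results.
            if dist.is_available() and dist.is_initialized():
                gathered = [None] * dist.get_world_size()
                dist.all_gather_object(gathered, local)
            else:
                gathered = [local]
            # 3) Handlers merge the per-rank results into the final values.
            return {
                key: self._handlers[state.agg_type]._finalize_dist_agg(
                    [rank_local[key] for rank_local in gathered]
                )
                for key, state in states.items()
            }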
> my claim that it would be better to hold off on more complex cases like distribution aggregators

If we don't do aggregation across ranks, we wouldn't be able to count things like "tokens_seen", right? :/
Or do you mean that we should delete DistributionAggHandler? To clarify, this distribution has nothing to do with multiple GPUs. It's just stats, e.g. std, percentiles, max, min, etc. Maybe I should rename it if it's causing confusion.
if cfg.get("dataset_val") is not None: | ||
raise NotImplementedError( | ||
"Validation is not supported yet with iterable datasets." | ||
) |
Is there a specific technical reason here? Or we just haven't gotten to it yet
Validation datasets are not infinite!!! Need to figure out how to solve this one, but it won't be on this PR.
self.global_step = 0

# Step-based training support
self.num_training_steps = cfg.num_training_steps
self._dataset_metrics_log_freq = cfg.get("dataset_metrics_log_freq", 100)
Very minor but does this really need to be configured separately from log_every_n_steps? (Edit: I see you already added a TODO about this below, I guess you know my vote)
self._metric_logger.log_dict(log_dict, step=self.global_step)

# Log dataset metrics
# TODO: it requires all_gather. Should we keep a separate log_freq for this?
Worth benchmarking, but if it's just a few floats I think it shouldn't matter
Context
What is the purpose of this PR?
Enable iterable datasets in torchtune.
Context: built on top of the ongoing step-based-ckpt PR: #2384
Tips when reviewing this PR
Follow this order:
torchtune/datasets/_hf_iterable.py
Changelog
Config and builder design based on the discussions after this RFC: #2785
Next steps:
7. Gather feedback on metric logging. E.g. we can add more aggregation types.
8. Polish the code a little bit
9. Add packing from this RFC: #2819
10. Add curriculum learning
11. Docs?
Test plan
UNTESTED: resume from ckpt in the recipe. However, we have plenty of tests showing that resuming works for these iterable datasets.