
Sample packing for ConcatDataset #2278


Merged: 3 commits into pytorch:main on Jan 18, 2025

Conversation

@ebsmothers (Contributor) commented Jan 17, 2025

Currently we error when any individual dataset in ConcatDataset has packed=True. This goes back to #1708: because packed and unpacked datasets require different collators, and packing is done on the individual datasets rather than on the ConcatDataset, we can't guarantee that a single collator for ConcatDataset is well-defined in all cases.

Fortunately, someone who enables packing on one dataset almost certainly wants it on all of them, and supporting this case is trivial. So this PR relaxes the check in ConcatDataset to also allow the case where every dataset is packed.
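Conceptually, the relaxed check looks something like the sketch below (a minimal illustration, not the exact torchtune source; detecting packed datasets via isinstance(ds, PackedDataset) is an assumption):

from typing import List

from torch.utils.data import Dataset
from torchtune.datasets import PackedDataset

def _validate_packing(datasets: List[Dataset]) -> bool:
    # Hypothetical sketch: previously, any packed dataset raised an error;
    # now only a mix of packed and unpacked datasets is rejected.
    packed_flags = [isinstance(ds, PackedDataset) for ds in datasets]
    if any(packed_flags) and not all(packed_flags):
        raise ValueError(
            "ConcatDataset requires either all datasets to be packed or "
            "none of them; got a mix of packed and unpacked datasets."
        )
    return all(packed_flags)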

A couple of things that can be revisited here:

  1. Given that dataset packing in ConcatDataset is all-or-nothing, it's probably more natural to define the packed attribute there instead. However, then we get into questions of whether we pack before or after merging. Relatedly,
  2. this implementation packs each individual dataset and then merges, which means packs cannot combine samples from different datasets.

Test plan:

Added a unit test:

pytest tests/torchtune/datasets/test_concat_dataset.py
...
======== 5 passed in 0.10s ==========

Also ran one of the updated recipes with the following config updates:

tokenizer:
  max_seq_len: 512

dataset:
  - _component_: torchtune.datasets.alpaca_dataset
    packed: True
  - _component_: torchtune.datasets.alpaca_cleaned_dataset
    packed: True

(Per the Hugging Face datasets pages, the Alpaca dataset has ~52k samples and Alpaca-cleaned has ~51.8k samples.)
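For reference, a roughly equivalent construction in Python (a sketch under assumptions: the tokenizer setup is illustrative, and the exact builder signatures should be checked against the torchtune docs):

from torchtune.datasets import ConcatDataset, alpaca_dataset, alpaca_cleaned_dataset
from torchtune.models.llama3 import llama3_tokenizer

# Any model tokenizer works here; max_seq_len caps the packed sequence length
tokenizer = llama3_tokenizer("/path/to/tokenizer.model", max_seq_len=512)

# With packed=True each builder returns a PackedDataset, so ConcatDataset
# sees all-packed inputs and the relaxed check passes
ds = ConcatDataset(
    [
        alpaca_dataset(tokenizer, packed=True),
        alpaca_cleaned_dataset(tokenizer, packed=True),
    ]
)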

Printing dataset lengths at different points:

# Inside PackedDataset, prior to packing
>>> len(self.ds) # alpaca dataset
52002
>>> len(self.ds) # alpaca-cleaned dataset
51760

# Inside PackedDataset, after packing
>>> len(self) # alpaca dataset
13348
>>> len(self) # alpaca-cleaned dataset
25891

# Inside ConcatDataset, after construction
>>> len(self)
39239
>>> self._indexes
[(0, 13348, 0), (13348, 39239, 1)]
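The _indexes entries are consistent with simple cumulative offsets over the packed lengths. A quick sanity check in plain Python, using the numbers printed above (the (start, stop, dataset_index) layout is inferred from the output):

packed_lengths = [13348, 25891]  # alpaca, alpaca-cleaned after packing

indexes, start = [], 0
for i, n in enumerate(packed_lengths):
    indexes.append((start, start + n, i))
    start += n

assert indexes == [(0, 13348, 0), (13348, 39239, 1)]
assert start == 39239  # total length of the ConcatDataset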


pytorch-bot bot commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2278

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 50b72a1 with merge base b68cddd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Jan 17, 2025
@@ -90,3 +90,33 @@ def test_packed_dataset(self, torch_datasets):

with pytest.raises(ValueError):
concated_dataset = ConcatDataset(torch_datasets)

def test_all_packed_datasets(self, torch_datasets):
Collaborator

Should we also test that the error is raised when some datasets are unpacked?

Contributor Author

Oh yeah, it already exists in the previous unit test.

codecov-commenter commented Jan 17, 2025

Codecov Report

Attention: Patch coverage is 8.33333% with 22 lines in your changes missing coverage. Please review.

Project coverage is 23.94%. Comparing base (baae232) to head (50b72a1).
Report is 245 commits behind head on main.

Files with missing lines                          Patch %   Lines
tests/torchtune/datasets/test_concat_dataset.py   15.38%    11 Missing ⚠️
torchtune/datasets/_concat.py                      0.00%     4 Missing ⚠️
recipes/full_finetune_distributed.py               0.00%     1 Missing ⚠️
recipes/full_finetune_single_device.py             0.00%     1 Missing ⚠️
recipes/knowledge_distillation_distributed.py      0.00%     1 Missing ⚠️
recipes/lora_finetune_distributed.py               0.00%     1 Missing ⚠️
recipes/lora_finetune_single_device.py             0.00%     1 Missing ⚠️
recipes/qat_distributed.py                         0.00%     1 Missing ⚠️
recipes/qat_lora_finetune_distributed.py           0.00%     1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (baae232) and HEAD (50b72a1): HEAD has 2 fewer uploads than BASE.

           BASE (baae232)   HEAD (50b72a1)
Uploads    3                1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2278       +/-   ##
===========================================
- Coverage   64.30%   23.94%   -40.37%     
===========================================
  Files         352      357        +5     
  Lines       20566    21151      +585     
===========================================
- Hits        13225     5064     -8161     
- Misses       7341    16087     +8746     

☔ View full report in Codecov by Sentry.

@ebsmothers merged commit 1036095 into pytorch:main Jan 18, 2025
17 checks passed
@RdoubleA mentioned this pull request Jan 21, 2025