
Commit 90338a6

joecummings authored and mori360 committed
Delete deprecated ChatDataset and InstructDataset (pytorch#1781)
1 parent 38df2dd commit 90338a6

8 files changed: +11 -326 lines changed

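For users migrating off the deleted classes, the supported replacements are the `chat_dataset` and `instruct_dataset` builders, which return the unified `SFTDataset` (see the docs diff below). A minimal migration sketch, not part of this commit's diff: the tokenizer and data file are placeholders, and the exact keyword names are assumptions based on the builders' signatures around this release.

# Hypothetical migration sketch; `my_tokenizer` and "my_chat_data.json"
# are placeholders you supply, and keyword names are assumptions.
from torchtune.datasets import chat_dataset

ds = chat_dataset(
    tokenizer=my_tokenizer,              # any ModelTokenizer
    source="json",
    conversation_column="conversations",
    conversation_style="sharegpt",       # or "openai"
    train_on_input=False,
    data_files="my_chat_data.json",
    split="train",
)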

docs/source/api_ref_datasets.rst

Lines changed: 0 additions & 2 deletions
@@ -64,8 +64,6 @@ Class representations for the above dataset builders.
    :toctree: generated/
    :nosignatures:
 
-   InstructDataset
-   ChatDataset
    TextCompletionDataset
    ConcatDataset
    PackedDataset

docs/source/deep_dives/configs.rst

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ keyword arguments not specified in the config if we'd like:
     tokenizer: ModelTokenizer,
     train_on_input: bool = True,
     max_seq_len: int = 512,
-) -> InstructDataset:
+) -> SFTDataset:
 
 from torchtune import config

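For context, the doc example this diff touches defines a custom dataset builder that a YAML config can point at. A sketch of how such a builder gets instantiated through `config.instantiate`; the `_component_` path and tokenizer are illustrative placeholders, not from this commit.

# Illustrative sketch: instantiating a custom dataset builder from a config.
# "my_module.custom_dataset" and `my_tokenizer` are placeholders.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.create(
    {
        "dataset": {
            "_component_": "my_module.custom_dataset",
            "train_on_input": False,
            "max_seq_len": 512,
        }
    }
)
# Config fields become keyword arguments; the builder returns an SFTDataset.
ds = config.instantiate(cfg.dataset, tokenizer=my_tokenizer)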
docs/source/tutorials/datasets.rst

Lines changed: 2 additions & 7 deletions
@@ -491,15 +491,10 @@ Fully customized datasets
 -------------------------
 
 More advanced tasks and dataset formats that don't fit into the templating and processing
-that :class:`~torchtune.datasets.InstructDataset`, :class:`~torchtune.datasets.ChatDataset`,
-and :class:`~torchtune.datasets.TextCompletionDataset` provide may require you to create your own dataset
-class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`,
+that :class:`~torchtune.datasets.SFTDataset` and :class:`~torchtune.datasets.TextCompletionDataset` provide may require
+you to create your own dataset class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`,
 which has custom functionality for RLHF preference data, as an example to understand what you'll need to do.
 
-If you take a look at the code for the :class:`~torchtune.datasets.PreferenceDataset` class,
-you'll notice it's quite similar to :class:`~torchtune.datasets.InstructDataset` with a few
-adjustments for chosen and rejected samples in preference data.
-
 .. code-block:: python
 
     chosen_message = [

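The code block at the end of this hunk is the start of the tutorial's preference-data example. For reference, a sketch of what chosen/rejected message pairs look like with torchtune's Message class; the content strings here are made up.

# Illustrative chosen/rejected pair for preference data; contents are made up.
from torchtune.data import Message

chosen_message = [
    Message(role="user", content="What is the capital of France?"),
    Message(role="assistant", content="The capital of France is Paris."),
]
rejected_message = [
    Message(role="user", content="What is the capital of France?"),
    Message(role="assistant", content="I don't know."),
]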
tests/torchtune/datasets/test_chat_dataset.py

Lines changed: 1 addition & 58 deletions
@@ -7,9 +7,8 @@
 import pytest
 from tests.common import ASSETS
 from tests.test_utils import DummyChatFormat, DummyTokenizer
-from torchtune.data import get_sharegpt_messages
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-from torchtune.datasets import chat_dataset, ChatDataset
+from torchtune.datasets import chat_dataset
 
 
 class TestChatDataset:

@@ -18,62 +17,6 @@ def chat_format(self):
         return DummyChatFormat
 
     def test_get_item(self, chat_format):
-        expected_tokenized_prompts = [
-            [
-                0,
-                7,
-                3,
-                3,
-                2,
-                2,
-                10,
-                5,
-                4,
-                2,
-                3,
-                7,
-                2,
-                5,
-                10,
-                3,
-                7,
-                2,
-                4,
-                2,
-                3,
-                -1,
-                0,
-                5,
-                6,
-                11,
-                10,
-                1,
-                6,
-                -1,
-            ]
-        ]
-        prompt_lengths = (15, 5)
-        expected_labels = [
-            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]
-            + [3, 7, 2, 4, 2, 3, -1]
-            + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
-            + [1, 6, -1]
-        ]
-        ds = ChatDataset(
-            tokenizer=DummyTokenizer(),
-            source="json",
-            convert_to_messages=get_sharegpt_messages,
-            chat_format=chat_format,
-            max_seq_len=100,
-            train_on_input=False,
-            data_files=str(ASSETS / "chat_tiny.json"),
-            split="train",
-        )
-        assert len(ds) == 1
-        prompt, label = ds[0]["tokens"], ds[0]["labels"]
-        assert prompt == expected_tokenized_prompts[0]
-        assert label == expected_labels[0]
-
         expected_tokenized_prompts = [
             [
                 0,

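The deleted half of this test constructed ChatDataset directly; the retained half exercises the chat_dataset builder against the same fixtures. A sketch of that builder-based setup; the conversation_column/conversation_style keywords are assumptions based on the builder's signature around this release.

# Sketch of the builder-based replacement; keyword names are assumptions.
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer
from torchtune.datasets import chat_dataset

ds = chat_dataset(
    tokenizer=DummyTokenizer(),
    source="json",
    conversation_column="conversations",
    conversation_style="sharegpt",
    train_on_input=False,
    data_files=str(ASSETS / "chat_tiny.json"),
    split="train",
)
assert len(ds) == 1
tokens, labels = ds[0]["tokens"], ds[0]["labels"]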
tests/torchtune/datasets/test_instruct_dataset.py

Lines changed: 1 addition & 33 deletions
@@ -9,7 +9,7 @@
 from tests.test_utils import DummyTokenizer
 from torchtune.data import InstructTemplate
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-from torchtune.datasets import instruct_dataset, InstructDataset
+from torchtune.datasets import instruct_dataset
 
 
 def dummy_transform(sample):

@@ -29,38 +29,6 @@ def format(cls, sample, column_map):
 class TestInstructDataset:
     @pytest.mark.parametrize("train_on_input", [True, False])
     def test_get_item(self, train_on_input):
-        template = DummyTemplate
-        expected_tokenized_prompts = [
-            [0, 12, 4, 4, 2, 2, 2, 7, 10, 9, 2, 2, 5, 2, 2, 6, 10, -1],
-            [0, 12, 2, 2, 8, 2, 15, 10, 9, 8, 3, 15, 3, 4, 9, 3, 15, 10, -1],
-        ]
-        prompt_lengths = (10, 9)
-        expected_labels = [
-            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + [2, 2, 5, 2, 2, 6, 10, -1],
-            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
-            + [8, 3, 15, 3, 4, 9, 3, 15, 10, -1],
-        ]
-
-        dataset = InstructDataset(
-            tokenizer=DummyTokenizer(),
-            source="json",
-            template=template,
-            transform=dummy_transform,
-            train_on_input=train_on_input,
-            data_files=str(ASSETS / "instruct_tiny.json"),
-            column_map={"output": "response"},
-            split="train",
-        )
-        assert len(dataset) == 2
-
-        for i in range(len(dataset)):
-            prompt, label = dataset[i]["tokens"], dataset[i]["labels"]
-            assert prompt == expected_tokenized_prompts[i]
-            if train_on_input:
-                assert label == expected_tokenized_prompts[i]
-            else:
-                assert label == expected_labels[i]
-
         expected_tokenized_prompts = [
             [0, 6, 4, 6, 4, 4, 2, 2, 2, 7, 2, 2, 5, 2, 2, 6, -1],
             [0, 6, 4, 6, 2, 2, 8, 2, 15, 8, 3, 15, 3, 4, 9, 3, 15, -1],

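Similarly, the surviving portion of this test goes through the instruct_dataset builder rather than the deleted class. A sketch of the builder call reusing the same fixtures; the exact keywords, including the column_map mapping, are assumptions.

# Sketch of the builder-based replacement; keyword names are assumptions.
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer
from torchtune.datasets import instruct_dataset

dataset = instruct_dataset(
    tokenizer=DummyTokenizer(),
    source="json",
    train_on_input=False,
    data_files=str(ASSETS / "instruct_tiny.json"),
    column_map={"output": "response"},
    split="train",
)
tokens, labels = dataset[0]["tokens"], dataset[0]["labels"]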
torchtune/datasets/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -6,12 +6,12 @@
 
 from torchtune.datasets import multimodal
 from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
-from torchtune.datasets._chat import chat_dataset, ChatDataset
+from torchtune.datasets._chat import chat_dataset
 from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
 from torchtune.datasets._concat import ConcatDataset
 from torchtune.datasets._grammar import grammar_dataset
 from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset
-from torchtune.datasets._instruct import instruct_dataset, InstructDataset
+from torchtune.datasets._instruct import instruct_dataset
 from torchtune.datasets._packed import PackedDataset
 from torchtune.datasets._preference import preference_dataset, PreferenceDataset
 from torchtune.datasets._samsum import samsum_dataset

@@ -30,9 +30,7 @@
     "grammar_dataset",
     "samsum_dataset",
     "stack_exchange_paired_dataset",
-    "InstructDataset",
     "slimorca_dataset",
-    "ChatDataset",
     "instruct_dataset",
     "preference_dataset",
     "chat_dataset",

torchtune/datasets/_chat.py

Lines changed: 2 additions & 101 deletions
@@ -4,111 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Callable, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, Optional, Union
 
-import numpy as np
-
-from datasets import load_dataset
-from torch.utils.data import Dataset
-from torchtune.data._chat_formats import ChatFormat
-from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-from torchtune.data._messages import (
-    Message,
-    OpenAIToMessages,
-    ShareGPTToMessages,
-    validate_messages,
-)
+from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
 from torchtune.datasets._packed import PackedDataset
 from torchtune.datasets._sft import SFTDataset
 from torchtune.modules.tokenizers import ModelTokenizer
-from torchtune.utils._logging import deprecated
-
-
-@deprecated(msg="Please use `torchtune.datasets.SFTDataset` for custom chat data.")
-class ChatDataset(Dataset):
-    """
-    Note:
-        This class is deprecated and will be removed in a future release. Please use
-        :class:`~torchtune.datasets.SFTDataset` or :func:`~torchtune.datasets.chat_dataset`
-        for custom chat data.
-
-    Class that supports any custom dataset with multiturn conversations.
-
-    The general flow from loading a sample to tokenized prompt is:
-    load sample -> apply transform -> foreach turn{format into template -> tokenize}
-
-    Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat format
-    and roles::
-
-        [
-            Message(
-                role=<system|user|assistant>,
-                content=<message>,
-            ),
-            ...
-        ]
-
-    This class supports multi-turn conversations. If a tokenizer sample with multiple
-    turns does not fit within ``max_seq_len`` then it is truncated.
-
-    Args:
-        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
-        source (str): path to dataset repository on Hugging Face. For local datasets,
-            define source as the data file type (e.g. "json", "csv", "text") and pass
-            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
-            (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
-            for more details.
-        convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample
-            and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys
-        chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual
-            messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not
-            as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed,
-            unless you want to structure messages in a particular way for inference.
-        max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
-        train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
-        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
-            such as ``data_files`` or ``split``.
-    """
-
-    def __init__(
-        self,
-        *,
-        tokenizer: ModelTokenizer,
-        source: str,
-        convert_to_messages: Callable[[Mapping[str, Any]], List[Message]],
-        chat_format: Optional[ChatFormat] = None,
-        max_seq_len: int,
-        train_on_input: bool = False,
-        **load_dataset_kwargs: Dict[str, Any],
-    ) -> None:
-
-        self._tokenizer = tokenizer
-        self._data = load_dataset(source, **load_dataset_kwargs)
-        self._convert_to_messages = convert_to_messages
-        self.chat_format = chat_format
-        self.max_seq_len = max_seq_len
-        self.train_on_input = train_on_input
-
-    def __len__(self):
-        return len(self._data)
-
-    def __getitem__(self, index: int) -> Dict[str, List[int]]:
-        sample = self._data[index]
-        return self._prepare_sample(sample)
-
-    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
-        messages = self._convert_to_messages(sample, self.train_on_input)
-        if self.chat_format is not None:
-            messages = self.chat_format.format(messages)
-        validate_messages(messages)
-        tokens, mask = self._tokenizer.tokenize_messages(
-            messages,
-        )
-        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
-        labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
-        assert len(tokens) == len(labels)
-
-        return {"tokens": tokens, "labels": labels}
 
 
 def chat_dataset(

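The core of the deleted _prepare_sample is the prompt-masking step, which lives on in SFTDataset. Isolated below as a runnable sketch: the token ids and mask are made up, and CROSS_ENTROPY_IGNORE_IDX is hard-coded to torchtune's standard -100 ignore index.

# Standalone sketch of the label-masking step; inputs are made up.
import numpy as np

CROSS_ENTROPY_IGNORE_IDX = -100  # value torchtune uses to mask CE loss

tokens = [0, 5, 6, 11, 10, 1, 6, -1]
mask = [True, True, True, True, False, False, False, False]  # True = prompt

# Replace prompt positions with the ignore index so the loss is computed
# only on the completion tokens.
labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
assert len(labels) == len(tokens)
print(labels)  # [-100, -100, -100, -100, 10, 1, 6, -1]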