Delete deprecated ChatDataset and InstructDataset #1781


Merged
3 commits merged on Oct 9, 2024
docs/source/api_ref_datasets.rst (2 changes: 0 additions & 2 deletions)
@@ -64,8 +64,6 @@ Class representations for the above dataset builders.
:toctree: generated/
:nosignatures:

InstructDataset
ChatDataset
TextCompletionDataset
ConcatDataset
PackedDataset
docs/source/deep_dives/configs.rst (2 changes: 1 addition & 1 deletion)
@@ -119,7 +119,7 @@ keyword arguments not specified in the config if we'd like:
tokenizer: ModelTokenizer,
train_on_input: bool = True,
max_seq_len: int = 512,
) -> InstructDataset:
) -> SFTDataset:

from torchtune import config

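For context, the builder in the snippet above now returns the unified SFTDataset. Below is a minimal sketch of what such a builder could look like, assuming torchtune's instruct_dataset builder; the data file path and source type are hypothetical, and max_seq_len is kept only to mirror the documented signature (recent torchtune versions configure truncation on the tokenizer).

```python
from torchtune.datasets import instruct_dataset, SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer


def custom_dataset(
    tokenizer: ModelTokenizer,
    train_on_input: bool = True,
    max_seq_len: int = 512,
) -> SFTDataset:
    # Hypothetical body: wrap the generic instruct_dataset builder so extra
    # keyword arguments can be filled in from code rather than from the config.
    # max_seq_len is shown only to match the docs snippet above.
    return instruct_dataset(
        tokenizer=tokenizer,
        source="json",
        data_files="my_instruct_data.json",  # hypothetical local file
        train_on_input=train_on_input,
        split="train",
    )
```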
docs/source/tutorials/datasets.rst (9 changes: 2 additions & 7 deletions)
@@ -491,15 +491,10 @@ Fully customized datasets
-------------------------

More advanced tasks and dataset formats that don't fit into the templating and processing
that :class:`~torchtune.datasets.InstructDataset`, :class:`~torchtune.datasets.ChatDataset`,
and :class:`~torchtune.datasets.TextCompletionDataset` provide may require you to create your own dataset
class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`,
that :class:`~torchtune.datasets.SFTDataset` and :class:`~torchtune.datasets.TextCompletionDataset` provide may require
you to create your own dataset class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`,
Collaborator: Line below still contains a reference to InstructDataset.

Contributor Author: whoop

which has custom functionality for RLHF preference data, as an example to understand what you'll need to do.

If you take a look at the code for the :class:`~torchtune.datasets.PreferenceDataset` class,
you'll notice it's quite similar to :class:`~torchtune.datasets.InstructDataset` with a few
adjustments for chosen and rejected samples in preference data.

.. code-block:: python

chosen_message = [
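Because the tutorial passage above points readers toward writing their own dataset class, here is a minimal, hypothetical sketch of the shape such a class typically takes, modeled on the deleted ChatDataset further down in this diff rather than on PreferenceDataset itself; the class name and the prompt/response column names are assumptions.

```python
from typing import Any, Dict, List

import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset

from torchtune.data import Message
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.modules.tokenizers import ModelTokenizer


class MyCustomDataset(Dataset):
    def __init__(
        self, *, tokenizer: ModelTokenizer, source: str, **load_dataset_kwargs: Dict[str, Any]
    ) -> None:
        self._tokenizer = tokenizer
        self._data = load_dataset(source, **load_dataset_kwargs)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Dict[str, List[int]]:
        sample = self._data[index]
        # Convert the raw row into torchtune Messages; "prompt" and "response"
        # are hypothetical column names for this sketch.
        messages = [
            Message(role="user", content=sample["prompt"], masked=True),
            Message(role="assistant", content=sample["response"]),
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Wherever mask is True the token belongs to the prompt, so it is
        # excluded from the loss via CROSS_ENTROPY_IGNORE_IDX.
        labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
        return {"tokens": tokens, "labels": labels}
```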
tests/torchtune/datasets/test_chat_dataset.py (59 changes: 1 addition & 58 deletions)
@@ -7,9 +7,8 @@
import pytest
from tests.common import ASSETS
from tests.test_utils import DummyChatFormat, DummyTokenizer
from torchtune.data import get_sharegpt_messages
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import chat_dataset, ChatDataset
from torchtune.datasets import chat_dataset


class TestChatDataset:
@@ -18,62 +17,6 @@ def chat_format(self):
return DummyChatFormat

def test_get_item(self, chat_format):
expected_tokenized_prompts = [
[
0,
7,
3,
3,
2,
2,
10,
5,
4,
2,
3,
7,
2,
5,
10,
3,
7,
2,
4,
2,
3,
-1,
0,
5,
6,
11,
10,
1,
6,
-1,
]
]
prompt_lengths = (15, 5)
expected_labels = [
[CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]
+ [3, 7, 2, 4, 2, 3, -1]
+ [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
+ [1, 6, -1]
]
ds = ChatDataset(
tokenizer=DummyTokenizer(),
source="json",
convert_to_messages=get_sharegpt_messages,
chat_format=chat_format,
max_seq_len=100,
train_on_input=False,
data_files=str(ASSETS / "chat_tiny.json"),
split="train",
)
assert len(ds) == 1
prompt, label = ds[0]["tokens"], ds[0]["labels"]
assert prompt == expected_tokenized_prompts[0]
assert label == expected_labels[0]

expected_tokenized_prompts = [
[
0,
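The portion of the test retained after this deletion (truncated above) builds the dataset through the chat_dataset builder instead of the removed class. A rough sketch of that style of construction, reusing the chat_tiny.json asset from the deleted code; the conversation_column and conversation_style keyword names are assumptions about the builder's signature at the time of this PR.

```python
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer
from torchtune.datasets import chat_dataset

# Build the same tiny ShareGPT-style dataset via the public builder.
ds = chat_dataset(
    tokenizer=DummyTokenizer(),
    source="json",
    data_files=str(ASSETS / "chat_tiny.json"),
    conversation_column="conversations",  # assumed column name in the asset
    conversation_style="sharegpt",
    train_on_input=False,
    split="train",
)
tokens, labels = ds[0]["tokens"], ds[0]["labels"]
```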
tests/torchtune/datasets/test_instruct_dataset.py (34 changes: 1 addition & 33 deletions)
@@ -9,7 +9,7 @@
from tests.test_utils import DummyTokenizer
from torchtune.data import InstructTemplate
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import instruct_dataset, InstructDataset
from torchtune.datasets import instruct_dataset


def dummy_transform(sample):
@@ -29,38 +29,6 @@ def format(cls, sample, column_map):
class TestInstructDataset:
@pytest.mark.parametrize("train_on_input", [True, False])
def test_get_item(self, train_on_input):
template = DummyTemplate
expected_tokenized_prompts = [
[0, 12, 4, 4, 2, 2, 2, 7, 10, 9, 2, 2, 5, 2, 2, 6, 10, -1],
[0, 12, 2, 2, 8, 2, 15, 10, 9, 8, 3, 15, 3, 4, 9, 3, 15, 10, -1],
]
prompt_lengths = (10, 9)
expected_labels = [
[CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + [2, 2, 5, 2, 2, 6, 10, -1],
[CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
+ [8, 3, 15, 3, 4, 9, 3, 15, 10, -1],
]

dataset = InstructDataset(
tokenizer=DummyTokenizer(),
source="json",
template=template,
transform=dummy_transform,
train_on_input=train_on_input,
data_files=str(ASSETS / "instruct_tiny.json"),
column_map={"output": "response"},
split="train",
)
assert len(dataset) == 2

for i in range(len(dataset)):
prompt, label = dataset[i]["tokens"], dataset[i]["labels"]
assert prompt == expected_tokenized_prompts[i]
if train_on_input:
assert label == expected_tokenized_prompts[i]
else:
assert label == expected_labels[i]

expected_tokenized_prompts = [
[0, 6, 4, 6, 4, 4, 2, 2, 2, 7, 2, 2, 5, 2, 2, 6, -1],
[0, 6, 4, 6, 2, 2, 8, 2, 15, 8, 3, 15, 3, 4, 9, 3, 15, -1],
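Likewise, the retained part of this test (truncated above) constructs the dataset with the instruct_dataset builder. A hedged sketch under the same caveat that keyword names may differ slightly in the updated test; the column_map value is taken from the deleted code.

```python
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer
from torchtune.datasets import instruct_dataset

# Build the tiny instruct dataset via the public builder rather than the
# removed InstructDataset class.
dataset = instruct_dataset(
    tokenizer=DummyTokenizer(),
    source="json",
    data_files=str(ASSETS / "instruct_tiny.json"),
    column_map={"output": "response"},
    train_on_input=False,
    split="train",
)
tokens, labels = dataset[0]["tokens"], dataset[0]["labels"]
```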
torchtune/datasets/__init__.py (6 changes: 2 additions & 4 deletions)
@@ -6,12 +6,12 @@

from torchtune.datasets import multimodal
from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
from torchtune.datasets._chat import chat_dataset, ChatDataset
from torchtune.datasets._chat import chat_dataset
from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
from torchtune.datasets._concat import ConcatDataset
from torchtune.datasets._grammar import grammar_dataset
from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset
from torchtune.datasets._instruct import instruct_dataset, InstructDataset
from torchtune.datasets._instruct import instruct_dataset
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._preference import preference_dataset, PreferenceDataset
from torchtune.datasets._samsum import samsum_dataset
@@ -30,9 +30,7 @@
"grammar_dataset",
"samsum_dataset",
"stack_exchange_paired_dataset",
"InstructDataset",
"slimorca_dataset",
"ChatDataset",
"instruct_dataset",
"preference_dataset",
"chat_dataset",
torchtune/datasets/_chat.py (103 changes: 2 additions & 101 deletions)
@@ -4,111 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, Optional, Union

import numpy as np

from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data._chat_formats import ChatFormat
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import (
Message,
OpenAIToMessages,
ShareGPTToMessages,
validate_messages,
)
from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.utils._logging import deprecated


@deprecated(msg="Please use `torchtune.datasets.SFTDataset` for custom chat data.")
class ChatDataset(Dataset):
"""
Note:
This class is deprecated and will be removed in a future release. Please use
:class:`~torchtune.datasets.SFTDataset` or :func:`~torchtune.datasets.chat_dataset`
for custom chat data.

Class that supports any custom dataset with multiturn conversations.

The general flow from loading a sample to tokenized prompt is:
load sample -> apply transform -> foreach turn{format into template -> tokenize}

Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat format
and roles::

[
Message(
role=<system|user|assistant>,
content=<message>,
),
...
]

This class supports multi-turn conversations. If a tokenizer sample with multiple
turns does not fit within ``max_seq_len`` then it is truncated.

Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text") and pass
in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
(https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path)
for more details.
convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample
and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys
chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual
messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not
as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed,
unless you want to structure messages in a particular way for inference.
max_seq_len (int): Maximum number of tokens in the returned input and label token id lists.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
such as ``data_files`` or ``split``.
"""

def __init__(
self,
*,
tokenizer: ModelTokenizer,
source: str,
convert_to_messages: Callable[[Mapping[str, Any]], List[Message]],
chat_format: Optional[ChatFormat] = None,
max_seq_len: int,
train_on_input: bool = False,
**load_dataset_kwargs: Dict[str, Any],
) -> None:

self._tokenizer = tokenizer
self._data = load_dataset(source, **load_dataset_kwargs)
self._convert_to_messages = convert_to_messages
self.chat_format = chat_format
self.max_seq_len = max_seq_len
self.train_on_input = train_on_input

def __len__(self):
return len(self._data)

def __getitem__(self, index: int) -> Dict[str, List[int]]:
sample = self._data[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
messages = self._convert_to_messages(sample, self.train_on_input)
if self.chat_format is not None:
messages = self.chat_format.format(messages)
validate_messages(messages)
tokens, mask = self._tokenizer.tokenize_messages(
messages,
)
# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens))
assert len(tokens) == len(labels)

return {"tokens": tokens, "labels": labels}


def chat_dataset(
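For anyone still constructing the class deleted above, a minimal migration sketch using SFTDataset with a ShareGPT message transform, as the deprecation message suggests; the tokenizer argument and the JSON file path are placeholders.

```python
from torchtune.data import ShareGPTToMessages
from torchtune.datasets import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer


def build_chat_sft_dataset(tokenizer: ModelTokenizer) -> SFTDataset:
    # "my_conversations.json" is a placeholder path. ShareGPTToMessages plays the
    # role of the old convert_to_messages hook (raw sample -> Messages), and the
    # tokenizer acts as the model_transform that produces tokens and labels.
    return SFTDataset(
        source="json",
        data_files="my_conversations.json",
        split="train",
        message_transform=ShareGPTToMessages(train_on_input=False),
        model_transform=tokenizer,
    )
```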