From 03bf740585cd5c6bba834070d7c40a174216262b Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 9 Oct 2024 04:59:42 -0700 Subject: [PATCH 1/3] Delete deprecated ChatDataset and InstructDataset --- tests/torchtune/datasets/test_chat_dataset.py | 59 +-------- .../datasets/test_instruct_dataset.py | 34 +---- torchtune/datasets/__init__.py | 6 +- torchtune/datasets/_chat.py | 103 +-------------- torchtune/datasets/_instruct.py | 122 +----------------- 5 files changed, 8 insertions(+), 316 deletions(-) diff --git a/tests/torchtune/datasets/test_chat_dataset.py b/tests/torchtune/datasets/test_chat_dataset.py index 23db9569fe..8e99ad92ae 100644 --- a/tests/torchtune/datasets/test_chat_dataset.py +++ b/tests/torchtune/datasets/test_chat_dataset.py @@ -7,9 +7,8 @@ import pytest from tests.common import ASSETS from tests.test_utils import DummyChatFormat, DummyTokenizer -from torchtune.data import get_sharegpt_messages from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX -from torchtune.datasets import chat_dataset, ChatDataset +from torchtune.datasets import chat_dataset class TestChatDataset: @@ -18,62 +17,6 @@ def chat_format(self): return DummyChatFormat def test_get_item(self, chat_format): - expected_tokenized_prompts = [ - [ - 0, - 7, - 3, - 3, - 2, - 2, - 10, - 5, - 4, - 2, - 3, - 7, - 2, - 5, - 10, - 3, - 7, - 2, - 4, - 2, - 3, - -1, - 0, - 5, - 6, - 11, - 10, - 1, - 6, - -1, - ] - ] - prompt_lengths = (15, 5) - expected_labels = [ - [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] - + [3, 7, 2, 4, 2, 3, -1] - + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] - + [1, 6, -1] - ] - ds = ChatDataset( - tokenizer=DummyTokenizer(), - source="json", - convert_to_messages=get_sharegpt_messages, - chat_format=chat_format, - max_seq_len=100, - train_on_input=False, - data_files=str(ASSETS / "chat_tiny.json"), - split="train", - ) - assert len(ds) == 1 - prompt, label = ds[0]["tokens"], ds[0]["labels"] - assert prompt == expected_tokenized_prompts[0] - assert label == expected_labels[0] - expected_tokenized_prompts = [ [ 0, diff --git a/tests/torchtune/datasets/test_instruct_dataset.py b/tests/torchtune/datasets/test_instruct_dataset.py index c734bec885..04f3c2f49c 100644 --- a/tests/torchtune/datasets/test_instruct_dataset.py +++ b/tests/torchtune/datasets/test_instruct_dataset.py @@ -9,7 +9,7 @@ from tests.test_utils import DummyTokenizer from torchtune.data import InstructTemplate from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX -from torchtune.datasets import instruct_dataset, InstructDataset +from torchtune.datasets import instruct_dataset def dummy_transform(sample): @@ -29,38 +29,6 @@ def format(cls, sample, column_map): class TestInstructDataset: @pytest.mark.parametrize("train_on_input", [True, False]) def test_get_item(self, train_on_input): - template = DummyTemplate - expected_tokenized_prompts = [ - [0, 12, 4, 4, 2, 2, 2, 7, 10, 9, 2, 2, 5, 2, 2, 6, 10, -1], - [0, 12, 2, 2, 8, 2, 15, 10, 9, 8, 3, 15, 3, 4, 9, 3, 15, 10, -1], - ] - prompt_lengths = (10, 9) - expected_labels = [ - [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0] + [2, 2, 5, 2, 2, 6, 10, -1], - [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1] - + [8, 3, 15, 3, 4, 9, 3, 15, 10, -1], - ] - - dataset = InstructDataset( - tokenizer=DummyTokenizer(), - source="json", - template=template, - transform=dummy_transform, - train_on_input=train_on_input, - data_files=str(ASSETS / "instruct_tiny.json"), - column_map={"output": "response"}, - split="train", - ) - assert len(dataset) == 2 - - for i in range(len(dataset)): - prompt, label = dataset[i]["tokens"], dataset[i]["labels"] - assert prompt == expected_tokenized_prompts[i] - if train_on_input: - assert label == expected_tokenized_prompts[i] - else: - assert label == expected_labels[i] - expected_tokenized_prompts = [ [0, 6, 4, 6, 4, 4, 2, 2, 2, 7, 2, 2, 5, 2, 2, 6, -1], [0, 6, 4, 6, 2, 2, 8, 2, 15, 8, 3, 15, 3, 4, 9, 3, 15, -1], diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index de2e22beda..b0c7c11738 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -6,12 +6,12 @@ from torchtune.datasets import multimodal from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset -from torchtune.datasets._chat import chat_dataset, ChatDataset +from torchtune.datasets._chat import chat_dataset from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset from torchtune.datasets._concat import ConcatDataset from torchtune.datasets._grammar import grammar_dataset from torchtune.datasets._hh_rlhf_helpful import hh_rlhf_helpful_dataset -from torchtune.datasets._instruct import instruct_dataset, InstructDataset +from torchtune.datasets._instruct import instruct_dataset from torchtune.datasets._packed import PackedDataset from torchtune.datasets._preference import preference_dataset, PreferenceDataset from torchtune.datasets._samsum import samsum_dataset @@ -30,9 +30,7 @@ "grammar_dataset", "samsum_dataset", "stack_exchange_paired_dataset", - "InstructDataset", "slimorca_dataset", - "ChatDataset", "instruct_dataset", "preference_dataset", "chat_dataset", diff --git a/torchtune/datasets/_chat.py b/torchtune/datasets/_chat.py index f18961d36f..b9f5639706 100644 --- a/torchtune/datasets/_chat.py +++ b/torchtune/datasets/_chat.py @@ -4,111 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, Optional, Union -import numpy as np - -from datasets import load_dataset -from torch.utils.data import Dataset -from torchtune.data._chat_formats import ChatFormat -from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX -from torchtune.data._messages import ( - Message, - OpenAIToMessages, - ShareGPTToMessages, - validate_messages, -) +from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset from torchtune.modules.tokenizers import ModelTokenizer -from torchtune.utils._logging import deprecated - - -@deprecated(msg="Please use `torchtune.datasets.SFTDataset` for custom chat data.") -class ChatDataset(Dataset): - """ - Note: - This class is deprecated and will be removed in a future release. Please use - :class:`~torchtune.datasets.SFTDataset` or :func:`~torchtune.datasets.chat_dataset` - for custom chat data. - - Class that supports any custom dataset with multiturn conversations. - - The general flow from loading a sample to tokenized prompt is: - load sample -> apply transform -> foreach turn{format into template -> tokenize} - - Use ``convert_to_messages`` to prepare your dataset into the Llama2 chat format - and roles:: - - [ - Message( - role=, - content=, - ), - ... - ] - - This class supports multi-turn conversations. If a tokenizer sample with multiple - turns does not fit within ``max_seq_len`` then it is truncated. - - Args: - tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. - source (str): path to dataset repository on Hugging Face. For local datasets, - define source as the data file type (e.g. "json", "csv", "text") and pass - in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` - (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) - for more details. - convert_to_messages (Callable[[Mapping[str, Any]], List[Message]]): function that keys into the desired field in the sample - and converts to a list of :class:`~torchtune.data.Message` that follows the Llama format with the expected keys - chat_format (Optional[ChatFormat]): template used to format the chat. This is used to add structured text around the actual - messages, such as the [INST] tags in Llama2 and in Mistral. The extra text will still get tokenized as normal text, not - as special tokens. In models like Llama3 where the tokenizer adds tags as special tokens, ``chat_format`` is not needed, - unless you want to structure messages in a particular way for inference. - max_seq_len (int): Maximum number of tokens in the returned input and label token id lists. - train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. - **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, - such as ``data_files`` or ``split``. - """ - - def __init__( - self, - *, - tokenizer: ModelTokenizer, - source: str, - convert_to_messages: Callable[[Mapping[str, Any]], List[Message]], - chat_format: Optional[ChatFormat] = None, - max_seq_len: int, - train_on_input: bool = False, - **load_dataset_kwargs: Dict[str, Any], - ) -> None: - - self._tokenizer = tokenizer - self._data = load_dataset(source, **load_dataset_kwargs) - self._convert_to_messages = convert_to_messages - self.chat_format = chat_format - self.max_seq_len = max_seq_len - self.train_on_input = train_on_input - - def __len__(self): - return len(self._data) - - def __getitem__(self, index: int) -> Dict[str, List[int]]: - sample = self._data[index] - return self._prepare_sample(sample) - - def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: - messages = self._convert_to_messages(sample, self.train_on_input) - if self.chat_format is not None: - messages = self.chat_format.format(messages) - validate_messages(messages) - tokens, mask = self._tokenizer.tokenize_messages( - messages, - ) - # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens - labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens)) - assert len(tokens) == len(labels) - - return {"tokens": tokens, "labels": labels} def chat_dataset( diff --git a/torchtune/datasets/_instruct.py b/torchtune/datasets/_instruct.py index e291feadb3..82d9bb267e 100644 --- a/torchtune/datasets/_instruct.py +++ b/torchtune/datasets/_instruct.py @@ -4,130 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, Optional, Union -import numpy as np -from datasets import load_dataset -from torch.utils.data import Dataset -from torchtune.data import ( - CROSS_ENTROPY_IGNORE_IDX, - InputOutputToMessages, - InstructTemplate, - Message, - validate_messages, -) +from torchtune.data import InputOutputToMessages from torchtune.datasets._packed import PackedDataset from torchtune.datasets._sft import SFTDataset from torchtune.modules.tokenizers import ModelTokenizer -from torchtune.utils._logging import deprecated - - -@deprecated( - msg="Please use `torchtune.datasets.SFTDataset` or :func:`~torchtune.datasets.instruct_dataset` for custom instruct data." -) -class InstructDataset(Dataset): - """ - Note: - This class is deprecated and will be removed in a future release. Please use - :class:`~torchtune.datasets.SFTDataset` or :func:`~torchtune.datasets.instruct_dataset` - for custom instruct data. - - Class that supports any custom dataset with instruction-based prompts and a - configurable template. - - The general flow from loading a sample to tokenized prompt is: - load sample -> apply transform -> format into template -> tokenize - - If the column/key names differ from the expected names in the :class:`~torchtune.data.InstructTemplate`, - then the ``column_map`` argument can be used to provide this mapping. - - Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is - set to ``False`` by default. - - If ``train_on_input`` is True, the prompt is used during training and - contributes to the loss. - - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100) - - Args: - tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method. - source (str): path to dataset repository on Hugging Face. For local datasets, - define source as the data file type (e.g. "json", "csv", "text") and pass - in the filepath in ``data_files``. See Hugging Face's ``load_dataset`` - (https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path) - for more details. - template (InstructTemplate): template used to format the prompt. If the placeholder variable - names in the template do not match the column/key names in the dataset, use ``column_map`` to map them. - transform (Optional[Callable]): transform to apply to the sample before formatting to the template. - Default is None. - column_map (Optional[Dict[str, str]]): a mapping from the expected placeholder names in the template - to the column/key names in the sample. If None, assume these are identical. - The output column can be indicated using the ``output`` key mapping. - If no placeholder for the ``output`` column is provided in ``column_map`` it is assumed to be ``output``. - train_on_input (bool): Whether the model is trained on the prompt or not. Default is False. - max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists. - Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory - and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length. - **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``, - such as ``data_files`` or ``split``. - Raises: - ValueError: If ``template`` is not an instance of :class:`torchtune.data.InstructTemplate` - """ - - def __init__( - self, - tokenizer: ModelTokenizer, - source: str, - template: InstructTemplate, - transform: Optional[Callable] = None, - column_map: Optional[Dict[str, str]] = None, - train_on_input: bool = False, - max_seq_len: Optional[int] = None, - **load_dataset_kwargs: Dict[str, Any], - ) -> None: - if not isinstance(template(), InstructTemplate): - raise ValueError( - f"template must be an InstructTemplate class, not {type(template())}" - ) - - self._tokenizer = tokenizer - self._data = load_dataset(source, **load_dataset_kwargs) - self.template = template - self._transform = transform - self._column_map = column_map - self.train_on_input = train_on_input - self.max_seq_len = max_seq_len - - def __len__(self): - return len(self._data) - - def __getitem__(self, index: int) -> Dict[str, List[int]]: - sample = self._data[index] - return self._prepare_sample(sample) - - def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]: - transformed_sample = self._transform(sample) if self._transform else sample - - prompt = self.template.format(transformed_sample, self._column_map) - key_output = ( - self._column_map["output"] - if self._column_map and "output" in self._column_map - else "output" - ) - messages = [ - Message(role="user", content=prompt, masked=(not self.train_on_input)), - Message(role="assistant", content=transformed_sample[key_output]), - ] - - validate_messages(messages) - - tokens, mask = self._tokenizer.tokenize_messages( - messages, - ) - - # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens - labels = list(np.where(mask, CROSS_ENTROPY_IGNORE_IDX, tokens)) - assert len(tokens) == len(labels) - - return {"tokens": tokens, "labels": labels} def instruct_dataset( From c6c1b7ae43f6a4234f18ce1ed43a36075c60dcc1 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 9 Oct 2024 05:03:26 -0700 Subject: [PATCH 2/3] Update docs --- docs/source/api_ref_datasets.rst | 2 -- docs/source/tutorials/datasets.rst | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst index cc0f6da466..40def346e4 100644 --- a/docs/source/api_ref_datasets.rst +++ b/docs/source/api_ref_datasets.rst @@ -64,8 +64,6 @@ Class representations for the above dataset builders. :toctree: generated/ :nosignatures: - InstructDataset - ChatDataset TextCompletionDataset ConcatDataset PackedDataset diff --git a/docs/source/tutorials/datasets.rst b/docs/source/tutorials/datasets.rst index a45239a523..6ece20d385 100644 --- a/docs/source/tutorials/datasets.rst +++ b/docs/source/tutorials/datasets.rst @@ -491,9 +491,8 @@ Fully customized datasets ------------------------- More advanced tasks and dataset formats that don't fit into the templating and processing -that :class:`~torchtune.datasets.InstructDataset`, :class:`~torchtune.datasets.ChatDataset`, -and :class:`~torchtune.datasets.TextCompletionDataset` provide may require you to create your own dataset -class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`, +that :class:`~torchtune.datasets.SFTDataset` and :class:`~torchtune.datasets.TextCompletionDataset` provide may require +you to create your own dataset class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`, which has custom functionality for RLHF preference data, as an example to understand what you'll need to do. If you take a look at the code for the :class:`~torchtune.datasets.PreferenceDataset` class, From 01c90db06efa19f864ac91e828321dddd5c7e125 Mon Sep 17 00:00:00 2001 From: joecummings Date: Wed, 9 Oct 2024 05:17:30 -0700 Subject: [PATCH 3/3] Finish removing references to InstructDataset --- docs/source/deep_dives/configs.rst | 2 +- docs/source/tutorials/datasets.rst | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/source/deep_dives/configs.rst b/docs/source/deep_dives/configs.rst index 54eef7144d..0f86a29a58 100644 --- a/docs/source/deep_dives/configs.rst +++ b/docs/source/deep_dives/configs.rst @@ -119,7 +119,7 @@ keyword arguments not specified in the config if we'd like: tokenizer: ModelTokenizer, train_on_input: bool = True, max_seq_len: int = 512, - ) -> InstructDataset: + ) -> SFTDataset: from torchtune import config diff --git a/docs/source/tutorials/datasets.rst b/docs/source/tutorials/datasets.rst index 6ece20d385..781573b89e 100644 --- a/docs/source/tutorials/datasets.rst +++ b/docs/source/tutorials/datasets.rst @@ -495,10 +495,6 @@ that :class:`~torchtune.datasets.SFTDataset` and :class:`~torchtune.datasets.Tex you to create your own dataset class for more flexibility. Let's walk through the :class:`~torchtune.datasets.PreferenceDataset`, which has custom functionality for RLHF preference data, as an example to understand what you'll need to do. -If you take a look at the code for the :class:`~torchtune.datasets.PreferenceDataset` class, -you'll notice it's quite similar to :class:`~torchtune.datasets.InstructDataset` with a few -adjustments for chosen and rejected samples in preference data. - .. code-block:: python chosen_message = [