
Add vqa_dataset, update docs #1820


Merged
merged 24 commits into from Oct 17, 2024
Changes from 4 commits
7 changes: 7 additions & 0 deletions tests/assets/multimodal_instruct_tiny.json
Collaborator:

nit: can rename this to vqa_tiny.json

@@ -0,0 +1,7 @@
[
{
"input": "What is presented on image?",
"output": "PyTorch logo.",
Collaborator:

🫡

Contributor Author:

Yeah, that was an example.

"image": "tests/assets/rgb_pytorch.png"
}
]
Binary file added tests/assets/rgb_pytorch.png
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from PIL.PngImagePlugin import PngImageFile

import pytest
import numpy as np
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer

from torchtune.datasets.multimodal import multimodal_instruct_dataset


class TestMultimodalInstructDataset:
@pytest.fixture
def tokenizer(self):
return DummyTokenizer()

@pytest.mark.parametrize("train_on_input", [True, False])
def test_get_item(self, tokenizer, train_on_input):
Contributor:

Can we remove train_on_input parametrization? Otherwise I think this is just gonna run twice identically

Contributor Author:

Yeah, good point

system_prompt = "follow this prompt"

dataset = multimodal_instruct_dataset(
tokenizer=tokenizer,
source="json",
train_on_input=train_on_input,
data_files=str(ASSETS / "multimodal_instruct_tiny.json"),
split="train",
new_system_prompt=system_prompt,
)

system_prompt_offset = len(system_prompt.split(" ")) + 1

expected_tokens = [
[0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
]

expected_labels = [
[np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(7), np.int64(5), np.int64(-1)]
Contributor:

nit but is the np.int64 strictly necessary? I know labels are long dtype, but if we are just comparing raw lists via == and not using torch tensor asserts or something I think you shouldn't need it (and it makes the expected values easier to read)
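A quick check illustrating the point (not part of the diff): list equality compares elements with ==, and plain Python ints compare equal to np.int64 values, so the wrapper can be dropped.

import numpy as np

expected = [-100, -100, -100, 7, 5, -1]        # plain ints, easier to read
produced = [np.int64(x) for x in expected]     # what the dataset yields
assert produced == expected                    # passes: element-wise == handles mixed int types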

Contributor Author:

Yeah, will be fixed.

]

assert len(dataset) == 1

for i in range(len(dataset)):
prompt, label, image = dataset[i]["tokens"], dataset[i]["labels"], dataset[i]["images"]
assert prompt == expected_tokens[i]
if train_on_input:
assert (
label[system_prompt_offset:]
== expected_tokens[i][system_prompt_offset:]
)
else:
assert label == expected_labels[i]
assert isinstance(image[0], PngImageFile)

26 changes: 23 additions & 3 deletions torchtune/data/_messages.py
@@ -161,20 +161,23 @@ class InputOutputToMessages(Transform):
keeping the default "input" and "output" column names.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Default is None.
is_multimodal (bool): Whether the dataset is multimodal or not. Default is False.

Raises:
ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
``output`` not in ``column_map``.
``output`` not in ``column_map``, or ``image`` not in ``column_map``.
"""

def __init__(
self,
train_on_input: bool = False,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
is_multimodal: bool = False,
Collaborator:

can we infer this similar to in ShareGPTToMessages?

is_multimodal = "image" in sample or (
            "image" in self._column_map and self._column_map["image"] in sample
        )

so users don't need to worry about another flag

Contributor Author:

Sure!

):
self.train_on_input = train_on_input
Collaborator:

I have a slight nagging feeling that this transform is over-generalized. It wouldn't be harmful to define a separate transform that properly documents that the column map can include an image, includes the validation logic, and doesn't need to check for multimodal.

Fine to leave this as a follow-up, but if there's consensus here I'd like to see an issue.
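For illustration only, a minimal sketch of what such a dedicated transform could look like; the class name and structure below are hypothetical and not part of this PR:

# Hypothetical dedicated multimodal transform; name and details are illustrative only.
class ImageInputOutputToMessages(Transform):
    def __init__(self, column_map=None, new_system_prompt=None):
        self._column_map = column_map or {"input": "input", "output": "output", "image": "image"}
        self.new_system_prompt = new_system_prompt

    def __call__(self, sample):
        pil_image = load_image(sample[self._column_map["image"]])
        user_content = [
            {"type": "image", "content": pil_image},
            {"type": "text", "content": sample[self._column_map["input"]]},
        ]
        messages = [
            Message(role="user", content=user_content, masked=True, eot=True),
            Message(role="assistant", content=sample[self._column_map["output"]]),
        ]
        if self.new_system_prompt is not None:
            messages = [
                Message(role="system", content=self.new_system_prompt, masked=True, eot=True)
            ] + messages
        return {"messages": messages}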

Collaborator:

I see this is the same for ShareGPTToMessages, so maybe it's fine. We need to properly document MM usage either way.

Contributor Author:

Let's consider this in follow-up PRs.

Collaborator:

Yeah, I see what you mean. Especially as we add more modalities this may get bloated quickly. I was the one who initiated this design in ShareGPTToMessages, but I think it's worth reconsidering. I'll put up an issue.

self.new_system_prompt = new_system_prompt
self.is_multimodal = is_multimodal
if column_map:
if "input" not in column_map:
raise ValueError(
@@ -184,15 +187,32 @@ def __init__(
raise ValueError(
f"Expected a key of 'output' in column_map but found {column_map.keys()}."
)
if "image" not in column_map:
raise ValueError(
f"Expected a key of 'image' in column_map but found {column_map.keys()}."
)

self._column_map = column_map
else:
self._column_map = {"input": "input", "output": "output"}
self._column_map = {"input": "input", "output": "output", "image": "image"}

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if self.is_multimodal:
image_path = sample[self._column_map["image"]]
pil_image = load_image(image_path)
content = [
{"type": "image", "content": pil_image},
{"type": "text", "content": sample[self._column_map["input"]]},
]
else:
content = [
{"type": "text", "content": sample[self._column_map["input"]]}
]

messages = [
Message(
role="user",
content=sample[self._column_map["input"]],
content=content,
masked=not self.train_on_input,
eot=True,
),
2 changes: 2 additions & 0 deletions torchtune/datasets/multimodal/__init__.py
@@ -7,9 +7,11 @@
from ._llava_instruct import llava_instruct_dataset
from ._multimodal import multimodal_chat_dataset
from ._the_cauldron import the_cauldron_dataset
from ._multimodal_instruct import multimodal_instruct_dataset

__all__ = [
"the_cauldron_dataset",
"llava_instruct_dataset",
"multimodal_chat_dataset",
"multimodal_instruct_dataset",
]
158 changes: 158 additions & 0 deletions torchtune/datasets/multimodal/_multimodal_instruct.py
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, Optional, Union

from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer


def multimodal_instruct_dataset(
tokenizer: ModelTokenizer,
Collaborator:

for multimodal datasets, this needs to be:

Suggested change
tokenizer: ModelTokenizer,
model_transform: Transform,

I would follow multimodal_chat_dataset more closely instead of instruct_dataset for this one

*,
source: str,
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
Collaborator:

we are removing train_on_input for multimodal datasets, as it is not really used. you'll notice that multimodal_chat_dataset also has this removed

new_system_prompt: Optional[str] = None,
packed: bool = False,
Collaborator:

we don't currently support packing for multimodal (yet... we are planning to add this soon)

Contributor Author:

@RdoubleA Would there be a problem with adding packing for multimodal? I don't see an issue tracking it. I might open a follow-up PR for it.

Collaborator:

Yes, there is non-trivial work to ensure the cross-attention masks are also packed and collated correctly, so it would involve 1) adding logic to pack the cross-attention masks and 2) creating a new collator for packed multimodal.

We are actively discussing an overhaul to our dataloader / collate / packing logic in the near term, so I would hold off on multimodal packing support for now.

filter_fn: Optional[Callable] = None,
split: str = "train",
Collaborator:

Would also be good to include the image_dir functionality similar to multimodal chat: https://github.com/pytorch/torchtune/blob/main/torchtune/datasets/multimodal/_multimodal.py#L22. This would enable users to configure datasets that either have image paths or the images directly. For example, take a look at MathVista (https://huggingface.co/datasets/AI4Math/MathVista), a math image QA benchmark that has both the image paths and the raw image. Some datasets may have one or the other.

This may require some non-trivial changes to InputOutputToMessages, however, to make it support loading images. I'm happy to provide more guidance on that, but it would be similar to the changes to ShareGPTToMessages in this PR: #1667
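A minimal sketch of how optional image_dir handling could look inside InputOutputToMessages.__call__ (the image_dir attribute here is an assumption, mirroring the linked PR rather than anything in this diff):

import os

# Hypothetical handling; self.image_dir is an assumed optional attribute.
image_value = sample[self._column_map["image"]]
if isinstance(image_value, str):
    # The dataset stores a path; resolve it against image_dir if one was configured.
    if self.image_dir is not None:
        image_value = os.path.join(self.image_dir, image_value)
    image_value = load_image(image_value)
# Otherwise the dataset already stores a decoded image and it is used as-is.
content = [
    {"type": "image", "content": image_value},
    {"type": "text", "content": sample[self._column_map["input"]]},
]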

Collaborator:

nvm, I see you already made those changes :)

**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
Contributor:

Strictly speaking, if we don't support packing yet, should the return type just be SFTDataset?

Contributor Author:

Sure, the type hint needs to be fixed.

"""
Configure a custom multimodal dataset with user instruction prompts and model responses.

This builder function can be used to configure a custom multimodal instruct dataset directly from the yaml config
as an alternative to :class:`~torchtune.datasets.SFTDataset`, as it is made to be config friendly.

The dataset should follow this format:

.. code-block:: text

| input | image | output |
|-----------------|-----------------|------------------|
| "user prompt" | "user image" | "model response" |

If your column names are different, you can use the ``column_map`` parameter to change
the expected column names. For example, if your dataset has columns ``"question"``,
``"answer"`` and ``"picture"`` you can use:

column_map = {"input": "question", "output": "answer", "image": "picture"}

Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is
Collaborator:

remove train_on_input

set to ``False`` by default.

- If ``train_on_input`` is True, the prompt is used during training and
  contributes to the loss.
- If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100).

Args:
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
source (str): path to dataset repository on Hugging Face. For local datasets,
define source as the data file type (e.g. "json", "csv", "text"), pass
in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
<https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
``load_dataset`` for more details.
column_map (Optional[Dict[str, str]]): a mapping to change the expected "input",
"output", and "image" column names to the actual column names in the dataset. Keys should be
"input", "output", and "image", and values should be the actual column names. Default is None,
keeping the default "input", "output", and "image" column names.
train_on_input (bool): Whether the model is trained on the user prompt or not.
Default is False.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Default is None.
packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
such as ``data_files`` or ``split``.

Examples:

::

my_dataset.json
[
{
"question": "What is presented on the image?",
"answer": "PyTorch logo.",
"picture": "rgb_pytorch.png"
},
{
...
},
...,
]

::

>>> from torchtune.datasets.multimodal import multimodal_instruct_dataset
>>> dataset = multimodal_instruct_dataset(
Collaborator:

maybe structure it similar to this example:

>>> model_transform = FlamingoTransform(

and FlamingoTransform should be replaced everywhere in the example with Llama3VisionTransform https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2_vision/_transform.py#L17
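Roughly, a sketch of how the example could be restructured (the constructor arguments shown for Llama3VisionTransform are placeholders, not verified values; check its docstring for the actual signature):

>>> from torchtune.models.llama3_2_vision import Llama3VisionTransform
>>> model_transform = Llama3VisionTransform(
...     path="/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model",
...     tile_size=224,
...     patch_size=14,
...     max_seq_len=8192,
... )
>>> dataset = multimodal_instruct_dataset(
...     model_transform=model_transform,
...     source="json",
...     data_files="my_dataset.json",
...     column_map={"input": "question", "output": "answer", "image": "picture"},
...     split="train",
... )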

... tokenizer=tokenizer,
... source="json",
... data_files="my_dataset.json",
... column_map={
... "input": "question",
... "output": "answer",
... "image": "picture"
... },
... train_on_input=False,
... packed=False,
... split="train",
... )
>>> tokens = dataset[0]["tokens"]
>>> tokenizer.decode(tokens)
"What is presented on the image?PyTorch logo."

This can also be accomplished via the yaml config:

.. code-block:: yaml

dataset:
_component_: torchtune.datasets.multimodal.multimodal_instruct_dataset
source: json
data_files: my_dataset.json
column_map:
input: question
output: answer
image: picture
train_on_input: False
packed: False
split: train

Returns:
Union[SFTDataset, PackedDataset]: the configured :class:`~torchtune.datasets.SFTDataset`
Contributor:

Return type here needs to be fixed.

Contributor Author:

Fixed

or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``

Raises:
ValueError: If ``packed=True`` and ``tokenizer.max_seq_len`` is not set.
"""
message_transform = InputOutputToMessages(
train_on_input=train_on_input,
column_map=column_map,
new_system_prompt=new_system_prompt,
is_multimodal=True,
)

ds = SFTDataset(
source=source,
message_transform=message_transform,
model_transform=tokenizer,
filter_fn=filter_fn,
split=split,
**load_dataset_kwargs,
)

if packed:
if tokenizer.max_seq_len is None:
raise ValueError(
"PackedDataset requires a max_seq_len to be set on the tokenizer."
)
return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
return ds