Add vqa_dataset, update docs #1820
Changes from 4 commits
New file: tests/assets/multimodal_instruct_tiny.json
@@ -0,0 +1,7 @@
[
    {
        "input": "What is presented on image?",
        "output": "PyTorch logo.",
Comment: 🫡
Reply: Yeah, that was an example.
"image": "tests/assets/rgb_pytorch.png" | ||
} | ||
] |
New test file: TestMultimodalInstructDataset
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from PIL.PngImagePlugin import PngImageFile

import pytest
import numpy as np
from tests.common import ASSETS
from tests.test_utils import DummyTokenizer

from torchtune.datasets.multimodal import multimodal_instruct_dataset


class TestMultimodalInstructDataset:
    @pytest.fixture
    def tokenizer(self):
        return DummyTokenizer()

    @pytest.mark.parametrize("train_on_input", [True, False])
    def test_get_item(self, tokenizer, train_on_input):
Comment: Can we remove the train_on_input parametrization? Otherwise I think this is just gonna run twice identically.
Reply: Yeah, good point
        system_prompt = "follow this prompt"

        dataset = multimodal_instruct_dataset(
            tokenizer=tokenizer,
            source="json",
            train_on_input=train_on_input,
            data_files=str(ASSETS / "multimodal_instruct_tiny.json"),
            split="train",
            new_system_prompt=system_prompt,
        )

        system_prompt_offset = len(system_prompt.split(" ")) + 1

        expected_tokens = [
            [0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1],
        ]

        expected_labels = [
            [np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(7), np.int64(5), np.int64(-1)]
Comment: nit, but is the …
Reply: Yeah, will be fixed.
        ]

        assert len(dataset) == 1

        for i in range(len(dataset)):
            prompt, label, image = dataset[i]["tokens"], dataset[i]["labels"], dataset[i]["images"]
            assert prompt == expected_tokens[i]
            if train_on_input:
                assert (
                    label[system_prompt_offset:]
                    == expected_tokens[i][system_prompt_offset:]
                )
            else:
                assert label == expected_labels[i]
            assert isinstance(image[0], PngImageFile) == True
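Following the review suggestion to drop the train_on_input parametrization, here is a minimal sketch of how the test body might look once it is removed. This is not part of the PR diff: it assumes the builder's default of train_on_input=False, relies on the imports and tokenizer fixture already defined in the file above, and reuses the expected token and label values from the test above.

    # Hypothetical simplification sketched from the review comment; not code from this PR.
    def test_get_item(self, tokenizer):
        system_prompt = "follow this prompt"
        dataset = multimodal_instruct_dataset(
            tokenizer=tokenizer,
            source="json",
            data_files=str(ASSETS / "multimodal_instruct_tiny.json"),
            split="train",
            new_system_prompt=system_prompt,
        )

        expected_tokens = [[0, 6, 4, 6, -2, 4, 2, 9, 2, 6, 7, 5, -1]]
        # Prompt tokens are masked with -100; only the response tokens contribute to the loss.
        expected_labels = [[-100] * 10 + [7, 5, -1]]

        assert len(dataset) == 1
        for i in range(len(dataset)):
            assert dataset[i]["tokens"] == expected_tokens[i]
            assert dataset[i]["labels"] == expected_labels[i]
            assert isinstance(dataset[i]["images"][0], PngImageFile)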
Modified file: InputOutputToMessages (torchtune.data)
@@ -161,20 +161,23 @@ class InputOutputToMessages(Transform):
            keeping the default "input" and "output" column names.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Default is None.
+       is_multimodal (bool): Whether the dataset is multimodal or not. Default is False.

    Raises:
        ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
-           ``output`` not in ``column_map``.
+           ``output`` not in ``column_map`` or ``image`` not in ``column_map``.
    """

    def __init__(
        self,
        train_on_input: bool = False,
        column_map: Optional[Dict[str, str]] = None,
        new_system_prompt: Optional[str] = None,
+       is_multimodal: bool = False,
Comment: can we infer this similar to in ShareGPTToMessages? … so users don't need to worry about another flag
Reply: Sure!
    ):
        self.train_on_input = train_on_input
Comment: I have this slight nagging feeling that this transform is over-generalized and it wouldn't be harmful to define another transform which should properly document that the column map can include image, and to also include validation logic, and without needing to check for multimodal. Fine to leave this as a followup but if there's consensus here I'd like to see an issue.
Reply: I see this is the same for …
Reply: Let's consider this in further PRs
Reply: Yeah, I see what you mean. Especially as we add more modalities this may get bloated quick. I was the one who initiated this design in ShareGPTToMessages but I think it's worth reconsidering. I'll put up an issue.
        self.new_system_prompt = new_system_prompt
+       self.is_multimodal = is_multimodal
        if column_map:
            if "input" not in column_map:
                raise ValueError(

@@ -184,15 +187,32 @@ def __init__(
                raise ValueError(
                    f"Expected a key of 'output' in column_map but found {column_map.keys()}."
                )
+           if "image" not in column_map:
+               raise ValueError(
+                   f"Expected a key of 'image' in column_map but found {column_map.keys()}."
+               )

            self._column_map = column_map
        else:
-           self._column_map = {"input": "input", "output": "output"}
+           self._column_map = {"input": "input", "output": "output", "image": "image"}

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+       if self.is_multimodal:
+           image_path = sample[self._column_map["image"]]
+           pil_image = load_image(image_path)
+           content = [
+               {"type": "image", "content": pil_image},
+               {"type": "text", "content": sample[self._column_map["input"]]},
+           ]
+       else:
+           content = [
+               {"type": "text", "content": sample[self._column_map["input"]]}
+           ]
+
        messages = [
            Message(
                role="user",
-               content=sample[self._column_map["input"]],
+               content=content,
                masked=not self.train_on_input,
                eot=True,
            ),
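On the thread about inferring multimodality rather than adding an is_multimodal flag, here is a minimal standalone sketch of what such an inference could look like, assuming the decision is driven by whether the (possibly remapped) image column is present and non-empty in the sample. The helper name infer_multimodal and the exact rule are illustrative assumptions, not the mechanism ShareGPTToMessages actually uses.

from typing import Any, Dict, Mapping, Optional


def infer_multimodal(
    sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None
) -> bool:
    # Hypothetical helper: treat the sample as multimodal when its image column
    # (remapped via column_map if provided) is present and non-empty, so no
    # constructor flag is needed.
    image_column = (column_map or {}).get("image", "image")
    value = sample.get(image_column)
    return value is not None and value != ""


if __name__ == "__main__":
    text_only = {"input": "What is 2 + 2?", "output": "4"}
    with_image = {
        "input": "What is presented on image?",
        "output": "PyTorch logo.",
        "image": "tests/assets/rgb_pytorch.png",
    }
    print(infer_multimodal(text_only))   # False
    print(infer_multimodal(with_image))  # True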
New file: multimodal_instruct_dataset builder (torchtune.datasets.multimodal)
@@ -0,0 +1,158 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Dict, Optional, Union

from torchtune.data import InputOutputToMessages
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer


def multimodal_instruct_dataset(
    tokenizer: ModelTokenizer,
Comment: for multimodal datasets, this needs to be:
Suggested change: …
I would follow …
    *,
    source: str,
    column_map: Optional[Dict[str, str]] = None,
    train_on_input: bool = False,
Comment: we are removing train_on_input for multimodal datasets, as it is not really used. You'll notice that …
    new_system_prompt: Optional[str] = None,
    packed: bool = False,
Comment: we don't currently support packing for multimodal (yet... we are planning to add this soon)
Reply: @RdoubleA Will there be a problem with adding packing for multimodal? I don't see an issue about it, so I might open the next PR with it.
Reply: Yes, there is non-trivial work to ensure the cross attention masks are also packed and collated correctly, so it would involve 1) adding logic to pack the cross attention masks and 2) creating a new collator for packed multimodal. We are actively discussing an overhaul to our dataloader / collate / packing logic in the near term, so I would hold off on multimodal packing support for now.
    filter_fn: Optional[Callable] = None,
    split: str = "train",
Comment: would also be good to include the … This may require some non-trivial changes to InputOutputToMessages, however, to make it support loading images. I'm happy to provide more guidance on that, but it would be similar to the changes to …
Reply: nvm, I see you already made those changes :)
    **load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
Comment: Strictly speaking, if we don't support packing yet, should the return type just be …
Reply: Sure, typehint needs to be fixed
""" | ||||||
Configure a custom multimodal dataset with user instruction prompts and model responses. | ||||||

    This builder function can be used to configure a custom multimodal instruct dataset directly from the yaml config
    as an alternative to :class:`~torchtune.datasets.SFTDataset`, as it is made to be config friendly.

    The dataset should follow this format:

    .. code-block:: text

        | input           | image           | output           |
        |-----------------|-----------------|------------------|
        | "user prompt"   | "user image"    | "model response" |

    If your column names are different, you can use the ``column_map`` parameter to change
    the expected column names. For example, if your dataset has columns ``"question"``,
    ``"answer"`` and ``"picture"`` you can use:

        column_map = {"input": "question", "output": "answer", "image": "picture"}

    Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is
    set to ``False`` by default.

    - If ``train_on_input`` is True, the prompt is used during training and
      contributes to the loss.
    - If ``train_on_input`` is False, the prompt is masked out (tokens replaced with -100)

Comment: remove train_on_input

    Args:
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
        source (str): path to dataset repository on Hugging Face. For local datasets,
            define source as the data file type (e.g. "json", "csv", "text"), pass
            in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
            <https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
            ``load_dataset`` for more details.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "input",
            "output", and "image" column names to the actual column names in the dataset. Keys should be
            "input", "output", and "image", and values should be the actual column names. Default is None,
            keeping the default "input", "output", and "image" column names.
        train_on_input (bool): Whether the model is trained on the user prompt or not.
            Default is False.
        new_system_prompt (Optional[str]): if specified, prepend a system message. This can
            serve as instructions to guide the model response. Default is None.
        packed (bool): Whether or not to pack the dataset to tokenizer's ``max_seq_len`` prior to training. Default is False.
        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
            details.
        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
            such as ``data_files`` or ``split``.

    Examples:

    ::

        my_dataset.json
        [
            {
                "question": "What is presented on the image?",
                "answer": "PyTorch logo.",
                "picture": "rgb_pytorch.png"
            },
            {
                ...
            },
            ...,
        ]

    ::

        >>> from torchtune.datasets.multimodal import multimodal_instruct_dataset
        >>> dataset = multimodal_instruct_dataset(
Comment: maybe structure it similar to this example: … and FlamingoTransform should be replaced everywhere in the example with …
        ... tokenizer=tokenizer,
        ... source="json",
        ... data_files="my_dataset.json",
        ... column_map={
        ...     "input": "question",
        ...     "output": "answer",
        ...     "image": "picture"
        ... },
        ... train_on_input=False,
        ... packed=False,
        ... split="train",
        ... )
        >>> tokens = dataset[0]["tokens"]
        >>> tokenizer.decode(tokens)
        "What is presented on the image?PyTorch logo."

    This can also be accomplished via the yaml config:

    .. code-block:: yaml

        dataset:
          _component_: torchtune.datasets.multimodal.multimodal_instruct_dataset
          source: json
          data_files: my_dataset.json
          column_map:
            input: question
            output: answer
            image: picture
          train_on_input: False
          packed: False
          split: train

    Returns:
        Union[SFTDataset, PackedDataset]: the configured :class:`~torchtune.datasets.SFTDataset`
Comment: Return type here needs to be fixed.
Reply: Fixed
        or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``

    Raises:
        ValueError: If ``packed=True`` and ``tokenizer.max_seq_len`` is not set.
    """
    message_transform = InputOutputToMessages(
        train_on_input=train_on_input,
        column_map=column_map,
        new_system_prompt=new_system_prompt,
        is_multimodal=True,
    )

    ds = SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=tokenizer,
        filter_fn=filter_fn,
        split=split,
        **load_dataset_kwargs,
    )

    if packed:
        if tokenizer.max_seq_len is None:
            raise ValueError(
                "PackedDataset requires a max_seq_len to be set on the tokenizer."
            )
        return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
    return ds
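Pulling together several of the review threads above (model_transform typed as Transform rather than a ModelTokenizer, train_on_input and packed dropped, plain SFTDataset return type), here is a rough sketch of the shape the builder might take after addressing them. This is an assumption about the eventual design, not the code in this PR; the function name is deliberately different, and the Transform import path is assumed.

from typing import Any, Callable, Dict, Optional

from torchtune.data import InputOutputToMessages
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.transforms import Transform  # assumed import path


def multimodal_instruct_dataset_sketch(
    model_transform: Transform,  # encodes both text and images, per the review
    *,
    source: str,
    column_map: Optional[Dict[str, str]] = None,
    new_system_prompt: Optional[str] = None,
    filter_fn: Optional[Callable] = None,
    split: str = "train",
    **load_dataset_kwargs: Dict[str, Any],
) -> SFTDataset:
    # No packed/PackedDataset branch (packing is not yet supported for multimodal)
    # and no train_on_input (dropped per the review, as it is not really used here).
    message_transform = InputOutputToMessages(
        column_map=column_map,
        new_system_prompt=new_system_prompt,
        is_multimodal=True,
    )
    return SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=model_transform,
        filter_fn=filter_fn,
        split=split,
        **load_dataset_kwargs,
    )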
Comment: nit: can rename this to vqa_tiny.json