Add vqa_dataset, update docs #1820

Merged
merged 24 commits on Oct 17, 2024

Conversation

krammnic
Contributor

@krammnic krammnic commented Oct 12, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#1704

Changelog

What are the changes made in this PR?

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot bot commented Oct 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1820

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1737744 with merge base 7744608:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 12, 2024
@krammnic
Contributor Author

@RdoubleA Requesting review.

@krammnic
Contributor Author

I see where the problem is. I committed and then switched branches...

@krammnic
Contributor Author

@RdoubleA Should be fine now.

@krammnic krammnic changed the title [WIP] Add multimodal_instruct_dataset, update docs Add multimodal_instruct_dataset, update docs Oct 13, 2024
@krammnic
Contributor Author

Yeah, some stuff failed. Let me fix the other PRs first and then this one.

Collaborator

@RdoubleA RdoubleA left a comment

Could you test this builder with an example VQA dataset once comments are addressed? Something like https://huggingface.co/datasets/derek-thomas/ScienceQA using the question and solution columns.

If you're able to easily configure this dataset using the builder you added, you can successfully decode the tokens of a random sample, and the image is encoded correctly, then I would declare this builder well-tested (in addition to the unit test you put up).

[
{
"input": "What is presented on image?",
"output": "PyTorch logo.",
Collaborator

🫡

Contributor Author

Yeah, that was an example.

*,
source: str,
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
Collaborator

we are removing train_on_input for multimodal datasets, as it is not really used. you'll notice that multimodal_chat_dataset also has this removed

column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
new_system_prompt: Optional[str] = None,
packed: bool = False,
Collaborator

we don't currently support packing for multimodal (yet... we are planning to add this soon)

Contributor Author

@RdoubleA Would it be a problem to add packing for multimodal? I don't see an open issue about it. I might open a follow-up PR for it.

Collaborator

Yes, there is non-trivial work to ensure the cross-attention masks are also packed and collated correctly, so it would involve 1) adding logic to pack the cross-attention masks and 2) creating a new collator for packed multimodal.

We are actively discussing an overhaul to our dataloader / collate / packing logic in the near term, so I would hold off on multimodal packing support for now.
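
For intuition only, here is a minimal sketch of point 1 above: packing per-sample cross-attention masks block-diagonally. This is an illustration, not torchtune's actual mask format or collate API.

import torch
from typing import List


def pack_cross_attention_masks(masks: List[torch.Tensor]) -> torch.Tensor:
    # Each mask is [text_len_i, image_tokens_i] for one sample. In a packed
    # sequence, text tokens from sample i may only attend to image tokens from
    # sample i, so the packed mask is block-diagonal.
    total_text = sum(m.shape[0] for m in masks)
    total_image = sum(m.shape[1] for m in masks)
    packed = torch.zeros(total_text, total_image, dtype=masks[0].dtype)
    text_offset, image_offset = 0, 0
    for m in masks:
        t, i = m.shape
        packed[text_offset : text_offset + t, image_offset : image_offset + i] = m
        text_offset += t
        image_offset += i
    return packed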

new_system_prompt: Optional[str] = None,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
Collaborator

would also be good to include the image_dir functionality similar to multimodal chat: https://github.com/pytorch/torchtune/blob/main/torchtune/datasets/multimodal/_multimodal.py#L22. This would enable users to configure datasets that either have image paths or the images directly. For example, take a look at MathVista (https://huggingface.co/datasets/AI4Math/MathVista), a math image QA benchmark that has both the image paths and the raw images. Some datasets may have one or the other.

This may require some non-trivial changes to InputOutputToMessages, however, to make it support loading images. I'm happy to provide more guidance on that, but it would be similar to the changes to ShareGPTToMessages in this PR: #1667
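
To illustrate the idea only (a hypothetical helper, not torchtune's actual implementation), a message transform could accept either a raw image or a path and join relative paths with image_dir roughly like this:

from pathlib import Path
from typing import Optional, Union

from PIL import Image


def load_sample_image(
    image: Union[str, Path, Image.Image],
    image_dir: Optional[str] = None,
) -> Image.Image:
    # Case 1: the dataset column already holds a decoded PIL image
    # (e.g. a Hugging Face dataset with an Image feature).
    if isinstance(image, Image.Image):
        return image
    # Case 2: the column holds a file path; prepend image_dir so users can
    # point at a locally downloaded image folder.
    path = Path(image)
    if image_dir is not None and not path.is_absolute():
        path = Path(image_dir) / path
    return Image.open(path).convert("RGB")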

Collaborator

nvm, I see you already made those changes :)


column_map = {"input": "question", "output": "answer", "image": "picture"}

Masking of the prompt during training is controlled by the ``train_on_input`` flag, which is
Collaborator

remove train_on_input



def multimodal_instruct_dataset(
tokenizer: ModelTokenizer,
Collaborator

for multimodal datasets, this needs to be:

Suggested change
tokenizer: ModelTokenizer,
model_transform: Transform,

I would follow multimodal_chat_dataset more closely instead of instruct_dataset for this one

::

>>> from torchtune.datasets.multimodal import multimodal_instruct_dataset
>>> dataset = multimodal_instruct_dataset(
Collaborator

maybe structure it similar to this example:

>>> model_transform = FlamingoTransform(

and FlamingoTransform should be replaced everywhere in the example with Llama3VisionTransform https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2_vision/_transform.py#L17
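
For illustration, a rough sketch of what the restructured docstring example might look like. The Llama3VisionTransform arguments and the exact import paths are assumptions here (check the linked transform file for the real signature), and the builder keeps the multimodal_instruct_dataset name it had at this point in the PR (it was later renamed vqa_dataset):

>>> from torchtune.models.llama3_2_vision import Llama3VisionTransform
>>> from torchtune.datasets.multimodal import multimodal_instruct_dataset
>>> model_transform = Llama3VisionTransform(
...     path="/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model",
...     tile_size=448,
...     patch_size=14,
... )
>>> dataset = multimodal_instruct_dataset(
...     model_transform=model_transform,
...     source="derek-thomas/ScienceQA",
...     column_map={"input": "question", "output": "solution", "image": "image"},
...     split="train",
... )
>>> tokens = dataset[0]["tokens"]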

"""

def __init__(
self,
train_on_input: bool = False,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
is_multimodal: bool = False,
Collaborator

can we infer this similarly to how it's done in ShareGPTToMessages?

is_multimodal = "image" in sample or (
            "image" in self._column_map and self._column_map["image"] in sample
        )

so users don't need to worry about another flag

Contributor Author

Sure!

@krammnic krammnic changed the title Add multimodal_instruct_dataset, update docs [WIP] Add multimodal_instruct_dataset, update docs Oct 13, 2024
]

expected_labels = [
[np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(7), np.int64(5), np.int64(-1)]
Contributor

nit, but is the np.int64 strictly necessary? I know labels are long dtype, but if we are just comparing raw lists via == and not using torch tensor asserts or something, I think you shouldn't need it (and dropping it makes the expected values easier to read).
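
For instance, the same expected labels written with plain Python ints (assuming the test compares plain lists with ==):

expected_labels = [
    [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 7, 5, -1]
]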

Contributor Author

Yeah, will be fixed.

@krammnic
Contributor Author

Done with all the small fixes plus the CI fixes. Only the extra test is still needed.

@krammnic
Contributor Author

@RdoubleA @ebsmothers
Made all the required fixes and manually tested it on the requested dataset:

from typing import Any, List, Mapping, Optional, Tuple

from torchtune.data import Message, truncate
from torchtune.datasets.multimodal import multimodal_instruct_dataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform


class DummyTokenizer(ModelTokenizer, Transform):
    def __init__(self, max_seq_len: Optional[int] = None):
        self.max_seq_len = max_seq_len

    def encode(self, text, add_bos=True, add_eos=True, **kwargs) -> List[int]:
        words = text.split()
        tokens = [len(word) for word in words]
        if add_bos:
            tokens = [self.bos_id] + tokens
        if add_eos:
            tokens = tokens + [self.eos_id]
        return tokens

    def tokenize_messages(
        self,
        messages: List[Message],
    ) -> Tuple[List[int], List[bool]]:
        """
        A simplified version of Llama2Tokenizer's ``tokenize_messages`` for testing purposes.
        """
        start_of_turn = True
        end_of_turn = False
        tokenized_messages = []
        mask = []
        for message in messages:
            # If assistant message, this is the end of a turn
            end_of_turn = message.role == "assistant"

            # Prepend BOS on start of new turns
            if start_of_turn:
                tokenized_messages.append(self.bos_id)
                mask.append(message.masked)

            # Tokenize current message, append with masks
            tokens = []
            for item in message.content:
                if item["type"] == "text":
                    tokens = tokens + self.encode(
                        item["content"],
                        add_bos=False,
                        add_eos=False,
                    )
                elif item["type"] == "image":
                    tokens = tokens + [self.image_id]

            tokenized_messages.extend(tokens)
            mask.extend([message.masked] * len(tokens))

            # If assistant message, append EOS at end
            if end_of_turn:
                tokenized_messages.append(self.eos_id)
                mask.append(message.masked)
                end_of_turn = False
                start_of_turn = True
            else:
                start_of_turn = False

            # Break out early if we reach max_seq_len
            if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
                break

        # Finally, truncate if necessary
        if self.max_seq_len:
            tokenized_messages = truncate(
                tokenized_messages, self.max_seq_len, self.eos_id
            )
            mask = truncate(mask, self.max_seq_len, message.masked)

        return tokenized_messages, mask

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        messages: List[Message] = sample.pop("messages")
        images = []
        for message in messages:
            images += message.get_media()
        tokens, mask = self.tokenize_messages(messages)
        sample["tokens"] = tokens
        sample["mask"] = mask
        sample["images"] = images
        return sample

    @property
    def eos_id(self):
        return -1

    @property
    def bos_id(self):
        return 0

    @property
    def image_id(self):
        return -2

dataset = multimodal_instruct_dataset(
    model_transform=DummyTokenizer(),
    source="derek-thomas/ScienceQA",
    column_map={
        "input": "question",
        "output": "solution",
        "image": "image",
    },
    split="train",
)

tokens = dataset[0]
print(tokens)

Output:

{'input': 'question', 'output': 'solution', 'image': 'image'}
{'tokens': [0, -2, 5, 2, 5, 6, 2, 8, 6, 2, 4, 3, 7, 4, 2, 3, 7, 5, 4, 2, 5, 3, 3, 5, 5, 2, 9, 4, 8, 2, 8, 6, -1], 'mask': [True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False], 'images': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=750x429 at 0x7FD20A590D50>], 'labels': [np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(-100), np.int64(2), np.int64(4), np.int64(3), np.int64(7), np.int64(4), np.int64(2), np.int64(3), np.int64(7), np.int64(5), np.int64(4), np.int64(2), np.int64(5), np.int64(3), np.int64(3), np.int64(5), np.int64(5), np.int64(2), np.int64(9), np.int64(4), np.int64(8), np.int64(2), np.int64(8), np.int64(6), np.int64(-1)]}

I'm pretty confident that it is actually working

@joecummings joecummings mentioned this pull request Oct 15, 2024
@krammnic
Contributor Author

@ebsmothers @RdoubleA Added another fix; tests pass locally.

@krammnic
Contributor Author

Yeah, finally. Can someone approve and merge? Or should we do some extra checks?

Comment on lines 20 to 21
@pytest.mark.parametrize("train_on_input", [True, False])
def test_get_item(self, tokenizer, train_on_input):
Contributor

Can we remove train_on_input parametrization? Otherwise I think this is just gonna run twice identically

Contributor Author

Yeah, good point

Comment on lines 179 to 193
      self.column_map = column_map

      if self.column_map:
          if "input" not in self.column_map:
              raise ValueError(
-                 f"Expected a key of 'input' in column_map but found {column_map.keys()}."
+                 f"Expected a key of 'input' in column_map but found {self.column_map.keys()}."
              )
-         if "output" not in column_map:
+         if "output" not in self.column_map:
              raise ValueError(
-                 f"Expected a key of 'output' in column_map but found {column_map.keys()}."
+                 f"Expected a key of 'output' in column_map but found {self.column_map.keys()}."
              )
          self._column_map = column_map
      else:
          self._column_map = {"input": "input", "output": "output"}
          self.column_map = {"input": "input", "output": "output"}

      self._column_map = column_map
Contributor

This feels duplicative.. why do we need to define both self.column_map and self._column_map?

Contributor Author

@krammnic krammnic Oct 16, 2024

Hmm, I'll think about this. The problem is that we can't fully define column_map in __init__ because of the is_multimodal check in __call__.

filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
Contributor

Strictly speaking, if we don't support packing yet, should the return type just be SFTDataset?

Contributor Author

Sure, the type hint needs to be fixed.

)
assert prompt == expected_tokens[i]
assert label == expected_labels[i]
assert isinstance(image[0], PngImageFile) is True
Contributor

nit: do we need is True here?

Contributor Author

Yes, the is True can probably be removed.

@krammnic
Contributor Author

Pushed some fixes

@krammnic
Contributor Author

I have a feeling that we will have to separate two versions of InputOutputToMessages so as not to break SRP, for example.

@krammnic
Contributor Author

@RdoubleA @joecummings Can we merge?

@@ -175,30 +175,54 @@ def __init__(
):
self.train_on_input = train_on_input
Collaborator

I have a slight nagging feeling that this transform is over-generalized. It wouldn't be harmful to define another transform that properly documents that the column map can include image, includes the validation logic, and doesn't need to check for multimodal.

Fine to leave this as a follow-up, but if there's consensus here I'd like to see an issue.

Collaborator

I see this is the same for ShareGPTToMessages, so maybe it's fine. We need to properly document MM usage either way.

Contributor Author

Let's consider this in future PRs.

Collaborator

Yeah, I see what you mean. Especially as we add more modalities, this may get bloated quickly. I was the one who initiated this design in ShareGPTToMessages, but I think it's worth reconsidering. I'll put up an issue.
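
As a rough illustration for that issue, one shape a dedicated, explicitly documented VQA message transform could take (class name and details here are hypothetical, not this PR's implementation):

from typing import Any, Mapping, Optional

from torchtune.data import Message
from torchtune.modules.transforms import Transform


class VQAToMessages(Transform):
    """Hypothetical transform for image question-answer samples.

    Unlike the generic InputOutputToMessages, the column map is documented to
    contain exactly "input", "output", and "image" and is resolved up front,
    so no is_multimodal check is needed at call time.
    """

    def __init__(
        self,
        column_map: Optional[Mapping[str, str]] = None,
        new_system_prompt: Optional[str] = None,
    ):
        # Fill in defaults for any columns the user didn't remap.
        self.column_map = dict(column_map or {})
        for key in ("input", "output", "image"):
            self.column_map.setdefault(key, key)
        self.new_system_prompt = new_system_prompt

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # User turn carries the image plus the question; assistant turn is the answer.
        content = [
            {"type": "image", "content": sample[self.column_map["image"]]},
            {"type": "text", "content": sample[self.column_map["input"]]},
        ]
        messages = [
            Message(role="user", content=content, masked=True),
            Message(role="assistant", content=sample[self.column_map["output"]]),
        ]
        if self.new_system_prompt is not None:
            messages = [
                Message(role="system", content=self.new_system_prompt, masked=True)
            ] + messages
        return {"messages": messages}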

column_map = {"input": "question", "output": "answer", "image": "picture"}

Args:
model_transform (Transform): callable that applies model-specific pre-processing to the sample.
Collaborator

Can we be more concrete here? For example, the cauldron dataset builder describes this arg as:

        model_transform (Transform): model-specific transform class that takes in a sample dict and applies custom
            transforms on the keys. It should consist of at minimum two components: text tokenization (called
            on the "messages" field) and image transform (called on the "images" field). The keys returned by
            the model transform should be aligned with the expected inputs into the model.

Comment on lines +39 to +42
the expected column names. For example, if your dataset has columns ``"question"``,
``"answer"`` and ``"picture"`` you can use:

column_map = {"input": "question", "output": "answer", "image": "picture"}
Collaborator

I'm not sure we need a concrete example of using column_map here; it is better described below.

Contributor Author

I'm not sure it is. To me it looks fine, but we can remove this description if you want.

@SalmanMohammadi
Collaborator

SalmanMohammadi commented Oct 17, 2024

This looks great to me. Coming at this from an outside perspective, I'd really love to see a couple of follow-up issues to make sure we properly document the new features we add to our datasets. I've mentioned above some things I immediately found confusing when combing through the current docs in this PR.

@krammnic
Contributor Author

krammnic commented Oct 17, 2024

We can now document usage of a general VQA MM builder in https://pytorch.org/torchtune/main/basics/multimodal_datasets.html, to the same extent that we have documented the MM chat builder. Alternatively, we can rename the current MM doc page to "Multimodal chat datasets" and add a new page for MM instruct.

The second idea is probably better.

I'd like to see a fine-tuning run using this dataset.

Then I will run the full procedure.

Should I make these docs fixes in this PR or open a separate one?

@krammnic
Contributor Author

Oh, lint errors (I accepted the changes right here). Let me fix them.

@SalmanMohammadi
Collaborator

SalmanMohammadi commented Oct 17, 2024

Then I will run the full procedure.
Should I make these docs fixes in this PR or open a separate one?

We'd be incredibly grateful if you're up for helping out here! I was speaking slightly generally just to make sure some of those tasks get done at some point, so the other maintainers and I are happy to help out, particularly with a training run and, of course, reviews!

In any case, separate PRs are fine.

@krammnic
Contributor Author

krammnic commented Oct 17, 2024

Fixed

@krammnic krammnic changed the title [WIP] Add multimodal_instruct_dataset, update docs Add vqa_dataset, update docs Oct 17, 2024
@codecov-commenter

Codecov Report

Attention: Patch coverage is 34.88372% with 28 lines in your changes missing coverage. Please review.

Project coverage is 25.75%. Comparing base (54673b7) to head (3c2134f).
Report is 26 commits behind head on main.

Files with missing lines                                 Patch %   Lines
torchtune/data/_messages.py                               0.00%    14 Missing ⚠️
.../torchtune/datasets/multimodal/test_vqa_dataset.py    45.00%    11 Missing ⚠️
torchtune/datasets/multimodal/_vqa.py                     62.50%     3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1820       +/-   ##
===========================================
- Coverage   67.05%   25.75%   -41.30%     
===========================================
  Files         305      307        +2     
  Lines       15937    16068      +131     
===========================================
- Hits        10687     4139     -6548     
- Misses       5250    11929     +6679     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@krammnic
Contributor Author

Is there anything else to add in this PR?

Collaborator

@RdoubleA RdoubleA left a comment

Appreciate you pushing this all the way through. I think this highlighted some rough parts of our multimodal pipeline which we should follow up on, and it created some productive discussions. Just one comment; otherwise it's good to go once CI is green.

in the filepath in ``data_files``, and set ``split="train"``. See `Hugging Face's
<https://huggingface.co/docs/datasets/en/package_reference/loading_methods#datasets.load_dataset.path>`_
``load_dataset`` for more details.
image_dir (str): path to the directory containing the images as you are expected to download the COCO dataset
Collaborator

this should be generalized and not specific to COCO. Maybe use the same docstring as the image_dir argument in InputOutputToMessages?

Contributor Author

Fixed

@joecummings joecummings changed the title Add vqa_dataset, update docs Add vqa_dataset, update docs Oct 17, 2024
@joecummings joecummings merged commit f8073ed into pytorch:main Oct 17, 2024
17 checks passed