Add changes to support MM eval #1669


Merged · 10 commits into pytorch:main from update-eval-recipe-for-mm · Sep 25, 2024

Conversation

@joecummings (Contributor, Author) commented Sep 25, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Sep 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1669

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 04d1e53 with merge base 18efc81:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

```python
except ImportError:
    logger.error(
        "Recipe requires EleutherAI Eval Harness v0.4. Please install with `pip install lm_eval==0.4.*`"
    )

lm_eval_version = importlib.metadata.version("lm_eval")
```
@joecummings (Contributor, Author):

Instead of checking via importing legacy functions, we actually check the version.

This is the correct way lol
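
As a side note, a minimal sketch of that version-based guard using only the standard library (importlib.metadata ships with Python 3.8+); the error message is copied from the snippet above:

```python
import importlib.metadata

try:
    lm_eval_version = importlib.metadata.version("lm_eval")
except importlib.metadata.PackageNotFoundError:
    # Package is missing entirely; surface the same install hint.
    raise ImportError(
        "Recipe requires EleutherAI Eval Harness v0.4. "
        "Please install with `pip install lm_eval==0.4.*`"
    )
```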

```python
        return self._device

    @property
    def cache_hook(self):
```
@joecummings (Contributor, Author):

Have to do this to make the harness happy.
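
For context, a minimal sketch of the kind of shim being described, assuming the harness only needs a cache_hook attribute it can set and read (the lm_eval wiring here is paraphrased, not quoted from the diff):

```python
@property
def cache_hook(self):
    # The Eleuther harness sets a cache hook on the LM wrapper and reads it
    # back later; exposing it as a property keeps this wrapper compatible.
    return self._cache_hook

def set_cache_hook(self, cache_hook) -> None:
    self._cache_hook = cache_hook
```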

```python
for text, images in zip(all_texts, all_images):
    # Ensure images are all RGB
    proper_images = []
    for image in images:
```
@joecummings (Contributor, Author):

Not all the images are in RGB format so we need to convert them

```python
        text, image_tag=self._image_str, images=proper_images
    )
    messages.append(Message(role="user", content=content))
    messages.append(Message(role="assistant", content=""))
```
@joecummings (Contributor, Author):

Append assistant message to kick start generation.

Contributor:

Is it fine for the content here to be empty, or should it be defined somewhere?

@joecummings (Contributor, Author):

No, the content should be empty.

Collaborator:

So to properly prompt the model for generation we need to:

  • Make sure we trail with an empty assistant message
  • Set inference=True on tokenize_messages

Both of these are hard to remember... not for this PR, but we should see if there's a more intuitive way to do this (see the sketch after this exchange).

@joecummings (Contributor, Author), Sep 25, 2024:

Yeah, definitely agree. It shouldn't be this hard to know what's going to give proper generation.
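
For readers following along, a minimal sketch of the two-step pattern discussed above, assuming torchtune's Message type; the tokenize_messages call and its inference flag are paraphrased from this thread rather than copied from the diff:

```python
from torchtune.data import Message

messages = [
    Message(role="user", content=content),
    # Step 1: trail with an empty assistant message so the prompt ends
    # exactly where the model should begin generating.
    Message(role="assistant", content=""),
]

# Step 2: tokenize with inference=True so the tokenizer does not close
# off the assistant turn (flag name taken from the comment above).
tokens, _ = tokenizer.tokenize_messages(messages, inference=True)
```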

```python
batch_size=self._batch_size,
dtype=self._dtype,

# Finally, we set up the actual EvalWrapper class
eleuther_model_wrapper = (
```
@joecummings (Contributor, Author):

Per our discussion earlier, this is simplified, @felipemello1.

@felipemello1 (Contributor), Sep 25, 2024:

Maybe worth logging to let the user know whether we see it as a multimodal or a text-only eval? Could unify it with `self.logger.info(f"Running evaluation on the following tasks: {self.tasks}")`. A sketch follows below.
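
A sketch of what that unified log line could look like; the is_multimodal flag is hypothetical:

```python
# Hypothetical unification of the two log lines suggested above.
modality = "multimodal" if is_multimodal else "text-only"
self.logger.info(
    f"Running {modality} evaluation on the following tasks: {self.tasks}"
)
```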


```python
# Log metrics
self.logger.info(f"Eval completed in {t1:.02f} seconds.")
self.logger.info(
```
@joecummings (Contributor, Author):

I added some memory logging.
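
The diff itself isn't shown here; as a rough sketch of this kind of memory logging using only the plain torch.cuda API (CUDA-only, attribute names assumed):

```python
import torch

if self._device.type == "cuda":
    # Report the peak allocation observed during the eval run.
    peak_gib = torch.cuda.max_memory_allocated(self._device) / 1024**3
    self.logger.info(f"Peak memory allocated: {peak_gib:.02f} GiB")
```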

```diff
@@ -22,21 +22,21 @@ class TestEleutherEval:
     @pytest.mark.parametrize(
         "eval_name, expected_acc, bsz",
         [
-            ("truthfulqa_gen", 0.1, 8),
+            ("truthfulqa_gen", 0.1, 4),
```
@joecummings (Contributor, Author):

No need to do bsz 8, that's just slow as hell.

```diff
@@ -122,3 +122,40 @@ def test_eval_recipe_errors_without_lm_eval(self, caplog, monkeypatch, tmpdir):

         err_log = caplog.messages[-1]
         assert "Recipe requires EleutherAI Eval Harness v0.4" in err_log

+    @pytest.mark.integration_test
+    def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
```
@joecummings (Contributor, Author):

more tests good, yes?

Contributor:

I thought @SalmanMohammadi added a very similar test already? Also, the test isn't added here?

@joecummings (Contributor, Author):

I've decided to be a bad person and add it in a follow-up.

@SalmanMohammadi (Collaborator), Sep 25, 2024:

I can just take it out when I fix this behaviour.


```python
if not lm_eval_version >= "0.4.2":
    raise ImportError(
        "lm_eval version must be >= 0.4.2. Please install lm_eval >= 0.4.2."
```
Contributor:

FYI, I had to create this import_guard.py for another PR because I was using it in multiple files. I wonder if it would make sense to move this there. My first intuition is that we should NOT, as keeping this logic closer to the code is better, unless you are using it somewhere else too: https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_import_guard.py

@joecummings (Contributor, Author):

This is good to keep in mind - my guess is that we will end up utilizing version and package guards much more frequently as our recipes expand.
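
One caveat on the guard quoted earlier: comparing version strings lexicographically misorders releases like 0.4.10 (the string "0.4.10" sorts before "0.4.2"). A hedged sketch of a more robust check, assuming the packaging dependency is available:

```python
from importlib.metadata import version

from packaging.version import Version  # assumes `packaging` is installed

if Version(version("lm_eval")) < Version("0.4.2"):
    raise ImportError(
        "lm_eval version must be >= 0.4.2. Please install lm_eval >= 0.4.2."
    )
```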

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Sep 25, 2024.
```python
# +1% on truthfulqa_mc2 with a LoRA finetune. lit-gpt also sets this to False, see
# https://github.com/Lightning-AI/lit-gpt/blob/main/eval/lm_eval_harness.py#L66,
# though notably fast-gpt does the opposite
return self._transform.tokenizer.encode(string, add_bos=False, add_eos=False)
```
Contributor:

kwargs are not passed through. Should we forward *args and **kwargs?

@ebsmothers (Contributor), Sep 25, 2024:

In general a Transform may not have a tokenizer field, right? May wanna add a check or something. I see you use this in a couple of places, so maybe just in __init__?

@joecummings (Contributor, Author):

A multimodal transform will.

Collaborator:

The transform should have the encode method directly, no?

But yeah, in general I'm leaning towards just calling them both tokenizers now, since it still takes in a list of messages...

@joecummings (Contributor, Author):

Makes sense: we should have a unified language going forward and do a global s/ replace at some point.

WRT this specific issue, no, I don't think transforms are guaranteed to have the encode method. And the encode output is NOT being used as input to the model; it's just for sorting input lengths on the backend for Eleuther.
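
A minimal sketch of the init-time check @ebsmothers suggested; the error message is illustrative, not from the diff:

```python
# Fail fast in __init__ if the transform doesn't expose a tokenizer, since
# tok_encode relies on self._transform.tokenizer for length sorting.
if not hasattr(self._transform, "tokenizer"):
    raise ValueError(
        "Expected the model transform to expose a `tokenizer` attribute."
    )
```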

```python
def tok_batch_multimodal_encode(
    self,
    all_texts: List[str],
    all_images: List[List[PIL.Image.Image]],
```
Contributor:

Nit: fine to keep "all_", but it feels weird. I think the names should match encode to reduce the number of bumps along the way from differing arg names.

@joecummings (Contributor, Author):

Are `batch_texts` and `batch_images` better?

Contributor:

Your call, just a nit.

```python
# it into a Message format for our tokenizer
all_encoded_messages = []

for text, images in zip(all_texts, all_images):
```
Contributor:

We should probably add a check here that these are lists. If not, maybe we should wrap them in a list for the bsz=1 case.

@joecummings (Contributor, Author):

Not an issue: they are definitely lists even if bsz = 1.

```python
proper_images = []
for image in images:
    if image.mode != "RGB":
        image = image.convert("RGB")
```
@felipemello1 (Contributor), Sep 25, 2024:

Our image transform takes care of it, I believe, as long as the input is PIL. I don't like the idea that the eval is processing images.

@joecummings (Contributor, Author):

When I remove this, our transform does not work. Worth investigating whether our transform should take care of this, but I can tell you right now it doesn't seem to.

Contributor:

Yeah, this should be in the model transform. I think we enforce 3 channels now, but not RGB vs. BGR, which might be what's happening here.

Collaborator:

Worth just adding Phil's comment in the code?


```python
# Pad the encoded messages
tok_batch = padded_collate_tiled_images_and_mask(
    all_encoded_messages,
```
Contributor:

Nit (feel free to ignore): not a big fan of "all_" and "tok_", but I think it may help the reader understand it's not a single sample?

@felipemello1 (Contributor):

Please, if you make changes that are applicable to the other class, let's keep them in sync.

```diff
@@ -41,7 +41,7 @@ jobs:
       run: |
         python -m pip install torch torchvision torchao
         python -m pip install -e ".[dev]"
-        python -m pip install lm-eval==0.4.*
+        python -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@fb963f0f0a5b28b69763590bb59676072cf43a01
```
Contributor:

Do we plan to add something about this to our README?

@joecummings (Contributor, Author):

README would be ideal

Contributor:

U gonna add?

@joecummings (Contributor, Author):

no u

Collaborator:

Sorry, why are we pinning to a commit here?

@joecummings (Contributor, Author):

Eleuther added support for multimodal eval in this commit, but hasn't released a v0.4.5 patch yet. Once they do, a lot of this nonsense can go away.

```python
content = format_content_with_images(
    text, image_tag=self._image_tag, images=proper_images
)
messages.append(Message(role="user", content=content))
```
Contributor:

Dumb q: how do we support system messages?

@joecummings (Contributor, Author):

MMMU does not have system messages, so we don't support them right now.


```python
# 4. Prefill step
generated_tokens = []
logits = self.model(prompt, **batch)[:, -1]
```
Contributor:

Just wondering: why do we pop prompt earlier just to get its shape and then reinsert it here? Seems unintuitive.

@joecummings (Contributor, Author):

This mirrors our generate recipe, where the tokens used in the forward pass are passed in positionally and the rest of the batch is unrolled via the double asterisk (see the sketch below).
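
A sketch of that calling convention; the pop/reinsert step and the batch key name are illustrative, and only the prefill line comes from the diff:

```python
import torch

# `prompt` was popped from the batch earlier to read its shape; it is passed
# back positionally here while the remaining fields (encoder_input,
# encoder_mask, ...) are unrolled via **batch, as in the generate recipe.
prompt = batch.pop("tokens")                   # illustrative key name
logits = self.model(prompt, **batch)[:, -1]    # prefill: last-position logits
next_token = torch.argmax(logits, dim=-1)      # greedy decoding only
```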

```python
):
    # TODO (@joecummings): Remove this init function so we don't load in extraneous stuff
```
Contributor:

Sorry, what does this mean? Are some of these fields unused?

@joecummings (Contributor, Author):

Uhmmmmm, turns out we actually load in a copy of a GPT-2 model (not that large, so it's okay) when we do this call. We overwrite anything that would affect generation, but it's extra memory overhead we shouldn't need.

Contributor:

How big? If nontrivial, let's definitely file a follow-up task (I assume we can just load in some dummy empty model instead?).


codecov-commenter commented Sep 25, 2024

Codecov Report

Attention: Patch coverage is 1.57068% with 188 lines in your changes missing coverage. Please review.

Project coverage is 26.07%. Comparing base (50b24e5) to head (04d1e53).
Report is 18 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| recipes/eleuther_eval.py | 0.00% | 184 Missing ⚠️ |
| tests/recipes/test_eleuther_eval.py | 42.85% | 4 Missing ⚠️ |
Additional details and impacted files
```
@@             Coverage Diff             @@
##             main    #1669       +/-   ##
===========================================
- Coverage   71.11%   26.07%   -45.04%     
===========================================
  Files         297      299        +2     
  Lines       15120    15392      +272     
===========================================
- Hits        10752     4013     -6739     
- Misses       4368    11379     +7011     

Flag  Coverage Δ
      26.07% <1.57%> (-45.04%) ⬇️
```

Flags with carried forward coverage won't be shown.


Comment on lines 221 to 225
```python
    for k, v in model_state_dict.items():
        model_state_dict[k] = v.to(self._device)
    model.load_state_dict(model_state_dict, assign=True)
else:
    model.load_state_dict(model_state_dict)
```
Contributor:

Was this changed deliberately? Based on #1403, I think we want to load the state dict with assign=True when quantization is enabled.

@joecummings (Contributor, Author):

Yeah.
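
For the record, a sketch of the quantization-aware load being referenced; the condition name is assumed rather than taken from the diff (see #1403):

```python
if self._quantization_mode is not None:  # assumed flag, per the discussion
    # Quantized checkpoints: move tensors to the target device first and load
    # with assign=True so the loaded tensors replace the module's parameters.
    for k, v in model_state_dict.items():
        model_state_dict[k] = v.to(self._device)
    model.load_state_dict(model_state_dict, assign=True)
else:
    model.load_state_dict(model_state_dict)
```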

"Any decoding strategy other than greedy is not supported."
)

if bsz > 1:
Collaborator:

Bit out of the loop here: what needs to change to make this happen?

@joecummings (Contributor, Author):

So currently we have no utils that support batching for Fusion models (encoder_input and encoder_mask are never accounted for). My plan would be to add it here first, then upstream it to the generate function.

Comment on lines +252 to +253

```python
encoder_max_seq_len=self.model_transform.image_seq_len
* self._max_images_per_sample,
```

Collaborator:

nit

Suggested change:

```diff
-encoder_max_seq_len=self.model_transform.image_seq_len
-* self._max_images_per_sample,
+encoder_max_seq_len= (self.model_transform.image_seq_len
+* self._max_images_per_sample),
```

Collaborator:

unless this turns it into a tuple or smth

- Loading model in fp32 or bf16. Fp16 is currently not supported.
- Quantization and torch.compile (for text-only models) is supported.

We recommend launching evaluation using the tune CLI:
Collaborator:

Are you planning to add in compile support?

@joecummings (Contributor, Author):

Yes definitely.

Collaborator:

Take out compile for now?

@SalmanMohammadi (Collaborator):

very very nice

@SalmanMohammadi (Collaborator):

What's a guy gotta do to see some outputs round here?

"multimodal generation."
)

# 1. Setup caches for a given batch size
Collaborator:

We're not actually doing this rn

```python
if self.model.caches_are_enabled():
    self.model.reset_caches()
else:
    self.model.setup_caches(
```
Collaborator:

What if `self._enable_kv_cache` is False?

@joecummings (Contributor, Author):

Yeah, good catch. I'm actually considering removing this option, since using the KV cache is strictly faster.

For text-only models this was not a huge deal, but for multimodal models things are getting slowwwwww. I'll drop a note that this is what we're doing (see the sketch below).
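
Until that note lands, a sketch of the guard being asked about; the flag name comes from the review comment and the setup_caches arguments are illustrative:

```python
if self._enable_kv_cache:  # flag name taken from the comment above
    if self.model.caches_are_enabled():
        self.model.reset_caches()
    else:
        with self.device:
            self.model.setup_caches(batch_size=1, dtype=self._dtype)
```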

```python
if self.model.caches_are_enabled():
    self.model.reset_caches()
else:
    with self.device:
```
Collaborator:

Same comment as above about `self._enable_kv_cache`.

@joecummings (Contributor, Author):

> What's a guy gotta do to see some outputs round here?

|                 Tasks                 |Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|---------------------------------------|------:|------|------|------|---|-----:|---|-----:|
|mmmu_val                               |      0|none  |      |acc   |↑  |0.3233|±  |0.0271|
| - Art and Design                      |      0|none  |      |acc   |↑  |0.2000|±  |0.0656|
|  - Art                                |      0|none  |None  |acc   |↑  |0.1000|±  |0.1000|
|  - Art Theory                         |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Design                             |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Music                              |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
| - Business                            |      0|none  |      |acc   |↑  |0.3200|±  |0.0680|
|  - Accounting                         |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
|  - Economics                          |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
|  - Finance                            |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Manage                             |      0|none  |None  |acc   |↑  |0.5000|±  |0.1667|
|  - Marketing                          |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
| - Health and Medicine                 |      0|none  |      |acc   |↑  |0.3600|±  |0.0706|
|  - Basic Medical Science              |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
|  - Clinical Medicine                  |      0|none  |None  |acc   |↑  |0.5000|±  |0.1667|
|  - Diagnostics and Laboratory Medicine|      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
|  - Pharmacy                           |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
|  - Public Health                      |      0|none  |None  |acc   |↑  |0.4000|±  |0.1633|
| - Humanities and Social Science       |      0|none  |      |acc   |↑  |0.4000|±  |0.0791|
|  - History                            |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Literature                         |      0|none  |None  |acc   |↑  |0.5000|±  |0.1667|
|  - Psychology                         |      0|none  |None  |acc   |↑  |0.4000|±  |0.1633|
|  - Sociology                          |      0|none  |None  |acc   |↑  |0.5000|±  |0.1667|
| - Science                             |      0|none  |      |acc   |↑  |0.3000|±  |0.0585|
|  - Biology                            |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Chemistry                          |      0|none  |None  |acc   |↑  |0.0000|±  |0.0000|
|  - Geography                          |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Math                               |      0|none  |None  |acc   |↑  |0.4000|±  |0.1633|
|  - Physics                            |      0|none  |None  |acc   |↑  |0.7000|±  |0.1528|
| - Tech and Engineering                |      0|none  |      |acc   |↑  |0.3429|±  |0.0583|
|  - Agriculture                        |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Architecture and Engineering       |      0|none  |None  |acc   |↑  |0.4000|±  |0.1633|
|  - Computer Science                   |      0|none  |None  |acc   |↑  |0.4000|±  |0.1633|
|  - Electronics                        |      0|none  |None  |acc   |↑  |0.3000|±  |0.1528|
|  - Energy and Power                   |      0|none  |None  |acc   |↑  |0.4000|±  |0.1633|
|  - Materials                          |      0|none  |None  |acc   |↑  |0.2000|±  |0.1333|
|  - Mechanical Engineering             |      0|none  |None  |acc   |↑  |0.5000|±  |0.1667|

@joecummings merged commit 7207d3d into pytorch:main on Sep 25, 2024 · 17 checks passed

@joecummings deleted the update-eval-recipe-for-mm branch on September 25, 2024 at 16:28