GENERATE V2 HOT DOG #1563
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1563
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 026107d with merge base b4fea32.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -252,6 +252,17 @@ class Recipe:
        ],
        supports_distributed=False,
    ),
    Recipe(
        name="dev/generate_v2",
Do we wanna call it "dev/generate_v2"? We could do what we did for FSDP2 and just call it "generate_v2", which is snappier from the CLI.
@pbontrager yelled at me for that
I didn't want to call the recipe generate_v2 once it's moved out of dev. But I'm fine with not including dev in the CLI name.
could move the old recipe into dev as /generate_old
Yeah I would just remove dev from the CLI name
recipes/dev/generate_v2.py
Outdated
self._logger.info(
    f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
)
use our friendly util for this?
Overkill
recipes/dev/generate_v2.py
Outdated
class SingleTurnYAMLToMessages(Transform):
    """
    Converts a single turn conversation in YAML format to a list of messages.
    """
Just curious, why does it have to be in YAML format? Like why not just call .to_dict or whatever so this isn't as tied to configs?
Separately, is this the long-term home for this class or should it go in data?
Recipes ARE tied to configs though. So, this recipe needs a way to translate YAML to Message.
The core of the recipe can be ripped out / modified in any way now that things are in messages. Eventually, this should go in data, but not until we get some usage.
By YAML you mean an omegaconf DictConfig, right? That's why I'm suggesting to just use .to_dict cause then we get a primitive type straightaway (I think). But anyways I don't disagree with your points here
I'm being dumb but I'm not sure I understand. Can you show a code sample?
I mean in practice it's basically the same thing since the signature in SingleTurnYAMLToMessages already takes Dict[str, Any]. But really you are doing self.to_messages(cfg.prompt) and (I think) cfg.prompt is an OmegaConf.DictConfig. So might be better to just do self.to_messages(OmegaConf.to_container(cfg.prompt)) so that SingleTurnYAMLToMessages is actually just taking a vanilla dict rather than some OmegaConf class.
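A quick sketch of that suggestion, assuming cfg.prompt is an omegaconf DictConfig as discussed above (the helper name here is made up for illustration, not part of the PR):

```python
from typing import Any, Dict

from omegaconf import DictConfig, OmegaConf


def prompt_to_dict(prompt: DictConfig) -> Dict[str, Any]:
    # Convert the OmegaConf node into a vanilla dict so downstream code
    # (e.g. SingleTurnYAMLToMessages) never sees an OmegaConf type.
    return OmegaConf.to_container(prompt, resolve=True)


# Hypothetical usage inside the recipe:
#   messages = self.to_messages(prompt_to_dict(cfg.prompt))
```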
recipes/dev/generate_v2.py
Outdated
    return messages


def batch_to_device(batch: dict, device: torch.device) -> None:
Put in training/ or something?
Yeah @pbontrager is going to add this in his PR. This code will be deleted here.
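For reference, a minimal sketch of what a helper with this signature typically does; this is an assumption about the intended behavior, not the version that will land in training/:

```python
import torch


def batch_to_device(batch: dict, device: torch.device) -> None:
    """Move all tensor values in a (possibly nested) batch dict to `device`, in place."""
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, dict):
            batch_to_device(value, device)
```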
recipes/dev/generate_v2.py
Outdated
logits = self.model(**batch)[:, -1]
token = sample(logits, cfg.temperature, cfg.top_k)
generated_tokens.append(token.item())

if is_multimodal_input:
    cache_mask = {"encoder_mask": batch["encoder_mask"][:, -1:]}
shape comments here would be nice
Also dumb q: how should I think about batch["encoder_mask"][:, -1:]? Like is this just some kind of stand-in for "here are all the text tokens seen by the final image"? And why do we not need to update it as we generate more text tokens? I know we aggressively pad in the image dimension (I think in the collate function?), but didn't think we were doing this in the text dimension.
cc @pbontrager
This isn't a pretty solution, but essentially what's happening is that during decoding your attention shapes are [1 x kv_cache_size], where 1 is the token being decoded. Inside TransformerDecoder we take a single slice of a causal mask for our self-attention masks. For cross attention, the mask is static during decoding, and the assumption being made here is that the decoding tokens have the same mask as the final token of prefill, i.e. the decoded tokens can see whatever images the previous token could.
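To make the shapes concrete, here is a toy illustration (not recipe code; the real mask's image-token dimension is set by the vision encoder):

```python
import torch

# Toy cross-attention mask: [batch, text_len, num_image_tokens]
batch_size, text_len, num_image_tokens = 1, 6, 4
encoder_mask = torch.ones(batch_size, text_len, num_image_tokens, dtype=torch.bool)

# During decoding, every new token reuses the mask row of the final prefill
# position: it can attend to whatever image tokens the last prompt token could.
decode_mask = encoder_mask[:, -1:]
print(decode_mask.shape)  # torch.Size([1, 1, 4])
```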
It's a beautiful recipe, much more transparent than our existing one. Add tests 😃
recipes/dev/generate_v2.py
Outdated
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(**batch)[:, -1]
token = sample(logits, cfg.temperature, cfg.top_k)
generated_tokens.append(token.item())

if is_multimodal_input:
    cache_mask = {"encoder_mask": batch["encoder_mask"][:, -1:]}
else:
    cache_mask = {}

# 5. Continue generating
for _ in range(cfg.max_new_tokens):
    if token.item() in self.model_transform.stop_tokens:
        break
    logits = self.model(token, **cache_mask)[:, -1]
    token = sample(logits, cfg.temperature, cfg.top_k)
    generated_tokens.append(token.item())
t = time.perf_counter() - t0
Is this going to be replaced with a call to generation.generate? If so, does that need to be updated to support encoder masks?
My thinking is to actually leave this logic in the recipe as it increases visibility into how we actually do generate and allows users to hack around specific parts. Eventually, I'd like to deprecate the generation.generate for use in the generation recipe. We can still use it for DPO.
Then, we can expand our generation utils to easily support more versions of sampling or maybe even beam search decoding.
I guess generation.generate would also be used for evals?
When I put up the PR, I'd actually love to have a discussion on this. I'm honestly not sure of the best way to balance code readability + abstractions for this kind of thing right now.
I think this is definitely the way forward for a generation recipe just used to test out your finetuned models; we can benefit from this simplicity by not needing any funky masking/padding logic. It's precise and very easy to understand.
I think it's mostly okay to sacrifice this simplicity when we need to cover broader use-cases and maximise performance, like in PPO/evals/agentic etc. How easy is it to understand all the compile integrations in the codebase? Not very. The best we can do imo is surface as much of the generation logic as we can so users can reference a single file for all the chunkier generation functionality.
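For readers following the decoding loop quoted above, here is roughly what a temperature + top-k sample step does. This is a generic sketch, not torchtune's actual sample implementation:

```python
from typing import Optional

import torch


def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """Sample one token id per row from [batch, vocab] logits."""
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        # Keep only the top_k logits per row; everything else gets -inf.
        topk_vals, _ = torch.topk(logits, top_k, dim=-1)
        logits = torch.where(
            logits < topk_vals[:, -1:], torch.full_like(logits, float("-inf")), logits
        )
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # shape [batch, 1]
```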
from torchtune.modules.transforms import Transform


class SingleTurnYAMLToMessages(Transform):
Should this end up somewhere like data/_utils.py?
Yep, eventually will move it there.
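As a rough picture of what the transform does, here is a simplified text-only sketch, assuming torchtune.data.Message takes role and content; the actual class in this PR also handles the image entry shown in the configs below:

```python
from typing import Any, Dict, List

from torchtune.data import Message


class SingleTurnYAMLToMessages:
    """Turn a {"system": ..., "user": ...} prompt dict from the config into Messages."""

    def __call__(self, prompt: Dict[str, Any]) -> List[Message]:
        messages = []
        if "system" in prompt:
            messages.append(Message(role="system", content=prompt["system"]))
        # Text-only case; the real transform also prepends an image when one is given.
        messages.append(Message(role="user", content=prompt["user"]))
        return messages
```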
This looks really clean. Generation has felt like a bit of a second-class citizen in torchtune, and I think this pushes us in the right direction in setting the bar for clean + hackable code across all our recipes. I really like how transparent it makes going from MM prompt -> message; this is probably the most accessible example we have in our codebase. I don't have a huge amount to add on top of Evan's comments. What's the minimum feature set you'd like this to support before it lands?
Minimum feature set is what I have up now. Ideal state would have compile, as well. Quantization will definitely come later.
# Generation arguments
prompt:
  system: You are a helpful assistant who responds like the author Shakespeare.
  user:
Does this support multiple image/text entries or just prepend image?
prepend
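For concreteness, the multimodal prompt from the config dump later in this thread ends up as a plain dict like the one below, with the image prepended to the user text rather than interleaved (sketch only):

```python
prompt = {
    "system": "You are a helpful assistant who responds like the author Shakespeare.",
    "user": {
        "image": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
        "text": "What is in this image?",
    },
}
```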
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@           Coverage Diff            @@
##             main    #1563    +/-   ##
==========================================
- Coverage   71.11%   71.08%   -0.03%
==========================================
  Files         297      298       +1
  Lines       15120    15165      +45
==========================================
+ Hits        10752    10780      +28
- Misses       4368     4385      +17

Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry.
How is the prompt template applied here if we are not calling
We are! The tokenizer / transform call method uses
self.to_messages = SingleTurnYAMLToMessages()

@torch.inference_mode()
def generate(self, cfg: DictConfig):
I think this looks pretty neat!
model:
  _component_: torchtune.models.llama2.llama2_7b

# Transform arguments
transform:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
Can we update the defaults to something newer than Llama2?
This is just for an example + the test.
The real default is coming soon.
In reality, I want to move generation to under each model.
self.model_transform = config.instantiate(cfg.transform)
self.to_messages = SingleTurnYAMLToMessages()

def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
very nice very hackable
hot dog
Is there any reason why this wouldn't work with compile? edit: maybe because I broke something? Currently working this out.
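Not part of this PR, but as a sketch of where compile could slot in (the config flag name is hypothetical):

```python
import torch
import torch.nn as nn


def maybe_compile(model: nn.Module, enable: bool) -> nn.Module:
    # Compile the decoder once after checkpoint load, guarded by a config flag
    # such as cfg.compile (assumed name).
    return torch.compile(model) if enable else model
```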
@ebsmothers Why? I think it makes sense b/c it's buried in dev. Lots of our users use the recipes from the git clone installation.
Discussed the rationale in this comment. Our FSDP2 recipe was just
Finally, a generate recipe that doesn't make me wanna eat fire ants.
Output from generate with Llama2
(joe-torchtune-2) [[email protected] ~/projects/joe-torchtune (b4fea32)]$ tune run dev/generate_v2 --config llama2/generation_v2
2024-09-23:19:51:51,843 INFO [_logging.py:101] Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 200
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
seed: 1234
temperature: 0.6
top_k: 300
transform:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
2024-09-23:19:52:01,241 INFO [generate_v2.py:90] Model was initialized with precision torch.bfloat16.
2024-09-23:19:52:09,680 INFO [generate_v2.py:201]
Oh, how delightful! adjusts glasses The capital of France is... drumroll Paris! 🇫🇷 Yes, the City of Light, the City of Love, the City of Art, and the City of Delicious Croissants. 🥐 Is there anything else I can help you with? 😊
2024-09-23:19:52:09,684 INFO [generate_v2.py:108] Time for inference: 7.57 sec total, 11.09 tokens/sec
2024-09-23:19:52:09,684 INFO [generate_v2.py:111] Bandwidth achieved: 151.59 GB/s
2024-09-23:19:52:09,684 INFO [generate_v2.py:114] Max memory allocated: 13.95 GB
Output from a MM model
(joe-torchtune-2) [[email protected] ~/projects/(fe6d3ad0d)]$ tune run generate.py --config multimodal_generation.yaml
2024-09-23:19:36:31,579 INFO [_logging.py:101] Running InferenceRecipe with resolved config:

checkpointer:
  _component_: checkpointer.FullModelMetaCheckpointer
  checkpoint_dir: ../joe-torchtune
  checkpoint_files:
  model_type:
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 200
model:
  _component_:
prompt:
  system: You are a helpful assistant who responds like the author Shakespeare.
  user:
    image: https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg
    text: What is in this image?
seed: 1234
temperature: 0.6
top_k: 300
transform:
  _component_:
  max_seq_len: 8192
  tokenizer_path: ../joe-torchtune/tokenizer.model
2024-09-23:19:36:37,868 INFO [generate.py:90] Model was initialized with precision torch.bfloat16.
2024-09-23:19:36:48,979 INFO [generate.py:201]
Fair patron, thou dost behold a wondrous sight, a monument of liberty, a beacon of hope, that doth stand tall on Liberty Island, in the midst of the Hudson River's flow. 'Tis the Statue of Liberty, a gift from France, a symbol of freedom's call, that doth welcome all, who seek a new life, in this fair land of America's thrall.
2024-09-23:19:36:48,987 INFO [generate.py:108] Time for inference: 9.75 sec total, 8.51 tokens/sec
2024-09-23:19:36:48,987 INFO [generate.py:111] Bandwidth achieved: 187.46 GB/s
2024-09-23:19:36:48,987 INFO [generate.py:114] Max memory allocated: 22.32 GB
Limitations: