[RFC] Batched inference 🤝 KV-cache 🤝 compile #1424
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1424
✅ No Failures as of commit a0ba770 with merge base 000bb70.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
BIG MOVES
We can leave evals for a follow-up.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@ Coverage Diff @@
##             main    #1424      +/-   ##
==========================================
- Coverage   72.87%   72.78%   -0.09%
==========================================
  Files         290      288       -2
  Lines       14252    14214      -38
==========================================
- Hits        10386    10346      -40
- Misses       3866     3868       +2

☔ View full report in Codecov by Sentry.
cache_pos (Optional[torch.Tensor]): Optional tensor which contains the cache positions
    of each token, used during inference. This is useful when ``input_ids`` are
    right-shifted to account for padding tokens. Default is None, in which case
    ``input_pos`` is used (if specified).
I'm confused by cache_pos; is this just a stopgap until #1449? If so why not consolidate the changes across the two PRs a bit?
We're landing both now so I'll update in 1449 probably.
Note:
    At the very first step of inference, when the model is provided with a prompt,
    ``input_pos`` should contain the positions of all of the tokens in the prompt.
    For a single-batch prompt, or a batch of prompts with identical lengths, this
    will be ``torch.arange(prompt_length)``. For a batch of varying-length prompts,
    shorter prompts are left-padded and position ids are correspondingly right-shifted,
    thus positional ids should be of shape ``[b, padded_prompt_length]``.
    This is because we will need to retrieve the positional embeddings for each input id.
    In the subsequent steps, if the model has been set up with KV-caches, ``input_pos`` will contain
    the position(s) of the current token(s), ``torch.tensor([padded_prompt_length])``. Otherwise,
    ``input_pos`` will contain all the position ids up to the current token.

    In the case above when ``input_pos`` are right-shifted due to padding, ``cache_pos``
    should be used to correctly update KV-caches, where ``cache_pos`` is ``torch.arange(prompt_length)``
    during the first pre-fill step, and ``torch.tensor([prompt_length])`` for subsequent steps.
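For concreteness, a small illustration of the difference at the pre-fill step (values are made up; a sketch, not the exact code in the PR):

```python
import torch

# A 4-token prompt left-padded to length 6 (two pad tokens at the front).
# input_pos is right-shifted so the first *valid* token sits at position 0,
# while cache_pos simply enumerates the cache slots being written and is
# unaffected by padding.
input_pos = torch.tensor([0, 0, 0, 1, 2, 3])
cache_pos = torch.arange(6)  # tensor([0, 1, 2, 3, 4, 5])
```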
Similar comment here. Also I know that some of this is inevitable, but I notice that our transformer class is getting a bit bloated and a lot of the bloat comes from generation-related functionality. I think some of that is inevitable but still would like to keep it to a minimum as much as possible. I think something like this is useful for someone doing generation, but for those who are not it's a lot of extra stuff to sort through to get to the stuff they care about.
I generally agree here. I don't think I mess with things too much here, but in #1449 at the least, I'd like to lift a lot of the generation logic out of the transformer/layer/attention. Currently we do a bit of hand holding like creating input pos and masks if they're not there. I'd find things much cleaner and simpler if all the generation logic was front loaded to the generate fn or recipe.
q = torch.empty(
    (bsz, model.tok_embeddings.num_embeddings), device=prompt.device
).exponential_(1, generator=rng)
I feel like we've now just given a couple equivalent default definitions of this in different places which adds to the burden of understanding the code (but lmk if I'm missing something important here)
The main reason we have to define it here is to make generate_next_token compile-friendly when using rng, but also allow generate_next_token to work independently from this kind-of-hacky logic i.e. you can do all this without having to sample outside of generate_next_token.
I agree it makes things confusing and we could drop the default defn inside generate_next_token so it's a required arg.
Another dumb q - why does it need to be defined every time within this loop?
Similar to how you'd re-generate it every time in sample
you have a new random distribution to sample from, right?
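(For context, a rough sketch of the sampling trick this relies on; names are illustrative and this is not the exact code in the PR:)

```python
import torch

torch.manual_seed(0)
probs = torch.softmax(torch.randn(2, 8), dim=-1)  # per-row sampling distributions

# Fresh exponential noise is drawn for every sampling step: taking
# argmax(probs / q) with q ~ Exp(1) returns index i with probability probs[i]
# (an "exponential race", equivalent to the Gumbel-max trick), and avoids the
# device sync that torch.multinomial incurs. Reusing the same q across steps
# would correlate the samples, hence it is re-drawn inside the loop.
q = torch.empty_like(probs).exponential_(1)
tokens = torch.argmax(probs / q, dim=-1, keepdim=True)
```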
nit: `mask[2:, 0]` should be False.

Also one high-level question: what is the plan with respect to this and generate_v2 in #1563? Are we gonna consolidate to a single implementation, and if so when?
Overall looks good, but some questions mainly around the use of q
torchtune/generation/_generation.py (outdated)
    0, total_response_length, device=generated_tokens.device
).unsqueeze(0)

del padding_masks
Just curious - Do we know how much memory this saves?
in the best case 1 byte * bsz * model_max_seq_len, so like ~10MB with Llama3.1 and bsz=100. will remove
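(For scale, assuming Llama 3.1's 131072-token max_seq_len: 1 byte × 100 × 131072 ≈ 13MB, in the same ballpark as the ~10MB above.)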
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Closes #1250.
Changelog
Currently, we have two separate generation utils: one in `torchtune.generation`, and one in `torchtune.rlhf._generation`. `torchtune.generation` is used for single-device, single-sample generation through the `generate` recipe, and supports iterative decoding with KV-caching and compile. This makes it fast and efficient. `torchtune.rlhf._generation` is primarily used for the batched generation step in PPO, and it allows for generation with variable-length samples in a batch. This is about all it does currently.

This PR combines the above functionality into a single unified generation module, pairing batched generation with the speedups of KV-caching and compile.
Batched inference
To correctly perform inference on batched, variable length samples, we must first pad all sequences to the same length. It doesn't make a huge amount of sense to right-pad like we usually do for SFT, since we're generating new tokens to concatenate to the prompts. Sequences are therefore left padded:
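For illustration, a minimal left-padding sketch (token values and `pad_id` are made up):

```python
import torch
import torch.nn.functional as F

pad_id = 0
prompts = [
    torch.tensor([5, 8, 11, 3]),
    torch.tensor([4, 7]),
    torch.tensor([9, 2, 6]),
]
max_len = max(len(p) for p in prompts)
# Pad on the left so generated tokens can be appended directly on the right.
padded = torch.stack([F.pad(p, (max_len - len(p), 0), value=pad_id) for p in prompts])
# tensor([[ 5,  8, 11,  3],
#         [ 0,  0,  4,  7],
#         [ 0,  9,  2,  6]])
```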
Now, when we generate tokens, the completed prompt-generation sequences will be left-padded together. However, we don't actually want the model to attend to these padding tokens - they're invalid. How can we do this?
Attention masks are a way of ensuring the model ignores tokens we don't want it to pay attention to. A common example is for next-token prediction tasks - for each token in a sequence, attention masks ensure only preceding tokens are used in the attention calculation and not any future tokens. For SDPA, we typically use boolean masks where True indicates participation in attention, and False otherwise.
When we have invalid tokens in the prompt, we not only want to exclude all future tokens from the attention calculation, but also any invalid tokens in the sequence so far. If we take the first element of our unpadded prompt from earlier, we have a standard causal mask. With padding, however, this becomes a causal mask that also excludes the padding positions:
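A rough sketch of the two masks (assuming the first element has 4 valid tokens, left-padded to length 6; True means the position participates in attention):

```python
import torch

# Unpadded 4-token prompt: a standard causal mask.
unpadded_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# The same prompt left-padded to length 6: valid tokens must also ignore the
# two pad columns, so AND the causal mask with the padding mask.
padding_mask = torch.tensor([False, False, True, True, True, True])
padded_mask = torch.tril(torch.ones(6, 6, dtype=torch.bool)) & padding_mask.unsqueeze(0)
# Pad rows would otherwise be all-False (which breaks softmax), so let pad
# positions attend to themselves; their outputs are discarded anyway.
padded_mask |= torch.eye(6, dtype=torch.bool)
```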
Our original causal mask is now a block of the full causal mask for the whole sequence. For the batched case, the appropriate mask for each element is simply stacked, i.e. shape `[bsz, seq_len, seq_len]`.

Position IDs are used to inform the model of the relative position of each token in a sequence. Taking my example from above, `torch.arange(prompt_length)` tells the model that the first token is in position 0, the second in position 1, etc. Now, when we have padded inputs, these position IDs are right-shifted so that the first valid token still remains in position 0:
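A rough sketch of the shift for a batch of two prompts, the second left-padded by two tokens (illustrative; the PR's helper may differ):

```python
import torch

padding_mask = torch.tensor([[True, True, True, True],
                             [False, False, True, True]])
# Counting valid tokens gives the right-shifted positions: the first valid
# token in each row lands at position 0 (pad positions are clamped to 0).
input_pos = (padding_mask.long().cumsum(dim=-1) - 1).clamp(min=0)
# tensor([[0, 1, 2, 3],
#         [0, 0, 0, 1]])
```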
Make sense? The batched case is exactly the same, but, you know, batched.
There are two relatively small changes we need to make here to support KV-caching. Firstly, note that when setting caches up, we create empty key and value tensors with sequence length `model.max_seq_len`. When using inference with KV-caching enabled, we pre-fill the first `prompt_length` positions in the cache with the key-value tensors for the prompt so far. Then, rather than recomputing the keys/values for the whole sequence when generating every new token, we only do it for one token at a time and reuse the cached values. This also means caches are updated every time we generate a new token - first the position at `prompt_length + 1`, then `prompt_length + 2`, and so on, until `prompt_length + max_generated_tokens`.

Tying this into the above:
We don't actually need to change how we grab position IDs from the usual incremental decoding - just grab the position IDs at the current position for all batch elements!
We do, however, need to ensure we provide the correct shape of causal masks. If we work our way backwards, the full causal mask for a prompt + completion will be of shape `(bsz, prompt_length + max_generated_tokens, model_max_seq_len)`. To convince yourself of this fact, consider the full mask again for the example prompt above, with `max_generated_tokens=2` and `model_max_seq_len=9`: it's pretty much the same mask we showed above, but now extended out by an additional `max_generated_tokens` rows, and up to `model_max_seq_len` columns. It makes sense from here that during the prefill step, we're just using the first `prompt_length` rows of this mask. Then, for every subsequent step, we're just grabbing the corresponding row. For the first generated token, this will be the 7th row of the full causal mask for the whole sequence.
Great! Make sense so far? Seeing the Matrix yet? You can see how we can be kind-of efficient about this and pre-generate the whole causal mask for the entire expected sequence, and all the position IDs, and just correctly index at each step. This is what we're doing here.
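Roughly, the idea looks like this (shapes and names are illustrative, and the padding shift is ignored for brevity):

```python
import torch

bsz, prompt_length, max_generated_tokens, model_max_seq_len = 2, 6, 2, 9
total_response_length = prompt_length + max_generated_tokens

# Pre-generate the causal mask and position ids for the whole expected
# sequence, then slice into them at each step.
full_mask = torch.tril(
    torch.ones(bsz, total_response_length, model_max_seq_len, dtype=torch.bool)
)
full_input_pos = torch.arange(total_response_length).unsqueeze(0).expand(bsz, -1)

# Prefill uses the first prompt_length rows / positions.
prefill_mask = full_mask[:, :prompt_length]        # [bsz, prompt_length, model_max_seq_len]
prefill_pos = full_input_pos[:, :prompt_length]    # [bsz, prompt_length]

# Each subsequent decoding step grabs a single row / position; for the first
# generated token this is the 7th row (index prompt_length = 6).
curr_pos = prompt_length
step_mask = full_mask[:, curr_pos : curr_pos + 1]      # [bsz, 1, model_max_seq_len]
step_pos = full_input_pos[:, curr_pos : curr_pos + 1]  # [bsz, 1]
```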
One last thing. Recall that our position IDs are shifted to account for padding. Currently, we use position IDs to index and update KV-caches (this will be changed with #1449 when the cache tracks its own position internally). However, this doesn't make sense when position IDs are shifted, since they no longer line up neatly with the key/value tensors - KV cache positions should be unaffected by padding, i.e. the 0th cache position should correspond to the key/values for the 0th element in the sequence, even if it's an invalid token - otherwise we're stacking all the keys/values for invalid tokens into a single cache position. This is a nightmare for indexing and general sane reasoning.
Concretely, our cache positions are always just `torch.arange(0, model_max_seq_len)`, and indexed accordingly for prefill and subsequent token generation steps (`[:prompt_length]` and `[curr_pos]`, respectively). Hopefully it's clear that we don't need to use separate cache positions when there are no padding tokens in the input.

Did I say two changes earlier? I probably meant more.
Caveats
Currently, in order to use batched inference with KV-caching, we must use `model.setup_caches(batch_size=batch_size)`, which sets up caches with the model's `max_seq_len`. This is not efficient for models with long context lengths (or indeed for any scenario where we know how many tokens we want to generate, and where `max_generated_tokens << model.max_seq_len`). This means using batched generation with KV-caching will consume a huge amount of memory, scaling with the model's context length and the desired batch size. This will be addressed in #1449.

Very quick results
Will provide example outputs from `generate.py` - but also look at my tests pls.

Quick results - more comprehensive testing to come (bsz=4 varying-length sequences, max_generated_tokens = 256, on a 2080 Super).
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models