[RFC] Batched inference 🤝 KV-cache 🤝 compile #1424

Merged: 23 commits merged into pytorch:main from batched_kv on Sep 16, 2024

Conversation

SalmanMohammadi
Collaborator

@SalmanMohammadi SalmanMohammadi commented Aug 28, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Closes #1250.

Changelog

Currently, we have two separate generation utils, one in torchtune.generation, and one in torchtune.rlhf._generation.

  • torchtune.generation is used for single-device, single-sample generation through the generate recipe, and supports iterative decoding with KV-caching and compile. This makes it fast and efficient.
  • torchtune.rlhf._generation is primarily used for the batched generation step in PPO, and it allows for generation with variable-length samples in a batch. This is about all it does currently.

This PR combines the above functionality into a single unified generation module, pairing batched generation with the speedups of KV-caching and compile.

Batched inference

To correctly perform inference on batched, variable length samples, we must first pad all sequences to the same length. It doesn't make a huge amount of sense to right-pad like we usually do for SFT, since we're generating new tokens to concatenate to the prompts. Sequences are therefore left padded:

prompts = [
    [1, 2, 3],
    [1, 2, 3, 4, 5],
    [1]
]
prompts_padded = left_pad_sequence(..., pad_id=0)
prompts_padded
torch.tensor(
            [[0, 0, 0, 1, 2, 3], 
             [0, 1, 2, 3, 4, 5], 
             [0, 0, 0, 0, 0, 1]]
)
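
For reference, left padding only takes a few lines of plain PyTorch. The left_pad helper below is just an illustrative sketch, not the exact utility used in this PR:

import torch

def left_pad(sequences, pad_id=0):
    # Left-pad a list of 1D token tensors to a common length.
    max_len = max(len(seq) for seq in sequences)
    return torch.stack(
        [
            torch.cat([torch.full((max_len - len(seq),), pad_id, dtype=seq.dtype), seq])
            for seq in sequences
        ]
    )

prompts = [torch.tensor(p) for p in [[1, 2, 3], [1, 2, 3, 4, 5], [1]]]
print(left_pad(prompts))
# tensor([[0, 0, 0, 1, 2, 3],
#         [0, 1, 2, 3, 4, 5],
#         [0, 0, 0, 0, 0, 1]])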

Now, when we generate tokens, the completed prompt-generation sequences will be left-padded together. However, we don't actually want the model to attend to these padding tokens - they're invalid. How can we do this?

  1. Causal masks

Attention masks are a way of ensuring the model ignores tokens we don't want it to pay attention to. A common example is next-token prediction tasks - for each token in a sequence, attention masks ensure that only preceding tokens are used in the attention calculation, and not any future tokens. For SDPA, we typically use boolean masks where True indicates participation in attention, and False otherwise.


When we have invalid tokens in the prompt, we not only want to exclude all future tokens from the attention calculation, but also any invalid tokens in the sequence so far. If we take the first element of our unpadded prompts from earlier, we have:

prompt = [1, 2, 3]
mask = [
    [True, False, False],
    [True, True, False],
    [True, True, True]
]

However, with padding, this becomes:

prompt = [0, 0, 0, 1, 2, 3]
mask = [
    [True, False, False, False, False, False],
    [False, True, False, False, False, False],
    [False, False, True, False, False, False],
    [False, False, False, True, False, False],
    [False, False, False, True, True, False],
    [False, False, False, True, True, True],
]

Our original causal mask is now a block within the full causal mask for the whole sequence. For the batched case, the appropriate mask for each element is simply stacked, i.e. shape [bsz, seq_len, seq_len].
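
To make this concrete, here's a minimal sketch (not the exact torchtune code) of building that padded causal mask from a boolean padding mask, where True marks a valid token:

import torch

def padded_causal_mask(padding_mask):
    # padding_mask: [bsz, seq_len] bool, True = valid (non-padding) token
    bsz, seq_len = padding_mask.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # a position may only attend to earlier positions which are themselves valid...
    mask = causal[None, :, :] & padding_mask[:, None, :]
    # ...but padding rows still attend to themselves so attention stays well-defined
    return mask | torch.eye(seq_len, dtype=torch.bool)[None, :, :]

padding_mask = torch.tensor([[False, False, False, True, True, True]])
print(padded_causal_mask(padding_mask)[0])
# reproduces the 6x6 mask above

For a batch, this directly gives the [bsz, seq_len, seq_len] mask described above.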

  2. Position IDs

Position IDs are used to inform the model of the relative position in a sequence of each token. Taking my example from above:

prompt = [1, 2, 3]
position_ids = [0, 1, 2]

tells the model that the first token is in position 0, the second in position 1, etc. Now, when we have padded inputs, these position IDs are right-shifted so that the first valid token still remains in position 0:

prompt = [0, 0, 0, 1, 2, 3]
position_ids = [0, 0, 0, 0, 1, 2]

Make sense? The batched case is exactly the same, but, you know, batched.
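
One way to derive those shifted position IDs from the same boolean padding mask (again a sketch, not necessarily the exact code in this PR):

import torch

padding_mask = torch.tensor([[False, False, False, True, True, True]])  # True = valid token
# cumulative count of valid tokens so far, clamped so padding positions stay at 0
position_ids = (padding_mask.long().cumsum(dim=-1) - 1).clamp(min=0)
print(position_ids)
# tensor([[0, 0, 0, 0, 1, 2]])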

  3. How does this work with KV-caching?

There are two relatively small changes we need to make here to support KV-caching. Firstly, note that when setting caches up, we create empty key and value tensors with sequence length model.max_seq_len. When running inference with KV-caching enabled, we pre-fill the first prompt_length positions in the cache with the key-value tensors for the prompt so far. Then, rather than recomputing the keys/values for the whole sequence when generating every new token, we only compute them for one token at a time and reuse the cached values. This also means caches are updated every time we generate a new token - first the position at prompt_length + 1, then prompt_length + 2, and so on, until prompt_length + max_generated_tokens.
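
Roughly, the cache usage looks like this (illustrative shapes and names only, not the torchtune KVCache API):

import torch

bsz, num_heads, head_dim = 1, 8, 64
max_seq_len, prompt_length, max_generated_tokens = 9, 6, 2

# setup: empty caches sized to the model's max sequence length
k_cache = torch.zeros(bsz, num_heads, max_seq_len, head_dim)
v_cache = torch.zeros(bsz, num_heads, max_seq_len, head_dim)

# prefill: keys/values for the whole (padded) prompt are written in one go
k_cache[:, :, :prompt_length] = torch.randn(bsz, num_heads, prompt_length, head_dim)
v_cache[:, :, :prompt_length] = torch.randn(bsz, num_heads, prompt_length, head_dim)

# decode: one position per generated token, reusing everything cached so far
for curr_pos in range(prompt_length, prompt_length + max_generated_tokens):
    k_cache[:, :, curr_pos : curr_pos + 1] = torch.randn(bsz, num_heads, 1, head_dim)
    v_cache[:, :, curr_pos : curr_pos + 1] = torch.randn(bsz, num_heads, 1, head_dim)
    # attention for the new token only sees k_cache/v_cache up to curr_pos + 1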

Tying this into above:

We don't actually need to change how we grab position IDs from the usual incremental decoding - just grab the position IDs at the current position for all batch elements!

We do, however, need to ensure we provide the correct shape of causal masks. If we work our way backwards, the full causal mask for a prompt + completion will be of shape (bsz, prompt_length + max_generated_tokens, model_max_seq_len). To convince yourself of this fact, consider the full mask again for the example prompt above, with max_generated_tokens=2, and model_max_seq_len=9:

prompt = [0, 0, 0, 1, 2, 3]
mask = tensor(
     [[[ True, False, False, False, False, False, False, False, False],
         [False,  True, False, False, False, False, False, False, False],
         [False, False,  True, False, False, False, False, False, False],
         [False, False, False,  True, False, False, False, False, False],
         [False, False, False,  True,  True, False, False, False, False],
         [False, False, False,  True,  True,  True, False, False, False],
         [False, False, False,  True,  True,  True,  True, False, False],
         [False, False, False,  True,  True,  True,  True,  True, False]]])

It's pretty much the same mask we showed above, but now we've extended it out by an additional max_generated_tokens rows, and up to model_max_seq_len columns. It makes sense from here that during the prefill step, we're just using the first prompt_length rows of this mask:

tensor([[[ True, False, False, False, False, False, False, False, False],
         [False,  True, False, False, False, False, False, False, False],
         [False, False,  True, False, False, False, False, False, False],
         [False, False, False,  True, False, False, False, False, False],
         [False, False, False,  True,  True, False, False, False, False],
         [False, False, False,  True,  True,  True, False, False, False]]])

Then for every subsequent step, we're just grabbing the corresponding row. For the first generated token, this will be the 7th row of the full causal mask for the whole sequence:

tensor([[False, False, False,  True,  True,  True,  True,  False, False]])

Great! Make sense so far? Seeing the Matrix yet? You can see how we can be kind-of efficient about this and pre-generate the whole causal mask for the entire expected sequence, and all the position IDs, and just correctly index at each step. This is what we're doing here.
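
Putting the pieces together for the running example, here's a self-contained sketch of that pre-generate-then-index approach (illustrative only; the real implementation lives in torchtune.generation):

import torch

model_max_seq_len, prompt_length, max_generated_tokens = 9, 6, 2
total_len = prompt_length + max_generated_tokens

# padding mask over the full expected sequence: the three left-pad tokens are
# invalid, everything else (prompt + generated tokens) is valid
padding_mask = torch.tensor([[False, False, False, True, True, True, True, True]])

# full causal mask of shape [bsz, total_len, model_max_seq_len]
causal = torch.tril(torch.ones(total_len, model_max_seq_len, dtype=torch.bool))
valid = torch.cat(
    [padding_mask, torch.zeros(1, model_max_seq_len - total_len, dtype=torch.bool)], dim=-1
)
full_mask = (causal[None] & valid[:, None, :]) | torch.eye(
    model_max_seq_len, dtype=torch.bool
)[:total_len][None]

# full position ids of shape [bsz, total_len]
full_position_ids = (padding_mask.long().cumsum(dim=-1) - 1).clamp(min=0)

# prefill step: the first prompt_length rows / positions
prefill_mask = full_mask[:, :prompt_length]
prefill_input_pos = full_position_ids[:, :prompt_length]

# each subsequent step: a single row / position
for i in range(max_generated_tokens):
    curr_pos = prompt_length + i
    step_mask = full_mask[:, curr_pos : curr_pos + 1]              # [bsz, 1, model_max_seq_len]
    step_input_pos = full_position_ids[:, curr_pos : curr_pos + 1]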

  4. Cache positions

One last thing. Recall that our position IDs are shifted to account for padding. Currently, we use position IDs to index into and update KV-caches (this will change with #1449, when the cache tracks its own position internally). However, this doesn't make sense when position IDs are shifted, since they no longer line up neatly with the key/value tensors - KV-cache positions should be unaffected by padding, i.e. the 0th cache position should correspond to the key/values for the 0th element in the sequence, even if it's an invalid token. Otherwise, we'd be stacking the keys/values for all invalid tokens into a single cache position, which is a nightmare for indexing and general sane reasoning.

Concretely, our cache positions are always just torch.arange(0, model_max_seq_len), and indexed accordingly for prefill and subsequent token generation steps ([:prompt_length], and [curr_pos], respectively). Hopefully it's clear that we don't need to use separate cache positions when there are no padding tokens in the input.
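
In code, the distinction between the (shifted) position IDs and the (unshifted) cache positions for the running example looks like this; again just a sketch:

import torch

model_max_seq_len, prompt_length = 9, 6

# position ids are shifted to account for padding...
input_pos = torch.tensor([[0, 0, 0, 0, 1, 2]])

# ...but cache positions are not: they're always just an arange over the cache
cache_pos = torch.arange(0, model_max_seq_len)

prefill_cache_pos = cache_pos[:prompt_length]                          # tensor([0, 1, 2, 3, 4, 5])
first_decode_cache_pos = cache_pos[prompt_length : prompt_length + 1]  # tensor([6])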

Did I say two changes earlier? I probably meant more.

Caveats

Currently, in order to use batched inference with KV-caching, we must use model.setup_caches(batch_size=batch_size), which sets up caches with the model's max_seq_len. This is not efficient for models with long context lengths (or indeed any scenario where we know how many tokens we want to generate and max_generated_tokens << model.max_seq_len). This means using batched generation with KV-caching will consume a huge amount of memory, scaling with the model's context length and the desired batch size. This will be addressed in #1449.

very quick results

Will provide example outputs from generate.py - but also look at my tests pls

quick results - more comprehensive testing to come.
(bsz=4 varying-length sequences, max_generated_tokens=256, on a 2080 Super)

on main - rlhf.generate_with_logits - no cache or compile:
Time for inference: 28.48 sec total, 35.96 tokens/sec
Bandwidth achieved: 79.52 GB/s
Memory used: 2.79 GB
---
this branch - generation.generate w/cache+compile:
Time for inference: 1.73 sec total, 590.29 tokens/sec
Bandwidth achieved: 1493.16 GB/s
Memory used: 2.68 GB
---
on main - utils.generate w/cache+compile (generates nonsense):
Time for inference: 2.68 sec total, 381.58 tokens/sec
Bandwidth achieved: 1407.27 GB/s
Memory used: 3.85 GB

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;


pytorch-bot bot commented Aug 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1424

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a0ba770 with merge base 000bb70:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 28, 2024
@SalmanMohammadi SalmanMohammadi changed the title [WIP] Batched inference 🤝 KV-cache 🤝 compile [WIP][RFC] Batched inference 🤝 KV-cache 🤝 compile Aug 29, 2024
Contributor

@joecummings joecummings left a comment

BIG MOVES

We can leave evals for a follow-up.

@SalmanMohammadi SalmanMohammadi marked this pull request as ready for review September 13, 2024 16:38
@codecov-commenter

codecov-commenter commented Sep 14, 2024

Codecov Report

Attention: Patch coverage is 93.47826% with 12 lines in your changes missing coverage. Please review.

Project coverage is 72.78%. Comparing base (6820089) to head (b27951b).
Report is 1 commit behind head on main.

Files with missing lines                        Patch %   Lines
recipes/ppo_full_finetune_single_device.py      0.00%     5 Missing ⚠️
recipes/generate.py                             0.00%     3 Missing ⚠️
tests/torchtune/generation/test_generation.py   98.00%    2 Missing ⚠️
torchtune/generation/_generation.py             98.24%    1 Missing ⚠️
torchtune/modules/kv_cache.py                   85.71%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1424      +/-   ##
==========================================
- Coverage   72.87%   72.78%   -0.09%     
==========================================
  Files         290      288       -2     
  Lines       14252    14214      -38     
==========================================
- Hits        10386    10346      -40     
- Misses       3866     3868       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@SalmanMohammadi SalmanMohammadi changed the title [WIP][RFC] Batched inference 🤝 KV-cache 🤝 compile [RFC] Batched inference 🤝 KV-cache 🤝 compile Sep 15, 2024
Comment on lines +94 to +97
cache_pos (Optional[torch.Tensor]): Optional tensor which contains the cache positions
of each token, used during inference. This is useful when ``input_ids`` are
right-shifted to account for padding tokens. Default is None, in which case
``input_pos`` is used (if specified).
Contributor

I'm confused by cache_pos; is this just a stopgap until #1449? If so why not consolidate the changes across the two PRs a bit?

Collaborator Author

We're landing both now so I'll update in 1449 probably.

Comment on lines +470 to +484
Note:
At the very first step of inference, when the model is provided with a prompt,
``input_pos`` should contain the positions of all of the tokens in the prompt.
For a single-batch prompt, or a batch of prompts with identical lengths, this
will be ``torch.arange(prompt_length)``. For a batch of varying-length prompts,
shorter prompts are left-padded and position ids are correspondingly right-shifted,
thus positional ids should be of shape ``[b, padded_prompt_length]``.
This is because we will need to retrieve the positional embeddings for each input id.
In the subsequent steps, if the model has been setup with KV-caches, ``input_pos`` will contain
the position(s) of the current token(s) ``torch.tensor([padded_prompt_length])``. Otherwise,
``input_pos`` will contain all the position ids up to the current token.

In the case above when ``input_pos`` are right-shifted due to padding, ``cache_pos``
should be used to correctly update KV-caches, where ``cache_pos`` is ``torch.arange(prompt_length)``
during the first pre-fill step, and ``torch.tensor([prompt_length])`` for subsequent steps.
Contributor

Similar comment here. Also I know that some of this is inevitable, but I notice that our transformer class is getting a bit bloated and a lot of the bloat comes from generation-related functionality. I think some of that is inevitable but still would like to keep it to a minimum as much as possible. I think something like this is useful for someone doing generation, but for those who are not it's a lot of extra stuff to sort through to get to the stuff they care about.

Collaborator Author

I generally agree here. I don't think I mess with things too much here, but in #1449 at the least, I'd like to lift a lot of the generation logic out of the transformer/layer/attention. Currently we do a bit of hand holding like creating input pos and masks if they're not there. I'd find things much cleaner and simpler if all the generation logic was front loaded to the generate fn or recipe.

Comment on lines +331 to +333
q = torch.empty(
(bsz, model.tok_embeddings.num_embeddings), device=prompt.device
).exponential_(1, generator=rng)
Contributor

I feel like we've now just given a couple equivalent default definitions of this in different places which adds to the burden of understanding the code (but lmk if I'm missing something important here)

Collaborator Author

The main reason we have to define it here is to make generate_next_token compile-friendly when using rng, but also allow generate_next_token to work independently from this kind-of-hacky logic i.e. you can do all this without having to sample outside of generate_next_token.

I agree it makes things confusing and we could drop the default defn inside generate_next_token so it's a required arg.
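
For readers following along: my understanding of the trick under discussion is the "exponential race" below, where argmax(probs / q) with q ~ Exponential(1) is equivalent to sampling from probs, but compiles more cleanly than torch.multinomial. The temperature and vocab size here are made up for illustration:

import torch

rng = torch.Generator().manual_seed(0)
logits = torch.randn(2, 32000)                # [bsz, vocab_size], illustrative
probs = torch.softmax(logits / 0.8, dim=-1)   # temperature-scaled probabilities

q = torch.empty_like(probs).exponential_(1, generator=rng)
next_token = torch.argmax(probs / q, dim=-1, keepdim=True)  # one sampled token per batch element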

Contributor

Another dumb q - why does it need to be defined every time within this loop?

Collaborator Author

Similar to how you'd re-generate it every time in sample - at each step you have a new random distribution to sample from, right?

@ebsmothers
Contributor

ebsmothers commented Sep 15, 2024

However, with padding, this becomes:

prompt = [0, 0, 0, 1, 2, 3]
mask = [
[True, False, False, False, False, False],
[False, True, False, False, False, False],
[False, False, True, False, False, False],
[False, False, False, True, False, False],
[False, False, False, True, True, False],
[False, False, False, True, True, True],
]

nit: mask[2:, 0] should be False

Also one high-level question: what is the plan with respect to this and generate_v2 in #1563? Are we gonna consolidate to a single implementation, and if so when?

Contributor

@joecummings joecummings left a comment

Overall looks good, but some questions mainly around the use of q

0, total_response_length, device=generated_tokens.device
).unsqueeze(0)

del padding_masks
Contributor

Just curious - Do we know how much memory this saves?

Collaborator Author

in the best case 1 byte * bsz * model_max_seq_len, so like ~10MB with Llama3.1 and bsz=100. will remove

@SalmanMohammadi SalmanMohammadi merged commit 726abb0 into pytorch:main Sep 16, 2024
17 checks passed
@SalmanMohammadi SalmanMohammadi deleted the batched_kv branch September 16, 2024 15:07
@SalmanMohammadi SalmanMohammadi mentioned this pull request Sep 16, 2024
13 tasks