[Feat] Kernel Support for Deepseek Headdim=192 (For prefilling only) #112
Conversation
A complete patch has been prepared. Update: to keep the code space clean, the standalone test script has been removed from this PR.
Force-pushed from 754830b to e207ba6.
Accuracy has been reached.
Deepseek-V2-Lite-Chat test results:
TTFT (avg): 0.61 s -> 0.54 s
Good job! Could you please provide an example when it's ready? Cheers!
Thanks for your attention. There is a standalone script: b580f8c#diff-0817e284a4dbf70d9085c87bbf5d2119d24174bafb29e1fdaa7ae52630508d9b. I removed it from this PR to keep the code space clean, but if you want a test script, that one should work.

By the way, regarding how batching is supported during inference: SageAttention currently does not support batching sequences with different seq_lens in a single batch, so what we do is simply pad them to the same seq_len, as sketched below. We are also working on unpadded batch serving, like FA2, but that is still work in progress; I will open another PR when it is ready.
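For illustration only, here is a minimal sketch of the padding approach described above, in plain PyTorch; the helper name pad_batch and the [seq_len, num_head, head_dim] layout are assumptions for this example, not part of this PR:

```python
import torch
import torch.nn.functional as F

def pad_batch(seqs, pad_value=0.0):
    # Hypothetical helper: pad a list of [seq_len, num_head, head_dim]
    # tensors to the same seq_len and stack them into a single
    # [bsz, max_seq_len, num_head, head_dim] batch.
    max_len = max(s.shape[0] for s in seqs)
    padded = [F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0]), value=pad_value)
              for s in seqs]
    return torch.stack(padded, dim=0)

# Two sequences of length 5 and 9, 16 heads, head_dim=192 (Deepseek MLA q/k).
batch = pad_batch([torch.randn(5, 16, 192), torch.randn(9, 16, 192)])
print(batch.shape)  # torch.Size([2, 9, 16, 192])
```

Unpadded (varlen) serving in the FA2 style would instead pack the sequences and track cumulative sequence lengths; that is the work-in-progress part mentioned above.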
Is this suitable for the RTX 5000 issue?
I did not expect that somebody would use this PR. So far only the SM90 arch is supported for head_dim=192. I will implement the SM80 and SM89 versions soon.
Introduction
Hello, this PR mainly adds support for Deepseek model inference. Deepseek uses MLA during inference, and its attention q/k shape differs from most trending language models: instead of [bsz, seq_len, num_head, head_dim], it is [bsz, seq_len, num_head, head_dim_nope + head_dim_pe], where pe stands for position embedding. For details, please refer to this link to see their implementation.

To integrate Sage Attention with Deepseek, head_dim support must be modified to fit Deepseek's head_dim=192 inference. And as we already know, the Sage Attention kernel is mainly used in the prefilling phase of language models, so the very first thing is to support head_dim=192 inference (see the shape sketch below).
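For concreteness, a small shape sketch assuming the Deepseek-V2 configuration (qk_nope_head_dim=128, qk_rope_head_dim=64, so head_dim = 128 + 64 = 192); the tensor names are illustrative only:

```python
import torch

bsz, seq_len, num_head = 2, 1024, 16
head_dim_nope, head_dim_pe = 128, 64   # Deepseek-V2: 128 + 64 = 192

# q/k as seen by the attention kernel during prefilling: the nope part and
# the rope (pe) part are concatenated along the last dim.
q = torch.randn(bsz, seq_len, num_head, head_dim_nope + head_dim_pe)
k = torch.randn(bsz, seq_len, num_head, head_dim_nope + head_dim_pe)
print(q.shape)  # torch.Size([2, 1024, 16, 192])
```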
1. Quant QK phase
Currently our approach is to enlarge the q/k input to head_dim=256 (pad them to 256 in the last dim) and then do the quantization, because we have noticed that head_dim does not affect the quantization process very much. A sketch of this padding is given below.
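As an illustration of the quant-phase input handling (not the actual CUDA code in this PR), the zero padding of the last dim from 192 to 256 could look like this in PyTorch:

```python
import torch
import torch.nn.functional as F

bsz, seq_len, num_head = 2, 1024, 16
q = torch.randn(bsz, seq_len, num_head, 192)
k = torch.randn(bsz, seq_len, num_head, 192)

# Zero-pad the last dim from 192 to 256 before quantization. The padded
# columns are zero, so they neither change q @ k^T nor the per-block max
# that a max-based quantization scale would use.
pad = 256 - q.shape[-1]
q_pad = F.pad(q, (0, pad))   # [bsz, seq_len, num_head, 256]
k_pad = F.pad(k, (0, pad))
print(q_pad.shape)  # torch.Size([2, 1024, 16, 256])
```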
2. QKV attention phase
As for the attention kernel, we split q and k as:
- query_nope: [bsz, seq_len, num_head, 128]
- query_pe: [bsz, seq_len, num_head, 64]
- key_nope: [bsz, seq_len, num_head, 128]
- key_pe: [bsz, seq_len, num_head, 64]
- _: [bsz, seq_len, num_head, 64] (all-zero padding tensor for q, leave it)
- _: [bsz, seq_len, num_head, 64] (all-zero padding tensor for k, leave it)

and merge them before doing softmax, so the output shape is not changed (see the sketch after this list).
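To make the merge step concrete, here is a minimal numerical sketch in plain PyTorch (not the quantized CUDA kernel): the partial scores from the nope and pe parts are summed before softmax, and since the zero-padded chunks contribute nothing, the merged scores match a straight head_dim=192 matmul.

```python
import torch

bsz, seq_len, num_head = 2, 128, 16
q = torch.randn(bsz, seq_len, num_head, 192)
k = torch.randn(bsz, seq_len, num_head, 192)

# Split along the last dim: nope(128) + pe(64); the zero chunk (64) that
# pads q/k to 256 is omitted here because it adds nothing to the scores.
q_nope, q_pe = q.split([128, 64], dim=-1)
k_nope, k_pe = k.split([128, 64], dim=-1)

# [bsz, num_head, seq_len, seq_len] scores, computed per part and merged.
s_nope = torch.einsum("bqhd,bkhd->bhqk", q_nope, k_nope)
s_pe   = torch.einsum("bqhd,bkhd->bhqk", q_pe, k_pe)
s_merged = s_nope + s_pe

# Reference: a single matmul over the full head_dim=192.
s_ref = torch.einsum("bqhd,bkhd->bhqk", q, k)
print(torch.allclose(s_merged, s_ref, atol=1e-4))  # True
```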
The reason we adopt this method is that we met countless obstacles when trying to enlarge the q/k input to head_dim=192 directly: static_assert checks are ubiquitous in the code, and enlarging head_dim to 192 triggers countless bugs. So the workable path is to split q/k into nope(128), pe(64), and zero_tensor(64), compute q_nope@k_nope and q_pe@k_pe separately, and merge them together; an end-to-end sketch follows below.
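For reference only, a hedged end-to-end sketch of this pipeline in plain PyTorch; the function name sage_attn_deepseek_prefill_ref is hypothetical and merely stands in for the fused, quantized CUDA kernel that this PR implements (causal masking omitted for brevity):

```python
import math
import torch
import torch.nn.functional as F

def sage_attn_deepseek_prefill_ref(q, k, v):
    # Hypothetical pure-PyTorch reference for the head_dim=192 prefill path:
    # pad q/k to 256, split into nope/pe/zero, compute the two partial score
    # matmuls, merge before softmax, then multiply by v.
    # q, k: [bsz, seq_len, num_head, 192]; v: [bsz, seq_len, num_head, v_dim]
    q_pad = F.pad(q, (0, 256 - q.shape[-1]))     # all-zero chunk, "leave it"
    k_pad = F.pad(k, (0, 256 - k.shape[-1]))

    q_nope, q_pe, _ = q_pad.split([128, 64, 64], dim=-1)
    k_nope, k_pe, _ = k_pad.split([128, 64, 64], dim=-1)

    scale = 1.0 / math.sqrt(q.shape[-1])         # softmax scale over 192 dims
    s = torch.einsum("bqhd,bkhd->bhqk", q_nope, k_nope)
    s = s + torch.einsum("bqhd,bkhd->bhqk", q_pe, k_pe)
    p = torch.softmax(s * scale, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", p, v)  # [bsz, seq_len, num_head, v_dim]

# Usage with Deepseek-V2-Lite-like shapes (v_head_dim=128):
q = torch.randn(1, 256, 16, 192)
k = torch.randn(1, 256, 16, 192)
v = torch.randn(1, 256, 16, 128)
print(sage_attn_deepseek_prefill_ref(q, k, v).shape)  # torch.Size([1, 256, 16, 128])
```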
Plan
This PR is currently a work in progress; we have hit some bugs and are trying to fix them. (We are also looking forward to the maintainers' review of our implementation.) Right now we have only uploaded the kernel implementation file; we have also made many modifications to other .cuh files, but those are not uploaded yet.