
[Feat] Kernel Support for Deepseek Headdim=192 (For prefilling only) #112


Open · wants to merge 6 commits into main

Conversation


@l1cacheDell commented Feb 19, 2025

Introduction

Hello, this PR mainly adds support for DeepSeek model inference. DeepSeek uses MLA (Multi-head Latent Attention) during inference, and its attention QK shape differs from that of most mainstream language models.

  • Mainstream model QK shape: [bsz, seq_len, num_head, head_dim]
  • DeepSeek QK shape: [bsz, seq_len, num_head, head_dim_nope + head_dim_pe]

where pe stands for position embedding. For details, please refer to this link to see their implementation.
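For illustration, here is a minimal PyTorch sketch of the resulting QK head layout. The 128/64 nope/pe split sizes follow the DeepSeek-V2 config (qk_nope_head_dim=128, qk_rope_head_dim=64); the variable names are ours, not from this PR:

```python
import torch

# Illustrative shapes only; the 128/64 split follows DeepSeek-V2's config
# (qk_nope_head_dim=128, qk_rope_head_dim=64), giving head_dim=192.
bsz, seq_len, num_head = 1, 1024, 16
head_dim_nope, head_dim_pe = 128, 64

q_nope = torch.randn(bsz, seq_len, num_head, head_dim_nope)  # content part
q_pe = torch.randn(bsz, seq_len, num_head, head_dim_pe)      # RoPE part

# DeepSeek concatenates the two parts along the last dim -> head_dim = 192.
q = torch.cat([q_nope, q_pe], dim=-1)
assert q.shape == (bsz, seq_len, num_head, head_dim_nope + head_dim_pe)
```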

To integrate SageAttention with DeepSeek, the head_dim handling must be modified to fit DeepSeek's head_dim=192 inference.

As we already know, the SageAttention kernels are used in language models mainly during the prefill phase, so the very first step is to support head_dim=192 inference there.

1. Quant QK phase

Currently, our approach is to enlarge the QK input to head_dim=256 (pad the last dimension with zeros up to 256) and then perform the quantization, because we have noticed that head_dim does not affect the quantization process much.
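A minimal sketch of this idea at the PyTorch level, for illustration only (the real quantization happens inside the CUDA kernels; zero_pad_to_256 is a hypothetical helper, not code from this PR):

```python
import torch
import torch.nn.functional as F

def zero_pad_to_256(x: torch.Tensor) -> torch.Tensor:
    """Zero-pad the last (head_dim) dimension up to 256."""
    pad = 256 - x.shape[-1]        # 64 when head_dim == 192
    return F.pad(x, (0, pad))      # F.pad pads the last dim as (left, right)

q = torch.randn(1, 1024, 16, 192)
k = torch.randn(1, 1024, 16, 192)
q_pad, k_pad = zero_pad_to_256(q), zero_pad_to_256(k)
# The head_dim=256 tensors are what get quantized; the appended zero columns
# quantize to zero and therefore do not change q @ k^T.
```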

2. QKV attention phase

As for the attention kernel, we split Q and K as:

  • query_nope: [bsz, seq_len, num_head, 128]
  • query_pe: [bsz, seq_len, num_head, 128]
  • key_nope: [bsz, seq_len, num_head, 64]
  • key_pe: [bsz, seq_len, num_head, 64]
  • _: [bsz, seq_len, num_head, 64] (all-zero padding tensor, left as-is)
  • _: [bsz, seq_len, num_head, 64] (all-zero padding tensor, left as-is)

and merge the partial scores before the softmax, so the output shape is unchanged.

We adopt this method because we hit countless obstacles when trying to enlarge the QK input to head_dim=192 directly: static_asserts are ubiquitous in the code, and bumping head_dim to 192 triggers countless bugs.

So the workable solution is: split QK into nope (128), pe (64), and a zero tensor (64) -> compute q_nope@k_nope and q_pe@k_pe separately -> merge the results.
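Conceptually (ignoring quantization and the kernel's internal tiling), the correctness of this split rests on the fact that zero-padded columns contribute nothing to the dot product, so the partial scores can be computed separately and summed. A rough PyTorch sketch of that identity, with shapes chosen for illustration rather than matching the kernel's layout:

```python
import torch

bsz, num_head, seq_len = 1, 4, 128
# float64 so the numerical check below is exact up to tiny rounding.
q = torch.randn(bsz, num_head, seq_len, 192, dtype=torch.float64)
k = torch.randn(bsz, num_head, seq_len, 192, dtype=torch.float64)

# Reference attention scores using the full 192-dim heads.
s_ref = q @ k.transpose(-1, -2)

# Split into nope (128) and pe (64) parts; an all-zero padding block would
# add nothing to the dot product, so it is omitted here.
q_nope, q_pe = q.split([128, 64], dim=-1)
k_nope, k_pe = k.split([128, 64], dim=-1)
s_split = q_nope @ k_nope.transpose(-1, -2) + q_pe @ k_pe.transpose(-1, -2)

torch.testing.assert_close(s_ref, s_split)  # merged partial scores match
```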

Plan

This PR is currently a work in progress; we have hit some bugs and are trying to fix them. (We are also looking forward to the maintainers' review of our implementation.)

Right now we have only uploaded the kernel implementation file; we have also made many modifications to other .cuh files, but those are not uploaded yet.

@l1cacheDell marked this pull request as draft on February 19, 2025, 11:47
@l1cacheDell (Author) commented Feb 19, 2025

A complete patch has been prepared.

Just run pip install -e . and then cd patch_test && python test_dsk.py to see the results.


Update (Mar 4): to keep the code base clean, the patch_test directory has been removed. See bench/bench_qk_int8_pv_fp8_cuda_dsk_sm90.py instead.

@l1cacheDell (Author) commented Feb 20, 2025

Sim and Diff of Sage Attn & torch SDPA: 0.9994051456451416, 0.140625
Sim and Diff of Sage Attn & Flash Attn: 0.9994051456451416, 0.140625
Sim and Diff of Flash Attn & torch SDPA: 1.0000001192092896, 0.0

The accuracy target has been reached.
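For reference, a sketch of how such "Sim" and "Diff" numbers are commonly computed in attention benchmarks: cosine similarity of the flattened outputs and the maximum absolute difference. This is our assumption about the metric, not necessarily the exact code used in this PR's benchmark:

```python
import torch
import torch.nn.functional as F

def sim_and_diff(out_a: torch.Tensor, out_b: torch.Tensor):
    """Cosine similarity of flattened outputs and max absolute difference."""
    a, b = out_a.float().flatten(), out_b.float().flatten()
    sim = F.cosine_similarity(a, b, dim=0).item()
    diff = (a - b).abs().max().item()
    return sim, diff
```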

@l1cacheDell changed the title from "[Feat] Kernel Support for Deepseek MLA" to "[Feat] Kernel Support for Deepseek Headdim=192 (For prefilling only)" on Feb 26, 2025
@l1cacheDell marked this pull request as ready for review on March 3, 2025, 14:17
@l1cacheDell (Author) commented:
Deepseek-V2-Lite-Chat test results:

  • seq_len=32K
  • GPU: H100
  • Serving 100+ requests

TTFT (avg): 0.61 s -> 0.54 s

@Andy0422 commented Mar 4, 2025

> Deepseek-V2-Lite-Chat test results:
>
>   • seq_len=32K
>   • GPU: H100
>   • Serving 100+ requests
>
> TTFT (avg): 0.61 s -> 0.54 s

Good job! Could you please provide an example once it's ready? Cheers!

@l1cacheDell (Author) commented:
> > Deepseek-V2-Lite-Chat test results:
> >
> >   • seq_len=32K
> >   • GPU: H100
> >   • Serving 100+ requests
> >
> > TTFT (avg): 0.61 s -> 0.54 s
>
> Good job! Could you please provide an example once it's ready? Cheers!

Thank you for your attention. There is a standalone script:

b580f8c#diff-0817e284a4dbf70d9085c87bbf5d2119d24174bafb29e1fdaa7ae52630508d9b

See patch_test/test_dsk.py, the deleted file in that commit.

I removed it from this PR to keep the code base clean, but if you want a test script, I think that one should still work.


By the way, regarding how to support batching during inference: SageAttention currently does not support batches with different seq_lens in a single batch.

So what we did was simply pad them to the same seq_len.
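A minimal sketch of that padding workaround, assuming plain PyTorch tensors (pad_batch is a hypothetical helper for illustration, not code from this PR):

```python
import torch
import torch.nn.functional as F

def pad_batch(seqs, pad_value=0.0):
    """Right-pad per-request tensors of shape [seq_len, num_head, head_dim]
    to a common max seq_len so they can be stacked into one batch.
    Padded positions must still be masked out (or their outputs discarded)."""
    max_len = max(s.shape[0] for s in seqs)
    padded = [F.pad(s, (0, 0, 0, 0, 0, max_len - s.shape[0]), value=pad_value)
              for s in seqs]
    return torch.stack(padded)  # [bsz, max_len, num_head, head_dim]

seqs = [torch.randn(n, 16, 192) for n in (100, 250, 512)]
batch = pad_batch(seqs)
assert batch.shape == (3, 512, 16, 192)
```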


We are also trying to support unpadded batch serving, like FA2, but that is still a work in progress. I will open another PR when it is ready.

@kevinj-44 commented:
Is this suitable for the RTX 5000-series issue?
I have a kernel error (SM89); mine is an RTX 5080 (Python 3.12.9, torch 2.8, cu128).

@l1cacheDell (Author) commented:
> Is this suitable for the RTX 5000-series issue? I have a kernel error (SM89); mine is an RTX 5080 (Python 3.12.9, torch 2.8, cu128).

I never thought somebody would actually use this PR. So far, only the SM90 arch is supported for headdim=192.

I will implement the SM80 and SM89 versions soon.
