
Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT #6493


Merged
10 commits from beilei/fp8_qkvproj merged into sgl-project:main on May 23, 2025

Conversation

@blzheng (Contributor) commented on May 21, 2025:

Motivation

This PR is a follow-up to #2807 and #5150, adding the fp8 qkv_proj_with_rope kernel for CPU. The bf16 and int8 fused_experts kernels were already added in #5150.

This PR also adds UTs for the bf16, int8, and fp8 qkv_proj_with_rope kernels on CPU.
It also fixes issues with the definition and usage of decode_attention_cpu and extend_attention_cpu on the main branch.

Modifications

The main change is the C++ kernel for fp8 qkv_proj_with_rope on CPU: sgl-kernel/csrc/cpu/qkv_proj.cpp
The UTs for the qkv_proj_with_rope ops on CPU: test/srt/cpu/test_qkv_proj_with_rope.py; a sketch of the kind of fp8 reference check such tests perform follows.
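
A minimal sketch of an fp8 w8a16 reference check (names and tolerances are illustrative assumptions, not the PR's actual test code): round the bf16 weight through float8_e4m3fn with a per-tensor scale, then compare a bf16 matmul against an fp32 reference that uses the same dequantized weight.

import torch

def fp8_roundtrip(w: torch.Tensor) -> torch.Tensor:
    # per-tensor scale so the weight fits the e4m3 range (max ~448)
    scale = w.abs().max().float() / 448.0
    w_fp8 = (w.float() / scale).to(torch.float8_e4m3fn)
    # dequantize back to bf16, mimicking what a w8a16 kernel computes with
    return (w_fp8.to(torch.float32) * scale).to(torch.bfloat16)

x = torch.randn(4, 128, dtype=torch.bfloat16)
w = torch.randn(256, 128, dtype=torch.bfloat16)
w_dq = fp8_roundtrip(w)
out = x @ w_dq.t()                                       # bf16 compute path
ref = (x.float() @ w_dq.float().t()).to(torch.bfloat16)  # fp32 reference
torch.testing.assert_close(out.float(), ref.float(), atol=5e-2, rtol=5e-2)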

Checklist

@blzheng changed the title from "Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel andd add UT" to "Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT" on May 21, 2025
@blzheng marked this pull request as ready for review on May 22, 2025 03:16
@mingfeima added the sgl-kernel, intel, and cpu (cpu backend performance optimization) labels on May 22, 2025
@mingfeima marked this pull request as a draft on May 22, 2025 03:21
@blzheng force-pushed the beilei/fp8_qkvproj branch from c123aa7 to 7fca26f on May 22, 2025 03:33
@blzheng marked this pull request as ready for review on May 22, 2025 09:11
@mingfeima (Collaborator) left a comment:

Just some minor issues to address.

Comment on lines 36 to 57
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style rotation: pair adjacent elements and interleave (-x2, x1)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    # reference rotary embedding; rotary_dim is a module-level constant in
    # the test file, and torch is imported at the top of the file
    orig_dtype = q_pe.dtype
    q_pe = q_pe.float()
    k_pe = k_pe.float()
    cos_sin_cache = cos_sin_cache.float()

    query_rot = q_pe[..., :rotary_dim]
    key_rot = k_pe[..., :rotary_dim]
    cos_sin = cos_sin_cache[pos]
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
    sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    query_rot = query_rot * cos + _rotate_gptj(query_rot) * sin
    key_rot = key_rot * cos + _rotate_gptj(key_rot) * sin
    return query_rot.to(orig_dtype), key_rot.to(orig_dtype)
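
For context, a small usage sketch of the reference above (the shapes, rotary_dim value, and cache size are illustrative assumptions, not the test's actual values):

import torch

rotary_dim = 64                                    # module-level in the test file
q_pe = torch.randn(2, 1, rotary_dim, dtype=torch.bfloat16)  # [tokens, heads, dim]
k_pe = torch.randn(2, 1, rotary_dim, dtype=torch.bfloat16)
pos = torch.randint(0, 4096, (2,))                 # one position per token
cos_sin_cache = torch.randn(4096, rotary_dim)      # cos half | sin half per position

q_rot, k_rot = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
assert q_rot.shape == q_pe.shape and q_rot.dtype == q_pe.dtype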
@blzheng (Contributor, Author) replied:
Done.

kva_packed = torch.ops.sgl_kernel.convert_weight_packed(kv_a_proj_weight)
wkc_packed = torch.ops.sgl_kernel.convert_weight_packed(w_kc)

q_out, k_out, v_out = torch.ops.sgl_kernel.qkv_proj_with_rope(
    # ... remaining arguments not shown in this excerpt
@mingfeima (Collaborator) commented:

You can import everything you need at the beginning:
from torch.ops.sgl_kernel import xxx, yyy, zzz

@blzheng (Contributor, Author) replied:

Custom registered ops cannot be accessed through this kind of import statement. I manually wrapped convert_weight_packed and qkv_proj_with_rope as local variables at the beginning of the file to simplify their usage in the unit tests.
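
A minimal sketch of that workaround (assuming, as the snippets above suggest, that importing sgl_kernel registers the custom ops under torch.ops):

import torch
import sgl_kernel  # importing the package registers the custom CPU ops (assumed)

# torch.ops entries cannot be pulled in with `from ... import ...`, so bind
# the registered ops to module-level names once and call those in the tests.
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope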

@blzheng requested a review from mingfeima on May 23, 2025 04:18
@zhyncs merged commit 4ba1eea into sgl-project:main on May 23, 2025
29 of 47 checks passed
Layssy pushed a commit to Layssy/sglang-iaas that referenced this pull request Jun 9, 2025
Labels: cpu (cpu backend performance optimization), intel, sgl-kernel
3 participants