Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT #6493
Conversation
just some minor issues to address.
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    orig_dtype = q_pe.dtype
    q_pe = q_pe.float()
    k_pe = k_pe.float()
    cos_sin_cache = cos_sin_cache.float()

    query_rot = q_pe[..., :rotary_dim]
    key_rot = k_pe[..., :rotary_dim]
    cos_sin = cos_sin_cache[pos]
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
    sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    query_rot = query_rot * cos + _rotate_gptj(query_rot) * sin
    key_rot = key_rot * cos + _rotate_gptj(key_rot) * sin
    return query_rot.to(orig_dtype), key_rot.to(orig_dtype)
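As a quick sanity check on this reference path, GPT-J-style rotary embedding is a pure rotation and therefore norm-preserving. The self-contained snippet below exercises rotary_emb directly; the shapes, the rotary_dim value, and the [cos | sin] cache construction are illustrative assumptions, not values taken from the test file.

import torch

# Arbitrary example sizes; the real test uses the model's own dimensions.
rotary_dim = 64
q_pe = torch.randn(4, 1, rotary_dim)
k_pe = torch.randn(4, 1, rotary_dim)
pos = torch.arange(4)

# Build a [cos | sin] cache in the layout rotary_emb expects.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(16).float(), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

q_ref, k_ref = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
# A pure rotation leaves per-head vector norms unchanged.
torch.testing.assert_close(q_ref.norm(dim=-1), q_pe.norm(dim=-1), rtol=1e-5, atol=1e-5)
torch.testing.assert_close(k_ref.norm(dim=-1), k_pe.norm(dim=-1), rtol=1e-5, atol=1e-5)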
Done.
kva_packed = torch.ops.sgl_kernel.convert_weight_packed(kv_a_proj_weight)
wkc_packed = torch.ops.sgl_kernel.convert_weight_packed(w_kc)

q_out, k_out, v_out = torch.ops.sgl_kernel.qkv_proj_with_rope(
You can import everything you need at the beginning:
from torch.ops.sgl_kernel import xxx, yyy, zzz
Custom registered ops cannot be accessed through this kind of import statement. I manually wrapped convert_weight_packed and qkv_proj_with_rope as local variables at the beginning of the file to simplify their usage in the unit tests.
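For readers of the test, a minimal sketch of the aliasing approach described here (assuming the sgl_kernel package is imported first so its custom ops get registered with torch; the op names follow the snippets above):

import torch
import sgl_kernel  # noqa: F401  -- importing the package registers the custom CPU ops

# torch.ops namespaces are resolved dynamically, so a plain
# `from torch.ops.sgl_kernel import ...` does not work; bind the op
# handles to names at the top of the file instead.
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope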
Motivation
This PR is a follow-up to #2807 and #5150, adding an fp8 qkv_proj_with_rope kernel for CPU. The bf16 and int8 fused_experts kernels were already added in #5150.
This PR also adds UTs for the bf16, int8, and fp8 qkv_proj_with_rope kernels for CPU.
This PR also addresses issues with the definition and usage of decode_attention_cpu and extend_attention_cpu on the main branch.
Modifications
The main change is the C++ kernel for fp8 qkv_proj_with_rope on CPU: sgl-kernel/csrc/cpu/qkv_proj.cpp
The UTs for the qkv_proj_with_rope ops on CPU: test/srt/cpu/test_qkv_proj_with_rope.py
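For reference, a minimal way to run the new test file locally might look like the following (assuming a CPU build of sgl-kernel is installed; the repository's own test runner or CI invocation may differ):

import pytest

# Run only the new qkv_proj_with_rope CPU tests.
pytest.main(["test/srt/cpu/test_qkv_proj_with_rope.py", "-v"])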
Checklist