
Commit 47bc8df

Add rotary_position_embedding_cpu kernel instead of native impl (#18)

* add rope
* remove B
* Fix issue
* update
* Add fused rope
* refactor
* add checks
* support non-contiguous
* update parallel

1 parent 731290e commit 47bc8df

File tree

6 files changed: +280 additions, -27 deletions

- python/sglang/srt/layers/rotary_embedding.py
- sgl-kernel/csrc/cpu/rope.cpp
- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
- sgl-kernel/python/sgl_kernel/cpu.py
- sgl-kernel/setup.py
- test/srt/test_rope.py

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 35 additions & 27 deletions
@@ -15,6 +15,9 @@
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

+from sglang.srt.cpu_utils import cpu_has_amx_support
+if cpu_has_amx_support():
+    import sgl_kernel.cpu

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -719,37 +722,42 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        query_rot = query[..., : self.rotary_dim]
-        key_rot = key[..., : self.rotary_dim]
-        if self.rotary_dim < self.head_size:
-            query_pass = query[..., self.rotary_dim :]
-            key_pass = key[..., self.rotary_dim :]
+        positions = torch.add(positions, offsets) if offsets is not None else positions

-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
-        cos_sin = self.cos_sin_cache[
-            torch.add(positions, offsets) if offsets is not None else positions
-        ]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        if self.is_neox_style:
-            # NOTE(woosuk): Here we assume that the positions tensor has the
-            # shape [batch_size, seq_len].
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        # TODO: Add scenario of self.rotary_dim < self.head_size
+        if positions.device == torch.device("cpu") and cpu_has_amx_support():
+            return sgl_kernel.cpu.rotary_position_embedding(
+                positions, query, key, self.cos_sin_cache)
         else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+            query_rot = query[..., : self.rotary_dim]
+            key_rot = key[..., : self.rotary_dim]
+            if self.rotary_dim < self.head_size:
+                query_pass = query[..., self.rotary_dim :]
+                key_pass = key[..., self.rotary_dim :]
+
+            cos_sin = self.cos_sin_cache[positions]
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            if self.is_neox_style:
+                # NOTE(woosuk): Here we assume that the positions tensor has the
+                # shape [batch_size, seq_len].
+                cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+                sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+            else:
+                cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+                sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

-        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
-        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+            rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+            query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+            key_rot = key_rot * cos + rotate_fn(key_rot) * sin

-        if self.rotary_dim < self.head_size:
-            query = torch.cat((query_rot, query_pass), dim=-1)
-            key = torch.cat((key_rot, key_pass), dim=-1)
-        else:
-            query = query_rot
-            key = key_rot
-        return query, key
+            if self.rotary_dim < self.head_size:
+                query = torch.cat((query_rot, query_pass), dim=-1)
+                key = torch.cat((key_rot, key_pass), dim=-1)
+            else:
+                query = query_rot
+                key = key_rot
+            return query, key


 class Llama3RotaryEmbedding(RotaryEmbedding):
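
For reference, the non-neox branch above applies the GPT-J-style rotation: each adjacent (even, odd) channel pair is rotated by the angle cached for its position. Below is a minimal self-contained sketch of that math, assuming 1-D positions and rotary_dim == head_size (the only case the fused CPU path handles, per the TODO in the diff); the helper name rope_gptj_ref is ours for illustration, not part of the commit.

import torch

def rope_gptj_ref(positions, q, k, cos_sin_cache):
    # Illustrative helper, not part of the commit.
    # positions: [S] int64; q: [S, Hq, D]; k: [S, Hk, D]
    # cos_sin_cache: [max_pos, D], cos in the first D/2 columns, sin in the rest
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [S, D/2]
    cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)  # [S, 1, D], broadcasts over heads
    sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

    def rotate(x):  # per pair: (x_even, x_odd) -> (-x_odd, x_even)
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    return q * cos + rotate(q) * sin, k * cos + rotate(k) * sin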

sgl-kernel/csrc/cpu/rope.cpp

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+#include "common.h"
+#include "vec.h"
+
+namespace {
+
+template <typename scalar_t>
+void rope_kernel_impl(
+    scalar_t* __restrict__ q_pe_out,
+    scalar_t* __restrict__ k_pe_out,
+    int64_t* __restrict__ t_pos,
+    scalar_t* __restrict__ q_pe,
+    scalar_t* __restrict__ k_pe,
+    scalar_t* __restrict__ t_emb_pos,
+    int64_t seq_len,
+    int64_t num_head,
+    int64_t rotary_dim,
+    int64_t HR,
+    int64_t q_pe_stride_s,
+    int64_t out_stride_qs,
+    int64_t out_stride_ks,
+    int64_t HK,
+    int64_t k_pe_stride_s,
+    int64_t q_pe_stride_n,
+    int64_t out_stride_qn) {
+  int64_t COFF = HR / 2;
+  at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
+    int64_t seq{0}, head_id{0};
+    data_index_init(begin, seq, seq_len, head_id, num_head);
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
+      int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
+      int64_t out_offset_k = seq * out_stride_ks;
+      int64_t p = 0;
+      scalar_t* sin_start = nullptr;
+      scalar_t* cos_start = nullptr;
+      // step 0) get the rotary position embedding for the current position
+      p = t_pos[seq];
+      sin_start = t_emb_pos + p * HR + COFF;
+      cos_start = t_emb_pos + p * HR;
+      // step 1) apply_rotary_pos_emb for the rotary_dim elements in every
+      // head of query/key
+      for (int64_t h = 0; h < rotary_dim; h += 2) {
+        scalar_t cos = cos_start[h >> 1];
+        scalar_t sin = sin_start[h >> 1];
+        scalar_t in1 = q_pe[in_offset_q + h];
+        scalar_t in2 = q_pe[in_offset_q + h + 1];
+        scalar_t out1 = in1 * cos - in2 * sin;
+        scalar_t out2 = in2 * cos + in1 * sin;
+        q_pe_out[out_offset_q + h] = out1;
+        q_pe_out[out_offset_q + h + 1] = out2;
+      }
+      for (int64_t h = 0; h < HK; h += 2) {
+        scalar_t cos = cos_start[h >> 1];
+        scalar_t sin = sin_start[h >> 1];
+        int64_t k_pe_offset = seq * k_pe_stride_s;
+        scalar_t in1_k = k_pe[k_pe_offset + h];
+        scalar_t in2_k = k_pe[k_pe_offset + h + 1];
+        scalar_t out1_k = in1_k * cos - in2_k * sin;
+        scalar_t out2_k = in2_k * cos + in1_k * sin;
+        k_pe_out[out_offset_k + h] = out1_k;
+        k_pe_out[out_offset_k + h + 1] = out2_k;
+      }
+      // move to the next index
+      data_index_step(seq, seq_len, head_id, num_head);
+    }
+  });
+}
+} // namespace
+
+std::tuple<at::Tensor, at::Tensor>
+rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
+  RECORD_FUNCTION(
+      "sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
+  CHECK_INPUT(t_pos);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
+  CHECK_INPUT(t_emb_pos);
+  CHECK_DIM(1, t_pos);
+  CHECK_DIM(3, q_pe);
+  CHECK_DIM(3, k_pe);
+  CHECK_DIM(2, t_emb_pos);
+
+  int64_t seq_len = q_pe.size(0);
+  int64_t num_head = q_pe.size(1);
+  int64_t rotary_dim = q_pe.size(2);
+  int64_t HK = k_pe.size(2);
+  int64_t HR = t_emb_pos.size(1);
+  CHECK_EQ(HR, rotary_dim);
+  CHECK_EQ(k_pe.size(0), seq_len);
+  CHECK_EQ(k_pe.size(1), 1);
+  CHECK_EQ(t_pos.size(0), seq_len);
+  CHECK_EQ(HK, rotary_dim);
+
+  at::Tensor q_pe_out = at::empty_like(q_pe);
+  at::Tensor k_pe_out = at::empty_like(k_pe);
+  int64_t q_pe_stride_s = q_pe.stride(0);
+  int64_t q_pe_stride_n = q_pe.stride(1);
+  int64_t k_pe_stride_s = k_pe.stride(0);
+  int64_t out_stride_qs = q_pe_out.stride(0);
+  int64_t out_stride_qn = q_pe_out.stride(1);
+  int64_t out_stride_ks = k_pe_out.stride(0);
+
+  const auto input_dtype = q_pe.scalar_type();
+  TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
+  TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
+  TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
+    rope_kernel_impl<scalar_t>(
+        q_pe_out.data_ptr<scalar_t>(),
+        k_pe_out.data_ptr<scalar_t>(),
+        t_pos.data_ptr<int64_t>(),
+        q_pe.data_ptr<scalar_t>(),
+        k_pe.data_ptr<scalar_t>(),
+        t_emb_pos.data_ptr<scalar_t>(),
+        seq_len,
+        num_head,
+        rotary_dim,
+        HR,
+        q_pe_stride_s,
+        out_stride_qs,
+        out_stride_ks,
+        HK,
+        k_pe_stride_s,
+        q_pe_stride_n,
+        out_stride_qn);
+  });
+  return std::make_tuple(q_pe_out, k_pe_out);
+}
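
The kernel flattens the (seq, head) loop into a single at::parallel_for range and re-derives the two indices with data_index_init/data_index_step, sizing the grain by rotary_dim. Each inner iteration applies a 2x2 rotation to an adjacent element pair, reading cos from the start of the cache row for position p and sin from an offset of HR/2 (COFF), so the cache layout is [cos | sin] per row. A hedged Python restatement of that inner loop, vectorized over pairs (the name rotate_pairs is ours, for exposition only):

import torch

def rotate_pairs(x, cos_half, sin_half):
    # Illustrative restatement of the C++ inner loop, not part of the commit.
    # x: [..., D] with interleaved pairs; cos_half, sin_half: [..., D // 2]
    x1, x2 = x[..., ::2], x[..., 1::2]              # in1 = x[h], in2 = x[h + 1]
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos_half - x2 * sin_half   # out1 = in1*cos - in2*sin
    out[..., 1::2] = x2 * cos_half + x1 * sin_half  # out2 = in2*cos + in1*sin
    return out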

sgl-kernel/csrc/cpu/torch_extension_cpu.cpp

Lines changed: 7 additions & 0 deletions
@@ -81,6 +81,10 @@ void initialize(int size, int rank);
 // shared mmeory all_reduce
 void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);

+// rope
+std::tuple<at::Tensor, at::Tensor> rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe,
+                                                                 at::Tensor& k_pe, at::Tensor& t_emb_pos);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // activation
   m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
@@ -122,4 +126,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // all reduce
   m.def("initialize", &initialize, "shared memory initialization for CPU");
   m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
+
+  // rope
+  m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
 }

sgl-kernel/python/sgl_kernel/cpu.py

Lines changed: 12 additions & 0 deletions
@@ -200,3 +200,15 @@ def int8_scaled_mm(

 def per_token_quant_int8(x):
     return sgl_kernel.common_ops.per_token_quant_int8_cpu(x)
+def rotary_position_embedding(
+    t_pos,
+    q_pe,
+    k_pe,
+    t_emb_pos,
+):
+    return sgl_kernel.common_ops.rotary_position_embedding_cpu(
+        t_pos,
+        q_pe,
+        k_pe,
+        t_emb_pos,
+    )
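
A minimal usage sketch of this wrapper, mirroring the shapes the unit test below exercises. It assumes sgl-kernel was built with the CPU extension (csrc/cpu/rope.cpp) on a supported machine; per the C++ checks, positions must be int64, k_pe must have a single head, and the last dimension of q_pe/k_pe must equal the cache row width:

import torch
import sgl_kernel.cpu

seq_len, num_head, rotary_dim, max_pos = 1024, 16, 64, 256
positions = torch.randint(0, max_pos, (seq_len,))                        # int64, shape [seq_len]
q_pe = torch.randn(seq_len, num_head, rotary_dim, dtype=torch.bfloat16)  # query rope channels
k_pe = torch.randn(seq_len, 1, rotary_dim, dtype=torch.bfloat16)         # k_pe.size(1) must be 1
cos_sin_cache = torch.randn(max_pos, rotary_dim, dtype=torch.bfloat16)   # [cos | sin] per row
q_out, k_out = sgl_kernel.cpu.rotary_position_embedding(
    positions, q_pe, k_pe, cos_sin_cache
)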

sgl-kernel/setup.py

Lines changed: 1 addition & 0 deletions
@@ -155,6 +155,7 @@ def copy_deepgemm_to_build_lib(self):
         "csrc/cpu/moe.cpp",
         "csrc/cpu/moe_int8.cpp",
         "csrc/cpu/norm.cpp",
+        "csrc/cpu/rope.cpp",
         "csrc/cpu/topk.cpp",
         "csrc/cpu/interface.cpp",
         "csrc/cpu/shm.cpp",

test/srt/test_rope.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+import unittest
+import expecttest
+
+import torch
+import sgl_kernel.cpu
+
+class TestROPE(expecttest.TestCase):
+    def test_deepseek_v2_rope(self):
+        def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            return torch.cat((-x2, x1), dim=-1)
+
+        def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+            x1 = x[..., ::2]
+            x2 = x[..., 1::2]
+            x = torch.stack((-x2, x1), dim=-1)
+            return x.flatten(-2)
+
+        def forward_ref(positions, query, key, cos_sin_cache, offsets=None):
+            self.rotary_dim = 64
+            self.head_size = 64
+            self.is_neox_style = False
+            query_rot = query[..., : self.rotary_dim]
+            key_rot = key[..., : self.rotary_dim]
+            if self.rotary_dim < self.head_size:
+                query_pass = query[..., self.rotary_dim :]
+                key_pass = key[..., self.rotary_dim :]
+
+            cos_sin = cos_sin_cache[
+                torch.add(positions, offsets) if offsets is not None else positions
+            ]
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            if self.is_neox_style:
+                # NOTE(woosuk): Here we assume that the positions tensor has the
+                # shape [batch_size, seq_len].
+                cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+                sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+            else:
+                cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+                sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+            rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+            query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+            key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+            if self.rotary_dim < self.head_size:
+                query = torch.cat((query_rot, query_pass), dim=-1)
+                key = torch.cat((key_rot, key_pass), dim=-1)
+            else:
+                query = query_rot
+                key = key_rot
+            return query, key
+
+        num_head = 16
+        seq_len = 1024
+        q_head_dim = 192
+        qk_nope_head_dim = 128
+        qk_rope_head_dim = 64
+        max_pos = 256
+        k_dim = 576
+
+        # Create cos_sin_cache
+        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
+        cos = freqs.cos() * 0.7
+        sin = freqs.sin() * 0.7
+        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
+        positions = torch.randint(0, max_pos, (seq_len,))
+
+        for dtype in [torch.bfloat16]:
+            enable_autocast = True
+
+            with torch.no_grad(), torch.cpu.amp.autocast(enabled=enable_autocast):
+                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
+                q_clone = q.clone()
+                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
+                k_clone = k.clone()
+                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
+                _, q_pe_clone = q_clone.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
+                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
+                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]
+
+                # ref kernel
+                q_pe, k_pe = forward_ref(positions, q_pe, k_pe, cos_sin_cache)
+
+                # fused rope kernel
+                q_pe_clone, k_pe_clone = sgl_kernel.cpu.rotary_position_embedding(
+                    positions, q_pe_clone, k_pe_clone, cos_sin_cache
+                )
+
+                assert torch.allclose(q_pe, q_pe_clone)
+                assert torch.allclose(k_pe, k_pe_clone)
+
+
+if __name__ == "__main__":
+    unittest.main()
