Commit 6e313c1 (parent: a45a4b2)

Revert "Revert "fix: import vllm_rotary_embedding error when head_size not in 64, 128, 256, 512"" (#5777)
File tree: 1 file changed (+7, -3)

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 7 additions & 3 deletions
@@ -14,8 +14,6 @@
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
-else:
-    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,6 +82,12 @@ def __init__(
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
         if not _is_cuda:
             cache = cache.to(dtype)
+
+        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+            from vllm._custom_ops import rotary_embedding
+
+            self.vllm_rotary_embedding = rotary_embedding
+
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -160,7 +164,7 @@ def forward_cuda(
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            vllm_rotary_embedding(
+            self.vllm_rotary_embedding(
                 positions,
                 query,
                 key,
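Why the change works: before this commit, vllm_rotary_embedding was only imported at module level in the non-CUDA branch, so on a CUDA build a model whose head_size falls outside the fused-kernel set (64, 128, 256, 512, per the condition in the diff) reached the fallback call in forward_cuda with a name that had never been imported. The fix imports the vLLM op inside __init__ whenever the fallback can actually be needed and binds it to the instance. Below is a minimal sketch of that lazy-import pattern, not the actual sglang class: LazyRotaryEmbedding, the simplified constructor (cos/sin cache passed in directly, Neox-style rotation), and the fallback call's argument order are assumptions made for illustration.

# Minimal sketch of the lazy-import pattern introduced by this commit; it is not
# the actual sglang RotaryEmbedding class. `_is_cuda`, the supported head sizes,
# and the `vllm_rotary_embedding` attribute name come from the diff above;
# `LazyRotaryEmbedding`, the simplified constructor, and the fallback call's
# argument order are assumptions made for illustration.
import torch

_is_cuda = torch.cuda.is_available()

if _is_cuda:
    # Fused CUDA path; imported unconditionally on CUDA, but only usable for
    # certain head sizes.
    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

_FUSED_HEAD_SIZES = (64, 128, 256, 512)


class LazyRotaryEmbedding(torch.nn.Module):
    def __init__(self, head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool = True):
        super().__init__()
        self.head_size = head_size
        self.is_neox_style = is_neox
        self.register_buffer("cos_sin_cache", cos_sin_cache, persistent=False)

        # Core of the fix: import the vLLM op only when the fallback can actually
        # be reached, and keep a reference on the instance so forward() does not
        # depend on a module-level name that may never have been imported.
        if not _is_cuda or head_size not in _FUSED_HEAD_SIZES:
            from vllm._custom_ops import rotary_embedding

            self.vllm_rotary_embedding = rotary_embedding

    def forward(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor):
        if _is_cuda and self.head_size in _FUSED_HEAD_SIZES:
            # Fused path via sgl_kernel.apply_rope_with_cos_sin_cache_inplace
            # (exact call elided in this sketch).
            ...
        else:
            # Fallback path, operating in place on query/key; argument order is
            # assumed to match vllm._custom_ops.rotary_embedding.
            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
            self.vllm_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
        return query, key

Binding the op in __init__ also means a pure-CUDA environment with a supported head size never touches vllm at all; only configurations that genuinely need the fallback pay for (and require) the vllm import.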