Commit 93c6fb1

Fix: deepseek forward absorb (#5723)
Co-authored-by: ispobock <[email protected]>
1 parent 11e27d0 commit 93c6fb1

File tree

1 file changed: +41, -4 lines changed

python/sglang/srt/layers/layernorm.py

Lines changed: 41 additions & 4 deletions
@@ -20,9 +20,10 @@
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -32,6 +33,8 @@
         rmsnorm,
     )
 
+if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
 logger = logging.getLogger(__name__)
 
@@ -46,23 +49,49 @@ def __init__(
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Remove this if aiter kernel supports discontinuous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -88,6 +117,14 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -139,8 +176,8 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
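For readers skimming the change: the diff adds a platform dispatch to RMSNorm.forward (native path under torch.compile, sgl-kernel on CUDA, vllm/aiter ops on HIP) and makes inputs contiguous on the HIP and native paths. A minimal usage sketch, assuming sglang is installed and RMSNorm is imported directly from this module (shapes, dtype, and device below are hypothetical, not from the commit), might look like:

# Illustrative sketch only; not part of the commit.
import torch
from sglang.srt.layers.layernorm import RMSNorm

hidden_size = 4096                       # hypothetical model width
layer = RMSNorm(hidden_size, eps=1e-6)

# Fused CUDA/HIP kernels typically expect half-precision tensors on the
# accelerator; plain float32 CPU tensors exercise the native fallback.
x = torch.randn(8, hidden_size)          # hypothetical hidden states
residual = torch.randn(8, hidden_size)

# Without a residual, forward() dispatches to the platform kernel and
# returns only the normalized tensor.
out = layer(x)

# With a residual, it returns (normalized, updated_residual); the fused
# CUDA/HIP kernels update x and residual in place.
out, residual = layer(x, residual)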
