 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         rmsnorm,
     )

+if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm

 logger = logging.getLogger(__name__)

@@ -46,23 +49,49 @@ def __init__(
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Remove this once the aiter kernel supports non-contiguous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -88,6 +117,14 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps

+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -139,8 +176,8 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel layernorm implementation is not available on the current platform. Falling back to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
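
A quick usage sketch of the resulting dispatch (module path and shapes are illustrative assumptions, not taken from this diff): calling the module routes to the CUDA, HIP, or native implementation at runtime, while `torch.compile` traces always take `forward_native`.

```python
import torch

# Module path assumed from the sglang.srt.* package layout seen in the imports above.
from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(4096, eps=1e-6)      # weight initialized to ones
x = torch.randn(8, 4096)            # move to the GPU when a kernel backend is active
residual = torch.randn_like(x)

out = norm(x)                       # rmsnorm / rms_norm kernel, or the native fallback
out, residual = norm(x, residual)   # fused residual-add + RMSNorm; residual is updated as well

compiled = torch.compile(norm)      # within compilation, forward() defers to forward_native
out2 = compiled(x)
```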