@@ -3,7 +3,6 @@
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
-from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
@@ -45,20 +44,49 @@ def get_tolerance(self, dtype):
         else:
             return 1e-2, 1e-2  # Default values for other types
 
-    def torch_naive_moe(self, a, w1, w2, score, topk):
+    def torch_naive_moe(
+        self,
+        a,
+        w1,
+        w2,
+        score,
+        topk,
+        w1_scale=None,
+        w2_scale=None,
+        a1_scale=None,
+        a2_scale=None,
+    ):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
         score = torch.softmax(score, dim=-1, dtype=torch.float32)
         topk_weight, topk_ids = torch.topk(score, topk)
         topk_weight = topk_weight.view(-1)
         topk_ids = topk_ids.view(-1)
-        for i in range(w1.shape[0]):
+
+        if w1.dtype == torch.float8_e4m3fn:
+            w1_compute = w1.to(a.dtype)
+            w2_compute = w2.to(a.dtype)
+
+            if w1_scale is not None:
+                w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
+            if w2_scale is not None:
+                w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
+            if a1_scale is not None:
+                a = (a * a1_scale).to(a.dtype)
+            if a2_scale is not None:
+                a = (a * a2_scale).to(a.dtype)
+        else:
+            w1_compute = w1
+            w2_compute = w2
+
+        for i in range(w1_compute.shape[0]):
             mask = topk_ids == i
             if mask.sum():
-                out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
-                    i
-                ].transpose(0, 1)
+                out[mask] = SiluAndMul()(
+                    a[mask] @ w1_compute[i].transpose(0, 1)
+                ) @ w2_compute[i].transpose(0, 1)
+
         return (
             out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
         ).sum(dim=1)
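Note on the new fp8 branch above: the reference path simply upcasts the fp8 weights and reapplies the per-expert scale before the matmuls. A minimal sketch of that dequantization round trip (the shapes and the max-abs scale construction here are illustrative assumptions, not taken from the test):

import torch

# Two hypothetical experts with [E, N, K]-shaped weights, as in the test.
w = torch.randn(2, 16, 32, dtype=torch.float16)
# Per-expert scale so the largest magnitude maps to the fp8 max (~448 for e4m3fn).
w_scale = w.abs().amax(dim=(1, 2)) / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / w_scale.view(-1, 1, 1)).to(torch.float8_e4m3fn)

# Dequantize the way torch_naive_moe does: upcast, then rescale per expert.
w_compute = (w_fp8.to(w.dtype) * w_scale.view(-1, 1, 1)).to(w.dtype)
print(torch.max((w - w_compute).abs()))  # small quantization error remains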
@@ -98,21 +126,12 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
                 a2_scale=a2_scale,
             )
 
-            vllm_output = fused_moe_vllm(
-                a,
-                w1,
-                w2,
-                score,
-                topk,
-                renormalize=False,
-                use_fp8_w8a8=True,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                a1_scale=a1_scale,
-                a2_scale=a2_scale,
+            torch_output = self.torch_naive_moe(
+                a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
+            )
+            torch.testing.assert_close(
+                sglang_output, torch_output, rtol=rtol, atol=atol
             )
-
-            torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
 
         else:
             a = self.create_random_cuda_tensor((m, k), dtype)
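The swapped-in check relies on torch.testing.assert_close, which passes when |actual - expected| <= atol + rtol * |expected| holds elementwise. A quick illustration with tolerances of the same order as get_tolerance above:

import torch

actual = torch.tensor([1.000, 2.000])
expected = torch.tensor([1.005, 2.015])
# 1e-2 absolute plus 1e-2 relative slack absorbs both deviations here.
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)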
@@ -127,8 +146,8 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
         )
 
     def test_various_configurations(self):
-        m_values = [1, 33, 64, 222, 1024 * 128]
-        n_values = [128, 1024, 2048]
+        m_values = [1, 33, 64, 222]
+        n_values = [128, 1024]
         k_values = [128, 511, 1024]
         dtypes = [torch.float16, torch.bfloat16]
         fp8_modes = [False, True]
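Trimming m_values and n_values shrinks the sweep considerably; counting only the dimensions visible in this hunk (the expert-count and top-k lists are defined elsewhere in the test), the product drops from 180 to 96 combinations:

from itertools import product

m_values = [1, 33, 64, 222]   # was [1, 33, 64, 222, 1024 * 128]
n_values = [128, 1024]        # was [128, 1024, 2048]
k_values = [128, 511, 1024]
dtypes = ["float16", "bfloat16"]
fp8_modes = [False, True]

print(len(list(product(m_values, n_values, k_values, dtypes, fp8_modes))))  # 96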
@@ -171,6 +190,7 @@ def test_various_configurations(self):
                         dtype,
                         use_fp8_w8a8=use_fp8_w8a8,
                     )
+                    torch.cuda.empty_cache()
                     pbar.update(1)
 
 
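Finally, the torch.cuda.empty_cache() added after each case returns cached allocator blocks to the driver, keeping peak memory down across a long parameter sweep. A minimal sketch of the pattern (requires a CUDA device; the loop body is a placeholder, not the test's actual workload):

import torch

for m in (1, 33, 64, 222):
    x = torch.randn(m, 4096, device="cuda")
    y = x @ x.T  # stand-in for one test case's work
    del x, y
    torch.cuda.empty_cache()  # release cached blocks before the next case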