Commit a550741

BBuf authored and jianan-gu committed
[ci] fix ci test fused_moe op (sgl-project#5102)
1 parent 1a84cfb commit a550741

File tree

2 files changed: +43 -22 lines changed


test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -76,6 +76,7 @@ class TestFile:
         TestFile("test_create_kvindices.py", 2),
         TestFile("test_hicache.py", 60),
         TestFile("test_hicache_mla.py", 90),
+        TestFile("test_fused_moe.py", 30),
         TestFile("test_triton_moe_channel_fp8_kernel.py", 25),
     ],
     "per-commit-2-gpu": [
```

test/srt/test_fused_moe.py

Lines changed: 42 additions & 22 deletions
```diff
@@ -3,7 +3,6 @@
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
-from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
```
```diff
@@ -45,20 +44,49 @@ def get_tolerance(self, dtype):
         else:
             return 1e-2, 1e-2  # Default values for other types
 
-    def torch_naive_moe(self, a, w1, w2, score, topk):
+    def torch_naive_moe(
+        self,
+        a,
+        w1,
+        w2,
+        score,
+        topk,
+        w1_scale=None,
+        w2_scale=None,
+        a1_scale=None,
+        a2_scale=None,
+    ):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
         score = torch.softmax(score, dim=-1, dtype=torch.float32)
         topk_weight, topk_ids = torch.topk(score, topk)
         topk_weight = topk_weight.view(-1)
         topk_ids = topk_ids.view(-1)
-        for i in range(w1.shape[0]):
+
+        if w1.dtype == torch.float8_e4m3fn:
+            w1_compute = w1.to(a.dtype)
+            w2_compute = w2.to(a.dtype)
+
+            if w1_scale is not None:
+                w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
+            if w2_scale is not None:
+                w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
+            if a1_scale is not None:
+                a = (a * a1_scale).to(a.dtype)
+            if a2_scale is not None:
+                a = (a * a2_scale).to(a.dtype)
+        else:
+            w1_compute = w1
+            w2_compute = w2
+
+        for i in range(w1_compute.shape[0]):
             mask = topk_ids == i
             if mask.sum():
-                out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[
-                    i
-                ].transpose(0, 1)
+                out[mask] = SiluAndMul()(
+                    a[mask] @ w1_compute[i].transpose(0, 1)
+                ) @ w2_compute[i].transpose(0, 1)
+
         return (
             out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
         ).sum(dim=1)
```
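The fp8 branch added above dequantizes per expert before running the reference: weights stored as `torch.float8_e4m3fn` are upcast to the activation dtype, then multiplied by one scale per expert, with `view(-1, 1, 1)` broadcasting that scale across each expert's weight matrix. A minimal standalone sketch of the same round trip, with made-up shapes (`E`, `N`, `K` are illustrative, not from the test):

```python
import torch

# Hypothetical sizes: 4 experts; w1 stacks the gate/up projection as (E, 2N, K).
E, N, K = 4, 8, 16
w = torch.randn(E, 2 * N, K, dtype=torch.float16)

# Per-expert quantization: pick one scale per expert so its max maps into fp8 range.
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn
w1_scale = w.abs().amax(dim=(1, 2)) / fp8_max   # shape (E,)
w_fp8 = (w / w1_scale.view(-1, 1, 1)).to(torch.float8_e4m3fn)

# Dequantize exactly as torch_naive_moe does: upcast, then multiply by the scale.
w_compute = w_fp8.to(torch.float16) * w1_scale.view(-1, 1, 1)
print((w_compute - w).abs().max())  # small quantization error, not zero
```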
```diff
@@ -98,21 +126,12 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
                 a2_scale=a2_scale,
             )
 
-            vllm_output = fused_moe_vllm(
-                a,
-                w1,
-                w2,
-                score,
-                topk,
-                renormalize=False,
-                use_fp8_w8a8=True,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                a1_scale=a1_scale,
-                a2_scale=a2_scale,
+            torch_output = self.torch_naive_moe(
+                a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
+            )
+            torch.testing.assert_close(
+                sglang_output, torch_output, rtol=rtol, atol=atol
             )
-
-            torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
 
         else:
             a = self.create_random_cuda_tensor((m, k), dtype)
```
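With the vLLM call removed, the fp8 path now compares the Triton kernel against the naive PyTorch reference directly via `torch.testing.assert_close`, which checks elementwise that `|actual - expected| <= atol + rtol * |expected|`. A tiny illustration of how the `rtol`/`atol` pair from `get_tolerance` gates the comparison:

```python
import torch

actual = torch.tensor([1.000, 2.000])
expected = torch.tensor([1.001, 1.999])

# Passes: the 0.001 difference is well inside atol + rtol * |expected| ~= 0.02.
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)

# Fails: the allowed band shrinks to ~2e-5, and a per-element report is raised.
try:
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
except AssertionError as e:
    print(e)
```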
```diff
@@ -127,8 +146,8 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
         )
 
     def test_various_configurations(self):
-        m_values = [1, 33, 64, 222, 1024 * 128]
-        n_values = [128, 1024, 2048]
+        m_values = [1, 33, 64, 222]
+        n_values = [128, 1024]
         k_values = [128, 511, 1024]
         dtypes = [torch.float16, torch.bfloat16]
         fp8_modes = [False, True]
```
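Trimming `1024 * 128` from `m_values` and `2048` from `n_values` drops the heaviest configurations, which is presumably what kept the test within CI memory limits: `torch_naive_moe` replicates the activations `topk` times before the expert loop. Back-of-envelope for the removed largest `m` (with `topk = 2` assumed purely for illustration):

```python
# Hypothetical numbers: m and k from the removed case; topk = 2 is an assumption.
m, k, topk = 1024 * 128, 1024, 2
bytes_per_elem = 2  # float16 / bfloat16
replicated_a = m * topk * k * bytes_per_elem  # `a` after .repeat(1, topk, 1)
print(f"{replicated_a / 2**30:.2f} GiB")  # 0.50 GiB for one intermediate alone
```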
```diff
@@ -171,6 +190,7 @@ def test_various_configurations(self):
                     dtype,
                     use_fp8_w8a8=use_fp8_w8a8,
                 )
+                torch.cuda.empty_cache()
                 pbar.update(1)
 
 
```
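The added `torch.cuda.empty_cache()` between cases hands blocks cached by PyTorch's allocator back to the CUDA driver, so one large configuration cannot starve the next of GPU memory. A small demonstration of what it actually frees (requires a CUDA device):

```python
import torch

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")
    del x  # the tensor is gone, but its memory stays in PyTorch's cache
    print(torch.cuda.memory_reserved())   # nonzero: still reserved by the allocator
    torch.cuda.empty_cache()              # release cached blocks to the driver
    print(torch.cuda.memory_reserved())   # drops back toward zero
```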
