Skip to content

Commit 924ca7c

Browse files
authored
Add DeepSeek V3/R1 shared experts fusion (#4918)
1 parent 6ff9c6a commit 924ca7c

File tree

14 files changed

+536
-36
lines changed

14 files changed

+536
-36
lines changed

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -399,7 +399,12 @@ def main(args: argparse.Namespace):
399399
intermediate_size = config.moe_intermediate_size
400400
shard_intermediate_size = 2 * intermediate_size // args.tp_size
401401
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
402-
E = config.n_routed_experts
402+
n_share_fusion_experts = args.n_share_experts_fusion
403+
E = (
404+
config.n_routed_experts + n_share_fusion_experts
405+
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
406+
else config.n_routed_experts
407+
)
403408
topk = config.num_experts_per_tok
404409
intermediate_size = config.moe_intermediate_size
405410
shard_intermediate_size = 2 * intermediate_size // args.tp_size
@@ -559,6 +564,12 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
559564
parser.add_argument("--seed", type=int, default=0)
560565
parser.add_argument("--batch-size", type=int, required=False)
561566
parser.add_argument("--tune", action="store_true")
567+
parser.add_argument(
568+
"--n-share-experts-fusion",
569+
type=int,
570+
default=0,
571+
help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1",
572+
)
562573
args = parser.parse_args()
563574

564575
main(args)

python/sglang/bench_serving.py

Lines changed: 31 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -993,13 +993,16 @@ async def limited_request_func(request_func_input, pbar):
993993
return await request_func(request_func_input=request_func_input, pbar=pbar)
994994

995995
# Warmup
996-
print("Starting initial single prompt test run...")
996+
print(f"Starting warmup with {args.warmup_requests} sequences...")
997+
998+
# Use the first request for all warmup iterations
997999
test_prompt, test_prompt_len, test_output_len = input_requests[0]
9981000
if lora_names != None and len(lora_names) != 0:
9991001
lora_name = lora_names[0]
10001002
else:
10011003
lora_name = None
10021004

1005+
# Create the test input once
10031006
test_input = RequestFuncInput(
10041007
model=model_id,
10051008
prompt=test_prompt,
@@ -1009,14 +1012,26 @@ async def limited_request_func(request_func_input, pbar):
10091012
lora_name=lora_name,
10101013
extra_request_body=extra_request_body,
10111014
)
1012-
test_output = await request_func(request_func_input=test_input)
1013-
if not test_output.success:
1015+
1016+
# Run warmup requests
1017+
warmup_tasks = []
1018+
for _ in range(args.warmup_requests):
1019+
warmup_tasks.append(
1020+
asyncio.create_task(request_func(request_func_input=test_input))
1021+
)
1022+
1023+
warmup_outputs = await asyncio.gather(*warmup_tasks)
1024+
1025+
# Check if at least one warmup request succeeded
1026+
if not any(output.success for output in warmup_outputs):
10141027
raise ValueError(
1015-
"Initial test run failed - Please make sure benchmark arguments "
1016-
f"are correctly specified. Error: {test_output.error}"
1028+
"Warmup failed - Please make sure benchmark arguments "
1029+
f"are correctly specified. Error: {warmup_outputs[0].error}"
10171030
)
10181031
else:
1019-
print("Initial test run completed. Starting main benchmark run...")
1032+
print(
1033+
f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
1034+
)
10201035

10211036
# Flush cache
10221037
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
@@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace):
12531268
if not hasattr(args, "max_concurrency"):
12541269
args.max_concurrency = None
12551270

1271+
# Set default value for warmup_requests if not present
1272+
if not hasattr(args, "warmup_requests"):
1273+
args.warmup_requests = 1
1274+
12561275
print(f"benchmark_args={args}")
12571276

12581277
# Set global environments
@@ -1560,6 +1579,12 @@ def __call__(self, parser, namespace, values, option_string=None):
15601579
action="store_true",
15611580
help="Flush the cache before running the benchmark",
15621581
)
1582+
parser.add_argument(
1583+
"--warmup-requests",
1584+
type=int,
1585+
default=1,
1586+
help="Number of warmup requests to run before the benchmark",
1587+
)
15631588

15641589
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
15651590
group.add_argument(
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,146 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 128,
5+
"BLOCK_SIZE_K": 128,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 4,
8+
"num_stages": 4
9+
},
10+
"2": {
11+
"BLOCK_SIZE_M": 16,
12+
"BLOCK_SIZE_N": 128,
13+
"BLOCK_SIZE_K": 128,
14+
"GROUP_SIZE_M": 16,
15+
"num_warps": 4,
16+
"num_stages": 4
17+
},
18+
"4": {
19+
"BLOCK_SIZE_M": 16,
20+
"BLOCK_SIZE_N": 128,
21+
"BLOCK_SIZE_K": 128,
22+
"GROUP_SIZE_M": 32,
23+
"num_warps": 4,
24+
"num_stages": 4
25+
},
26+
"8": {
27+
"BLOCK_SIZE_M": 16,
28+
"BLOCK_SIZE_N": 128,
29+
"BLOCK_SIZE_K": 128,
30+
"GROUP_SIZE_M": 64,
31+
"num_warps": 4,
32+
"num_stages": 4
33+
},
34+
"16": {
35+
"BLOCK_SIZE_M": 16,
36+
"BLOCK_SIZE_N": 128,
37+
"BLOCK_SIZE_K": 128,
38+
"GROUP_SIZE_M": 1,
39+
"num_warps": 4,
40+
"num_stages": 3
41+
},
42+
"24": {
43+
"BLOCK_SIZE_M": 16,
44+
"BLOCK_SIZE_N": 128,
45+
"BLOCK_SIZE_K": 128,
46+
"GROUP_SIZE_M": 1,
47+
"num_warps": 4,
48+
"num_stages": 4
49+
},
50+
"32": {
51+
"BLOCK_SIZE_M": 16,
52+
"BLOCK_SIZE_N": 128,
53+
"BLOCK_SIZE_K": 128,
54+
"GROUP_SIZE_M": 1,
55+
"num_warps": 4,
56+
"num_stages": 5
57+
},
58+
"48": {
59+
"BLOCK_SIZE_M": 16,
60+
"BLOCK_SIZE_N": 128,
61+
"BLOCK_SIZE_K": 128,
62+
"GROUP_SIZE_M": 32,
63+
"num_warps": 4,
64+
"num_stages": 4
65+
},
66+
"64": {
67+
"BLOCK_SIZE_M": 16,
68+
"BLOCK_SIZE_N": 128,
69+
"BLOCK_SIZE_K": 128,
70+
"GROUP_SIZE_M": 64,
71+
"num_warps": 4,
72+
"num_stages": 4
73+
},
74+
"96": {
75+
"BLOCK_SIZE_M": 16,
76+
"BLOCK_SIZE_N": 128,
77+
"BLOCK_SIZE_K": 128,
78+
"GROUP_SIZE_M": 16,
79+
"num_warps": 4,
80+
"num_stages": 3
81+
},
82+
"128": {
83+
"BLOCK_SIZE_M": 16,
84+
"BLOCK_SIZE_N": 128,
85+
"BLOCK_SIZE_K": 128,
86+
"GROUP_SIZE_M": 32,
87+
"num_warps": 4,
88+
"num_stages": 3
89+
},
90+
"256": {
91+
"BLOCK_SIZE_M": 16,
92+
"BLOCK_SIZE_N": 128,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 16,
95+
"num_warps": 4,
96+
"num_stages": 3
97+
},
98+
"512": {
99+
"BLOCK_SIZE_M": 64,
100+
"BLOCK_SIZE_N": 128,
101+
"BLOCK_SIZE_K": 128,
102+
"GROUP_SIZE_M": 32,
103+
"num_warps": 4,
104+
"num_stages": 4
105+
},
106+
"1024": {
107+
"BLOCK_SIZE_M": 64,
108+
"BLOCK_SIZE_N": 128,
109+
"BLOCK_SIZE_K": 128,
110+
"GROUP_SIZE_M": 16,
111+
"num_warps": 4,
112+
"num_stages": 4
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 64,
116+
"BLOCK_SIZE_N": 128,
117+
"BLOCK_SIZE_K": 128,
118+
"GROUP_SIZE_M": 32,
119+
"num_warps": 4,
120+
"num_stages": 4
121+
},
122+
"2048": {
123+
"BLOCK_SIZE_M": 64,
124+
"BLOCK_SIZE_N": 128,
125+
"BLOCK_SIZE_K": 128,
126+
"GROUP_SIZE_M": 32,
127+
"num_warps": 4,
128+
"num_stages": 4
129+
},
130+
"3072": {
131+
"BLOCK_SIZE_M": 64,
132+
"BLOCK_SIZE_N": 128,
133+
"BLOCK_SIZE_K": 128,
134+
"GROUP_SIZE_M": 16,
135+
"num_warps": 4,
136+
"num_stages": 4
137+
},
138+
"4096": {
139+
"BLOCK_SIZE_M": 64,
140+
"BLOCK_SIZE_N": 128,
141+
"BLOCK_SIZE_K": 128,
142+
"GROUP_SIZE_M": 16,
143+
"num_warps": 4,
144+
"num_stages": 4
145+
}
146+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,146 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 64,
4+
"BLOCK_SIZE_N": 64,
5+
"BLOCK_SIZE_K": 128,
6+
"GROUP_SIZE_M": 16,
7+
"num_warps": 4,
8+
"num_stages": 4
9+
},
10+
"2": {
11+
"BLOCK_SIZE_M": 64,
12+
"BLOCK_SIZE_N": 32,
13+
"BLOCK_SIZE_K": 128,
14+
"GROUP_SIZE_M": 1,
15+
"num_warps": 4,
16+
"num_stages": 3
17+
},
18+
"4": {
19+
"BLOCK_SIZE_M": 64,
20+
"BLOCK_SIZE_N": 64,
21+
"BLOCK_SIZE_K": 128,
22+
"GROUP_SIZE_M": 1,
23+
"num_warps": 4,
24+
"num_stages": 4
25+
},
26+
"8": {
27+
"BLOCK_SIZE_M": 64,
28+
"BLOCK_SIZE_N": 128,
29+
"BLOCK_SIZE_K": 128,
30+
"GROUP_SIZE_M": 32,
31+
"num_warps": 4,
32+
"num_stages": 3
33+
},
34+
"16": {
35+
"BLOCK_SIZE_M": 64,
36+
"BLOCK_SIZE_N": 128,
37+
"BLOCK_SIZE_K": 128,
38+
"GROUP_SIZE_M": 16,
39+
"num_warps": 4,
40+
"num_stages": 3
41+
},
42+
"24": {
43+
"BLOCK_SIZE_M": 64,
44+
"BLOCK_SIZE_N": 128,
45+
"BLOCK_SIZE_K": 128,
46+
"GROUP_SIZE_M": 16,
47+
"num_warps": 4,
48+
"num_stages": 3
49+
},
50+
"32": {
51+
"BLOCK_SIZE_M": 64,
52+
"BLOCK_SIZE_N": 128,
53+
"BLOCK_SIZE_K": 128,
54+
"GROUP_SIZE_M": 32,
55+
"num_warps": 4,
56+
"num_stages": 3
57+
},
58+
"48": {
59+
"BLOCK_SIZE_M": 64,
60+
"BLOCK_SIZE_N": 128,
61+
"BLOCK_SIZE_K": 128,
62+
"GROUP_SIZE_M": 32,
63+
"num_warps": 4,
64+
"num_stages": 3
65+
},
66+
"64": {
67+
"BLOCK_SIZE_M": 64,
68+
"BLOCK_SIZE_N": 128,
69+
"BLOCK_SIZE_K": 128,
70+
"GROUP_SIZE_M": 64,
71+
"num_warps": 4,
72+
"num_stages": 3
73+
},
74+
"96": {
75+
"BLOCK_SIZE_M": 64,
76+
"BLOCK_SIZE_N": 128,
77+
"BLOCK_SIZE_K": 128,
78+
"GROUP_SIZE_M": 64,
79+
"num_warps": 4,
80+
"num_stages": 3
81+
},
82+
"128": {
83+
"BLOCK_SIZE_M": 64,
84+
"BLOCK_SIZE_N": 128,
85+
"BLOCK_SIZE_K": 128,
86+
"GROUP_SIZE_M": 16,
87+
"num_warps": 4,
88+
"num_stages": 3
89+
},
90+
"256": {
91+
"BLOCK_SIZE_M": 64,
92+
"BLOCK_SIZE_N": 128,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 16,
95+
"num_warps": 4,
96+
"num_stages": 3
97+
},
98+
"512": {
99+
"BLOCK_SIZE_M": 64,
100+
"BLOCK_SIZE_N": 128,
101+
"BLOCK_SIZE_K": 128,
102+
"GROUP_SIZE_M": 16,
103+
"num_warps": 4,
104+
"num_stages": 3
105+
},
106+
"1024": {
107+
"BLOCK_SIZE_M": 64,
108+
"BLOCK_SIZE_N": 128,
109+
"BLOCK_SIZE_K": 128,
110+
"GROUP_SIZE_M": 32,
111+
"num_warps": 4,
112+
"num_stages": 3
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 64,
116+
"BLOCK_SIZE_N": 128,
117+
"BLOCK_SIZE_K": 128,
118+
"GROUP_SIZE_M": 32,
119+
"num_warps": 4,
120+
"num_stages": 3
121+
},
122+
"2048": {
123+
"BLOCK_SIZE_M": 64,
124+
"BLOCK_SIZE_N": 128,
125+
"BLOCK_SIZE_K": 128,
126+
"GROUP_SIZE_M": 16,
127+
"num_warps": 4,
128+
"num_stages": 3
129+
},
130+
"3072": {
131+
"BLOCK_SIZE_M": 128,
132+
"BLOCK_SIZE_N": 64,
133+
"BLOCK_SIZE_K": 128,
134+
"GROUP_SIZE_M": 32,
135+
"num_warps": 4,
136+
"num_stages": 3
137+
},
138+
"4096": {
139+
"BLOCK_SIZE_M": 64,
140+
"BLOCK_SIZE_N": 128,
141+
"BLOCK_SIZE_K": 128,
142+
"GROUP_SIZE_M": 64,
143+
"num_warps": 4,
144+
"num_stages": 3
145+
}
146+
}

0 commit comments

Comments
 (0)