Fuse shared experts in Llama 4 #5101

Closed

98 commits
93ab6b9
Add Llama4 support
CatherineSue Apr 3, 2025
73c5d6d
complete pipeline
ch-wan Apr 6, 2025
ca9870e
fix
ch-wan Apr 6, 2025
fdb0dd6
add locall_attn
CatherineSue Apr 6, 2025
2cd80c2
load weight
ch-wan Apr 6, 2025
5a56108
Merge branch 'main-upstream' into llama4
fzyzcjy Apr 6, 2025
ac4cca3
rm mllama4
fzyzcjy Apr 6, 2025
6cfb3a7
load experts
ch-wan Apr 6, 2025
9fd5188
load weight
ch-wan Apr 6, 2025
6afdfdf
Revert "rm mllama4"
fzyzcjy Apr 6, 2025
a8d4bff
Merge commit '9fd5188965867d0335d8dde357ec81b1a6880982' into pr/Cathe…
ch-wan Apr 6, 2025
b0703ec
polish code
ch-wan Apr 6, 2025
6b21ef5
cleanup
ispobock Apr 6, 2025
114a366
format
ispobock Apr 6, 2025
c097826
more
fzyzcjy Apr 6, 2025
470bd94
more
fzyzcjy Apr 6, 2025
63455cb
more
fzyzcjy Apr 6, 2025
a4bfbe8
more
fzyzcjy Apr 6, 2025
fb95032
more
fzyzcjy Apr 6, 2025
9cdbe53
more
fzyzcjy Apr 6, 2025
1a86c6a
more
fzyzcjy Apr 6, 2025
c5c5c9e
more
fzyzcjy Apr 6, 2025
27b9775
more
fzyzcjy Apr 6, 2025
2d72913
more
fzyzcjy Apr 6, 2025
762b0f1
more
fzyzcjy Apr 6, 2025
26c8ef5
more
fzyzcjy Apr 6, 2025
c2fec4b
more
fzyzcjy Apr 6, 2025
da62632
more
fzyzcjy Apr 6, 2025
8b5b0df
more
fzyzcjy Apr 6, 2025
13e7c10
more
fzyzcjy Apr 6, 2025
ab5347e
more
fzyzcjy Apr 6, 2025
91ef55d
fmt
fzyzcjy Apr 6, 2025
ea98be2
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 6, 2025
f78f65f
more
fzyzcjy Apr 6, 2025
ac1af5b
more
fzyzcjy Apr 6, 2025
3899283
more
fzyzcjy Apr 6, 2025
d505e37
more
fzyzcjy Apr 6, 2025
79efd9a
fmt
fzyzcjy Apr 6, 2025
8f10fc1
more
fzyzcjy Apr 6, 2025
1f7e0f2
more
fzyzcjy Apr 6, 2025
90fb6b6
more
fzyzcjy Apr 6, 2025
786f22c
fmt
fzyzcjy Apr 6, 2025
6b2aeb8
more
fzyzcjy Apr 6, 2025
040b69d
more
fzyzcjy Apr 6, 2025
a8f2c73
more
fzyzcjy Apr 6, 2025
3321217
more
fzyzcjy Apr 6, 2025
a601d6f
fmt
fzyzcjy Apr 6, 2025
8dd3313
more
fzyzcjy Apr 6, 2025
1378fe0
fix norm
ch-wan Apr 6, 2025
30e160f
more
fzyzcjy Apr 6, 2025
8de5eb8
more
fzyzcjy Apr 6, 2025
c31344b
more
fzyzcjy Apr 6, 2025
5e19521
more
fzyzcjy Apr 6, 2025
c9b8b2c
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 6, 2025
b1663c4
more
fzyzcjy Apr 6, 2025
407ec61
more
fzyzcjy Apr 6, 2025
1f18b0c
add conversation template
ispobock Apr 6, 2025
6a79fb7
more
fzyzcjy Apr 6, 2025
f123df7
fmt
fzyzcjy Apr 6, 2025
862cee9
more
fzyzcjy Apr 6, 2025
5c434d7
apply_router_weight_on_input
ch-wan Apr 6, 2025
3dc59e1
add chat template
ispobock Apr 6, 2025
8c36252
more
fzyzcjy Apr 6, 2025
ae98a1e
more
fzyzcjy Apr 6, 2025
9266d96
format
ispobock Apr 6, 2025
ec99690
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 6, 2025
7b534fa
more
fzyzcjy Apr 6, 2025
a204c21
fix load
ispobock Apr 6, 2025
55dfe89
suggestion by chwan
fzyzcjy Apr 6, 2025
01473c2
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 6, 2025
cedb65c
Merge branch 'main' into llama4
zhyncs Apr 6, 2025
6258644
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 6, 2025
9433d8f
more
fzyzcjy Apr 6, 2025
cc7e862
support k > 1
ch-wan Apr 6, 2025
0d53105
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 6, 2025
d227924
fmt
fzyzcjy Apr 6, 2025
ade4384
more
fzyzcjy Apr 6, 2025
5a07b9d
more
fzyzcjy Apr 6, 2025
773cdf6
more
fzyzcjy Apr 6, 2025
d8c4432
lint
ispobock Apr 6, 2025
f5d4cf7
fix
ch-wan Apr 6, 2025
49834d7
more
fzyzcjy Apr 6, 2025
a517a53
more
fzyzcjy Apr 6, 2025
95de87d
fix mlp
fzyzcjy Apr 6, 2025
7d45b7d
fix local_attn support
CatherineSue Apr 6, 2025
5cab0b5
Merge branch 'llama4' into feat/llama4_tuning
fzyzcjy Apr 6, 2025
fc81086
more
fzyzcjy Apr 6, 2025
82ee700
more
fzyzcjy Apr 6, 2025
f18a7de
Revert "more"
fzyzcjy Apr 6, 2025
f87f710
Revert "more"
fzyzcjy Apr 6, 2025
d789896
more
fzyzcjy Apr 6, 2025
663ae1e
Merge branch 'llama4' into feat/llama_tom_fused_shared_expert
fzyzcjy Apr 7, 2025
885d525
more
fzyzcjy Apr 7, 2025
cf9b9e7
more
fzyzcjy Apr 7, 2025
3bf8b1d
more
fzyzcjy Apr 7, 2025
5d4f66a
tuning
fzyzcjy Apr 7, 2025
42baa99
tuning
fzyzcjy Apr 7, 2025
add644b
Merge branch 'feat/llama4_tuning' into feat/llama_tom_fused_shared_ex…
fzyzcjy Apr 7, 2025
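For context before the per-file changes: "fusing" the shared expert means folding Llama 4's always-on shared expert into the same fused-MoE kernel that serves the routed experts, instead of running it as a separate dense MLP. Below is a rough, simplified sketch of that idea; it is illustrative only and is not asserted to be the exact routing logic in this PR (the expert counts, weights, and helper name are invented for the example).

import torch

def route_with_fused_shared_expert(router_logits, num_routed_experts, topk):
    # Routed experts occupy ids [0, num_routed_experts); the shared expert's
    # weights are appended to the fused expert table as id num_routed_experts.
    weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), topk, dim=-1)
    shared_ids = torch.full_like(ids[:, :1], num_routed_experts)
    shared_weights = torch.ones_like(weights[:, :1])  # shared expert applied to every token
    return torch.cat([weights, shared_weights], dim=-1), torch.cat([ids, shared_ids], dim=-1)

logits = torch.randn(4, 16)  # 4 tokens, 16 routed experts (illustrative numbers)
w, ids = route_with_fused_shared_expert(logits, num_routed_experts=16, topk=1)
print(ids)  # every row ends with id 16: the fused shared expert

With the shared expert inside the table, one grouped-GEMM launch covers both the routed and shared paths, which is why the tuning script below enlarges the expert count E by n_share_experts_fusion.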
23 changes: 17 additions & 6 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -9,8 +9,6 @@
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
@@ -19,6 +17,7 @@
get_moe_configs,
)
from sglang.srt.utils import is_hip
from transformers import AutoConfig

_is_hip_ = is_hip()

@@ -326,7 +325,7 @@ def tune(
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens} {best_config=}")
assert best_config is not None
return best_config

@@ -373,7 +372,7 @@ def save_configs(
block_shape,
)

print(f"Writing best config to {filename}...")
print(f"Writing best config to {filename} with content {configs=}")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
@@ -388,16 +387,19 @@ def main(args: argparse.Namespace):
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
n_share_fusion_experts = args.n_share_experts_fusion
E = (
@@ -408,6 +410,14 @@
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
n_share_fusion_experts = args.n_share_experts_fusion
E = config.text_config.num_local_experts + n_share_fusion_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.text_config.hidden_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
@@ -417,14 +427,15 @@
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size

hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
@@ -485,7 +496,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
print(f"Start tuning over {len(search_space)} configurations... ({E=} {shard_intermediate_size=} {dtype=} {block_shape=})")

start = time.time()
configs = _distribute(
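To summarize the new Llama4ForConditionalGeneration branch above, here is a minimal standalone sketch of the tuning dimensions it derives; the concrete numbers are illustrative assumptions for a Scout-like config, not values taken from this PR.

from types import SimpleNamespace

text_config = SimpleNamespace(
    num_local_experts=16,    # routed experts (assumed)
    num_experts_per_tok=1,   # router top-k (assumed)
    intermediate_size=8192,  # per-expert FFN width (assumed)
    hidden_size=5120,        # model width (assumed)
)
n_share_experts_fusion = 1   # shared experts folded into the expert table
tp_size = 8                  # tensor-parallel degree (assumed)

E = text_config.num_local_experts + n_share_experts_fusion               # 17 expert slots to tune
topk = text_config.num_experts_per_tok                                   # 1
shard_intermediate_size = 2 * text_config.intermediate_size // tp_size   # gate+up packed, then TP-sharded: 2048
hidden_size = text_config.hidden_size                                    # 5120

The only Llama 4-specific wrinkle is that all MoE shapes live under config.text_config, and the fused shared expert widens E by n_share_experts_fusion so the Triton kernel is tuned for the enlarged expert table.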
2 changes: 1 addition & 1 deletion docs/references/supported_models.md
@@ -1,7 +1,7 @@
# Supported Models

## Generative Models
- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3
- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3 / Llama 4
- Mistral / Mixtral / Mistral NeMo / Mistral Small 3
- Gemma / Gemma 2 / Gemma3
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL / Qwen 2.5 VL / Olympic Coder
24 changes: 24 additions & 0 deletions python/sglang/lang/chat_template.py
@@ -294,6 +294,30 @@ def get_chat_template_by_model_path(model_path):
)
)

# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="llama-4",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|header_start|>system<|header_end|>\n\n",
"<|eot|>",
),
"user": (
"<|header_start|>user<|header_end|>\n\n",
"<|eot|>",
),
"assistant": (
"<|header_start|>assistant<|header_end|>\n\n",
"<|eot|>",
),
},
stop_str=("<|eot|>",),
image_token="<|image|>",
)
)

# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
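To make the llama-4 template registration above concrete, here is a minimal rendering sketch. It is not SGLang's actual ChatTemplate renderer, just an illustration of how the registered role prefixes and suffixes compose; the example message is invented.

role_prefix_and_suffix = {
    "system": ("<|header_start|>system<|header_end|>\n\n", "<|eot|>"),
    "user": ("<|header_start|>user<|header_end|>\n\n", "<|eot|>"),
    "assistant": ("<|header_start|>assistant<|header_end|>\n\n", "<|eot|>"),
}

messages = [("user", "What is the capital of France?")]

prompt = ""
for role, text in messages:
    prefix, suffix = role_prefix_and_suffix[role]
    prompt += prefix + text + suffix
# Leave the assistant header open so the model writes the reply after it.
prompt += role_prefix_and_suffix["assistant"][0]
print(prompt)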
4 changes: 4 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -65,6 +65,9 @@ def __init__(
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)

# Check model type
self.is_generation = is_generation_model(
@@ -467,6 +470,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
# TODO: add multimodal support for "Llama4ForConditionalGeneration",
"LlavaLlamaForCausalLM",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
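The attention_chunk_size value read above feeds Llama 4's chunked local-attention layers (see the local_attn commits in this PR). As a rough illustration only (the real masking lives in the attention backend, not in model_config.py), a causal mask restricted to fixed-size chunks could be built like this:

import torch

def chunked_causal_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # True where query i may attend to key j: j <= i and both fall in the same chunk.
    idx = torch.arange(seq_len)
    same_chunk = (idx[:, None] // chunk_size) == (idx[None, :] // chunk_size)
    causal = idx[:, None] >= idx[None, :]
    return same_chunk & causal

mask = chunked_causal_mask(seq_len=8, chunk_size=4)  # chunk_size would come from attention_chunk_size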
28 changes: 27 additions & 1 deletion python/sglang/srt/conversation.py
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
LLAMA3 = auto()
LLAMA4 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
@@ -156,6 +157,19 @@ def get_prompt(self) -> str:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA4:
ret = "<|begin_of_text|>"
if self.system_message:
ret += system_prompt
else:
ret += ""
for i, (role, message) in enumerate(self.messages):
if message:
ret += f"<|header_start|>{role}<|header_end|>\n\n"
ret += f"{message.strip()}<|eot|>"
else:
ret += f"<|header_start|>{role}<|header_end|>\n\n"
return ret
elif self.sep_style == SeparatorStyle.LLAMA3:
ret = "<|begin_of_text|>"
if self.system_message:
@@ -168,7 +182,6 @@ def get_prompt(self) -> str:
ret += f"{message.strip()}<|eot_id|>"
else:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
# print(ret)
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
@@ -561,6 +574,19 @@ def generate_chat_conv(
)
)

# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_conv_template(
Conversation(
name="llama-4",
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA4,
sep="",
stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
image_token="<|image|>",
)
)

register_conv_template(
Conversation(
name="chatml",
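For reference, here is a standalone sketch of what the new SeparatorStyle.LLAMA4 branch of get_prompt() produces for a one-turn conversation with a pending assistant reply; the messages are invented, and the logic simply mirrors the branch added above.

system_prompt = "<|header_start|>system<|header_end|>\n\nYou are a helpful assistant.<|eot|>"
messages = [
    ("user", "Hello!"),
    ("assistant", None),  # None: emit an open assistant header for generation
]

ret = "<|begin_of_text|>" + system_prompt
for role, message in messages:
    if message:
        ret += f"<|header_start|>{role}<|header_end|>\n\n"
        ret += f"{message.strip()}<|eot|>"
    else:
        ret += f"<|header_start|>{role}<|header_end|>\n\n"
print(ret)
# <|begin_of_text|><|header_start|>system<|header_end|>
#
# You are a helpful assistant.<|eot|><|header_start|>user<|header_end|>
#
# Hello!<|eot|><|header_start|>assistant<|header_end|>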