
Support tuning MoE for Llama 4 model #5109


Closed
wants to merge 41 commits into from
Commits (41)
93ab6b9
Add Llama4 support
CatherineSue Apr 3, 2025
73c5d6d
complete pipeline
ch-wan Apr 6, 2025
ca9870e
fix
ch-wan Apr 6, 2025
fdb0dd6
add locall_attn
CatherineSue Apr 6, 2025
2cd80c2
load weight
ch-wan Apr 6, 2025
5a56108
Merge branch 'main-upstream' into llama4
fzyzcjy Apr 6, 2025
ac4cca3
rm mllama4
fzyzcjy Apr 6, 2025
6cfb3a7
load experts
ch-wan Apr 6, 2025
9fd5188
load weight
ch-wan Apr 6, 2025
6afdfdf
Revert "rm mllama4"
fzyzcjy Apr 6, 2025
a8d4bff
Merge commit '9fd5188965867d0335d8dde357ec81b1a6880982' into pr/Cathe…
ch-wan Apr 6, 2025
b0703ec
polish code
ch-wan Apr 6, 2025
6b21ef5
cleanup
ispobock Apr 6, 2025
114a366
format
ispobock Apr 6, 2025
1378fe0
fix norm
ch-wan Apr 6, 2025
1f18b0c
add conversation template
ispobock Apr 6, 2025
5c434d7
apply_router_weight_on_input
ch-wan Apr 6, 2025
3dc59e1
add chat template
ispobock Apr 6, 2025
9266d96
format
ispobock Apr 6, 2025
a204c21
fix load
ispobock Apr 6, 2025
cedb65c
Merge branch 'main' into llama4
zhyncs Apr 6, 2025
cc7e862
support k > 1
ch-wan Apr 6, 2025
d8c4432
lint
ispobock Apr 6, 2025
f5d4cf7
fix
ch-wan Apr 6, 2025
49834d7
more
fzyzcjy Apr 6, 2025
a517a53
more
fzyzcjy Apr 6, 2025
95de87d
fix mlp
fzyzcjy Apr 6, 2025
7d45b7d
fix local_attn support
CatherineSue Apr 6, 2025
5cab0b5
Merge branch 'llama4' into feat/llama4_tuning
fzyzcjy Apr 6, 2025
fc81086
more
fzyzcjy Apr 6, 2025
82ee700
more
fzyzcjy Apr 6, 2025
f18a7de
Revert "more"
fzyzcjy Apr 6, 2025
f87f710
Revert "more"
fzyzcjy Apr 6, 2025
d789896
more
fzyzcjy Apr 6, 2025
885d525
more
fzyzcjy Apr 7, 2025
cf9b9e7
more
fzyzcjy Apr 7, 2025
3bf8b1d
more
fzyzcjy Apr 7, 2025
5d4f66a
tuning
fzyzcjy Apr 7, 2025
42baa99
tuning
fzyzcjy Apr 7, 2025
53de4ec
tuning
fzyzcjy Apr 7, 2025
21e9f04
tuning
fzyzcjy Apr 7, 2025
23 changes: 17 additions & 6 deletions benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -9,8 +9,6 @@
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
@@ -19,6 +17,7 @@
get_moe_configs,
)
from sglang.srt.utils import is_hip
from transformers import AutoConfig

_is_hip_ = is_hip()

@@ -326,7 +325,7 @@ def tune(
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens} {best_config=}")
assert best_config is not None
return best_config

@@ -373,7 +372,7 @@ def save_configs(
block_shape,
)

print(f"Writing best config to {filename}...")
print(f"Writing best config to {filename} with content {configs=}")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
@@ -388,16 +387,19 @@ def main(args: argparse.Namespace):
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen2MoeForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
n_share_fusion_experts = args.n_share_experts_fusion
E = (
@@ -408,6 +410,14 @@
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Llama4ForConditionalGeneration":
n_share_fusion_experts = args.n_share_experts_fusion
E = config.text_config.num_local_experts + n_share_fusion_experts
topk = config.text_config.num_experts_per_tok
intermediate_size = config.text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.text_config.hidden_size
elif config.architectures[0] in [
"Grok1ForCausalLM",
"Grok1ImgGen",
@@ -417,14 +427,15 @@
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
else:
# Default: Mixtral
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size

hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
@@ -485,7 +496,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
for config in search_space
if block_k % config["BLOCK_SIZE_K"] == 0
]
print(f"Start tuning over {len(search_space)} configurations...")
print(f"Start tuning over {len(search_space)} configurations... ({E=} {shard_intermediate_size=} {dtype=} {block_shape=})")

start = time.time()
configs = _distribute(
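The new `Llama4ForConditionalGeneration` branch resolves every MoE dimension from the nested `text_config` instead of the top-level config object. Below is a minimal, hand-written sketch of that lookup; the model path and `tp_size` are example values, not taken from this PR, and the attribute names simply mirror the branch above:

```python
# Illustrative sketch of the config lookup performed by the new Llama 4 branch.
# Model path and tp_size are example values; attribute names mirror the diff above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
text_config = config.text_config  # Llama 4 nests its language-model settings here
tp_size = 8  # example tensor-parallel degree

E = text_config.num_local_experts  # routed experts (the script also adds --n-share-experts-fusion)
topk = text_config.num_experts_per_tok  # experts activated per token
intermediate_size = text_config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size  # per tensor-parallel rank
hidden_size = text_config.hidden_size

print(f"{E=} {topk=} {shard_intermediate_size=} {hidden_size=}")
```

The factor of 2 in `shard_intermediate_size` matches the existing branches, consistent with the fused gate/up projection layout used by `fused_moe_triton`.
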
2 changes: 1 addition & 1 deletion docs/references/supported_models.md
@@ -1,7 +1,7 @@
# Supported Models

## Generative Models
- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3
- Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 / Llama 3.3 / Llama 4
- Mistral / Mixtral / Mistral NeMo / Mistral Small 3
- Gemma / Gemma 2 / Gemma3
- Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL / Qwen 2.5 VL / Olympic Coder
24 changes: 24 additions & 0 deletions python/sglang/lang/chat_template.py
@@ -294,6 +294,30 @@ def get_chat_template_by_model_path(model_path):
)
)

# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_chat_template(
ChatTemplate(
name="llama-4",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
"<|header_start|>system<|header_end|>\n\n",
"<|eot|>",
),
"user": (
"<|header_start|>user<|header_end|>\n\n",
"<|eot|>",
),
"assistant": (
"<|header_start|>assistant<|header_end|>\n\n",
"<|eot|>",
),
},
stop_str=("<|eot|>",),
image_token="<|image|>",
)
)

# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
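For reference, this is the prompt shape the `llama-4` template above encodes for a single exchange. The string below is assembled by hand from the `role_prefix_and_suffix` entries and is illustrative only; the actual concatenation is performed by sglang's chat-template machinery:

```python
# Hand-assembled from the role prefixes/suffixes registered above; illustrative only.
prompt = (
    "<|header_start|>system<|header_end|>\n\n"
    "You are a helpful assistant.<|eot|>"
    "<|header_start|>user<|header_end|>\n\n"
    "What is the capital of France?<|eot|>"
    "<|header_start|>assistant<|header_end|>\n\n"  # left open for generation
)
# Generation stops on the registered stop_str, "<|eot|>".
```
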
4 changes: 4 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -65,6 +65,9 @@ def __init__(
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
)

# Check model type
self.is_generation = is_generation_model(
@@ -467,6 +470,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
# TODO: add multimodal support for "Llama4ForConditionalGeneration",
"LlavaLlamaForCausalLM",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
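The new `attention_chunk_size` lookup falls back to `None` for configs that do not define the field, so non-Llama-4 models are unaffected; the commit history ("add locall_attn", "fix local_attn support") suggests the value feeds the chunked/local-attention path, though that consumer is not shown in this hunk. A tiny sketch of the fallback behavior:

```python
from transformers import LlamaConfig

# A plain Llama config has no attention_chunk_size, so the getattr used in
# ModelConfig.__init__ above resolves to None ("no chunking").
hf_text_config = LlamaConfig()
attention_chunk_size = getattr(hf_text_config, "attention_chunk_size", None)
assert attention_chunk_size is None
```
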
28 changes: 27 additions & 1 deletion python/sglang/srt/conversation.py
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
LLAMA3 = auto()
LLAMA4 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
@@ -156,6 +157,19 @@ def get_prompt(self) -> str:
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA4:
ret = "<|begin_of_text|>"
if self.system_message:
ret += system_prompt
else:
ret += ""
for i, (role, message) in enumerate(self.messages):
if message:
ret += f"<|header_start|>{role}<|header_end|>\n\n"
ret += f"{message.strip()}<|eot|>"
else:
ret += f"<|header_start|>{role}<|header_end|>\n\n"
return ret
elif self.sep_style == SeparatorStyle.LLAMA3:
ret = "<|begin_of_text|>"
if self.system_message:
@@ -168,7 +182,6 @@ def get_prompt(self) -> str:
ret += f"{message.strip()}<|eot_id|>"
else:
ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
# print(ret)
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
@@ -561,6 +574,19 @@ def generate_chat_conv(
)
)

# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
register_conv_template(
Conversation(
name="llama-4",
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA4,
sep="",
stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
image_token="<|image|>",
)
)

register_conv_template(
Conversation(
name="chatml",
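Compared with the `llama-4` chat template registered in `chat_template.py`, the `SeparatorStyle.LLAMA4` branch of `get_prompt()` also prepends `<|begin_of_text|>` and emits a bare role header when the trailing message is empty (the generation prompt). A hand-written expectation for a short conversation, derived from that branch and illustrative only:

```python
# Expected get_prompt() output for the LLAMA4 branch above, assembled by hand.
# Assumes a system message is set and the final assistant message is None.
expected = (
    "<|begin_of_text|>"
    "<|header_start|>system<|header_end|>\n\n"
    "You are a helpful assistant.<|eot|>"
    "<|header_start|>user<|header_end|>\n\n"
    "Hello!<|eot|>"
    "<|header_start|>assistant<|header_end|>\n\n"  # open header invites the model's reply
)
```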