ROCm/AITER CK_MoE: update 2-stage kernels & support both Activations #5228

Merged · 2 commits · Apr 11, 2025
python/sglang/srt/layers/quantization/fp8.py — 38 changes: 25 additions & 13 deletions
@@ -71,7 +71,8 @@ def prepare_fp8_layer_for_marlin(*args, **kwargs):
 _is_hip = is_hip()
 
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
 _is_cuda = is_cuda()
@@ -487,7 +488,7 @@ def create_weights(
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
Member commented on the torch.uint32 change: QQ: why change this? It's also used by NVIDIA GPUs.

@HaiShaw (Collaborator, Author) replied on Apr 10, 2025: This is the uint32-packed INT4 case for serialized checkpoints (inside the USE_INT4_WEIGHT flag); the else branch is the generic case of OCP/NV FP8 data read from the checkpoint.

                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
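To make the reply concrete: with USE_INT4_WEIGHT set, each 32-bit checkpoint word carries eight 4-bit weight values, so an unsigned 32-bit container is the natural parameter dtype. Below is a minimal packing sketch; it is purely illustrative, the helper name is hypothetical, and it is not AITER's actual on-disk layout.

import numpy as np

def pack_int4_to_uint32(vals: np.ndarray) -> np.ndarray:
    # Pack eight 4-bit values (0..15) into one uint32 word (hypothetical layout).
    vals = vals.astype(np.uint32).reshape(*vals.shape[:-1], -1, 8)
    shifts = np.arange(8, dtype=np.uint32) * 4
    return (vals << shifts).sum(axis=-1, dtype=np.uint32)

w = np.random.randint(0, 16, size=(2, 16))   # toy int4 values
packed = pack_int4_to_uint32(w)
print(packed.shape)                          # (2, 2): eight nibbles per uint32 word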
@@ -822,12 +823,14 @@ def process_weights_hip_int4(self, layer: Module):
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            permute_weight(layer.w13_weight.data),
+            # permute_weight(layer.w13_weight.data),
+            shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            permute_weight(layer.w2_weight.data),
+            # permute_weight(layer.w2_weight.data),
+            shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
@@ -867,12 +870,14 @@ def process_weights_hip_scale_padding(self, layer: Module):

         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
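Both weight-processing hunks replace the asm-path permute_weight with aiter's shuffle_weight using a (16, 16) block size before re-wrapping the tensor as a frozen Parameter. A minimal standalone sketch of that pattern, assuming a ROCm build of PyTorch with aiter installed; the shapes and dtype here are placeholders rather than the real checkpoint layout.

import torch
from aiter.ops.shuffle import shuffle_weight  # ROCm/aiter only

# Placeholder expert weights: (num_experts, out_features, in_features).
w13 = torch.randn(8, 512, 256, dtype=torch.bfloat16, device="cuda")

# Re-layout the weights with a (16, 16) block shuffle, as the updated
# process_weights_* paths above do, then freeze them as a Parameter.
w13 = torch.nn.Parameter(shuffle_weight(w13, (16, 16)), requires_grad=False)
torch.cuda.empty_cache()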
@@ -928,23 +933,25 @@ def apply(
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,14 +964,19 @@
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
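Putting the apply() hunks together, the ROCm MoE dispatch after this change looks roughly like the condensed sketch below. It paraphrases the diff rather than reproducing the merged code verbatim: the _is_hip guard and the surrounding method signature are abbreviated, get_bool_env_var is assumed to come from sglang.srt.utils, and the asm_moe arguments hidden behind the collapsed diff region are elided.

from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
from sglang.srt.utils import get_bool_env_var  # assumed import location

def moe_dispatch_sketch(self, layer, x, topk_weights, topk_ids, activation="silu"):
    # Map the string activation to AITER's enum (introduced by this PR).
    act = ActivationType.Silu if activation == "silu" else ActivationType.Gelu
    if get_bool_env_var("USE_INT4_WEIGHT"):
        # INT4 weight / FP8 compute: CK 2-stage W4 kernel, silu or gelu.
        return ck_moe_2stages_win4(
            x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids,
            layer.w13_weight_scale1, layer.w2_weight_scale1, activation=act,
        )
    if get_bool_env_var("CK_MOE"):
        if self.block_quant:
            # FP8 block-quant still routes to asm_moe and only supports 'silu'.
            assert activation == "silu"
            return asm_moe(
                x, layer.w13_weight, layer.w2_weight,
                # ... remaining arguments are collapsed in the diff above ...
                expert_mask=None,
            )
        # Per-tensor FP8: CK 2-stage kernel, silu or gelu.
        return ck_moe_2stages(
            x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids,
            layer.w13_weight_scale1, layer.w2_weight_scale1, activation=act,
        )
    # Non-HIP / non-CK paths fall through to the existing fused_experts code.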