
Commit 8879944

ROCm/AITER CK_MoE: update 2-stage kernels & support both Activations (#5228)
1 parent a879811 commit 8879944

File tree

1 file changed: +25 −13 lines
  • python/sglang/srt/layers/quantization

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 25 additions & 13 deletions
@@ -71,7 +71,8 @@ def prepare_fp8_layer_for_marlin(*args, **kwargs):
 _is_hip = is_hip()
 
 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
 _is_cuda = is_cuda()
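
The new ActivationType import feeds the two kernel call sites further down, which repeat the same string-to-enum conversion. A minimal sketch of a shared helper, assuming only that aiter's ActivationType exposes Silu and Gelu members as the hunks below imply; the helper name is hypothetical and not part of this commit:

from aiter import ActivationType

def _to_activation_type(activation: str):
    # Hypothetical helper (not in this commit): mirrors the inline
    # conditionals below, where any value other than "silu" maps to Gelu.
    return ActivationType.Silu if activation == "silu" else ActivationType.Gelu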
@@ -487,7 +488,7 @@ def create_weights(
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
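
One plausible reading of the int32 → uint32 switch, shown here as a hypothetical illustration rather than anything stated in the commit: eight packed 4-bit fields fill all 32 bits, and once the top nibble is set the bit pattern exceeds the signed int32 range.

# Hypothetical illustration, not from this commit: eight 4-bit values
# packed little-nibble-first into one 32-bit word.
nibbles = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0xF]
packed = 0
for i, v in enumerate(nibbles):
    packed |= v << (4 * i)
assert packed == 0xF7654321
assert packed > 0x7FFFFFFF  # out of int32 range; representable as uint32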
@@ -822,12 +823,14 @@ def process_weights_hip_int4(self, layer: Module):
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            permute_weight(layer.w13_weight.data),
+            # permute_weight(layer.w13_weight.data),
+            shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            permute_weight(layer.w2_weight.data),
+            # permute_weight(layer.w2_weight.data),
+            shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
@@ -867,12 +870,14 @@ def process_weights_hip_scale_padding(self, layer: Module):
 
         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -928,23 +933,25 @@ def apply(
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,14 +964,19 @@ def apply(
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
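
Taken together, the apply path after this commit dispatches as sketched below; this is a condensed restatement of the hunks above for readability, not verbatim file contents:

# Condensed sketch of the dispatch logic in apply() after this commit.
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
    out = ck_moe_2stages_win4(...)   # INT4 weights, FP8 compute, silu or gelu
elif _is_hip and get_bool_env_var("CK_MOE"):
    if self.block_quant:
        out = asm_moe(...)           # FP8 block quant, silu only for now
    else:
        out = ck_moe_2stages(...)    # per-tensor FP8 scales, silu or gelu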
