@@ -71,7 +71,8 @@ def prepare_fp8_layer_for_marlin(*args, **kwargs):
 _is_hip = is_hip()

 if _is_hip:
-    from aiter.fused_moe_bf16_asm import asm_moe
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight

 _is_cuda = is_cuda()
@@ -487,7 +488,7 @@ def create_weights(
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = (
-                torch.int32
+                torch.uint32
                 if get_bool_env_var("USE_INT4_WEIGHT")
                 else torch.float8_e4m3fn
             )
@@ -822,12 +823,14 @@ def process_weights_hip_int4(self, layer: Module):
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
-            permute_weight(layer.w13_weight.data),
+            # permute_weight(layer.w13_weight.data),
+            shuffle_weight(layer.w13_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
         layer.w2_weight = torch.nn.Parameter(
-            permute_weight(layer.w2_weight.data),
+            # permute_weight(layer.w2_weight.data),
+            shuffle_weight(layer.w2_weight.data, (16, 16)),
             requires_grad=False,
         )
         torch.cuda.empty_cache()
@@ -867,12 +870,14 @@ def process_weights_hip_scale_padding(self, layer: Module):

         if get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
+                # permute_weight(layer.w13_weight.data),
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
             layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
+                # permute_weight(layer.w2_weight.data),
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
@@ -928,23 +933,25 @@ def apply(
         if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
-            return asm_moe(
+            return ck_moe_2stages_win4(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
                 topk_weights,
                 topk_ids,
                 layer.w13_weight_scale1,
                 layer.w2_weight_scale1,
-                activation=activation,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
             )
         if _is_hip and get_bool_env_var("CK_MOE"):
-            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
-            assert (
-                activation == "silu"
-            ), f"CK_MOE: FP8 and/or FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
+                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"CK_MOE: FP8 block_quant {activation=} will be supported later, unset CK_MOE"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -957,14 +964,19 @@ def apply(
                     expert_mask=None,
                 )
             else:
-                return asm_moe(
+                return ck_moe_2stages(
                     x,
                     layer.w13_weight,
                     layer.w2_weight,
                     topk_weights,
                     topk_ids,
                     layer.w13_weight_scale1,
                     layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
                 )
         else:
             # Expert fusion with FP8 quantization
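Taken together, the hunks above change the HIP (`_is_hip`) MoE path to dispatch by kernel flavor and to pass `aiter.ActivationType` enums instead of the raw activation string. Below is a condensed, hypothetical sketch of that dispatch, assuming a ROCm build with aiter installed; the helper name `hip_fused_moe`, its flat signature, and `_to_aiter_activation` are invented for illustration, and the block-quant `asm_moe` call is not reproduced because its full argument list is elided in the diff.

```python
# Hypothetical condensation of the dispatch added in apply() above (not the
# actual method); kernel names and argument order are copied from the hunks.
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import ck_moe_2stages, ck_moe_2stages_win4


def _to_aiter_activation(activation: str) -> ActivationType:
    # The diff maps "silu" explicitly and falls back to Gelu otherwise.
    return ActivationType.Silu if activation == "silu" else ActivationType.Gelu


def hip_fused_moe(layer, x, topk_weights, topk_ids, *, activation="silu",
                  use_int4_weight=False, block_quant=False):
    if use_int4_weight:
        # INT4 MoE weights with FP8 compute: new two-stage CK kernel (w-int4).
        return ck_moe_2stages_win4(
            x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids,
            layer.w13_weight_scale1, layer.w2_weight_scale1,
            activation=_to_aiter_activation(activation),
        )
    if block_quant:
        # FP8 block quantization stays on asm_moe and still supports only
        # "silu"; its argument list is elided in the hunks, so it is not
        # reproduced here.
        assert activation == "silu"
        raise NotImplementedError("see the asm_moe(...) call in the diff")
    # Plain FP8 (per-tensor / per-channel scales): new two-stage CK kernel.
    return ck_moe_2stages(
        x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids,
        layer.w13_weight_scale1, layer.w2_weight_scale1,
        activation=_to_aiter_activation(activation),
    )
```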