
Commit 02cabff

[V1] [ROCm] Enable EP with AITER Fused MoE (vllm-project#20270)
Signed-off-by: tjtanaa <[email protected]>
1 parent 3d19d47 commit 02cabff

File tree

4 files changed: +15 -5 lines changed


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
@@ -646,13 +646,13 @@ def forward_cuda(
             indices_type=self.topk_indices_dtype)
 
         if self.rocm_aiter_moe_enabled:
-            assert expert_map is None
             return self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
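
For context (not part of the commit): expert_map is the per-rank tensor vLLM passes down under expert parallelism (EP); by convention it appears to map each global expert id to a local expert index, with -1 for experts hosted on other ranks. A minimal sketch of such a map, assuming 8 global experts split evenly across 2 EP ranks (the sizes and layout here are illustrative only):

    import torch

    # Hypothetical EP layout: 8 global experts, 2 expert-parallel ranks.
    # This rank (rank 1) owns global experts 4..7 as local experts 0..3.
    global_num_experts = 8
    ep_size, ep_rank = 2, 1
    local_num_experts = global_num_experts // ep_size

    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)

    print(expert_map)  # tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)

With the assert removed above, a map like this can now flow into the AITER path instead of being rejected.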

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 9 additions & 2 deletions
@@ -315,14 +315,20 @@ def rocm_aiter_fused_experts(
         w2_scale: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
-        block_shape: Optional[list[int]] = None) -> torch.Tensor:
+        block_shape: Optional[list[int]] = None,
+        expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
 
     activation_method = (ActivationMethod.SILU
                          if activation == "silu" else ActivationMethod.GELU)
     # All AITER Fused MoE kernels are expecting the following datatypes
     topk_weights = topk_weights.to(torch.float32)
     topk_ids = topk_ids.to(torch.int32)
 
+    if expert_map is not None:
+        expert_mask = (expert_map > -1).to(torch.int32)
+    else:
+        expert_mask = None
+
     # w8a8 per-channel quantization
     if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
         # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`

@@ -346,7 +352,7 @@ def rocm_aiter_fused_experts(
             fc2_smooth_scale=None,
             a16=False,
             per_tensor_quant_scale=None,
-            expert_mask=None,
+            expert_mask=expert_mask,
             activation_method=activation_method)
 
     else:

@@ -378,6 +384,7 @@ def rocm_aiter_fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            expert_mask=expert_mask,
             quant_method=quant_method,
             activation_method=activation_method,
             w1_scale=w1_scale,
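
As a side note, the expert_map to expert_mask conversion added above is small enough to sketch standalone; the helper name below is ours, not from the commit. Entries greater than -1 mark experts owned by the current rank, and the AITER path forwards the resulting 0/1 int32 mask over the global expert dimension to its kernels:

    from typing import Optional

    import torch

    def expert_map_to_mask(expert_map: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Mirrors the logic added in rocm_aiter_fused_experts: experts owned by
        # this rank (local index > -1) become 1, all other experts become 0.
        if expert_map is None:
            return None
        return (expert_map > -1).to(torch.int32)

    expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3], dtype=torch.int32)
    print(expert_map_to_mask(expert_map))  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)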

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 2 additions & 1 deletion
@@ -633,7 +633,8 @@ def apply(
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale)
+            a2_scale=layer.w2_input_scale,
+            expert_map=expert_map)
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 1 deletion
@@ -442,6 +442,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+
         from vllm.model_executor.layers.fused_moe import fused_experts
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None

@@ -879,7 +880,8 @@ def apply(
                       if self.block_quant else layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size)
+            block_shape=self.quant_config.weight_block_size,
+            expert_map=expert_map)
         elif self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
