
Commit 7862409

sleepcoo, yyihuang, ch-wan, xutizhou, and ZhengHSI authored and committed
DeepEP normal support deepgemm-contiguous (sgl-project#5626)
Co-authored-by: Yingyi Huang <[email protected]>
Co-authored-by: Cheng Wan <[email protected]>
Co-authored-by: Xuting Zhou <[email protected]>
Co-authored-by: ZhengHSI <[email protected]>
1 parent 2ba27ea commit 7862409

File tree

6 files changed: +568 -59 lines changed


python/sglang/srt/layers/moe/ep_moe/kernels.py

Lines changed: 340 additions & 2 deletions
@@ -5,16 +5,23 @@
 import triton
 import triton.language as tl
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
-logger = logging.getLogger(__name__)
+
+try:
+    from deep_gemm import ceil_div
+except ImportError:
+    logger.error(f"Failed to import ceil_div from deep_gemm.")
+
+import triton.language as tl
 
 
 @triton.jit
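
If deep_gemm is unavailable, the except branch above only logs the failure, so a later call that needs ceil_div (see get_tma_aligned_size below) would raise NameError. ceil_div itself is plain ceiling division; a minimal stand-in with the same semantics, offered as a sketch rather than as part of this commit:

def ceil_div(x: int, y: int) -> int:
    # Smallest integer n such that n * y >= x, e.g. ceil_div(9, 4) == 3.
    return (x + y - 1) // y
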
@@ -704,3 +711,334 @@ def grouped_gemm_triton(
         **config,
     )
     return c
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert + offset_cumsum,
+        mask=offset_cumsum < num_experts,
+        other=0,
+    )
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+        )
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 + topk_index,
+                    dest_token_index,
+                )
+                output_tensor_ptr = (
+                    output_tensor + dest_token_index * output_tensor_stride0
+                )
+                output_tensor_scale_ptr = (
+                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
+                )
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # token num of per expert is aligned to 128
+    BLOCK_D = 128  # block size of quantization
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid,)](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid,)](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
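
For reference, a minimal sketch of how ep_scatter might be driven. Every size, dtype, and the padding step are illustrative assumptions, not taken from this commit; in SGLang the real inputs come from DeepEP normal dispatch plus per-token-group FP8 quantization. It assumes a CUDA device and a PyTorch build with FP8 support:

import torch

# Hypothetical sizes: 8 local experts, top-2 routing, hidden size 1024.
num_experts, topk, hidden, tokens = 8, 2, 1024, 16

recv_x = torch.randn(tokens, hidden, device="cuda").to(torch.float8_e4m3fn)
recv_x_scale = torch.rand(tokens, hidden // 128, device="cuda", dtype=torch.float32)
recv_topk = torch.randint(0, num_experts, (tokens, topk), device="cuda", dtype=torch.int32)

# Count token-expert pairs per expert, then pad each expert's segment to a
# multiple of BLOCK_E = 128 so the assert m_indices.shape[0] % 128 == 0 holds.
counts = torch.bincount(recv_topk.flatten(), minlength=num_experts)
padded_counts = ((counts + 127) // 128 * 128).to(torch.int32)
total = int(padded_counts.sum())

expert_start_loc = torch.empty(num_experts, device="cuda", dtype=torch.int32)
gathered_x = torch.empty(total, hidden, device="cuda", dtype=recv_x.dtype)
gathered_scale = torch.empty(total, hidden // 128, device="cuda", dtype=torch.float32)
m_indices = torch.empty(total, device="cuda", dtype=torch.int32)
output_index = torch.empty(tokens, topk, device="cuda", dtype=torch.int32)

ep_scatter(recv_x, recv_x_scale, recv_topk, padded_counts, expert_start_loc,
           gathered_x, gathered_scale, m_indices, output_index)

After the call, gathered_x holds tokens grouped contiguously by expert and m_indices[i] names the expert owning row i, which is the contiguous layout DeepGEMM's m_indices-grouped GEMM consumes. Note that expert_start_loc is written twice: kernel 1 fills it with per-expert segment offsets, and kernel 2 then reuses the same buffer as atomic write cursors.
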
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(
+                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # block size of quantization
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
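
Continuing the sketch above (same illustrative shapes), ep_gather recombines per-expert rows into per-token activations, weighting each contribution by its router weight; expert_out stands in for the grouped-GEMM output:

expert_out = torch.randn(total, hidden, device="cuda", dtype=torch.bfloat16)
recv_topk_weight = torch.rand(tokens, topk, device="cuda", dtype=torch.float32)
combined = torch.empty(tokens, hidden, device="cuda", dtype=torch.bfloat16)

# ep_gather asserts hidden % BLOCK_D (1024) == 0. For each token it sums the
# expert_out rows located via output_index, scaled by recv_topk_weight.
ep_gather(expert_out, recv_topk, recv_topk_weight, output_index, combined)
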
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
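
A few worked values for the padding rule (float32 scales have element_size 4, so the row alignment is 16 // 4 = 4):

get_tma_aligned_size(9, 4)   # ceil_div(9, 4) * 4 == 12
get_tma_aligned_size(12, 4)  # already a multiple of 4 -> 12
get_tma_aligned_size(5, 1)   # 1-byte elements: alignment 16 -> 16
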
+@triton.jit
+def _tma_align_input_scale_kernel(
+    input_scale_ptr,
+    output_ptr,
+    m,
+    k_div_block_size,
+    input_scale_stride_m,
+    input_scale_stride_k,
+    output_stride_m,
+    output_stride_k,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    grid_m = tl.num_programs(0)
+    k_offsets = tl.arange(0, BLOCK_SIZE_K)
+
+    for m_base in range(pid_m, m, grid_m):
+        input_offset = (
+            input_scale_ptr
+            + m_base * input_scale_stride_m
+            + k_offsets * input_scale_stride_k
+        )
+        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
+
+        output_offset = (
+            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
+        )
+        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+def tma_align_input_scale(input_scale: torch.Tensor):
+    assert input_scale.dim() == 2
+    m, k_div_block_size = input_scale.shape
+    padd_m = get_tma_aligned_size(m, input_scale.element_size())
+    output = torch.empty(
+        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
+    )
+
+    grid_m = min(m, 8192)
+    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
+
+    _tma_align_input_scale_kernel[(grid_m,)](
+        input_scale_ptr=input_scale,
+        output_ptr=output,
+        m=m,
+        k_div_block_size=k_div_block_size,
+        input_scale_stride_m=input_scale.stride(0),
+        input_scale_stride_k=input_scale.stride(1),
+        output_stride_m=output.stride(1),  # Note: these are swapped
+        output_stride_k=output.stride(0),  # for column-major
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+    return output.t()[:m]
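
A small usage sketch (shapes illustrative): the result keeps the logical [m, k] shape and values but is a column-major view whose backing M dimension is padded to the TMA-aligned size:

scales = torch.rand(9, 8, device="cuda", dtype=torch.float32)  # m=9, k_div_block_size=8
aligned = tma_align_input_scale(scales)
assert aligned.shape == (9, 8) and aligned.stride(0) == 1  # column-major view
torch.testing.assert_close(aligned, scales)  # same values; backing m padded 9 -> 12
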
