 import triton
 import triton.language as tl

-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda

+logger = logging.getLogger(__name__)
+
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )
-logger = logging.getLogger(__name__)
+
+    try:
+        from deep_gemm import ceil_div
+    except ImportError:
+        logger.error("Failed to import ceil_div from deep_gemm.")
+
+import triton.language as tl


 @triton.jit
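Editorial note: the except branch above only logs the failure, yet get_tma_aligned_size() near the end of this patch calls ceil_div() unconditionally, so a missing deep_gemm would later surface as a NameError. A minimal fallback sketch one could bind in that branch (an editorial suggestion, not part of the patch):

    # Not part of the patch: pure-Python fallback with the usual ceiling-division
    # semantics, so get_tma_aligned_size() still works without deep_gemm.
    def ceil_div(x: int, y: int) -> int:
        return (x + y - 1) // y  # round x up to the next multiple of y, then divide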
@@ -704,3 +711,334 @@ def grouped_gemm_triton(
         **config,
     )
     return c
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_1(
+    num_recv_tokens_per_expert,
+    expert_start_loc,
+    m_indices,
+    num_experts: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    BLOCK_EXPERT_NUM: tl.constexpr,
+):
+    cur_expert = tl.program_id(0)
+
+    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
+    tokens_per_expert = tl.load(
+        num_recv_tokens_per_expert + offset_cumsum,
+        mask=offset_cumsum < num_experts,
+        other=0,
+    )
+    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
+    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
+
+    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
+
+    m_indices_start_ptr = m_indices + cur_expert_start
+    off_expert = tl.arange(0, BLOCK_E)
+
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+        tl.store(
+            m_indices_start_ptr + start_m + off_expert,
+            cur_expert,
+        )
+
+
+@triton.jit
+def _fwd_kernel_ep_scatter_2(
+    total_token_num,
+    expert_start_loc,
+    recv_x,
+    recv_x_stride0,
+    recv_x_stride1,
+    recv_x_scale,
+    recv_x_scale_stride0,
+    recv_x_scale_stride1,
+    recv_topk,
+    recv_topk_stride0,
+    recv_topk_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    output_tensor_scale,
+    output_tensor_scale_stride0,
+    output_tensor_scale_stride1,
+    output_index,
+    output_index_stride0,
+    output_index_stride1,
+    topk_num: tl.constexpr,
+    HIDDEN_SIZE: tl.constexpr,
+    HIDDEN_SIZE_PAD: tl.constexpr,
+    SCALE_HIDDEN_SIZE: tl.constexpr,
+    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
+):
+    start_token_id = tl.program_id(0)
+    grid_num = tl.num_programs(0)
+
+    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
+    mask = offset_in < HIDDEN_SIZE
+
+    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+
+    for token_id in range(start_token_id, total_token_num, grid_num):
+        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
+        to_copy_s = tl.load(
+            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+        )
+
+        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
+            if expert_id >= 0:
+                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                tl.store(
+                    output_index + token_id * output_index_stride0 + topk_index,
+                    dest_token_index,
+                )
+                output_tensor_ptr = (
+                    output_tensor + dest_token_index * output_tensor_stride0
+                )
+                output_tensor_scale_ptr = (
+                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
+                )
+                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
+                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
+@torch.no_grad()
+def ep_scatter(
+    recv_x: torch.Tensor,
+    recv_x_scale: torch.Tensor,
+    recv_topk: torch.Tensor,
+    num_recv_tokens_per_expert: torch.Tensor,
+    expert_start_loc: torch.Tensor,
+    output_tensor: torch.Tensor,
+    output_tensor_scale: torch.Tensor,
+    m_indices: torch.Tensor,
+    output_index: torch.Tensor,
+):
+    BLOCK_E = 128  # the per-expert token count is padded to a multiple of 128
+    BLOCK_D = 128  # quantization group size along the hidden dimension
+    num_warps = 8
+    num_experts = num_recv_tokens_per_expert.shape[0]
+    hidden_size = recv_x.shape[1]
+    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
+    grid = num_experts
+
+    assert m_indices.shape[0] % BLOCK_E == 0
+
+    _fwd_kernel_ep_scatter_1[(grid,)](
+        num_recv_tokens_per_expert,
+        expert_start_loc,
+        m_indices,
+        num_experts=num_experts,
+        num_warps=num_warps,
+        BLOCK_E=BLOCK_E,
+        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
+    )
+
+    grid = min(recv_topk.shape[0], 1024 * 8)
+
+    _fwd_kernel_ep_scatter_2[(grid,)](
+        recv_topk.shape[0],
+        expert_start_loc,
+        recv_x,
+        recv_x.stride(0),
+        recv_x.stride(1),
+        recv_x_scale,
+        recv_x_scale.stride(0),
+        recv_x_scale.stride(1),
+        recv_topk,
+        recv_topk.stride(0),
+        recv_topk.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        output_tensor_scale,
+        output_tensor_scale.stride(0),
+        output_tensor_scale.stride(1),
+        output_index,
+        output_index.stride(0),
+        output_index.stride(1),
+        topk_num=recv_topk.shape[1],
+        num_warps=num_warps,
+        HIDDEN_SIZE=hidden_size,
+        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
+        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+    )
+    return
+
+
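Editorial note: the two kernels above implement a counting-sort-style scatter. A rough, unoptimized PyTorch sketch of the same semantics (the function name, CPU loops, and zero-initialized padded rows are illustrative assumptions, not part of the patch):

    import torch

    def ep_scatter_reference(recv_x, recv_x_scale, recv_topk, num_recv_tokens_per_expert):
        """Illustrative CPU sketch of ep_scatter: pack tokens expert-by-expert."""
        counts = num_recv_tokens_per_expert
        start = torch.cumsum(counts, 0) - counts          # exclusive prefix sum (kernel 1)
        total = int(counts.sum())
        out = recv_x.new_zeros((total, recv_x.shape[1]))
        out_scale = recv_x_scale.new_zeros((total, recv_x_scale.shape[1]))
        m_indices = torch.empty(total, dtype=torch.int32)
        out_index = torch.empty_like(recv_topk)
        for e in range(counts.shape[0]):                  # expert id for each packed row
            s, n = int(start[e]), int(counts[e])
            m_indices[s : s + n] = e
        cursor = start.clone()                            # kernel 2: atomic slot assignment
        for t in range(recv_topk.shape[0]):
            for k in range(recv_topk.shape[1]):
                e = int(recv_topk[t, k])
                if e < 0:
                    continue                              # slot not routed to any expert
                d = int(cursor[e])                        # next free row in expert e's block
                cursor[e] += 1
                out[d] = recv_x[t]
                out_scale[d] = recv_x_scale[t]
                out_index[t, k] = d
        return out, out_scale, m_indices, out_index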
+@triton.jit
+def _fwd_kernel_ep_gather(
+    total_token_num,
+    input_tensor,
+    input_tensor_stride0,
+    input_tensor_stride1,
+    recv_topk_ids,
+    recv_topk_ids_stride0,
+    recv_topk_ids_stride1,
+    recv_topk_weight,
+    recv_topk_weight_stride0,
+    recv_topk_weight_stride1,
+    input_index,
+    input_index_stride0,
+    input_index_stride1,
+    output_tensor,
+    output_tensor_stride0,
+    output_tensor_stride1,
+    topk_num: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    cur_block = tl.program_id(0)
+    start_cur_token = tl.program_id(1)
+    grid_num = tl.num_programs(1)
+
+    for cur_token in range(start_cur_token, total_token_num, grid_num):
+        off_d = tl.arange(0, BLOCK_D)
+        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
+        for topk_index in range(0, topk_num):
+            expert_id = tl.load(
+                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
+            )
+            if expert_id >= 0:
+                source_token_index = tl.load(
+                    input_index + cur_token * input_index_stride0 + topk_index
+                )
+                acc_weight = tl.load(
+                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
+                )
+                tmp = tl.load(
+                    input_tensor
+                    + source_token_index * input_tensor_stride0
+                    + cur_block * BLOCK_D
+                    + off_d
+                )
+                accumulator += tmp.to(tl.float32) * acc_weight
+
+        tl.store(
+            output_tensor
+            + cur_token * output_tensor_stride0
+            + cur_block * BLOCK_D
+            + off_d,
+            accumulator.to(output_tensor.dtype.element_ty),
+        )
+
+
+@torch.no_grad()
+def ep_gather(
+    input_tensor: torch.Tensor,
+    recv_topk_ids: torch.Tensor,
+    recv_topk_weight: torch.Tensor,
+    input_index: torch.Tensor,
+    output_tensor: torch.Tensor,
+):
+    BLOCK_D = 1024  # tile width along the hidden dimension handled per program
+    num_warps = 2
+    num_tokens = output_tensor.shape[0]
+    hidden_size = input_tensor.shape[1]
+    assert hidden_size % BLOCK_D == 0
+    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
+    _fwd_kernel_ep_gather[grid](
+        num_tokens,
+        input_tensor,
+        input_tensor.stride(0),
+        input_tensor.stride(1),
+        recv_topk_ids,
+        recv_topk_ids.stride(0),
+        recv_topk_ids.stride(1),
+        recv_topk_weight,
+        recv_topk_weight.stride(0),
+        recv_topk_weight.stride(1),
+        input_index,
+        input_index.stride(0),
+        input_index.stride(1),
+        output_tensor,
+        output_tensor.stride(0),
+        output_tensor.stride(1),
+        topk_num=recv_topk_ids.shape[1],
+        num_warps=num_warps,
+        BLOCK_D=BLOCK_D,
+    )
+    return
+
+
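Editorial note: a correspondingly rough CPU sketch of what ep_gather computes, again an illustrative reference (name and loops assumed), not the shipped API:

    def ep_gather_reference(input_tensor, recv_topk_ids, recv_topk_weight, input_index):
        """Illustrative CPU sketch of ep_gather: weighted sum back into token order."""
        num_tokens, topk = recv_topk_ids.shape
        out = torch.zeros(
            (num_tokens, input_tensor.shape[1]),
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )
        for t in range(num_tokens):
            acc = torch.zeros(
                input_tensor.shape[1], dtype=torch.float32, device=input_tensor.device
            )
            for k in range(topk):
                if int(recv_topk_ids[t, k]) >= 0:
                    src = int(input_index[t, k])  # row written by ep_scatter
                    acc += input_tensor[src].float() * recv_topk_weight[t, k]
            out[t] = acc.to(out.dtype)            # accumulate in fp32, store in output dtype
        return out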
+# copy from
+# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
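Editorial note: a quick worked example of the padding rule (the concrete numbers are illustrative, not from the patch). Float32 scales have element_size 4, so the M axis is padded to a multiple of 16 // 4 = 4 rows:

    get_tma_aligned_size(13, 4)   # -> 16  (13 rows rounded up to a multiple of 4)
    get_tma_aligned_size(128, 4)  # -> 128 (already aligned)
    get_tma_aligned_size(130, 4)  # -> 132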
+@triton.jit
+def _tma_align_input_scale_kernel(
+    input_scale_ptr,
+    output_ptr,
+    m,
+    k_div_block_size,
+    input_scale_stride_m,
+    input_scale_stride_k,
+    output_stride_m,
+    output_stride_k,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    grid_m = tl.num_programs(0)
+    k_offsets = tl.arange(0, BLOCK_SIZE_K)
+
+    for m_base in range(pid_m, m, grid_m):
+        input_offset = (
+            input_scale_ptr
+            + m_base * input_scale_stride_m
+            + k_offsets * input_scale_stride_k
+        )
+        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)
+
+        output_offset = (
+            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
+        )
+        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)
+
+
+# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
+def tma_align_input_scale(input_scale: torch.Tensor):
+    assert input_scale.dim() == 2
+    m, k_div_block_size = input_scale.shape
+    padd_m = get_tma_aligned_size(m, input_scale.element_size())
+    output = torch.empty(
+        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
+    )
+
+    grid_m = min(m, 8192)
+    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)
+
+    _tma_align_input_scale_kernel[(grid_m,)](
+        input_scale_ptr=input_scale,
+        output_ptr=output,
+        m=m,
+        k_div_block_size=k_div_block_size,
+        input_scale_stride_m=input_scale.stride(0),
+        input_scale_stride_k=input_scale.stride(1),
+        output_stride_m=output.stride(1),  # Note: these are swapped
+        output_stride_k=output.stride(0),  # for column-major
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+    )
+    return output.t()[:m]
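Editorial note: a usage sketch under assumed shapes (a CUDA device is required, since the copy runs a Triton kernel; the numbers are chosen to match the 128-wide quantization groups used above):

    # Illustrative only: 1023 tokens, hidden size 4096, FP8 group size 128
    # -> per-token-group scales of shape (1023, 32), float32.
    scales = torch.randn(1023, 4096 // 128, dtype=torch.float32, device="cuda")
    aligned = tma_align_input_scale(scales)
    assert aligned.shape == (1023, 32)
    # The backing buffer is (32, padded_m) and `aligned` is its transposed view, so
    # tokens are contiguous along dim 0 and dim 1 strides by the TMA-padded token
    # count, here get_tma_aligned_size(1023, 4) == 1024.
    assert aligned.stride(0) == 1
    assert aligned.stride(1) == get_tma_aligned_size(1023, scales.element_size())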