Skip to content

[sgl-kernel] per token group quant support COLUMN MAJOR #4817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ def sglang_per_token_group_quant_8bit(

def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
device = torch.device("cuda")
hidden_dim = group_size * 2
hidden_dim = 7168

x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
x = torch.randn(
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
)

x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
x.clone(), group_size, dst_dtype
Expand Down Expand Up @@ -196,7 +198,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
device = torch.device("cuda")
hidden_dim = 7168

x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16)
x = torch.randn(
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
)

quantiles = [0.5, 0.2, 0.8]

Expand Down
68 changes: 49 additions & 19 deletions sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ __device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
return val;
}

template <typename T, typename DST_DTYPE>
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false>
__global__ void per_token_group_quant_8bit_kernel(
const T* __restrict__ input,
void* __restrict__ output_q,
Expand All @@ -26,19 +26,30 @@ __global__ void per_token_group_quant_8bit_kernel(
const int groups_per_block,
const float eps,
const float min_8bit,
const float max_8bit) {
const float max_8bit,
const int scale_num_rows = 0,
const int scale_stride = 0) {
const int threads_per_group = 16;
const int local_group_id = threadIdx.x / threads_per_group;
const int lane_id = threadIdx.x % threads_per_group;

const int block_group_id = blockIdx.x * groups_per_block;
const int block_group_offset = (block_group_id + local_group_id) * group_size;
const int global_group_id = block_group_id + local_group_id;
const int block_group_offset = global_group_id * group_size;

float local_absmax = eps;

const T* group_input = input + block_group_offset;
DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
float* scale_output = output_s + (block_group_id + local_group_id);
float* scale_output;

if constexpr (IS_COLUMN_MAJOR) {
const int row_idx = global_group_id / scale_num_rows;
const int col_idx = global_group_id % scale_num_rows;
scale_output = output_s + (col_idx * scale_stride + row_idx);
} else {
scale_output = output_s + global_group_id;
}

constexpr uint32_t vec_size = 16 / sizeof(T);
using vec_t = flashinfer::vec_t<T, vec_size>;
Expand Down Expand Up @@ -88,11 +99,11 @@ void sgl_per_token_group_quant_8bit(
double max_8bit) {
CHECK_INPUT(input);
CHECK_INPUT(output_q);
CHECK_INPUT(output_s);

const int num_groups = input.numel() / group_size;

CHECK_EQ(input.numel() % group_size, 0);
CHECK_EQ(output_s.dim(), 2);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

Expand All @@ -114,20 +125,39 @@ void sgl_per_token_group_quant_8bit(
const int num_blocks = num_groups / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP;

#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
per_token_group_quant_8bit_kernel<T, DST_DTYPE><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit); \
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
const int scale_num_rows = output_s.size(1);
const int scale_stride = output_s.stride(1);

#define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(num_blocks); \
dim3 block(num_threads); \
if (is_column_major) { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit, \
scale_num_rows, \
scale_stride); \
} else { \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit); \
} \
} while (0)

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
Expand Down
Loading
Loading