
Commit 6f0eede

jwfromm authored and facebook-github-bot committed
Skip empty groups in FP8 Stacked Gemm (pytorch#952)
Summary:
X-link: pytorch#3862
Pull Request resolved: facebookresearch/FBGEMM#952

Applies the same optimization used in D71510967 to cutlass fp8 grouped gemm. This should help performance for cases where G > M.

Reviewed By: jiawenliu64

Differential Revision: D71582782

fbshipit-source-id: 05a86398164b1a4bd6af46e9af2ec7f5faabdeb0
1 parent 5120295 commit 6f0eede
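Both kernels changed in this commit apply the same pattern: one thread per group, every argument slot is first filled with an empty default, and non-empty groups are then compacted to the front of the argument array through a block-shared counter. The following is a minimal sketch of that pattern with assumed names (compact_nonempty_groups, group_slots); it is not the FBGEMM kernel itself, and it assumes a single-block launch covering all groups, as in the diffs below.

// Sketch only: one thread per group; empty groups (M_sizes[g] == 0) are
// skipped and non-empty groups are packed to the front of group_slots.
__global__ void compact_nonempty_groups(
    int group_count,
    const int64_t* M_sizes, // rows per group; zero means the group is empty
    int* group_slots) {     // stands in for the real kernel-argument array
  __shared__ int non_zero_counter;
  int g = blockIdx.x * blockDim.x + threadIdx.x;
  // Initialize the shared counter once per block.
  if (threadIdx.x == 0) {
    non_zero_counter = 0;
  }
  // Give every slot a default "empty" entry first.
  if (g < group_count) {
    group_slots[g] = 0;
  }
  // Make the defaults and the counter visible to all threads in the block.
  __syncthreads();
  if (g < group_count && M_sizes[g] > 0) {
    // Claim the next dense slot and write this group's real argument there.
    int slot = atomicAdd(&non_zero_counter, 1);
    group_slots[slot] = static_cast<int>(M_sizes[g]);
  }
}

Because only the first non_zero_counter slots hold real work, the grouped GEMM can then be launched over just that many groups (kernel_groups in the CUTLASS change below).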

File tree

4 files changed, +184 -67 lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def generate_group_tensor(G, M):
     # Finally, we multiply this tensor by M and round to the nearest integer
     output_tensor = torch.round(normalized_tensor * M).to(torch.int64)
     # Adjust the last element to ensure the sum is exactly M
-    output_tensor[-1] += M - output_tensor.sum()
+    output_tensor[-1] += max(0, M - output_tensor.sum())
     return output_tensor.tolist()

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip

Lines changed: 46 additions & 18 deletions
@@ -210,27 +210,55 @@ __global__ void set_kernel_args_m_sizes_kernel(
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
+    // In cases where M < G, we want to only set M groups since the rest are empty.
+    // To do this, we use a counter into the group argument tensor.
+    __shared__ int non_zero_counter;
+    // Initialize the counter in the first thread.
+    if (thread_idx == 0) {
+      non_zero_counter = 0;
+    }
+    // We need to set a default argument for all M groups.
+    KernelArguments default_group_args = {
+        XQ,
+        WQ,
+        {w_scale, x_scale},
+        output,
+        0,
+        0,
+        0,
+        0,
+        0,
+        {0, 0},
+        0};
+    kernel_args[thread_idx] = default_group_args;
+    // Sync threads to get consistent state.
+    __syncthreads();
     // Get M information for this group.
     int kernel_M = M_sizes[thread_idx];
-    int offset_M = 0;
-    // Offset is computed by finding the sum of previous group Ms.
-    for (int i = 0; i < thread_idx; i++) {
-      offset_M += M_sizes[i];
+    // Only write actual group information if this group is nonzero.
+    if (kernel_M > 0) {
+      // Get index automatically for this group.
+      int non_zero_idx = atomicAdd(&non_zero_counter, 1);
+      int offset_M = 0;
+      // Offset is computed by finding the sum of previous group Ms.
+      for (int i = 0; i < thread_idx; i++) {
+        offset_M += M_sizes[i];
+      }
+      KernelArguments kernel_group_args = {
+          XQ + (offset_M * K),
+          WQ + (thread_idx * N * K),
+          {w_scale + (thread_idx * N), x_scale + offset_M},
+          output + (offset_M * N),
+          kernel_M,
+          N,
+          K,
+          K,
+          K,
+          {0, 0},
+          N};
+      // Write kernel args to memory.
+      kernel_args[non_zero_idx] = kernel_group_args;
     }
-    KernelArguments kernel_group_args = {
-        XQ + (offset_M * K),
-        WQ + (thread_idx * N * K),
-        {w_scale + (thread_idx * N), x_scale + offset_M},
-        output + (offset_M * N),
-        kernel_M,
-        N,
-        K,
-        K,
-        K,
-        {0, 0},
-        N};
-    // Write kernel args to memory.
-    kernel_args[thread_idx] = kernel_group_args;
   }
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h

Lines changed: 9 additions & 1 deletion
@@ -134,7 +134,15 @@ OutputType f8f8bf16_rowwise_grouped_impl(
   // Get input information.
   int group_count;
   if constexpr (std::is_same_v<InputType, at::Tensor>) {
-    group_count = WQ.size(0);
+    // Two different modes when inputs are tensors.
+    // If XQ is 3D then its shape is [G, M, K].
+    // If its 2D then its shape is [total_M, K].
+    if (XQ.dim() == 2) {
+      // group count is the min of total_M and G.
+      group_count = std::min(XQ.size(0), WQ.size(0));
+    } else {
+      group_count = WQ.size(0);
+    }
   } else {
     group_count = XQ.size();
   }
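As a hypothetical illustration of the branch above (not code from this commit): in stacked 2D mode XQ is [total_M, K] while WQ is [G, N, K], so when total_M < G at most total_M groups can hold any rows, and clamping the group count avoids setting up arguments for groups that are guaranteed to be empty.

// Hypothetical helpers mirroring the shape logic above; names are assumed.
#include <algorithm>
#include <cstdint>

// Stacked (2D) mode: XQ is [total_M, K], WQ is [G, N, K]. At most total_M
// groups can receive rows, so only min(total_M, G) argument slots are needed.
int64_t stacked_group_count(int64_t total_M, int64_t G) {
  return std::min(total_M, G);
}

// Padded (3D) mode: XQ is [G, M, K]; every group has its own slice, so the
// group count stays G.
int64_t padded_group_count(int64_t G) {
  return G;
}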

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 128 additions & 47 deletions
@@ -107,30 +107,14 @@ __global__ void set_dynamic_kernel_args_kernel(
     StrideA* stride_a_ptr,
     StrideB* stride_b_ptr,
     StrideC* stride_c_ptr,
-    std::optional<int64_t*> zero_start_index_M,
-    std::optional<int64_t*> M_sizes) {
+    int64_t* zero_start_index_M) {
   uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;
   // If this thread corresponds to a valid group, write kernel args to device
   // memory.
   if (group_index < G) {
     // Compute shape for this group.
-    int offset_M;
-    int kernel_M;
-    if (zero_start_index_M.has_value()) {
-      // For inputs with padding, M is fixed and the number of rows
-      // to operate on is available in zero_start_index_M.
-      kernel_M = zero_start_index_M.value()[group_index];
-      offset_M = group_index * M;
-    } else {
-      // M for this group is pulled directly from M_sizes.
-      kernel_M = M_sizes.value()[group_index];
-      // We compute the offset by getting the cumulative sum over
-      // prior groups.
-      offset_M = 0;
-      for (int i = 0; i < group_index; i++) {
-        offset_M += M_sizes.value()[i];
-      }
-    }
+    int kernel_M = zero_start_index_M[group_index];
+    int offset_M = group_index * M;
     // Set the problem shape for this group.
     problem_shape_ptr[group_index] = ProblemShape(N, kernel_M, K);
     // Set input pointers.
@@ -148,6 +132,82 @@ __global__ void set_dynamic_kernel_args_kernel(
   }
 }
 
+template <
+    typename ProblemShape,
+    typename ElementA,
+    typename ElementB,
+    typename ElementC,
+    typename ElementComputeEpilogue,
+    typename StrideA,
+    typename StrideB,
+    typename StrideC>
+__global__ void set_stacked_kernel_args_kernel(
+    int G,
+    int N,
+    int K,
+    ProblemShape* problem_shape_ptr,
+    ElementA* xq,
+    const ElementA** xq_ptr,
+    ElementB* wq,
+    const ElementB** wq_ptr,
+    ElementComputeEpilogue* x_scale,
+    const ElementComputeEpilogue** x_scale_ptr,
+    ElementComputeEpilogue* w_scale,
+    const ElementComputeEpilogue** w_scale_ptr,
+    ElementC* output,
+    ElementC** output_ptr,
+    StrideA* stride_a_ptr,
+    StrideB* stride_b_ptr,
+    StrideC* stride_c_ptr,
+    int64_t* M_sizes) {
+  uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;
+  // If this thread corresponds to a valid group, write kernel args to device
+  // memory.
+  if (group_index < G) {
+    // Its possible that we're only writing a subset of the groups to
+    // kernel args. To do this, we need to set all groups initially to empty.
+    // and keep a problem counter for the number of non-empty groups.
+    __shared__ int non_zero_counter;
+    // Initialize counter in first group.
+    if (group_index == 0) {
+      non_zero_counter = 0;
+    }
+    // Set problem shapes to empty by default.
+    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);
+    // Sync threads to get consistent state in the block.
+    __syncthreads();
+
+    // Compute shape for this group.
+    // M for this group is pulled directly from M_sizes.
+    int M = M_sizes[group_index];
+    // Only proceed to writing kernel args if this group is non-empty.
+    if (M > 0) {
+      // Get the index for this group atomically.
+      int non_zero_idx = atomicAdd(&non_zero_counter, 1);
+      // We compute the offset by getting the cumulative sum over
+      // prior groups.
+      int offset_M = 0;
+      for (int i = 0; i < group_index; i++) {
+        offset_M += M_sizes[i];
+      }
+      // Set the problem shape for this group.
+      problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K);
+      // Set input pointers.
+      xq_ptr[non_zero_idx] = xq + (offset_M * K);
+      wq_ptr[non_zero_idx] = wq + (group_index * N * K);
+      x_scale_ptr[non_zero_idx] = x_scale + offset_M;
+      w_scale_ptr[non_zero_idx] = w_scale + (group_index * N);
+      output_ptr[non_zero_idx] = output + (offset_M * N);
+      stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideA{}, cute::make_shape(M, K, 1));
+      stride_b_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideB{}, cute::make_shape(N, K, 1));
+      stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideC{}, cute::make_shape(N, M, 1));
+    }
+  }
+}
+
 template <
     typename InputType,
     int TB_M,
@@ -178,6 +238,8 @@ at::Tensor f8f8bf16_rowwise_grouped_impl(
     G = WQ.size(0);
     options = XQ.options();
   }
+  // The number of groups the kernel uses may vary.
+  int kernel_groups = G;
   // Return early if there are no elements in the output.
   if (output.numel() == 0) {
     return output;
@@ -421,41 +483,60 @@ at::Tensor f8f8bf16_rowwise_grouped_impl(
     int M = XQ.size(XQ.dim() - 2);
     int N = WQ.size(1);
     int K = WQ.size(2);
-    std::optional<int64_t*> zero_start_index_M_ptr = std::nullopt;
-    std::optional<int64_t*> M_sizes_ptr = std::nullopt;
     if (zero_start_index_M.has_value()) {
-      zero_start_index_M_ptr =
+      int64_t* zero_start_index_M_ptr =
           reinterpret_cast<int64_t*>(zero_start_index_M.value().data_ptr());
+      set_dynamic_kernel_args_kernel<<<1, G, 0, stream>>>(
+          G,
+          M,
+          N,
+          K,
+          problem_shape_ptr,
+          reinterpret_cast<ElementA*>(XQ.data_ptr()),
+          xq_ptr,
+          reinterpret_cast<ElementB*>(WQ.data_ptr()),
+          wq_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(x_scale.data_ptr()),
+          x_scale_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(w_scale.data_ptr()),
+          w_scale_ptr,
+          reinterpret_cast<ElementC*>(output.data_ptr()),
+          output_ptr,
+          stride_a_ptr,
+          stride_b_ptr,
+          stride_c_ptr,
+          zero_start_index_M_ptr);
+    } else {
+      int64_t* M_sizes_ptr =
+          reinterpret_cast<int64_t*>(M_sizes.value().data_ptr());
+      set_stacked_kernel_args_kernel<<<1, G, 0, stream>>>(
+          G,
+          N,
+          K,
+          problem_shape_ptr,
+          reinterpret_cast<ElementA*>(XQ.data_ptr()),
+          xq_ptr,
+          reinterpret_cast<ElementB*>(WQ.data_ptr()),
+          wq_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(x_scale.data_ptr()),
+          x_scale_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(w_scale.data_ptr()),
+          w_scale_ptr,
+          reinterpret_cast<ElementC*>(output.data_ptr()),
+          output_ptr,
+          stride_a_ptr,
+          stride_b_ptr,
+          stride_c_ptr,
+          M_sizes_ptr);
+      // Set the number of groups to the kernel to be at most the number of
+      // non-zero rows.
+      kernel_groups = std::min(M, G);
     }
-    if (M_sizes.has_value()) {
-      M_sizes_ptr = reinterpret_cast<int64_t*>(M_sizes.value().data_ptr());
-    }
-    set_dynamic_kernel_args_kernel<<<1, G, 0, stream>>>(
-        G,
-        M,
-        N,
-        K,
-        problem_shape_ptr,
-        reinterpret_cast<ElementA*>(XQ.data_ptr()),
-        xq_ptr,
-        reinterpret_cast<ElementB*>(WQ.data_ptr()),
-        wq_ptr,
-        reinterpret_cast<ElementComputeEpilogue*>(x_scale.data_ptr()),
-        x_scale_ptr,
-        reinterpret_cast<ElementComputeEpilogue*>(w_scale.data_ptr()),
-        w_scale_ptr,
-        reinterpret_cast<ElementC*>(output.data_ptr()),
-        output_ptr,
-        stride_a_ptr,
-        stride_b_ptr,
-        stride_c_ptr,
-        zero_start_index_M_ptr,
-        M_sizes_ptr);
   }
 
   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGrouped,
-      {G, problem_shape_ptr, nullptr},
+      {kernel_groups, problem_shape_ptr, nullptr},
       {wq_ptr, stride_b_ptr, xq_ptr, stride_a_ptr},
       {{}, nullptr, stride_c_ptr, output_ptr, stride_c_ptr}};
