
Commit 5febc5a

mxz297 authored and facebook-github-bot committed
reduce overhead for f8f8bf16_rowwise_grouped_dynamic on amd (pytorch#823)
Summary:
X-link: pytorch#3742
Pull Request resolved: facebookresearch/FBGEMM#823

When there is no need to zero the output tensor, the argument setup kernel currently launches many wasted thread blocks, which can cause significant overhead. So we split the argument setup into two kernels, chosen based on whether zeroing is needed.

Reviewed By: zjing14, jwfromm
Differential Revision: D70327636
fbshipit-source-id: c68bc094972929ccf9773e31f9b8a362dc5037d3
1 parent f1ecae6 commit 5febc5a
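For context, the idea behind the split is simple: when zeroing is not required, argument setup only needs one thread per group, so the launch can be tiny instead of being sized to cover the whole output tensor. The sketch below illustrates that dispatch pattern in plain CUDA. It is only an illustration under simplified assumptions: GroupArgs, setup_args_only, setup_args_zeroing, and setup_dispatch are hypothetical stand-ins, not the FBGEMM kernels, which operate on KernelArguments and bf16 outputs as shown in the diff below.

// A minimal sketch of the two-kernel dispatch, in plain CUDA. GroupArgs,
// setup_args_only, setup_args_zeroing, and setup_dispatch are hypothetical
// names used for illustration only, not the FBGEMM API.
#include <cuda_runtime.h>

struct GroupArgs {
  const float* input; // per-group input pointer
  float* output;      // per-group output pointer
  int rows;           // number of valid rows (M) in this group
};

// Cheap path: one thread per group, only fills the argument structs.
__global__ void setup_args_only(
    GroupArgs* args, const float* in, float* out, const int* rows,
    int M, int N, int group_count) {
  int g = blockIdx.x * blockDim.x + threadIdx.x;
  if (g < group_count) {
    args[g] = {in + g * M * N, out + g * M * N, rows[g]};
  }
}

// Expensive path: same setup, plus zeroing of the full output buffer,
// which is why it needs a grid sized for group_count * M * N elements.
__global__ void setup_args_zeroing(
    GroupArgs* args, const float* in, float* out, const int* rows,
    int M, int N, int group_count) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < group_count) {
    args[idx] = {in + idx * M * N, out + idx * M * N, rows[idx]};
  }
  // Grid-stride loop so every output element gets zeroed.
  for (int i = idx; i < group_count * M * N; i += gridDim.x * blockDim.x) {
    out[i] = 0.0f;
  }
}

void setup_dispatch(
    GroupArgs* args, const float* in, float* out, const int* rows,
    int M, int N, int group_count, bool zeroing_output_tensor,
    cudaStream_t stream) {
  if (zeroing_output_tensor) {
    // Grid must cover the whole output tensor for zeroing.
    int total = group_count * M * N;
    int block = 256;
    int grid = (total + block - 1) / block;
    setup_args_zeroing<<<grid, block, 0, stream>>>(
        args, in, out, rows, M, N, group_count);
  } else {
    // Only group_count threads are needed, so one small launch suffices and
    // no idle blocks are spawned (assumes group_count fits in a single block).
    setup_args_only<<<1, group_count, 0, stream>>>(
        args, in, out, rows, M, N, group_count);
  }
}

The single-block launch in the non-zeroing branch mirrors the <<<1, group_count, 0, stream>>> launch in the diff; it relies on group_count fitting within one thread block.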

File tree: 1 file changed (+63, -17 lines)

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip

Lines changed: 63 additions & 17 deletions
@@ -195,7 +195,7 @@ void set_static_kernel_args(
   }
 }
 
-__global__ void set_kernel_args_fixed_nk_kernel(
+__global__ void set_kernel_args_fixed_nk_kernel_only(
     KernelArguments* kernel_args,
     ADataType* XQ,
     BDataType* WQ,
@@ -206,8 +206,41 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     int M,
     int N,
     int K,
-    int group_count,
-    bool zeroing_output_tensor) {
+    int group_count) {
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each thread is responsible for setting up the arguments for one group.
+  if (thread_idx < group_count) {
+    // Compute offsets for this group.
+    int group_M = prepad_M[thread_idx];
+    KernelArguments kernel_group_args = {
+        XQ + (thread_idx * M * K),
+        WQ + (thread_idx * N * K),
+        {w_scale + (thread_idx * N), x_scale + (thread_idx * M)},
+        output + (thread_idx * M * N),
+        group_M,
+        N,
+        K,
+        K,
+        K,
+        {0, 0},
+        N};
+    // Write kernel args to memory.
+    kernel_args[thread_idx] = kernel_group_args;
+  }
+}
+
+__global__ void set_kernel_args_fixed_nk_kernel_zeroing(
+    KernelArguments* kernel_args,
+    ADataType* XQ,
+    BDataType* WQ,
+    D0DataType* w_scale,
+    D1DataType* x_scale,
+    EDataType* output,
+    int64_t* prepad_M,
+    int M,
+    int N,
+    int K,
+    int group_count) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
@@ -228,7 +261,6 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     // Write kernel args to memory.
     kernel_args[thread_idx] = kernel_group_args;
   }
-  if (!zeroing_output_tensor) return;
 
   // Figure out where in memory we are.
   // Each thread sets one float 4 which corresponds to 8 bf16 values.
@@ -284,19 +316,33 @@ void set_dynamic_kernel_args(
   int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
   int blockSize = std::min(512, block_factor);
   int numBlocks = (block_factor + blockSize - 1) / blockSize;
-  set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
-      reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
-      reinterpret_cast<ADataType*>(XQ.data_ptr()),
-      reinterpret_cast<BDataType*>(WQ.data_ptr()),
-      reinterpret_cast<D0DataType*>(w_scale.data_ptr()),
-      reinterpret_cast<D1DataType*>(x_scale.data_ptr()),
-      reinterpret_cast<EDataType*>(output.data_ptr()),
-      reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
-      M,
-      N,
-      K,
-      group_count,
-      zeroing_output_tensor);
+  if (zeroing_output_tensor) {
+    set_kernel_args_fixed_nk_kernel_zeroing<<<numBlocks, blockSize, 0, stream>>>(
+        reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
+        reinterpret_cast<ADataType*>(XQ.data_ptr()),
+        reinterpret_cast<BDataType*>(WQ.data_ptr()),
+        reinterpret_cast<D0DataType*>(w_scale.data_ptr()),
+        reinterpret_cast<D1DataType*>(x_scale.data_ptr()),
+        reinterpret_cast<EDataType*>(output.data_ptr()),
+        reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
+        M,
+        N,
+        K,
+        group_count);
+  } else {
+    set_kernel_args_fixed_nk_kernel_only<<<1, group_count, 0, stream>>>(
+        reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
+        reinterpret_cast<ADataType*>(XQ.data_ptr()),
+        reinterpret_cast<BDataType*>(WQ.data_ptr()),
+        reinterpret_cast<D0DataType*>(w_scale.data_ptr()),
+        reinterpret_cast<D1DataType*>(x_scale.data_ptr()),
+        reinterpret_cast<EDataType*>(output.data_ptr()),
+        reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
+        M,
+        N,
+        K,
+        group_count);
+  }
 }
 
 template <typename OutputType>