
Commit 125ce44

renganxu authored and facebook-github-bot committed
jagged_dense_bmm operator optimization (pytorch#1643)
Summary:
Pull Request resolved: pytorch#1643

This diff optimizes the jagged_dense_bmm operator with the following optimizations:
* tiling across thread blocks, using GPU shared memory for each thread block
* tiling across threads within a thread block, using registers for each thread

Reviewed By: brad-mengchi

Differential Revision: D43674845

fbshipit-source-id: 85f0abf89fa958f79636ef59c3070a1c569b73c2
1 parent 35bdd40 commit 125ce44
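
The summary above describes the standard two-level tiling scheme for GPU matrix multiplication. As a point of reference, here is a minimal, self-contained sketch of that pattern applied to a plain dense GEMM (C = A * B). It is not the kernel added in this diff; the tile sizes, the names, and the assumption that the matrix dimensions divide the tiles evenly are illustrative only.

#include <cuda_runtime.h>

constexpr int BM = 32; // block tile: rows of C computed per thread block
constexpr int BN = 32; // block tile: cols of C computed per thread block
constexpr int BK = 8;  // block tile: depth of A/B staged per iteration
constexpr int TM = 4;  // thread tile: rows of C computed per thread
constexpr int TN = 4;  // thread tile: cols of C computed per thread

// Computes C = A * B for row-major float matrices, assuming M % BM == 0,
// N % BN == 0, and K % BK == 0. Launch with grid (N / BN, M / BM) and
// (BM / TM) * (BN / TN) = 64 threads per block.
__global__ void tiled_gemm_sketch(
    const float* A, const float* B, float* C, int M, int N, int K) {
  // Shared-memory tiles staged cooperatively by the whole thread block.
  __shared__ float As[BM][BK];
  __shared__ float Bs[BK][BN];

  const int threads_x = BN / TN;
  const int thread_row = threadIdx.x / threads_x;
  const int thread_col = threadIdx.x % threads_x;
  const int row0 = blockIdx.y * BM; // top-left corner of this block's C tile
  const int col0 = blockIdx.x * BN;

  float acc[TM][TN] = {0.0f}; // per-thread register tile of C

  for (int k0 = 0; k0 < K; k0 += BK) {
    // Stage the current A and B block tiles into shared memory.
    for (int i = threadIdx.x; i < BM * BK; i += blockDim.x) {
      As[i / BK][i % BK] = A[(row0 + i / BK) * K + (k0 + i % BK)];
    }
    for (int i = threadIdx.x; i < BK * BN; i += blockDim.x) {
      Bs[i / BN][i % BN] = B[(k0 + i / BN) * N + (col0 + i % BN)];
    }
    __syncthreads();

    // Each thread accumulates its TM x TN sub-tile entirely in registers.
    for (int k = 0; k < BK; ++k) {
      for (int r = 0; r < TM; ++r) {
        for (int c = 0; c < TN; ++c) {
          acc[r][c] += As[thread_row * TM + r][k] * Bs[k][thread_col * TN + c];
        }
      }
    }
    __syncthreads();
  }

  // Write the per-thread register tile back to global memory.
  for (int r = 0; r < TM; ++r) {
    for (int c = 0; c < TN; ++c) {
      C[(row0 + thread_row * TM + r) * N + (col0 + thread_col * TN + c)] =
          acc[r][c];
    }
  }
}

The kernel added in this diff follows the same structure, but additionally masks its loads and stores against each batch's jagged length and walks the batch dimension via gridDim.z.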

File tree

2 files changed: +158, -30 lines


fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
Lines changed: 5 additions & 0 deletions

@@ -62,6 +62,11 @@ static constexpr int32_t kWarpSize = 32;
 #endif
 // Max thread num in one thread block
 static constexpr int32_t kMaxThreads = 1024;
+// Max block size in Y dimension of a grid
+static constexpr int32_t kMaxBlockYDim = 65535;
+// Max block size in Z dimension of a grid
+static constexpr int32_t kMaxBlockZDim = 65535;
+
 static constexpr float kQParamEps = 1e-8f;
 
 /* For rowwise int8 quantization, two quantization parameters (qparams)
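
For context, 65535 is the CUDA hardware limit on gridDim.y and gridDim.z (gridDim.x may be much larger), so these constants let host code either reject or wrap work that would not fit a grid dimension. A standalone sketch of both patterns using the new constants follows; make_grid and its parameter names are hypothetical, not part of fbgemm_gpu.

#include <algorithm>
#include <cassert>
#include <cuda_runtime.h>

static constexpr int32_t kMaxBlockYDim = 65535;
static constexpr int32_t kMaxBlockZDim = 65535;

// Sketch: size a grid given the number of column tiles, row tiles, and
// batches. The Y dimension is validated against its cap (the host code in
// jagged_tensor_ops.cu below uses TORCH_CHECK for this); the Z dimension is
// clamped, and the kernel then loops
//   for (auto b = blockIdx.z; b < batches; b += gridDim.z)
// to cover the remaining batches.
dim3 make_grid(int col_tiles, int row_tiles, int batches) {
  assert(row_tiles <= kMaxBlockYDim && "row tiles exceed the gridDim.y limit");
  const int grid_z = std::min(batches, static_cast<int>(kMaxBlockZDim));
  return dim3(col_tiles, row_tiles, grid_z);
}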

fbgemm_gpu/src/jagged_tensor_ops.cu
Lines changed: 153 additions & 30 deletions

@@ -2071,36 +2071,135 @@ Tensor jagged_jagged_bmm_forward(
   return output;
 }
 
-template <typename index_t, typename scalar_t>
+template <
+    const int BLOCK_TILE_M, // tile height of C that each thread block
+                            // calculates
+    const int BLOCK_TILE_N, // tile width of C that each thread block
+                            // calculates
+    const int BLOCK_TILE_K, // tile width of A that each thread block calculates
+    const int THREAD_TILE_M, // tile height of C that each thread
+                             // calculates
+    const int THREAD_TILE_N, // tile width of C that each thread calculates
+    typename index_t,
+    typename scalar_t>
 __global__ __launch_bounds__(kMaxThreads) void jagged_dense_bmm_kernel(
-    const at::PackedTensorAccessor32<scalar_t, 2> x_values,
-    const at::PackedTensorAccessor32<index_t, 1> x_offsets,
-    const at::PackedTensorAccessor32<scalar_t, 3> y,
-    at::PackedTensorAccessor32<scalar_t, 2> output,
+    const at::PackedTensorAccessor32<scalar_t, 2> __restrict__ x_values,
+    const at::PackedTensorAccessor32<index_t, 1> __restrict__ x_offsets,
+    const at::PackedTensorAccessor32<scalar_t, 3> __restrict__ y,
+    at::PackedTensorAccessor32<scalar_t, 2> __restrict__ output,
     const int max_L) {
   const int B = x_offsets.size(0) - 1;
   const int K = x_values.size(1);
   const int N = y.size(2);
 
-  const int b_l_begin = blockIdx.x * blockDim.y + threadIdx.y;
-  const int b_l_step = gridDim.x * blockDim.y;
-  for (int b_l = b_l_begin; b_l < B * max_L; b_l += b_l_step) {
-    const int b = b_l / max_L;
-    const int l = b_l % max_L;
+  const auto block_row = blockIdx.y;
+  const auto block_col = blockIdx.x;
+
+  const int THREADS_X_PER_BLOCK = BLOCK_TILE_N / THREAD_TILE_N;
+  const int THREADS_Y_PER_BLOCK = BLOCK_TILE_M / THREAD_TILE_M;
+  const int THREADS_PER_BLOCK = THREADS_X_PER_BLOCK * THREADS_Y_PER_BLOCK;
+  const auto thread_row = threadIdx.x / THREADS_X_PER_BLOCK;
+  const auto thread_col = threadIdx.x % THREADS_X_PER_BLOCK;
+  const auto NUM_K_BLOCKS = (K + BLOCK_TILE_K - 1) / BLOCK_TILE_K;
+
+  __shared__ scalar_t As[BLOCK_TILE_M][BLOCK_TILE_K];
+  __shared__ scalar_t Bs[BLOCK_TILE_K][BLOCK_TILE_N];
+
+  for (auto b = blockIdx.z; b < B; b += gridDim.z) {
+    const index_t row_start = x_offsets[b];
+    const index_t row_end = x_offsets[b + 1];
+    const auto length = min(row_end - row_start, (index_t)max_L);
+
+    // the indices that this current thread will load into shared mem
+    const auto inner_row_a = threadIdx.x / BLOCK_TILE_K;
+    const auto inner_col_a = threadIdx.x % BLOCK_TILE_K;
+    // the number of rows of As that will be loaded per step by a thread block
+    const auto A_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_K;
+
+    const auto inner_row_b = threadIdx.x / BLOCK_TILE_N;
+    const auto inner_col_b = threadIdx.x % BLOCK_TILE_N;
+    const auto B_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_N;
+
+    // registers for C
+    scalar_t accum[THREAD_TILE_M][THREAD_TILE_N] = {0};
+
+    // registers for As and Bs
+    scalar_t fragment_a[THREAD_TILE_M] = {0};
+    scalar_t fragment_b[THREAD_TILE_N] = {0};
+
+    // loop for block tiles in K dimension
+    for (auto block = 0; block < NUM_K_BLOCKS; block++) {
+      // load a block of x_values from global memory to shared memory
+      // apply tiling for threads in a block
+#pragma unroll
+      for (auto offset = 0; offset < BLOCK_TILE_M;
+           offset += A_TILE_ROW_STRIDE) {
+        auto x_row_offset = block_row * BLOCK_TILE_M + inner_row_a + offset;
+        auto x_col_offset = block * BLOCK_TILE_K + inner_col_a;
+        if ((x_row_offset < length) && (x_col_offset < K)) {
+          As[inner_row_a + offset][inner_col_a] =
+              x_values[row_start + x_row_offset][x_col_offset];
+        } else {
+          As[inner_row_a + offset][inner_col_a] = 0;
+        }
+      }
 
-    const int row_start = x_offsets[b];
-    const int row_end = x_offsets[b + 1];
-    const int length = min(row_end - row_start, max_L);
-    if (length == 0 || l >= length) {
-      return;
-    } else {
-      // TODO: use shared memory and better reduction
-      for (int n = threadIdx.x; n < N; n += blockDim.x) {
-        at::acc_type<scalar_t, true> acc = 0;
-        for (int k = 0; k < K; ++k) {
-          acc += x_values[row_start + l][k] * y[b][k][n];
+      // load a block of y from global memory to shared memory
+      // apply tiling for threads in a block
+#pragma unroll
+      for (auto offset = 0; offset < BLOCK_TILE_K;
+           offset += B_TILE_ROW_STRIDE) {
+        auto y_row_offset = block * BLOCK_TILE_K + inner_row_b + offset;
+        auto y_col_offset = block_col * BLOCK_TILE_N + inner_col_b;
+        if ((y_row_offset < K) && (y_col_offset < N)) {
+          Bs[inner_row_b + offset][inner_col_b] =
+              y[b][y_row_offset][y_col_offset];
+        } else {
+          Bs[inner_row_b + offset][inner_col_b] = 0;
+        }
+      }
+
+      __syncthreads();
+
+      // calculate the results per thread
+#pragma unroll
+      for (auto k = 0; k < BLOCK_TILE_K; k++) {
+        // load values from shared memory to registers for x_values
+        for (auto row = 0; row < THREAD_TILE_M; row++) {
+          fragment_a[row] = As[thread_row * THREAD_TILE_M + row][k];
+        }
+
+        // load values from shared memory to registers for y
+#pragma unroll
+        for (auto col = 0; col < THREAD_TILE_N; col++) {
+          fragment_b[col] = Bs[k][thread_col * THREAD_TILE_N + col];
+        }
+
+        // each thread calculates THREAD_TILE_M * THREAD_TILE_N elements
+#pragma unroll
+        for (auto row = 0; row < THREAD_TILE_M; row++) {
+#pragma unroll
+          for (auto col = 0; col < THREAD_TILE_N; col++) {
+            accum[row][col] += fragment_a[row] * fragment_b[col];
+          }
+        }
+      }
+
+      __syncthreads();
+    }
+
+    // write the result to the output
+#pragma unroll
+    for (auto row = 0; row < THREAD_TILE_M; row++) {
+#pragma unroll
+      for (auto col = 0; col < THREAD_TILE_N; col++) {
+        auto out_row_offset =
+            block_row * BLOCK_TILE_M + thread_row * THREAD_TILE_M + row;
+        auto out_col_offset =
+            block_col * BLOCK_TILE_N + thread_col * THREAD_TILE_N + col;
+        if ((out_row_offset < length) && (out_col_offset < N)) {
+          output[row_start + out_row_offset][out_col_offset] = accum[row][col];
         }
-        output[row_start + l][n] = acc;
       }
     }
   }
@@ -2124,9 +2223,29 @@ Tensor jagged_dense_bmm_forward(
   const int total_L = x_values.size(0);
   auto output = at::zeros({total_L, N}, x_values.options());
   if (B > 0 && M > 0 && N > 0) {
-    const int block_dim_x =
-        std::min(div_round_up(N, kWarpSize) * kWarpSize, kMaxThreads);
-    const int block_dim_y = kMaxThreads / block_dim_x;
+    // The shared memory size is (BLOCK_TILE_M + BLOCK_TILE_N) * BLOCK_TILE_K
+    // BLOCK_TILE_M needs to be multiple of THREAD_TILE_M, and
+    // BLOCK_TILE_N needs to be multiple of THREAD_TILE_N
+    // The setting of these parameters needs to balance the hardware's shared
+    // memory size limit and occupancy
+    // TODO: autotune these parameters based on max_L and input and output
+    // tensor sizes
+    constexpr int BLOCK_TILE_M = 64;
+    constexpr int BLOCK_TILE_N = 8;
+    constexpr int BLOCK_TILE_K = 8;
+    constexpr int THREAD_TILE_M = 4;
+    constexpr int THREAD_TILE_N = 4;
+
+    const dim3 block(
+        (BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
+    const auto grid_dim_x = div_round_up(N, BLOCK_TILE_N);
+    const auto grid_dim_y = div_round_up(max_L, BLOCK_TILE_M);
+    TORCH_CHECK(
+        grid_dim_y <= kMaxBlockYDim,
+        "max_L cannot be larger than",
+        grid_dim_y * BLOCK_TILE_M + 1 - BLOCK_TILE_M);
+    const auto grid_dim_z = std::min(B, kMaxBlockZDim);
+    const dim3 grid(grid_dim_x, grid_dim_y, grid_dim_z);
 
     AT_DISPATCH_INDEX_TYPES(
         x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
@@ -2136,11 +2255,15 @@ Tensor jagged_dense_bmm_forward(
             x_values.scalar_type(),
             "jagged_dense_bmm_kernel_2",
             [&] {
-              jagged_dense_bmm_kernel<index_t, scalar_t>
-                  <<<div_round_up(B * max_L, block_dim_y),
-                     dim3(block_dim_x, block_dim_y),
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
+              jagged_dense_bmm_kernel<
+                  BLOCK_TILE_M,
+                  BLOCK_TILE_N,
+                  BLOCK_TILE_K,
+                  THREAD_TILE_M,
+                  THREAD_TILE_N,
+                  index_t,
+                  scalar_t>
+                  <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
                   x_values.packed_accessor32<scalar_t, 2>(),
                   x_offsets.packed_accessor32<index_t, 1>(),
                   y.packed_accessor32<scalar_t, 3>(),
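
For reference, here is the resource footprint implied by the tile parameters chosen above, worked out for float32 values. This is a back-of-the-envelope sketch, not part of the diff; actual register counts and occupancy depend on the compiler and GPU.

// Launch-configuration arithmetic for the tile sizes chosen above, assuming
// scalar_t = float. Sketch only.
constexpr int BLOCK_TILE_M = 64, BLOCK_TILE_N = 8, BLOCK_TILE_K = 8;
constexpr int THREAD_TILE_M = 4, THREAD_TILE_N = 4;

// Threads per block: each thread produces a 4x4 piece of the 64x8 block tile,
// so (64 * 8) / (4 * 4) = 32 threads, i.e. exactly one warp.
constexpr int kThreadsPerBlock =
    (BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N);
static_assert(kThreadsPerBlock == 32, "one warp per block");

// Shared memory per block: As is 64x8 and Bs is 8x8 floats, i.e.
// (64 * 8 + 8 * 8) * 4 bytes = 2304 bytes per block.
constexpr int kSmemBytes =
    (BLOCK_TILE_M * BLOCK_TILE_K + BLOCK_TILE_K * BLOCK_TILE_N) * 4;
static_assert(kSmemBytes == 2304, "about 2.3 KB of shared memory per block");

// Per-thread registers for the accumulator tile alone: 4 * 4 = 16 values,
// plus THREAD_TILE_M + THREAD_TILE_N = 8 values for the A/B fragments.

With one warp and roughly 2.3 KB of shared memory per block, parallelism comes mainly from the grid: div_round_up(N, BLOCK_TILE_N) blocks in x, div_round_up(max_L, BLOCK_TILE_M) in y, and up to 65535 batches in z.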
