@@ -2071,36 +2071,135 @@ Tensor jagged_jagged_bmm_forward(
   return output;
 }

-template <typename index_t, typename scalar_t>
+template <
+    const int BLOCK_TILE_M, // tile height of C that each thread block
+                            // calculates
+    const int BLOCK_TILE_N, // tile width of C that each thread block
+                            // calculates
+    const int BLOCK_TILE_K, // tile width of A that each thread block calculates
+    const int THREAD_TILE_M, // tile height of C that each thread
+                             // calculates
+    const int THREAD_TILE_N, // tile width of C that each thread calculates
+    typename index_t,
+    typename scalar_t>
 __global__ __launch_bounds__(kMaxThreads) void jagged_dense_bmm_kernel(
-    const at::PackedTensorAccessor32<scalar_t, 2> x_values,
-    const at::PackedTensorAccessor32<index_t, 1> x_offsets,
-    const at::PackedTensorAccessor32<scalar_t, 3> y,
-    at::PackedTensorAccessor32<scalar_t, 2> output,
+    const at::PackedTensorAccessor32<scalar_t, 2> __restrict__ x_values,
+    const at::PackedTensorAccessor32<index_t, 1> __restrict__ x_offsets,
+    const at::PackedTensorAccessor32<scalar_t, 3> __restrict__ y,
+    at::PackedTensorAccessor32<scalar_t, 2> __restrict__ output,
     const int max_L) {
   const int B = x_offsets.size(0) - 1;
   const int K = x_values.size(1);
   const int N = y.size(2);

-  const int b_l_begin = blockIdx.x * blockDim.y + threadIdx.y;
-  const int b_l_step = gridDim.x * blockDim.y;
-  for (int b_l = b_l_begin; b_l < B * max_L; b_l += b_l_step) {
-    const int b = b_l / max_L;
-    const int l = b_l % max_L;
+  const auto block_row = blockIdx.y;
+  const auto block_col = blockIdx.x;
+
+  const int THREADS_X_PER_BLOCK = BLOCK_TILE_N / THREAD_TILE_N;
+  const int THREADS_Y_PER_BLOCK = BLOCK_TILE_M / THREAD_TILE_M;
+  const int THREADS_PER_BLOCK = THREADS_X_PER_BLOCK * THREADS_Y_PER_BLOCK;
+  const auto thread_row = threadIdx.x / THREADS_X_PER_BLOCK;
+  const auto thread_col = threadIdx.x % THREADS_X_PER_BLOCK;
+  const auto NUM_K_BLOCKS = (K + BLOCK_TILE_K - 1) / BLOCK_TILE_K;
+
+  __shared__ scalar_t As[BLOCK_TILE_M][BLOCK_TILE_K];
+  __shared__ scalar_t Bs[BLOCK_TILE_K][BLOCK_TILE_N];
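+  // Example (illustrative): with the tile sizes chosen in
+  // jagged_dense_bmm_forward below (64, 8, 8, 4, 4), a block has
+  // (64 / 4) * (8 / 4) = 32 threads, each owning a 4x4 patch of the block's
+  // 64x8 output tile.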
+
+  for (auto b = blockIdx.z; b < B; b += gridDim.z) {
+    const index_t row_start = x_offsets[b];
+    const index_t row_end = x_offsets[b + 1];
+    const auto length = min(row_end - row_start, (index_t)max_L);
+
+    // the indices that this current thread will load into shared mem
+    const auto inner_row_a = threadIdx.x / BLOCK_TILE_K;
+    const auto inner_col_a = threadIdx.x % BLOCK_TILE_K;
+    // the number of rows of As that will be loaded per step by a thread block
+    const auto A_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_K;
+
+    const auto inner_row_b = threadIdx.x / BLOCK_TILE_N;
+    const auto inner_col_b = threadIdx.x % BLOCK_TILE_N;
+    const auto B_TILE_ROW_STRIDE = THREADS_PER_BLOCK / BLOCK_TILE_N;
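+    // Example (illustrative): with 32 threads and BLOCK_TILE_K =
+    // BLOCK_TILE_N = 8, both strides are 32 / 8 = 4, so per K block each
+    // thread stores 64 / 4 = 16 elements of As and 8 / 4 = 2 elements of Bs.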
+
+    // registers for C
+    scalar_t accum[THREAD_TILE_M][THREAD_TILE_N] = {0};
+
+    // registers for As and Bs
+    scalar_t fragment_a[THREAD_TILE_M] = {0};
+    scalar_t fragment_b[THREAD_TILE_N] = {0};
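+    // fragment_a caches one k-column of this thread's rows of As, fragment_b
+    // one k-row of its columns of Bs, so each shared-memory value is read
+    // once and then reused THREAD_TILE_N (or THREAD_TILE_M) times from
+    // registers.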
+
+    // loop for block tiles in K dimension
+    for (auto block = 0; block < NUM_K_BLOCKS; block++) {
+      // load a block of x_values from global memory to shared memory
+      // apply tiling for threads in a block
+#pragma unroll
+      for (auto offset = 0; offset < BLOCK_TILE_M;
+           offset += A_TILE_ROW_STRIDE) {
+        auto x_row_offset = block_row * BLOCK_TILE_M + inner_row_a + offset;
+        auto x_col_offset = block * BLOCK_TILE_K + inner_col_a;
+        if ((x_row_offset < length) && (x_col_offset < K)) {
+          As[inner_row_a + offset][inner_col_a] =
+              x_values[row_start + x_row_offset][x_col_offset];
+        } else {
+          As[inner_row_a + offset][inner_col_a] = 0;
+        }
+      }

-    const int row_start = x_offsets[b];
-    const int row_end = x_offsets[b + 1];
-    const int length = min(row_end - row_start, max_L);
-    if (length == 0 || l >= length) {
-      return;
-    } else {
-      // TODO: use shared memory and better reduction
-      for (int n = threadIdx.x; n < N; n += blockDim.x) {
-        at::acc_type<scalar_t, true> acc = 0;
-        for (int k = 0; k < K; ++k) {
-          acc += x_values[row_start + l][k] * y[b][k][n];
+      // load a block of y from global memory to shared memory
+      // apply tiling for threads in a block
+#pragma unroll
+      for (auto offset = 0; offset < BLOCK_TILE_K;
+           offset += B_TILE_ROW_STRIDE) {
+        auto y_row_offset = block * BLOCK_TILE_K + inner_row_b + offset;
+        auto y_col_offset = block_col * BLOCK_TILE_N + inner_col_b;
+        if ((y_row_offset < K) && (y_col_offset < N)) {
+          Bs[inner_row_b + offset][inner_col_b] =
+              y[b][y_row_offset][y_col_offset];
+        } else {
+          Bs[inner_row_b + offset][inner_col_b] = 0;
+        }
+      }
+
+      __syncthreads();
+
+      // calculate the results per thread
+#pragma unroll
+      for (auto k = 0; k < BLOCK_TILE_K; k++) {
+        // load values from shared memory to registers for x_values
+        for (auto row = 0; row < THREAD_TILE_M; row++) {
+          fragment_a[row] = As[thread_row * THREAD_TILE_M + row][k];
+        }
+
+        // load values from shared memory to registers for y
+#pragma unroll
+        for (auto col = 0; col < THREAD_TILE_N; col++) {
+          fragment_b[col] = Bs[k][thread_col * THREAD_TILE_N + col];
+        }
+
+        // each thread calculates THREAD_TILE_M * THREAD_TILE_N elements
+#pragma unroll
+        for (auto row = 0; row < THREAD_TILE_M; row++) {
+#pragma unroll
+          for (auto col = 0; col < THREAD_TILE_N; col++) {
+            accum[row][col] += fragment_a[row] * fragment_b[col];
+          }
+        }
+      }
+
+      __syncthreads();
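+      // The second barrier keeps fast threads from refilling As/Bs for the
+      // next K block while slower threads are still reading the current one.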
+    }
+
+    // write the result to the output
+#pragma unroll
+    for (auto row = 0; row < THREAD_TILE_M; row++) {
+#pragma unroll
+      for (auto col = 0; col < THREAD_TILE_N; col++) {
+        auto out_row_offset =
+            block_row * BLOCK_TILE_M + thread_row * THREAD_TILE_M + row;
+        auto out_col_offset =
+            block_col * BLOCK_TILE_N + thread_col * THREAD_TILE_N + col;
+        if ((out_row_offset < length) && (out_col_offset < N)) {
+          output[row_start + out_row_offset][out_col_offset] = accum[row][col];
        }
-        output[row_start + l][n] = acc;
      }
    }
  }
@@ -2124,9 +2223,29 @@ Tensor jagged_dense_bmm_forward(
   const int total_L = x_values.size(0);
   auto output = at::zeros({total_L, N}, x_values.options());
   if (B > 0 && M > 0 && N > 0) {
-    const int block_dim_x =
-        std::min(div_round_up(N, kWarpSize) * kWarpSize, kMaxThreads);
-    const int block_dim_y = kMaxThreads / block_dim_x;
+    // The shared memory size is (BLOCK_TILE_M + BLOCK_TILE_N) * BLOCK_TILE_K.
+    // BLOCK_TILE_M needs to be a multiple of THREAD_TILE_M, and
+    // BLOCK_TILE_N needs to be a multiple of THREAD_TILE_N.
+    // The setting of these parameters needs to balance the hardware's shared
+    // memory size limit and occupancy.
+    // TODO: autotune these parameters based on max_L and input and output
+    // tensor sizes
+    constexpr int BLOCK_TILE_M = 64;
+    constexpr int BLOCK_TILE_N = 8;
+    constexpr int BLOCK_TILE_K = 8;
+    constexpr int THREAD_TILE_M = 4;
+    constexpr int THREAD_TILE_N = 4;
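+    // Example (illustrative arithmetic): these values give
+    // (64 * 8) / (4 * 4) = 32 threads per block and
+    // (64 + 8) * 8 * sizeof(scalar_t) bytes of static shared memory,
+    // i.e. 2304 bytes for float.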
+
+    const dim3 block(
+        (BLOCK_TILE_M * BLOCK_TILE_N) / (THREAD_TILE_M * THREAD_TILE_N));
+    const auto grid_dim_x = div_round_up(N, BLOCK_TILE_N);
+    const auto grid_dim_y = div_round_up(max_L, BLOCK_TILE_M);
+    TORCH_CHECK(
+        grid_dim_y <= kMaxBlockYDim,
+        "max_L cannot be larger than",
+        grid_dim_y * BLOCK_TILE_M + 1 - BLOCK_TILE_M);
+    const auto grid_dim_z = std::min(B, kMaxBlockZDim);
+    const dim3 grid(grid_dim_x, grid_dim_y, grid_dim_z);
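+    // grid.z is capped at kMaxBlockZDim; the kernel strides over batches with
+    // b += gridDim.z, so a larger B is still fully covered.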

    AT_DISPATCH_INDEX_TYPES(
        x_offsets.scalar_type(), "jagged_dense_bmm_kernel_1", [&] {
@@ -2136,11 +2255,15 @@ Tensor jagged_dense_bmm_forward(
              x_values.scalar_type(),
              "jagged_dense_bmm_kernel_2",
              [&] {
-                jagged_dense_bmm_kernel<index_t, scalar_t>
-                    <<<div_round_up(B * max_L, block_dim_y),
-                       dim3(block_dim_x, block_dim_y),
-                       0,
-                       at::cuda::getCurrentCUDAStream()>>>(
+                jagged_dense_bmm_kernel<
+                    BLOCK_TILE_M,
+                    BLOCK_TILE_N,
+                    BLOCK_TILE_K,
+                    THREAD_TILE_M,
+                    THREAD_TILE_N,
+                    index_t,
+                    scalar_t>
+                    <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
                        x_values.packed_accessor32<scalar_t, 2>(),
                        x_offsets.packed_accessor32<index_t, 1>(),
                        y.packed_accessor32<scalar_t, 3>(),
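For readers less familiar with the register-tiling pattern the new kernel adopts, the sketch below shows the same block-tile / thread-tile scheme on a plain dense GEMM, stripped of the jagged offsets and PyTorch accessors. It is a minimal, self-contained illustration (the file name, tiled_gemm_example, and the test sizes are made up for this sketch; only the tile sizes 64/8/8/4/4 come from the change above), not FBGEMM code and not tuned.

// tiled_gemm_example.cu -- illustrative sketch only; build with
//   nvcc -o tiled_gemm_example tiled_gemm_example.cu
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Each block computes a BLOCK_TILE_M x BLOCK_TILE_N tile of C, each thread a
// THREAD_TILE_M x THREAD_TILE_N patch of that tile, with A and B staged
// through shared memory, mirroring jagged_dense_bmm_kernel.
template <
    int BLOCK_TILE_M,
    int BLOCK_TILE_N,
    int BLOCK_TILE_K,
    int THREAD_TILE_M,
    int THREAD_TILE_N>
__global__ void tiled_gemm_example(
    const float* A, const float* B, float* C, int M, int N, int K) {
  constexpr int THREADS_X = BLOCK_TILE_N / THREAD_TILE_N;
  constexpr int THREADS_PER_BLOCK = (BLOCK_TILE_M / THREAD_TILE_M) * THREADS_X;
  const int thread_row = threadIdx.x / THREADS_X;
  const int thread_col = threadIdx.x % THREADS_X;

  __shared__ float As[BLOCK_TILE_M][BLOCK_TILE_K];
  __shared__ float Bs[BLOCK_TILE_K][BLOCK_TILE_N];
  float accum[THREAD_TILE_M][THREAD_TILE_N] = {0};

  for (int kb = 0; kb < (K + BLOCK_TILE_K - 1) / BLOCK_TILE_K; ++kb) {
    // Cooperative loads: threads stride linearly over the tile elements
    // (the FBGEMM kernel expresses the same thing with explicit row strides).
    for (int i = threadIdx.x; i < BLOCK_TILE_M * BLOCK_TILE_K;
         i += THREADS_PER_BLOCK) {
      int r = i / BLOCK_TILE_K, c = i % BLOCK_TILE_K;
      int gr = blockIdx.y * BLOCK_TILE_M + r, gc = kb * BLOCK_TILE_K + c;
      As[r][c] = (gr < M && gc < K) ? A[gr * K + gc] : 0.f;
    }
    for (int i = threadIdx.x; i < BLOCK_TILE_K * BLOCK_TILE_N;
         i += THREADS_PER_BLOCK) {
      int r = i / BLOCK_TILE_N, c = i % BLOCK_TILE_N;
      int gr = kb * BLOCK_TILE_K + r, gc = blockIdx.x * BLOCK_TILE_N + c;
      Bs[r][c] = (gr < K && gc < N) ? B[gr * N + gc] : 0.f;
    }
    __syncthreads();
    for (int k = 0; k < BLOCK_TILE_K; ++k) {
      for (int i = 0; i < THREAD_TILE_M; ++i) {
        for (int j = 0; j < THREAD_TILE_N; ++j) {
          accum[i][j] += As[thread_row * THREAD_TILE_M + i][k] *
              Bs[k][thread_col * THREAD_TILE_N + j];
        }
      }
    }
    __syncthreads();
  }
  // Write back this thread's THREAD_TILE_M x THREAD_TILE_N patch of C.
  for (int i = 0; i < THREAD_TILE_M; ++i) {
    for (int j = 0; j < THREAD_TILE_N; ++j) {
      int gr = blockIdx.y * BLOCK_TILE_M + thread_row * THREAD_TILE_M + i;
      int gc = blockIdx.x * BLOCK_TILE_N + thread_col * THREAD_TILE_N + j;
      if (gr < M && gc < N) {
        C[gr * N + gc] = accum[i][j];
      }
    }
  }
}

int main() {
  const int M = 100, N = 20, K = 30;
  std::vector<float> hA(M * K), hB(K * N), hC(M * N, 0.f);
  for (int i = 0; i < M * K; ++i) hA[i] = static_cast<float>(i % 7);
  for (int i = 0; i < K * N; ++i) hB[i] = static_cast<float>(i % 5);

  float *dA, *dB, *dC;
  cudaMalloc(&dA, hA.size() * sizeof(float));
  cudaMalloc(&dB, hB.size() * sizeof(float));
  cudaMalloc(&dC, hC.size() * sizeof(float));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

  // Same tile sizes as jagged_dense_bmm_forward: 32 threads per block.
  const dim3 block((64 * 8) / (4 * 4));
  const dim3 grid((N + 8 - 1) / 8, (M + 64 - 1) / 64);
  tiled_gemm_example<64, 8, 8, 4, 4><<<grid, block>>>(dA, dB, dC, M, N, K);
  cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);

  // Spot-check one element against a CPU reference.
  float ref = 0.f;
  for (int k = 0; k < K; ++k) ref += hA[3 * K + k] * hB[k * N + 2];
  printf("C[3][2] = %f (reference %f)\n", hC[3 * N + 2], ref);
  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  return 0;
}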