|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#ifndef USE_ROCM |
| 10 | + |
| 11 | +#include <ATen/ATen.h> |
| 12 | +#include <ATen/cuda/CUDAContext.h> |
| 13 | + |
| 14 | +#include "cute/atom/copy_atom.hpp" |
| 15 | +#include "cute/atom/copy_traits_sm90_tma.hpp" |
| 16 | +#include "cute/numeric/integral_constant.hpp" |
| 17 | +#include "cute/tensor.hpp" |
| 18 | +#include "cutlass/cluster_launch.hpp" |
| 19 | +#include "cutlass/device_kernel.h" |
| 20 | + |
| 21 | +namespace fbgemm_gpu { |
| 22 | + |
| 23 | +namespace { |
| 24 | + |
| 25 | +template <int kBlkN, class DataType, class SmemLayout> |
| 26 | +struct SharedStorage { |
| 27 | + static constexpr int kPipeMax = cute::size<0>(SmemLayout{}); |
| 28 | + static constexpr int kTmaAlignment = 128; |
| 29 | + static constexpr int kMbarAlignemnt = 8; |
| 30 | + |
| 31 | + cute::array_aligned<int32_t, kBlkN> index; |
| 32 | + cute::array_aligned<DataType, cute::cosize_v<SmemLayout>, kTmaAlignment> data; |
| 33 | + |
| 34 | + CUTE_ALIGNAS(kMbarAlignemnt) uint64_t tma_load_barrier[kPipeMax]; |
| 35 | +}; |
| 36 | + |
| 37 | +template < |
| 38 | + class ProblemShape, |
| 39 | + class TileShape, |
| 40 | + class DataType, |
| 41 | + class SmemLayout, |
| 42 | + class TmaLoad, |
| 43 | + class TmaStore> |
| 44 | +__global__ static void gather_along_first_dim_kernel( |
| 45 | + ProblemShape problem_shape, |
| 46 | + TileShape tile_shape, |
| 47 | + CUTLASS_GRID_CONSTANT TmaLoad const tma_load_input, |
| 48 | + const int32_t* index, |
| 49 | + CUTLASS_GRID_CONSTANT TmaStore const tma_store_output) { |
| 50 | + // Input shape: A [M, K] |
| 51 | + // Output shape: B [N, K] |
| 52 | + int M = cute::get<0>(problem_shape); |
| 53 | + int N = cute::get<1>(problem_shape); |
| 54 | + int K = cute::get<2>(problem_shape); |
| 55 | + |
| 56 | + static_assert(cute::is_static<TileShape>::value); |
| 57 | + constexpr int kBlkN = cute::size<0>(tile_shape); |
| 58 | + constexpr int kBlkK = cute::size<1>(tile_shape); |
| 59 | + |
| 60 | + using SmemT = SharedStorage<kBlkN, DataType, SmemLayout>; |
| 61 | + constexpr int kPipeMax = SmemT::kPipeMax; |
| 62 | + |
| 63 | + extern __shared__ char smem_raw[]; |
| 64 | + SmemT& smem = *reinterpret_cast<SmemT*>(smem_raw); |
| 65 | + |
| 66 | + const int n_offset = blockIdx.x * kBlkN; |
| 67 | + if (n_offset >= N) { |
| 68 | + return; |
| 69 | + } |
| 70 | + |
| 71 | + // Straight-forward direct global read of indices. |
| 72 | + if (threadIdx.x < kBlkN && n_offset + threadIdx.x < N) { |
| 73 | + smem.index[threadIdx.x] = index[n_offset + threadIdx.x]; |
| 74 | + } |
| 75 | + __syncthreads(); |
| 76 | + |
| 77 | + if (threadIdx.x == 0) { |
| 78 | + // Tensors on HBM. |
| 79 | + cute::Tensor gA = tma_load_input.get_tma_tensor(cute::make_shape(M, K)); |
| 80 | + cute::Tensor gB = tma_store_output.get_tma_tensor(cute::make_shape(N, K)); |
| 81 | + // Tensors on SMEM. |
| 82 | + cute::Tensor sA = cute::make_tensor( |
| 83 | + cute::make_smem_ptr(smem.data.data()), cute::group<0, 2>(SmemLayout{})); |
| 84 | + |
| 85 | + constexpr int kTmaTransactionBytes = kBlkK * sizeof(DataType); |
| 86 | + const int kNumKTiles = ((K + kBlkK - 1) / kBlkK); |
| 87 | + const int kNumNKTiles = kBlkN * kNumKTiles; |
| 88 | + const int kNumIterations = kNumNKTiles + kPipeMax - 1; |
| 89 | + |
| 90 | + for (int iteration = 0; iteration < kNumIterations; ++iteration) { |
| 91 | + // Load. |
| 92 | + if (iteration < kNumNKTiles) { |
| 93 | + int load_pipe = iteration % kPipeMax; |
| 94 | + |
| 95 | + int n = iteration / kNumKTiles; |
| 96 | + int k = iteration % kNumKTiles; |
| 97 | + int m = smem.index[n]; |
| 98 | + |
| 99 | + cute::tma_store_wait<kPipeMax - 1>(); |
| 100 | + |
| 101 | + cute::Tensor tAgA = cute::local_tile( |
| 102 | + gA, |
| 103 | + cute::Tile<cute::_1, cute::Int<kBlkK>>{}, |
| 104 | + cute::make_coord(m, k)); |
| 105 | + cute::Tensor tAsA = cute::local_tile( |
| 106 | + sA, |
| 107 | + cute::Tile<cute::_1, cute::Int<kBlkK>>{}, |
| 108 | + cute::make_coord(load_pipe, 0)); |
| 109 | + |
| 110 | + auto& tma_load_mbar = smem.tma_load_barrier[load_pipe]; |
| 111 | + cute::initialize_barrier(smem.tma_load_barrier[load_pipe], 1); |
| 112 | + cute::set_barrier_transaction_bytes( |
| 113 | + tma_load_mbar, kTmaTransactionBytes); |
| 114 | + |
| 115 | + auto tma_load_per_cta = tma_load_input.get_slice(0); |
| 116 | + cute::copy( |
| 117 | + tma_load_input.with(tma_load_mbar), |
| 118 | + tma_load_per_cta.partition_S(tAgA), |
| 119 | + tma_load_per_cta.partition_D(tAsA)); |
| 120 | + } |
| 121 | + |
| 122 | + // Store |
| 123 | + if (iteration >= kPipeMax - 1) { |
| 124 | + int processing_index = iteration - kPipeMax + 1; |
| 125 | + int store_pipe = processing_index % kPipeMax; |
| 126 | + |
| 127 | + int n = processing_index / kNumKTiles; |
| 128 | + int k = processing_index % kNumKTiles; |
| 129 | + |
| 130 | + cute::wait_barrier(smem.tma_load_barrier[store_pipe], 0); |
| 131 | + |
| 132 | + cute::Tensor tAgB = cute::local_tile( |
| 133 | + gB, |
| 134 | + cute::Tile<cute::_1, cute::Int<kBlkK>>{}, |
| 135 | + cute::make_coord(n + n_offset, k)); |
| 136 | + cute::Tensor tAsA = cute::local_tile( |
| 137 | + sA, |
| 138 | + cute::Tile<cute::_1, cute::Int<kBlkK>>{}, |
| 139 | + cute::make_coord(store_pipe, 0)); |
| 140 | + |
| 141 | + auto tma_store_per_cta = tma_store_output.get_slice(0); |
| 142 | + cute::copy( |
| 143 | + tma_store_output, |
| 144 | + tma_store_per_cta.partition_S(tAsA), |
| 145 | + tma_store_per_cta.partition_D(tAgB)); |
| 146 | + cute::tma_store_arrive(); |
| 147 | + } |
| 148 | + } |
| 149 | + } |
| 150 | + cute::tma_store_wait<0>(); |
| 151 | +} |
| 152 | + |
| 153 | +} // namespace |
| 154 | + |
| 155 | +// TODO(shikaili): Templatize it and make it supports more configurations. |
| 156 | +at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) { |
| 157 | + using DataType = cutlass::bfloat16_t; |
| 158 | + constexpr auto kDataTypeEnum = at::kBFloat16; |
| 159 | + using IndexType = int32_t; |
| 160 | + constexpr auto kIndexTypeEnum = at::kInt; |
| 161 | + constexpr int kTmaGmemAlignment = 16; |
| 162 | + |
| 163 | + bool compatible = (data.dtype() == kDataTypeEnum && data.is_contiguous() && |
| 164 | + data.dim() == 2) && |
| 165 | + (index.dtype() == kIndexTypeEnum && index.is_contiguous() && |
| 166 | + index.dim() == 1) && |
| 167 | + (data.size(1) * sizeof(DataType) % kTmaGmemAlignment == 0); |
| 168 | + |
| 169 | + if (!compatible) { |
| 170 | + return at::index_select(data, 0, index); |
| 171 | + } |
| 172 | + |
| 173 | + const int M = data.size(0); |
| 174 | + const int K = data.size(1); |
| 175 | + const int N = index.size(0); |
| 176 | + |
| 177 | + auto src_gmem_layout = |
| 178 | + cute::make_layout(cute::make_shape(M, K), cute::make_stride(K, 1)); |
| 179 | + auto src_gmem_tensor = cute::make_tensor( |
| 180 | + cute::make_gmem_ptr(reinterpret_cast<DataType*>(data.data_ptr())), |
| 181 | + src_gmem_layout); |
| 182 | + |
| 183 | + at::Tensor output = at::empty( |
| 184 | + {N, K}, at::TensorOptions().dtype(at::kBFloat16).device(data.device())); |
| 185 | + auto dst_gmem_layout = |
| 186 | + cute::make_layout(cute::make_shape(N, K), cute::make_stride(K, 1)); |
| 187 | + auto dst_gmem_tensor = cute::make_tensor( |
| 188 | + cute::make_gmem_ptr(reinterpret_cast<DataType*>(output.data_ptr())), |
| 189 | + dst_gmem_layout); |
| 190 | + |
| 191 | + constexpr int kBlkN = 1; |
| 192 | + constexpr int kBlkK = 256; |
| 193 | + constexpr int kPipeMax = 4; |
| 194 | + |
| 195 | + auto smem_layout = cute::make_layout( |
| 196 | + cute::make_shape(cute::Int<kPipeMax>{}, cute::_1{}, cute::Int<kBlkK>{}), |
| 197 | + cute::make_stride(cute::Int<kBlkK>{}, cute::Int<kBlkK>{}, cute::_1{})); |
| 198 | + auto tma_load = cute::make_tma_copy( |
| 199 | + cute::SM90_TMA_LOAD{}, src_gmem_tensor, smem_layout(0, cute::_, cute::_)); |
| 200 | + auto tma_store = cute::make_tma_copy( |
| 201 | + cute::SM90_TMA_STORE{}, |
| 202 | + dst_gmem_tensor, |
| 203 | + smem_layout(0, cute::_, cute::_)); |
| 204 | + |
| 205 | + auto problem_shape = cute::make_shape(M, N, K); |
| 206 | + auto tile_shape = cute::make_shape(cute::Int<kBlkN>{}, cute::Int<kBlkK>{}); |
| 207 | + |
| 208 | + using SmemT = SharedStorage<kBlkN, DataType, decltype(smem_layout)>; |
| 209 | + |
| 210 | + int num_ctas = (N + kBlkN - 1) / kBlkN; |
| 211 | + dim3 grid_dims(num_ctas, 1, 1); |
| 212 | + dim3 block_dims(32, 1, 1); |
| 213 | + dim3 cluster_dims(1, 1, 1); |
| 214 | + int smem_size = sizeof(SmemT); |
| 215 | + auto stream = c10::cuda::getCurrentCUDAStream(); |
| 216 | + |
| 217 | + cutlass::ClusterLaunchParams launch_params{ |
| 218 | + grid_dims, block_dims, cluster_dims, smem_size, stream}; |
| 219 | + void* kernel = (void*)gather_along_first_dim_kernel< |
| 220 | + decltype(problem_shape), |
| 221 | + decltype(tile_shape), |
| 222 | + DataType, |
| 223 | + decltype(smem_layout), |
| 224 | + decltype(tma_load), |
| 225 | + decltype(tma_store)>; |
| 226 | + |
| 227 | + CUTE_CHECK_ERROR(cudaFuncSetAttribute( |
| 228 | + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); |
| 229 | + |
| 230 | + // Kernel Launch |
| 231 | + cutlass::Status status = cutlass::launch_kernel_on_cluster( |
| 232 | + launch_params, |
| 233 | + kernel, |
| 234 | + problem_shape, |
| 235 | + tile_shape, |
| 236 | + tma_load, |
| 237 | + reinterpret_cast<IndexType*>(index.data_ptr()), |
| 238 | + tma_store); |
| 239 | + |
| 240 | + if (status != cutlass::Status::kSuccess) { |
| 241 | + cudaError_t error = cudaGetLastError(); |
| 242 | + CUTE_ERROR_EXIT(error); |
| 243 | + } |
| 244 | + |
| 245 | + return output; |
| 246 | +} |
| 247 | + |
| 248 | +} // namespace fbgemm_gpu |
| 249 | + |
| 250 | +#endif |
0 commit comments