Commit 0bfcc77

levendlee authored and facebook-github-bot committed
torch.ops.fbgemm.gather_along_first_dim. (pytorch#800)
Summary:
X-link: pytorch#3719
Pull Request resolved: facebookresearch/FBGEMM#800

TMA-based gather operation optimized for large shapes. Hyperparameters could be fine-tuned for better performance; however, the expected headroom is small.

Reviewed By: jianyuh

Differential Revision: D69907204

fbshipit-source-id: 3e8b48d40b478a05359b55bc629e1f867fc82de8
1 parent 50de711 commit 0bfcc77
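
For context, the new op behaves like torch.index_select along the first dimension. A minimal usage sketch (illustrative only; it assumes an fbgemm_gpu build with the gen_ai extension and a Hopper (SM90) GPU, mirroring the unit test in this commit):

import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers torch.ops.fbgemm.*
import torch

# The TMA kernel handles contiguous 2-D bf16 data with contiguous 1-D int32 indices;
# other inputs fall back to at::index_select inside the op.
src = torch.randn(8192, 1024, device="cuda", dtype=torch.bfloat16)
idx = torch.randint(0, 8192, (4096,), device="cuda", dtype=torch.int32)

out = torch.ops.fbgemm.gather_along_first_dim(src, idx)  # shape [4096, 1024]
ref = torch.index_select(src, 0, idx)
assert torch.equal(out, ref)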

File tree

3 files changed: +355 -0 lines changed

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <torch/library.h>

namespace fbgemm_gpu {

#ifndef USE_ROCM

at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index);

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.set_python_module("fbgemm_gpu.experimental.gen_ai.gather");
  m.def("gather_along_first_dim(Tensor Data, Tensor Index) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("gather_along_first_dim", gather_along_first_dim);
}

#endif

} // namespace fbgemm_gpu
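
The set_python_module call above tells the dispatcher which Python module is expected to supply the op's fake/meta implementation for tracing. As an illustrative sketch only (the fbgemm_gpu.experimental.gen_ai.gather module itself is not part of this commit), such a registration could look like:

import torch

# Hypothetical fake kernel so FakeTensor / torch.compile can infer output shapes;
# the real registration lives in fbgemm_gpu.experimental.gen_ai.gather.
@torch.library.register_fake("fbgemm::gather_along_first_dim")
def _(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    return data.new_empty((index.size(0), data.size(1)))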
Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifndef USE_ROCM

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "cute/atom/copy_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/numeric/integral_constant.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/device_kernel.h"

namespace fbgemm_gpu {

namespace {

template <int kBlkN, class DataType, class SmemLayout>
struct SharedStorage {
  static constexpr int kPipeMax = cute::size<0>(SmemLayout{});
  static constexpr int kTmaAlignment = 128;
  static constexpr int kMbarAlignment = 8;

  cute::array_aligned<int32_t, kBlkN> index;
  cute::array_aligned<DataType, cute::cosize_v<SmemLayout>, kTmaAlignment> data;

  CUTE_ALIGNAS(kMbarAlignment) uint64_t tma_load_barrier[kPipeMax];
};

template <
    class ProblemShape,
    class TileShape,
    class DataType,
    class SmemLayout,
    class TmaLoad,
    class TmaStore>
__global__ static void gather_along_first_dim_kernel(
    ProblemShape problem_shape,
    TileShape tile_shape,
    CUTLASS_GRID_CONSTANT TmaLoad const tma_load_input,
    const int32_t* index,
    CUTLASS_GRID_CONSTANT TmaStore const tma_store_output) {
  // Input shape: A [M, K]
  // Output shape: B [N, K]
  int M = cute::get<0>(problem_shape);
  int N = cute::get<1>(problem_shape);
  int K = cute::get<2>(problem_shape);

  static_assert(cute::is_static<TileShape>::value);
  constexpr int kBlkN = cute::size<0>(tile_shape);
  constexpr int kBlkK = cute::size<1>(tile_shape);

  using SmemT = SharedStorage<kBlkN, DataType, SmemLayout>;
  constexpr int kPipeMax = SmemT::kPipeMax;

  extern __shared__ char smem_raw[];
  SmemT& smem = *reinterpret_cast<SmemT*>(smem_raw);

  const int n_offset = blockIdx.x * kBlkN;
  if (n_offset >= N) {
    return;
  }

  // Straightforward direct global read of indices.
  if (threadIdx.x < kBlkN && n_offset + threadIdx.x < N) {
    smem.index[threadIdx.x] = index[n_offset + threadIdx.x];
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    // Tensors on HBM.
    cute::Tensor gA = tma_load_input.get_tma_tensor(cute::make_shape(M, K));
    cute::Tensor gB = tma_store_output.get_tma_tensor(cute::make_shape(N, K));
    // Tensors on SMEM.
    cute::Tensor sA = cute::make_tensor(
        cute::make_smem_ptr(smem.data.data()), cute::group<0, 2>(SmemLayout{}));

    constexpr int kTmaTransactionBytes = kBlkK * sizeof(DataType);
    const int kNumKTiles = ((K + kBlkK - 1) / kBlkK);
    const int kNumNKTiles = kBlkN * kNumKTiles;
    const int kNumIterations = kNumNKTiles + kPipeMax - 1;

    for (int iteration = 0; iteration < kNumIterations; ++iteration) {
      // Load.
      if (iteration < kNumNKTiles) {
        int load_pipe = iteration % kPipeMax;

        int n = iteration / kNumKTiles;
        int k = iteration % kNumKTiles;
        int m = smem.index[n];

        cute::tma_store_wait<kPipeMax - 1>();

        cute::Tensor tAgA = cute::local_tile(
            gA,
            cute::Tile<cute::_1, cute::Int<kBlkK>>{},
            cute::make_coord(m, k));
        cute::Tensor tAsA = cute::local_tile(
            sA,
            cute::Tile<cute::_1, cute::Int<kBlkK>>{},
            cute::make_coord(load_pipe, 0));

        auto& tma_load_mbar = smem.tma_load_barrier[load_pipe];
        cute::initialize_barrier(smem.tma_load_barrier[load_pipe], 1);
        cute::set_barrier_transaction_bytes(
            tma_load_mbar, kTmaTransactionBytes);

        auto tma_load_per_cta = tma_load_input.get_slice(0);
        cute::copy(
            tma_load_input.with(tma_load_mbar),
            tma_load_per_cta.partition_S(tAgA),
            tma_load_per_cta.partition_D(tAsA));
      }

      // Store.
      if (iteration >= kPipeMax - 1) {
        int processing_index = iteration - kPipeMax + 1;
        int store_pipe = processing_index % kPipeMax;

        int n = processing_index / kNumKTiles;
        int k = processing_index % kNumKTiles;

        cute::wait_barrier(smem.tma_load_barrier[store_pipe], 0);

        cute::Tensor tAgB = cute::local_tile(
            gB,
            cute::Tile<cute::_1, cute::Int<kBlkK>>{},
            cute::make_coord(n + n_offset, k));
        cute::Tensor tAsA = cute::local_tile(
            sA,
            cute::Tile<cute::_1, cute::Int<kBlkK>>{},
            cute::make_coord(store_pipe, 0));

        auto tma_store_per_cta = tma_store_output.get_slice(0);
        cute::copy(
            tma_store_output,
            tma_store_per_cta.partition_S(tAsA),
            tma_store_per_cta.partition_D(tAgB));
        cute::tma_store_arrive();
      }
    }
  }
  cute::tma_store_wait<0>();
}

} // namespace

// TODO(shikaili): Templatize it and make it support more configurations.
at::Tensor gather_along_first_dim(at::Tensor data, at::Tensor index) {
  using DataType = cutlass::bfloat16_t;
  constexpr auto kDataTypeEnum = at::kBFloat16;
  using IndexType = int32_t;
  constexpr auto kIndexTypeEnum = at::kInt;
  constexpr int kTmaGmemAlignment = 16;

  bool compatible = (data.dtype() == kDataTypeEnum && data.is_contiguous() &&
                     data.dim() == 2) &&
      (index.dtype() == kIndexTypeEnum && index.is_contiguous() &&
       index.dim() == 1) &&
      (data.size(1) * sizeof(DataType) % kTmaGmemAlignment == 0);

  if (!compatible) {
    return at::index_select(data, 0, index);
  }

  const int M = data.size(0);
  const int K = data.size(1);
  const int N = index.size(0);

  auto src_gmem_layout =
      cute::make_layout(cute::make_shape(M, K), cute::make_stride(K, 1));
  auto src_gmem_tensor = cute::make_tensor(
      cute::make_gmem_ptr(reinterpret_cast<DataType*>(data.data_ptr())),
      src_gmem_layout);

  at::Tensor output = at::empty(
      {N, K}, at::TensorOptions().dtype(at::kBFloat16).device(data.device()));
  auto dst_gmem_layout =
      cute::make_layout(cute::make_shape(N, K), cute::make_stride(K, 1));
  auto dst_gmem_tensor = cute::make_tensor(
      cute::make_gmem_ptr(reinterpret_cast<DataType*>(output.data_ptr())),
      dst_gmem_layout);

  constexpr int kBlkN = 1;
  constexpr int kBlkK = 256;
  constexpr int kPipeMax = 4;

  auto smem_layout = cute::make_layout(
      cute::make_shape(cute::Int<kPipeMax>{}, cute::_1{}, cute::Int<kBlkK>{}),
      cute::make_stride(cute::Int<kBlkK>{}, cute::Int<kBlkK>{}, cute::_1{}));
  auto tma_load = cute::make_tma_copy(
      cute::SM90_TMA_LOAD{}, src_gmem_tensor, smem_layout(0, cute::_, cute::_));
  auto tma_store = cute::make_tma_copy(
      cute::SM90_TMA_STORE{},
      dst_gmem_tensor,
      smem_layout(0, cute::_, cute::_));

  auto problem_shape = cute::make_shape(M, N, K);
  auto tile_shape = cute::make_shape(cute::Int<kBlkN>{}, cute::Int<kBlkK>{});

  using SmemT = SharedStorage<kBlkN, DataType, decltype(smem_layout)>;

  int num_ctas = (N + kBlkN - 1) / kBlkN;
  dim3 grid_dims(num_ctas, 1, 1);
  dim3 block_dims(32, 1, 1);
  dim3 cluster_dims(1, 1, 1);
  int smem_size = sizeof(SmemT);
  auto stream = c10::cuda::getCurrentCUDAStream();

  cutlass::ClusterLaunchParams launch_params{
      grid_dims, block_dims, cluster_dims, smem_size, stream};
  void* kernel = (void*)gather_along_first_dim_kernel<
      decltype(problem_shape),
      decltype(tile_shape),
      DataType,
      decltype(smem_layout),
      decltype(tma_load),
      decltype(tma_store)>;

  CUTE_CHECK_ERROR(cudaFuncSetAttribute(
      kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

  // Kernel launch.
  cutlass::Status status = cutlass::launch_kernel_on_cluster(
      launch_params,
      kernel,
      problem_shape,
      tile_shape,
      tma_load,
      reinterpret_cast<IndexType*>(index.data_ptr()),
      tma_store);

  if (status != cutlass::Status::kSuccess) {
    cudaError_t error = cudaGetLastError();
    CUTE_ERROR_EXIT(error);
  }

  return output;
}

} // namespace fbgemm_gpu

#endif
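
To make the kernel's software pipelining easier to follow, here is a small host-side Python model (a sketch for illustration, not part of the commit) of how each iteration pairs a TMA load into one of the kPipeMax shared-memory stages with the store of the tile loaded kPipeMax - 1 iterations earlier:

# Models the (n, k, pipe_stage) schedule of gather_along_first_dim_kernel.
# Constant names and default values mirror the C++ code above.
def pipeline_schedule(K, kBlkN=1, kBlkK=256, kPipeMax=4):
    kNumKTiles = (K + kBlkK - 1) // kBlkK
    kNumNKTiles = kBlkN * kNumKTiles
    kNumIterations = kNumNKTiles + kPipeMax - 1
    schedule = []
    for it in range(kNumIterations):
        load = store = None
        if it < kNumNKTiles:  # issue a load into stage it % kPipeMax
            load = (it // kNumKTiles, it % kNumKTiles, it % kPipeMax)
        if it >= kPipeMax - 1:  # store the tile loaded kPipeMax - 1 iterations ago
            done = it - kPipeMax + 1
            store = (done // kNumKTiles, done % kNumKTiles, done % kPipeMax)
        schedule.append((it, load, store))
    return schedule

for step in pipeline_schedule(K=1024):
    print(step)  # (iteration, (n, k, load_stage) or None, (n, k, store_stage) or None)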
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[56]

import logging
import unittest

import fbgemm_gpu.experimental.gen_ai  # noqa: F401

import torch
import triton  # noqa: F401

logger: logging.Logger = logging.getLogger()
logger.setLevel(logging.INFO)


@unittest.skipIf(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
    "Skip when no Hopper GPU is available. This test is only for Hopper GPUs.",
)
class GatherTests(unittest.TestCase):
    """Test gathers."""

    def test_gather_along_first_dim(self) -> None:
        def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
            logger.info(f"Running test_gather_along_first_dim: {M=}, {N=}, {K=}")
            src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
            if M == N:
                indices = torch.randperm(N, device="cuda", dtype=torch.int32)
            else:
                indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int32)

            def fn():
                return torch.ops.fbgemm.gather_along_first_dim(src, indices)

            def ref_fn():
                return torch.index_select(src, 0, indices)

            logger.info("Running FBGEMM")
            dst = fn()
            logger.info("Running PyTorch")
            ref_dst = ref_fn()

            self.assertTrue((dst == ref_dst).all().item())

            # 2 bytes per bf16 element, each gathered element read once and written once.
            data_size_in_terabytes = N * K * 2 * 2 / 1e12

            time_in_us = triton.testing.do_bench(fn) * 1e3
            time_in_second = time_in_us / 1e6
            terabytes_per_second = data_size_in_terabytes / time_in_second

            ref_time_in_us = triton.testing.do_bench(ref_fn) * 1e3
            ref_time_in_second = ref_time_in_us / 1e6
            ref_terabytes_per_second = data_size_in_terabytes / ref_time_in_second

            logger.info(
                f"FBGEMM time: {time_in_us:.2f} us. Bandwidth: {terabytes_per_second:.2f} TB/s"
            )
            logger.info(
                f"PyTorch time: {ref_time_in_us:.2f} us. Bandwidth: {ref_terabytes_per_second:.2f} TB/s"
            )

        _test_gather_along_first_dim(127, 257, 1023)
        _test_gather_along_first_dim(127, 257, 1024)
        _test_gather_along_first_dim(255, 129, 2049)
        _test_gather_along_first_dim(255, 129, 2048)
        _test_gather_along_first_dim(1024, 1024, 1024)


if __name__ == "__main__":
    unittest.main()
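
One note on the benchmark's bandwidth figure: data_size_in_terabytes counts each gathered element twice, one bf16 read and one bf16 write at 2 bytes each. For the largest shape in the test, the implied traffic is roughly:

# Illustrative arithmetic for the N = K = 1024 case, matching the test's accounting.
N, K = 1024, 1024
bytes_moved = N * K * 2 * 2           # read + write, 2 bytes per bf16 element
print(bytes_moved / 1e6, "MB")        # ~4.19 MB; TB/s = (bytes_moved / 1e12) / seconds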
