
Commit ba33284

xw285cornell authored and facebook-github-bot committed
custom reduce scatter (pytorch#763)
Summary:
X-link: pytorch#3686

Pull Request resolved: facebookresearch/FBGEMM#763

Piggyback on the two-shot allreduce for the reduce-scatter: it is essentially the first half of the two-shot allreduce.

Reviewed By: jasonjk-park

Differential Revision: D69364062

fbshipit-source-id: fb9d8332e30325e15c009abed1eb4c1ed229f3db
1 parent b5ee7cd commit ba33284
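
The summary above describes the approach: reuse the two-shot allreduce kernel and stop after its first (reduce) phase. As a rough illustration of the contract the new op is expected to satisfy, here is a host-side C++ reference sketch of a reduce-scatter with an optional fused bias add. The function and buffer names are illustrative only, and the assumption that the bias has full length N with each rank adding its own slice mirrors the bias->numel() == src.numel() check in car.cu but is not spelled out by the commit itself.

// Reference semantics only (illustrative, not the CUDA kernel added in car.cu):
// every rank ends up with the element-wise sum, across all ranks, of its own
// N / world_size slice of the input, plus (optionally) the matching bias slice.
#include <cassert>
#include <cstdint>
#include <vector>

void reference_reduce_scatter(
    const std::vector<std::vector<float>>& inputs, // inputs[r]: rank r's full buffer, length N
    std::vector<std::vector<float>>& outputs,      // outputs[r]: rank r's slice, length N / world_size
    const std::vector<float>* bias) {              // optional, length N (assumed layout)
  const int64_t world_size = static_cast<int64_t>(inputs.size());
  const int64_t N = static_cast<int64_t>(inputs[0].size());
  assert(N % world_size == 0);
  const int64_t N_per_rank = N / world_size;
  for (int64_t r = 0; r < world_size; ++r) {
    outputs[r].assign(N_per_rank, 0.0f);
    for (int64_t i = 0; i < N_per_rank; ++i) {
      const int64_t idx = r * N_per_rank + i; // rank r owns slice [r*N_per_rank, (r+1)*N_per_rank)
      for (int64_t peer = 0; peer < world_size; ++peer) {
        outputs[r][i] += inputs[peer][idx];   // the "reduce" half of two-shot allreduce
      }
      if (bias != nullptr) {
        outputs[r][i] += (*bias)[idx];        // fused bias add, mirroring dst.add_(*bias)
      }
    }
  }
}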

File tree

3 files changed: +246 −69 lines


fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

Lines changed: 31 additions & 2 deletions
@@ -177,7 +177,11 @@ void nccl_alltoall(
   torch::cuda::nccl::all2all(dsts, srcs, *get_nccl_comm(comm_idx), stream);
 }
 
-void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
+void nccl_reducescatter(
+    at::Tensor dst,
+    at::Tensor src,
+    std::optional<at::Tensor> bias,
+    int64_t comm_idx) {
   using namespace c10d;
   TORCH_CHECK(src.is_contiguous());
   TORCH_CHECK(dst.is_contiguous());
@@ -194,6 +198,10 @@ void nccl_reducescatter(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
           *get_nccl_comm(comm_idx),
           at::cuda::getCurrentCUDAStream()),
       "ncclReduceScatter");
+
+  if (bias) {
+    dst.add_(*bias);
+  }
 }
 
 void nccl_allreduce(
@@ -259,6 +267,11 @@ void two_shot_car_allreduce(
     at::Tensor src,
     std::optional<at::Tensor> bias,
     int64_t comm_idx);
+void car_reducescatter(
+    at::Tensor dst,
+    at::Tensor src,
+    std::optional<at::Tensor> bias,
+    int64_t comm_idx);
 
 at::Tensor car_tensor();
 
@@ -282,7 +295,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "nccl_alltoall_single(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()");
   m.def("nccl_alltoall(Tensor(a!)[] dst, Tensor[] src, int comm_idx=0) -> ()");
 
-  m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");
+  m.def(
+      "nccl_reducescatter(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
 
   m.def(
       "nccl_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
@@ -302,6 +316,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 
   m.def(
       "two_shot_car_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
+
+  m.def(
+      "car_reducescatter(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -312,6 +329,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+  m.impl("car_reducescatter", car_reducescatter);
 }
 
 // Though it shouldnt be used, it is useful to define these functions for CPU to
@@ -324,6 +342,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("nccl_reducescatter", nccl_reducescatter);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+  m.impl("car_reducescatter", car_reducescatter);
 }
 
 // Shape registration functions for car operators.
@@ -360,6 +379,7 @@ void nccl_alltoall_meta(
 void nccl_reducescatter_meta(
     at::Tensor /* dst */,
     at::Tensor /* src */,
+    std::optional<at::Tensor> /* bias */,
    int64_t /* comm_idx */) {
   return;
 }
@@ -380,6 +400,14 @@ void two_shot_car_allreduce_meta(
   return;
 }
 
+void car_reducescatter_meta(
+    at::Tensor /* dst */,
+    at::Tensor /* src */,
+    std::optional<at::Tensor> /* bias */,
+    int64_t /* comm_idx */) {
+  return;
+}
+
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("nccl_allreduce", nccl_allreduce_meta);
   m.impl("nccl_allgather", nccl_allgather_meta);
@@ -388,6 +416,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("nccl_reducescatter", nccl_reducescatter_meta);
   m.impl("one_shot_car_allreduce", one_shot_car_allreduce_meta);
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce_meta);
+  m.impl("car_reducescatter", car_reducescatter_meta);
 }
 
 } // namespace fbgemm_gpu
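
The schema registered above makes the new operator reachable through the PyTorch dispatcher like the existing car ops. Below is a minimal C++ call-site sketch, assuming the fbgemm extension is loaded and the CAR state has already been initialized by the library's own init path; the wrapper name is illustrative and not part of this commit.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <cstdint>
#include <optional>

// Thin wrapper that resolves fbgemm::car_reducescatter once and calls it.
// dst is expected to hold src.numel() / world_size bf16 elements (see the
// checks added in car.cu).
void call_car_reducescatter(
    at::Tensor dst,
    at::Tensor src,
    std::optional<at::Tensor> bias = std::nullopt,
    int64_t comm_idx = 0) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::car_reducescatter", "")
          .typed<void(at::Tensor, at::Tensor, std::optional<at::Tensor>, int64_t)>();
  op.call(dst, src, bias, comm_idx);
}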

fbgemm_gpu/experimental/gen_ai/src/comm/car.cu

Lines changed: 112 additions & 6 deletions
@@ -345,7 +345,7 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
 #endif
 }
 
-template <int32_t kWorldSize, bool has_acc>
+template <int32_t kWorldSize, bool has_acc, bool reduce_scatter>
 #if defined(USE_ROCM)
 __launch_bounds__(512) __global__ void two_shot_all_reduce(
 #else
@@ -425,13 +425,18 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
     }
 
     // Store to the local buffer.
-    *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
-        *reinterpret_cast<const uint4*>(&sums);
+    if constexpr (reduce_scatter) {
+      *reinterpret_cast<uint4*>(&output[i]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    } else {
+      *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    }
   }
 
   __syncthreads();
 
-  // barreris among the blocks with the same idx (release-acuqire semantics)
+  // barriers among the blocks with the same idx (release-acuqire semantics)
   if (threadIdx.x < kWorldSize) {
     // The all blocks notifies the other ranks.
     int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;
@@ -445,6 +450,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
     } while (rank_barrier != flag);
   }
 
+  if constexpr (reduce_scatter) {
+    // reduce scatter we can stop here and skip the allgather below
+    return;
+  }
+
   __syncthreads();
 
   // Gather all needed elts from other intra-node ranks
@@ -628,7 +638,7 @@ void two_shot_car_allreduce(
 #define X(kWorldSize)                                                          \
   if (state->world_size_ == kWorldSize) {                                      \
     if (z) {                                                                   \
-      two_shot_all_reduce<kWorldSize, true>                                    \
+      two_shot_all_reduce<kWorldSize, true, false>                             \
          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>(  \
              state->rank_,                                                     \
              state->world_size_,                                               \
@@ -641,7 +651,7 @@ void two_shot_car_allreduce(
       C10_CUDA_KERNEL_LAUNCH_CHECK();                                          \
       return;                                                                  \
     } else {                                                                   \
-      two_shot_all_reduce<kWorldSize, false>                                   \
+      two_shot_all_reduce<kWorldSize, false, false>                            \
       <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>(     \
          state->rank_,                                                         \
          state->world_size_,                                                   \
@@ -667,4 +677,100 @@ void two_shot_car_allreduce(
   return;
 }
 
+void car_reducescatter(
+    at::Tensor dst,
+    at::Tensor src,
+    std::optional<at::Tensor> bias,
+    int64_t comm_idx) { // match the API with nccl_allreduce in
+                        // https://fburl.com/code/v538vig9
+  auto state = get_car_state();
+  c10::cuda::CUDAGuard gg(dst.device());
+  TORCH_CHECK(dst.is_contiguous());
+  TORCH_CHECK(src.is_contiguous());
+  TORCH_CHECK((state->world_size_ * dst.numel()) == src.numel());
+  TORCH_CHECK(src.numel() % 8 == 0);
+  TORCH_CHECK(src.numel() < kMaxCAR);
+  TORCH_CHECK(
+      state->world_size_ == 2 || state->world_size_ == 4 ||
+      state->world_size_ == 8);
+
+  const auto N = src.numel();
+  if (bias) {
+    TORCH_CHECK(bias->numel() == src.numel());
+  }
+  ++state->flag_;
+
+  std::array<at::BFloat16*, 8> inputs;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    inputs[ii] = state->buffers_[ii].data_ptr<at::BFloat16>();
+  }
+
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      inputs[state->rank_],
+      src.data_ptr<at::BFloat16>(),
+      src.numel() * src.element_size(),
+      cudaMemcpyDeviceToDevice,
+      at::cuda::getCurrentCUDAStream()));
+
+  std::array<int32_t*, 8> barriers;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
+  }
+
+  constexpr int32_t N_per_thread = 8;
+  TORCH_CHECK(N % state->world_size_ == 0);
+  const auto N_per_rank = N / state->world_size_;
+
+  TORCH_CHECK(N_per_rank % N_per_thread == 0);
+  auto threads_per_rank = div_round_up(N_per_rank, N_per_thread);
+
+#if defined(USE_ROCM)
+  constexpr int32_t kThreadsPerBlock = 512;
+#else
+  constexpr int32_t kThreadsPerBlock = 1024;
+#endif
+
+  constexpr int32_t kMaxBlocks = 24;
+
+  auto blocks = std::min<int32_t>(
+      cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks);
+
+#define X(kWorldSize)                                                          \
+  if (state->world_size_ == kWorldSize) {                                      \
+    if (bias) {                                                                \
+      two_shot_all_reduce<kWorldSize, true, true>                              \
+          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+              state->rank_,                                                    \
+              state->world_size_,                                              \
+              state->flag_ * state->world_size_,                               \
+              barriers,                                                        \
+              inputs,                                                          \
+              bias->data_ptr<at::BFloat16>(),                                  \
+              dst.data_ptr<at::BFloat16>(),                                    \
+              N);                                                              \
+      C10_CUDA_KERNEL_LAUNCH_CHECK();                                          \
+      return;                                                                  \
+    } else {                                                                   \
+      two_shot_all_reduce<kWorldSize, false, true>                             \
+          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+              state->rank_,                                                    \
+              state->world_size_,                                              \
+              state->flag_ * state->world_size_,                               \
+              barriers,                                                        \
+              inputs,                                                          \
+              nullptr,                                                         \
+              dst.data_ptr<at::BFloat16>(),                                    \
+              N);                                                              \
+      C10_CUDA_KERNEL_LAUNCH_CHECK();                                          \
+      return;                                                                  \
+    }                                                                          \
+  }
+  X(2);
+  X(4);
+  X(8);
+
+#undef X
+  return;
+}
+
 } // namespace fbgemm_gpu
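
For reference, the launch configuration added in car_reducescatter sizes the grid for one rank's slice only: each thread covers 8 bf16 elements (one uint4), blocks hold 1024 threads (512 on ROCm), and the grid is capped at 24 blocks. The sketch below re-derives that arithmetic in standalone C++, writing the library helpers div_round_up and cuda_calc_block_count out as plain ceiling divisions, which is an assumption about their behavior.

#include <algorithm>
#include <cstdint>

// Illustrative re-derivation of the grid size used by car_reducescatter.
int32_t reducescatter_grid_size(int64_t N, int64_t world_size, bool use_rocm) {
  constexpr int64_t N_per_thread = 8;                       // one uint4 of bf16 per thread
  const int64_t threads_per_block = use_rocm ? 512 : 1024;  // kThreadsPerBlock
  constexpr int64_t kMaxBlocks = 24;

  const int64_t N_per_rank = N / world_size;                // each rank reduces only its slice
  const int64_t threads_per_rank =
      (N_per_rank + N_per_thread - 1) / N_per_thread;       // div_round_up(N_per_rank, 8)
  const int64_t blocks =
      (threads_per_rank + threads_per_block - 1) / threads_per_block;
  return static_cast<int32_t>(std::min<int64_t>(blocks, kMaxBlocks));
}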
