@@ -345,7 +345,7 @@ static DEVICE_INLINE void ld_flag_acquire(int32_t& flag, int32_t* flag_addr) {
#endif
}

-template <int32_t kWorldSize, bool has_acc>
+template <int32_t kWorldSize, bool has_acc, bool reduce_scatter>
#if defined(USE_ROCM)
__launch_bounds__(512) __global__ void two_shot_all_reduce(
#else
@@ -425,13 +425,18 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
    }

    // Store to the local buffer.
-    *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
-        *reinterpret_cast<const uint4*>(&sums);
+    if constexpr (reduce_scatter) {
+      *reinterpret_cast<uint4*>(&output[i]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    } else {
+      *reinterpret_cast<uint4*>(&src_d[0][i + N_start]) =
+          *reinterpret_cast<const uint4*>(&sums);
+    }
  }

  __syncthreads();

-  // barreris among the blocks with the same idx (release-acuqire semantics)
+  // barriers among the blocks with the same idx (release-acquire semantics)
  if (threadIdx.x < kWorldSize) {
    // The all blocks notifies the other ranks.
    int32_t flag_block_offset = kWorldSize + blockIdx.x * kWorldSize;
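For orientation, the flag exchange both branches rely on (the `ld_flag_acquire` helper in the hunk header and its release-side counterpart) is an all-to-all barrier: each block publishes a monotonically increasing flag for every peer rank and spins until every peer has published the same value. The sketch below illustrates that release-acquire pattern on a single GPU, with libcu++ atomics standing in for the file's inline-PTX helpers; the toy kernel, flag layout, and names are illustrative assumptions, not part of this patch.

// Simplified illustration of the barrier idea: two blocks stand in for two
// ranks and synchronize through per-"rank" flags in device memory.
// Build with: nvcc -arch=sm_70 -std=c++17 toy_flag_barrier.cu
#include <cuda/atomic>
#include <cstdio>

__global__ void toy_flag_barrier(int32_t* flags, int32_t flag_value) {
  int rank = blockIdx.x;          // pretend each block is a rank
  int world_size = gridDim.x;
  if (threadIdx.x < world_size) {
    // "Release": publish this rank's flag into each peer's slot so the
    // peer can observe all writes made before this store.
    cuda::atomic_ref<int32_t, cuda::thread_scope_device> peer_flag(
        flags[threadIdx.x * world_size + rank]);
    peer_flag.store(flag_value, cuda::std::memory_order_release);
    // "Acquire": spin until that peer has published the same flag value.
    cuda::atomic_ref<int32_t, cuda::thread_scope_device> my_flag(
        flags[rank * world_size + threadIdx.x]);
    while (my_flag.load(cuda::std::memory_order_acquire) != flag_value) {
    }
  }
  __syncthreads();
  // Past this point, writes made by every other "rank" before its store
  // are visible to this block.
}

int main() {
  int32_t* flags = nullptr;
  cudaMalloc(&flags, 4 * sizeof(int32_t));
  cudaMemset(flags, 0, 4 * sizeof(int32_t));
  toy_flag_barrier<<<2, 32>>>(flags, /*flag_value=*/1);
  cudaDeviceSynchronize();
  std::printf("barrier completed\n");
  cudaFree(flags);
  return 0;
}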
@@ -445,6 +450,11 @@ __launch_bounds__(1024) __global__ void two_shot_all_reduce(
    } while (rank_barrier != flag);
  }

+  if constexpr (reduce_scatter) {
+    // For reduce-scatter we can stop here and skip the all-gather below.
+    return;
+  }
+
  __syncthreads();

  // Gather all needed elts from other intra-node ranks
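The reduce-scatter specialization changes two things in the kernel body: the reduced shard is written straight to output[i] (the caller's per-rank dst) instead of back into the shared buffer at i + N_start, and the kernel returns right after the inter-rank barrier, so the all-gather phase never runs. Because reduce_scatter is a compile-time bool, the skipped phase is removed from that instantiation entirely rather than branched over at run time. Below is a minimal, self-contained sketch of this compile-time pruning; the toy kernel and names are assumptions for illustration, not the patch's kernel.

// One kernel body, two collectives: a bool template parameter plus
// `if constexpr` compiles the second phase out of the reduce-scatter flavor.
#include <cstdio>

template <bool reduce_scatter>
__global__ void toy_two_phase(const float* partials, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) {
    return;
  }
  // Phase 1: reduce this thread's shard (stands in for the reduce step).
  out[i] = partials[i] * 2.0f;

  if constexpr (reduce_scatter) {
    return; // reduce-scatter: stop after the reduce phase
  }
  // Phase 2: only present in the all-reduce instantiation
  // (stands in for the all-gather of reduced shards).
  out[i] += 1.0f;
}

int main() {
  const int n = 8;
  float h_in[n], h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = float(i);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  toy_two_phase<true><<<1, 32>>>(d_in, d_out, n);   // "reduce-scatter" flavor
  toy_two_phase<false><<<1, 32>>>(d_in, d_out, n);  // "all-reduce" flavor
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("out[3] = %f\n", h_out[3]);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}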
@@ -628,7 +638,7 @@ void two_shot_car_allreduce(
#define X(kWorldSize) \
  if (state->world_size_ == kWorldSize) { \
    if (z) { \
-      two_shot_all_reduce<kWorldSize, true> \
+      two_shot_all_reduce<kWorldSize, true, false> \
          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
              state->rank_, \
              state->world_size_, \
@@ -641,7 +651,7 @@ void two_shot_car_allreduce(
      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
      return; \
    } else { \
-      two_shot_all_reduce<kWorldSize, false> \
+      two_shot_all_reduce<kWorldSize, false, false> \
          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
              state->rank_, \
              state->world_size_, \
@@ -667,4 +677,100 @@ void two_shot_car_allreduce(
  return;
}

+void car_reducescatter(
+    at::Tensor dst,
+    at::Tensor src,
+    std::optional<at::Tensor> bias,
+    int64_t comm_idx) { // match the API with nccl_allreduce in
+  // https://fburl.com/code/v538vig9
+  auto state = get_car_state();
+  c10::cuda::CUDAGuard gg(dst.device());
+  TORCH_CHECK(dst.is_contiguous());
+  TORCH_CHECK(src.is_contiguous());
+  TORCH_CHECK((state->world_size_ * dst.numel()) == src.numel());
+  TORCH_CHECK(src.numel() % 8 == 0);
+  TORCH_CHECK(src.numel() < kMaxCAR);
+  TORCH_CHECK(
+      state->world_size_ == 2 || state->world_size_ == 4 ||
+      state->world_size_ == 8);
+
+  const auto N = src.numel();
+  if (bias) {
+    TORCH_CHECK(bias->numel() == src.numel());
+  }
+  ++state->flag_;
+
+  std::array<at::BFloat16*, 8> inputs;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    inputs[ii] = state->buffers_[ii].data_ptr<at::BFloat16>();
+  }
+
+  AT_CUDA_CHECK(cudaMemcpyAsync(
+      inputs[state->rank_],
+      src.data_ptr<at::BFloat16>(),
+      src.numel() * src.element_size(),
+      cudaMemcpyDeviceToDevice,
+      at::cuda::getCurrentCUDAStream()));
+
+  std::array<int32_t*, 8> barriers;
+  for (auto ii = 0; ii < state->world_size_; ++ii) {
+    barriers[ii] = state->barriers_[ii].data_ptr<int32_t>();
+  }
+
+  constexpr int32_t N_per_thread = 8;
+  TORCH_CHECK(N % state->world_size_ == 0);
+  const auto N_per_rank = N / state->world_size_;
+
+  TORCH_CHECK(N_per_rank % N_per_thread == 0);
+  auto threads_per_rank = div_round_up(N_per_rank, N_per_thread);
+
+#if defined(USE_ROCM)
+  constexpr int32_t kThreadsPerBlock = 512;
+#else
+  constexpr int32_t kThreadsPerBlock = 1024;
+#endif
+
+  constexpr int32_t kMaxBlocks = 24;
+
+  auto blocks = std::min<int32_t>(
+      cuda_calc_block_count(threads_per_rank, kThreadsPerBlock), kMaxBlocks);
+
+#define X(kWorldSize) \
+  if (state->world_size_ == kWorldSize) { \
+    if (bias) { \
+      two_shot_all_reduce<kWorldSize, true, true> \
+          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+              state->rank_, \
+              state->world_size_, \
+              state->flag_ * state->world_size_, \
+              barriers, \
+              inputs, \
+              bias->data_ptr<at::BFloat16>(), \
+              dst.data_ptr<at::BFloat16>(), \
+              N); \
+      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      return; \
+    } else { \
+      two_shot_all_reduce<kWorldSize, false, true> \
+          <<<blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>( \
+              state->rank_, \
+              state->world_size_, \
+              state->flag_ * state->world_size_, \
+              barriers, \
+              inputs, \
+              nullptr, \
+              dst.data_ptr<at::BFloat16>(), \
+              N); \
+      C10_CUDA_KERNEL_LAUNCH_CHECK(); \
+      return; \
+    } \
+  }
+  X(2);
+  X(4);
+  X(8);
+
+#undef X
+  return;
+}
+

} // namespace fbgemm_gpu
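On the host side, car_reducescatter mirrors two_shot_car_allreduce: dst holds only this rank's shard (hence the check world_size * dst.numel() == src.numel()), every thread handles 8 bf16 elements (one 16-byte uint4), and the grid is sized from the per-rank element count and capped at 24 blocks. Below is a standalone sketch of that sizing arithmetic, assuming cuda_calc_block_count is a ceiling division of thread count by block size; the example sizes are chosen purely for illustration.

// Launch-geometry arithmetic restated locally; div_round_up is a stand-in
// for the codebase's helper of the same name.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t div_round_up(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t world_size = 8;
  const int64_t N = 8 * 1024 * 1024;          // src.numel(); dst.numel() is N / world_size
  const int64_t N_per_rank = N / world_size;  // 1,048,576 elements per rank
  const int64_t N_per_thread = 8;             // one uint4 of bf16 per thread
  const int64_t threads_per_rank =
      div_round_up(N_per_rank, N_per_thread);  // 131,072 threads
  const int64_t kThreadsPerBlock = 1024;       // 512 on ROCm
  const int64_t kMaxBlocks = 24;
  const int64_t blocks = std::min(
      div_round_up(threads_per_rank, kThreadsPerBlock), kMaxBlocks); // capped at 24
  std::printf(
      "threads_per_rank=%lld blocks=%lld\n",
      static_cast<long long>(threads_per_rank),
      static_cast<long long>(blocks));
  return 0;
}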