Commit cc7cd60

zdevito authored and facebook-github-bot committed
Use PyTorch's p2p access enable function (pytorch#2000)
Summary: We split the diff after adding a needed lazy CUDA init call to the p2p access function.

Diff 1: D48939723 [PyTorch] Add the lazy init call for p2p access function

*Prior context*

cudaDeviceEnablePeerAccess only enables cross-device access for memory allocated with cudaMalloc. When using other CUDA APIs such as cuMemMap, peer access is managed differently. expandable_segments:True in PyTorch uses cuMemMap, so code that only calls cudaDeviceEnablePeerAccess is not sufficient to enable cross-device copies.

This patch switches the p2p access enabling code to use PyTorch's `get_p2p_access`, which lets the allocator figure out how to correctly enable p2p access for that memory. In the normal case (expandable_segments:False), this code performs exactly the same CUDA calls as before.

Differential Revision: D49021817
1 parent 9d6ba13 commit cc7cd60

1 file changed: +2 -9 lines changed

fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp

Lines changed: 2 additions & 9 deletions
@@ -10,6 +10,7 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAEvent.h>
+#include <ATen/cuda/PeerToPeerAccess.h>
 #include <ATen/native/TensorAdvancedIndexing.h>
 #include <c10/core/Device.h>
 #include <c10/core/TensorOptions.h>
@@ -562,15 +563,7 @@ void init_p2p_access() {
   for (const auto i : c10::irange(at::cuda::getNumGPUs())) {
     for (const auto j : c10::irange(at::cuda::getNumGPUs())) {
       if (i != j) {
-        at::cuda::CUDAGuard g(i);
-        const auto err =
-            C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0));
-        if (err == cudaErrorPeerAccessAlreadyEnabled) {
-          // ignore and clear the error if access was already enabled
-          C10_CUDA_CLEAR_ERROR();
-        } else {
-          AT_CUDA_CHECK(err);
-        }
+        AT_ASSERT(at::cuda::get_p2p_access(i, j));
       }
     }
   }
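
The commit summary argues for routing every peer-access request through a single query function, so that whoever owns the memory decides how access is enabled. The sketch below is a CUDA-free model of that idea, not PyTorch's implementation: `PeerAccessCache` and its `Probe` callback are hypothetical names, and the callback stands in for whatever driver or allocator work the real code performs. It shows a per-pair cache that probes each (device, peer) pair at most once, which is also why repeated calls stay harmless without the "already enabled" error handling the removed lines needed.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical model of a cached peer-access table (not a PyTorch API).
class PeerAccessCache {
 public:
  // Stand-in for the real work of checking and enabling access from
  // `dev` to memory on `peer`; returns true if access is available.
  using Probe = std::function<bool(int dev, int peer)>;

  PeerAccessCache(int num_devices, Probe probe)
      : n_(num_devices),
        probe_(std::move(probe)),
        state_(static_cast<std::size_t>(num_devices) * num_devices, kUnknown) {
    // A device can always access its own memory, so never probe (i, i).
    for (int i = 0; i < n_; ++i) {
      state_[index(i, i)] = kYes;
    }
  }

  // Returns whether `dev` can access `peer`, probing at most once per pair.
  bool get(int dev, int peer) {
    assert(dev >= 0 && dev < n_ && peer >= 0 && peer < n_);
    std::int8_t& slot = state_[index(dev, peer)];
    if (slot == kUnknown) {
      slot = probe_(dev, peer) ? kYes : kNo;
    }
    return slot == kYes;
  }

 private:
  enum : std::int8_t { kUnknown = -1, kNo = 0, kYes = 1 };

  std::size_t index(int dev, int peer) const {
    return static_cast<std::size_t>(dev) * n_ + peer;
  }

  int n_;
  Probe probe_;
  std::vector<std::int8_t> state_;
};
```

A caller would loop over all pairs with `i != j`, as `init_p2p_access` does, and check the boolean result of `get(i, j)`; calling it again for the same pair returns the cached answer without re-running the probe.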
