NCCL AllReduce 2400x slower via Ray actors vs torchrun on H200 (NVSwitch) multi-node #61073

@dmvevents

Description

Environment

  • Nodes: 2x AWS P5en.48xlarge (8x NVIDIA H200 143GB per node, NVSwitch intra-node)
  • Networking: 16 EFA v2 devices per node, EKS Kubernetes with hostNetwork: true
  • Container: NCCL 2.27.5, aws-ofi-nccl 1.18.0, Libfabric 2.4, PyTorch 2.x
  • OS: Ubuntu 22.04 (EKS AMI)
  • GPU persistence mode: enabled (nvidia-smi -pm 1)

Problem

When using Ray actors with num_gpus=1 and calling torch.distributed.init_process_group('nccl') followed by all_reduce, NCCL performance is catastrophically degraded on H200 (P5en) multi-node compared to torchrun on the exact same hardware.

Benchmark comparison (same nodes, same GPUs, same network)

| Launcher   | Operation        | Latency   | Notes                             |
|------------|------------------|-----------|-----------------------------------|
| torchrun   | AllReduce 4 MB   | 1.5 ms    | Expected performance              |
| Ray actors | AllReduce 4 KB   | 3,660 ms  | 2400x slower than torchrun        |
| Ray actors | AllReduce 933 MB | Hangs     | Never completes; timeout required |

Critical detail: works on P5, fails on P5en

  • The exact same Ray code on P5.48xlarge (H100, no NVSwitch) works correctly with expected NCCL performance.
  • The issue is specific to H200 with NVSwitch topology (P5en.48xlarge).

Environment variables tried (none resolved the issue)

NCCL_P2P_DISABLE=1
NCCL_CUMEM_ENABLE=0
NCCL_NVLS_ENABLE=0
NCCL_NET_GDR_LEVEL=SYS
NCCL_SHM_DISABLE=1

None of these changed the behavior. The AllReduce remains 2400x slower or hangs.
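For anyone reproducing this, NCCL's standard debug variables (diagnostic only, not a workaround) make the selected transport visible in the logs and let the Ray and torchrun runs be compared directly:

```shell
# Standard NCCL debug settings. In the resulting logs, lines such as
# "via NET/Socket" indicate NCCL fell back to the slow TCP transport
# instead of NVLink/NVSwitch or the EFA plugin.
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,GRAPH,NET
# Dump the topology NCCL derived, to diff between Ray and torchrun runs:
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
```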

Likely root cause

Ray sets CUDA_VISIBLE_DEVICES to remap GPU indices for each actor (e.g., actor on physical GPU 5 sees CUDA_VISIBLE_DEVICES=5 and uses logical device 0). On H200 nodes with NVSwitch, this remapping appears to cause NCCL to misidentify the hardware topology, leading it to select incorrect transport paths (likely falling back to slow socket transport instead of NVSwitch/NVLink or EFA/RDMA).
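A minimal illustration of the remapping (pure Python, no GPUs needed; the function name is mine, but the index arithmetic mirrors what an actor sees):

```python
import os

def logical_device_index(physical_gpu: int) -> int:
    """Simulate how a Ray actor assigned one physical GPU sees it.

    Ray exports CUDA_VISIBLE_DEVICES="<physical index>" inside the actor,
    so CUDA renumbers the single visible GPU as logical device 0.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(physical_gpu)
    visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    # torch.cuda device indices are positions in this list, so the
    # physical index is lost: physical GPU 5 becomes logical device 0.
    return visible.index(str(physical_gpu))

print(logical_device_index(5))  # -> 0
```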

Evidence:

  1. torchrun (which does not remap CUDA_VISIBLE_DEVICES the same way) works perfectly on the same nodes.
  2. P5.48xlarge (H100, PCIe-based without NVSwitch) works fine with Ray — suggesting NVSwitch topology discovery is what breaks.
  3. NCCL debug logs show the communicator initializes but data transfer is orders of magnitude slower than hardware capability.
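If the remapping hypothesis holds, Ray's documented escape hatch RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES offers a direct test: with it set, Ray leaves CUDA_VISIBLE_DEVICES alone and each actor must select its physical GPU itself. A minimal sketch (must be set in the environment before the Ray workers start):

```python
import os

# Documented Ray env var: when set, Ray does NOT rewrite CUDA_VISIBLE_DEVICES
# for actors, so NCCL sees the full, unremapped GPU topology.
os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"

# Inside each actor, the physical device must then be chosen explicitly:
#   gpu_id = int(ray.get_gpu_ids()[0])  # physical index Ray assigned
#   torch.cuda.set_device(gpu_id)

print(os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"])  # -> 1
```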

Steps to reproduce

  1. Launch 2x P5en.48xlarge nodes on EKS with EFA and hostNetwork: true
  2. Start a Ray cluster across both nodes
  3. Create 16 Ray actors (8 per node) with num_gpus=1
  4. In each actor, call torch.distributed.init_process_group('nccl', ...) using TCPStore
  5. Run torch.distributed.all_reduce(tensor) with a 4KB+ tensor
  6. Observe multi-second latency or hang
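
The steps above can be condensed into a sketch (helper names are my own; it assumes a live Ray cluster with 2x8 GPUs and a reachable head-node IP, and illustrates the failing path rather than being a verified script):

```python
def init_url(addr: str, port: int = 29500) -> str:
    """Rendezvous URL for torch.distributed; the port is a placeholder."""
    return f"tcp://{addr}:{port}"

def make_worker_cls():
    # Imports live inside the function so the sketch stays importable on
    # machines without ray/torch installed.
    import ray
    import torch
    import torch.distributed as dist

    @ray.remote(num_gpus=1)  # Ray remaps CUDA_VISIBLE_DEVICES per actor
    class AllReduceWorker:
        def run(self, rank: int, world_size: int, master_addr: str) -> float:
            dist.init_process_group(
                backend="nccl",
                init_method=init_url(master_addr),
                rank=rank,
                world_size=world_size,
            )
            t = torch.ones(1024, device="cuda")  # 4 KB fp32 tensor
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            dist.all_reduce(t)  # multi-second latency or hang observed here
            end.record()
            torch.cuda.synchronize()
            return start.elapsed_time(end)  # latency in milliseconds

    return AllReduceWorker

# Driver, run on the head node of a live cluster (not executed here):
#   import ray
#   ray.init(address="auto")
#   Worker = make_worker_cls()
#   workers = [Worker.remote() for _ in range(16)]  # 8 GPUs x 2 nodes
#   times = [w.run.remote(i, 16, "<head-node-ip>") for i, w in enumerate(workers)]
#   print(max(ray.get(times)))
```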

For comparison, run the same all_reduce via torchrun --nproc_per_node=8 --nnodes=2 on the same nodes — it completes in ~1.5ms for 4MB.

Expected behavior

Ray actors should achieve NCCL AllReduce performance comparable to torchrun on H200 NVSwitch nodes (~1.5ms for 4MB, not 3660ms for 4KB).

Versions

  • Ray: 2.44.1
  • PyTorch: 2.x
  • NCCL: 2.27.5
  • CUDA: 12.x
  • aws-ofi-nccl: 1.18.0
  • Libfabric: 2.4
