
Commit 82bffd1

sryap authored and facebook-github-bot committed
Make the scratch pad tensor UVA (pytorch#2844)
Summary: Pull Request resolved: pytorch#2844

Before this diff, the scratch pad in SSD TBE (see D55998215 for more detail) was a CPU tensor that was later transferred to the GPU so that the TBE kernels could access it. The scratch pad transfer was highly inefficient because TBE over-provisioned the scratch pad buffer allocation (it did not know the exact number of cache-missed rows), causing extra data transfer. This extra transfer could be large, since the number of cache-missed rows was normally much smaller than the amount that TBE over-provisioned.

There are two ways to avoid the extra data transfer:

(1) Let TBE know the exact number of cache-missed rows on the host. This requires a device-to-host data transfer, which creates a sync point between host and device (not desirable in most trainings). However, it would allow TBE to use `cudaMemcpy`, which utilizes the DMA engine and lets the memory copy overlap efficiently with other compute kernels.

(2) Make the scratch pad accessible by both CPU and GPU, i.e., make the scratch pad a UVA tensor. This does not require device-host synchronization. However, the memory copy has to be done through CUDA loads/stores, which requires a kernel running on SMs; overlapping the memory copy with compute kernels therefore requires careful SM management.

Based on the tradeoffs above, we chose to implement (2) to avoid the host-device sync point.

Differential Revision: D58631974
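To make the chosen approach (2) concrete, here is a minimal sketch of allocating a scratch pad as a UVA (managed) tensor with `torch.ops.fbgemm.new_managed_tensor`, the same operator the diff below uses. The shape, dtype, and fill operation are illustrative placeholders, not the exact values from this change.

import torch
import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm operators)

# Illustrative shape/dtype only; the real scratch pad matches the evicted-row buffer.
num_rows, emb_dim = 1024, 128

# Allocate a UVA (CUDA managed) tensor. The 1-element prototype tensor supplies
# the device and dtype; the second argument is the requested shape.
scratch_pad = torch.ops.fbgemm.new_managed_tensor(
    torch.zeros(1, device=torch.device("cuda"), dtype=torch.float32),
    [num_rows, emb_dim],
)

# GPU kernels can write cache-missed rows into this buffer directly ...
scratch_pad[:4].fill_(1.0)

# ... and host-side code (e.g. the RocksDB writer behind ssd_db.set_cuda) can
# access the same allocation without an extra device-to-host copy, which is
# what removes the host/device sync point described above.
print(scratch_pad.shape, scratch_pad.device)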
1 parent 0e0d2ea commit 82bffd1

File tree

1 file changed: +87 -51 lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 87 additions & 51 deletions
@@ -278,7 +278,7 @@ def __init__(
         self.ssd_event_evict_sp = torch.cuda.Event()
 
         self.timesteps_prefetched: List[int] = []
-        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
+        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []
@@ -397,48 +397,71 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
 
     def evict(
         self,
-        evicted_rows: Tensor,
-        evicted_indices: Tensor,
+        rows: Tensor,
+        indices: Tensor,
         actions_count_cpu: Tensor,
-        eviction_stream: torch.cuda.Stream,
+        stream: torch.cuda.Stream,
         pre_event: torch.cuda.Event,
         post_event: torch.cuda.Event,
+        is_rows_uvm: bool,
     ) -> None:
         """
         Evict data from the given input tensors to SSD via RocksDB
+
+        Args:
+            rows (Tensor): The 2D tensor that contains rows to evict
+            indices (Tensor): The 1D tensor that contains the row indices that
+                              the rows will be evicted to
+            actions_count_cpu (Tensor): A scalar tensor that contains the
+                                        number of rows that the evict function
+                                        has to process
+            stream (Stream): The CUDA stream that cudaStreamAddCallback will
+                             synchronize the host function with. Moreover, the
+                             asynchronous D->H memory copies will operate on
+                             this stream
+            pre_event (Event): The CUDA event that the stream has to wait on
+            post_event (Event): The CUDA event that the current will record on
+                                when the eviction is done
+            is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
+                                tensor (which is accessible on both host and
+                                device)
+        Returns:
+            None
         """
-        with torch.cuda.stream(eviction_stream):
-            eviction_stream.wait_event(pre_event)
+        with torch.cuda.stream(stream):
+            stream.wait_event(pre_event)
 
-            evicted_rows_cpu = self.to_pinned_cpu(evicted_rows)
-            evicted_indices_cpu = self.to_pinned_cpu(evicted_indices)
+            rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
+            indices_cpu = self.to_pinned_cpu(indices)
 
-            evicted_rows.record_stream(eviction_stream)
-            evicted_indices.record_stream(eviction_stream)
+            rows.record_stream(stream)
+            indices.record_stream(stream)
 
             self.ssd_db.set_cuda(
-                evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep
+                indices_cpu, rows_cpu, actions_count_cpu, self.timestep
             )
 
             # TODO: is this needed?
             # Need a way to synchronize
             # actions_count_cpu.record_stream(self.ssd_stream)
-            eviction_stream.record_event(post_event)
+            stream.record_event(post_event)
 
     def _evict_from_scratch_pad(self, grad: Tensor) -> None:
         assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
-        (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = (
+        (inserted_rows, post_bwd_evicted_indices, actions_count_cpu, do_evict) = (
            self.ssd_scratch_pads.pop(0)
        )
-        torch.cuda.current_stream().record_event(self.ssd_event_backward)
-        self.evict(
-            inserted_rows_gpu,
-            post_bwd_evicted_indices,
-            actions_count_cpu,
-            self.ssd_stream,
-            self.ssd_event_backward,
-            self.ssd_event_evict_sp,
-        )
+        if do_evict:
+            torch.cuda.current_stream().record_event(self.ssd_event_backward)
+            self.evict(
+                inserted_rows,
+                post_bwd_evicted_indices,
+                actions_count_cpu,
+                self.ssd_stream,
+                self.ssd_event_backward,
+                self.ssd_event_evict_sp,
+                is_rows_uvm=True,
+            )
 
     def _compute_cache_ptrs(
         self,
@@ -447,7 +470,7 @@ def _compute_cache_ptrs(
         linear_index_inverse_indices: torch.Tensor,
         unique_indices_count_cumsum: torch.Tensor,
         cache_set_inverse_indices: torch.Tensor,
-        inserted_rows_gpu: torch.Tensor,
+        inserted_rows: torch.Tensor,
         unique_indices_length: torch.Tensor,
         inserted_indices: torch.Tensor,
         actions_count_cpu: torch.Tensor,
@@ -468,7 +491,7 @@ def _compute_cache_ptrs(
                 unique_indices_count_cumsum,
                 cache_set_inverse_indices,
                 self.lxu_cache_weights,
-                inserted_rows_gpu,
+                inserted_rows,
                 unique_indices_length,
                 inserted_indices,
             )
@@ -477,14 +500,19 @@ def _compute_cache_ptrs(
         with record_function("## ssd_scratch_pads ##"):
             # Store scratch pad info for post backward eviction
             self.ssd_scratch_pads.append(
-                (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu)
+                (
+                    inserted_rows,
+                    post_bwd_evicted_indices,
+                    actions_count_cpu,
+                    linear_cache_indices.numel() > 0,
+                )
             )
 
         # pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, Tensor,
         # typing.Any, Tensor]`.
         return (
             lxu_cache_ptrs,
-            inserted_rows_gpu,
+            inserted_rows,
             post_bwd_evicted_indices,
             actions_count_cpu,
         )
@@ -522,42 +550,50 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
             evicted_rows = self.lxu_cache_weights[
                 assigned_cache_slots.clamp(min=0).long(), :
             ]
-            inserted_rows = torch.empty(
-                evicted_rows.shape,
-                dtype=self.lxu_cache_weights.dtype,
-                pin_memory=True,
-            )
+
+            if linear_cache_indices.numel() > 0:
+                inserted_rows = torch.ops.fbgemm.new_managed_tensor(
+                    torch.zeros(
                        1, device=self.current_device, dtype=self.lxu_cache_weights.dtype
                    ),
                    evicted_rows.shape,
                )
+            else:
+                inserted_rows = torch.empty(
+                    evicted_rows.shape,
+                    dtype=self.lxu_cache_weights.dtype,
+                    device=self.current_device,
+                )
 
             current_stream = torch.cuda.current_stream()
 
+            inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)
+
             # Ensure the previous iterations l3_db.set(..) has completed.
             current_stream.wait_event(self.ssd_event_evict)
             current_stream.wait_event(self.ssd_event_evict_sp)
-
-            self.ssd_db.get_cuda(
-                self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu
-            )
+            if linear_cache_indices.numel() > 0:
+                self.ssd_db.get_cuda(inserted_indices_cpu, inserted_rows, actions_count_cpu)
             current_stream.record_event(self.ssd_event_get)
-            # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
-            # Should we allocate on HBM?
-            inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)
 
             torch.ops.fbgemm.masked_index_put(
                 self.lxu_cache_weights,
                 assigned_cache_slots,
-                inserted_rows_gpu,
+                inserted_rows,
                 actions_count_gpu,
             )
 
-            # Evict rows from cache to SSD
-            self.evict(
-                evicted_rows,
-                evicted_indices,
-                actions_count_cpu,
-                self.ssd_stream,
-                self.ssd_event_get,
-                self.ssd_event_evict,
-            )
+            if linear_cache_indices.numel() > 0:
+                # Evict rows from cache to SSD
+                self.evict(
+                    evicted_rows,
+                    evicted_indices,
+                    actions_count_cpu,
+                    self.ssd_stream,
+                    self.ssd_event_get,
+                    self.ssd_event_evict,
+                    is_rows_uvm=False,
+                )
 
             # TODO: keep only necessary tensors
             self.ssd_prefetch_data.append(
@@ -567,7 +603,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
                     linear_index_inverse_indices,
                     unique_indices_count_cumsum,
                     cache_set_inverse_indices,
-                    inserted_rows_gpu,
+                    inserted_rows,
                     unique_indices_length,
                     inserted_indices,
                     actions_count_cpu,
@@ -593,7 +629,7 @@ def forward(
         prefetch_data = self.ssd_prefetch_data.pop(0)
         (
             lxu_cache_ptrs,
-            inserted_rows_gpu,
+            inserted_rows,
             post_bwd_evicted_indices,
             actions_count_cpu,
         ) = self._compute_cache_ptrs(*prefetch_data)
@@ -635,7 +671,7 @@ def forward(
             # codegen/genscript/optimizer_args.py
             ssd_tensors={
                 "row_addrs": lxu_cache_ptrs,
-                "inserted_rows": inserted_rows_gpu,
+                "inserted_rows": inserted_rows,
                 "post_bwd_evicted_indices": post_bwd_evicted_indices,
                 "actions_count": actions_count_cpu,
             },
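As a usage note, the core of the change is the dispatch at the top of `evict()`: with `is_rows_uvm=True` (the post-backward scratch-pad eviction) the rows are handed to `ssd_db.set_cuda` directly, while with `is_rows_uvm=False` (the prefetch-time cache eviction) they are first staged in pinned CPU memory. A standalone paraphrase of that dispatch, with a hypothetical `to_pinned_cpu` stand-in for the module method of the same name:

import torch

def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the module's to_pinned_cpu: stage a GPU tensor in pinned
    # host memory so the RocksDB writer can read it.
    out = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    out.copy_(t, non_blocking=True)
    return out

def rows_for_db_write(rows: torch.Tensor, is_rows_uvm: bool) -> torch.Tensor:
    # Mirrors `rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)`:
    # a UVA/UVM scratch pad is already host-accessible, so no staging copy is
    # needed before handing it to ssd_db.set_cuda.
    return rows if is_rows_uvm else to_pinned_cpu(rows)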
