
Commit 0e0d2ea

sarunya authored and facebook-github-bot committed
Fix stream sync for scratch pad eviction
Differential Revision: D59716516
1 parent bc78e2e commit 0e0d2ea

File tree

1 file changed (+44, -12 lines)


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 44 additions & 12 deletions
@@ -270,8 +270,13 @@ def __init__(
         # pyre-fixme[20]: Argument `self` expected.
         (low_priority, high_priority) = torch.cuda.Stream.priority_range()
         self.ssd_stream = torch.cuda.Stream(priority=low_priority)
-        self.ssd_set_start = torch.cuda.Event()
-        self.ssd_set_end = torch.cuda.Event()
+
+        # SSD events
+        self.ssd_event_get = torch.cuda.Event()
+        self.ssd_event_evict = torch.cuda.Event()
+        self.ssd_event_backward = torch.cuda.Event()
+        self.ssd_event_evict_sp = torch.cuda.Event()
+
         self.timesteps_prefetched: List[int] = []
         self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
         # TODO: add type annotation
@@ -391,31 +396,49 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
         return t_cpu
 
     def evict(
-        self, evicted_rows: Tensor, evicted_indices: Tensor, actions_count_cpu: Tensor
+        self,
+        evicted_rows: Tensor,
+        evicted_indices: Tensor,
+        actions_count_cpu: Tensor,
+        eviction_stream: torch.cuda.Stream,
+        pre_event: torch.cuda.Event,
+        post_event: torch.cuda.Event,
     ) -> None:
         """
        Evict data from the given input tensors to SSD via RocksDB
         """
-        with torch.cuda.stream(self.ssd_stream):
-            self.ssd_stream.wait_event(self.ssd_set_start)
+        with torch.cuda.stream(eviction_stream):
+            eviction_stream.wait_event(pre_event)
+
             evicted_rows_cpu = self.to_pinned_cpu(evicted_rows)
             evicted_indices_cpu = self.to_pinned_cpu(evicted_indices)
-            evicted_rows.record_stream(self.ssd_stream)
-            evicted_indices.record_stream(self.ssd_stream)
+
+            evicted_rows.record_stream(eviction_stream)
+            evicted_indices.record_stream(eviction_stream)
+
             self.ssd_db.set_cuda(
                 evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep
             )
+
             # TODO: is this needed?
             # Need a way to synchronize
             # actions_count_cpu.record_stream(self.ssd_stream)
-            self.ssd_stream.record_event(self.ssd_set_end)
+            eviction_stream.record_event(post_event)
 
     def _evict_from_scratch_pad(self, grad: Tensor) -> None:
         assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
         (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = (
             self.ssd_scratch_pads.pop(0)
         )
-        self.evict(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu)
+        torch.cuda.current_stream().record_event(self.ssd_event_backward)
+        self.evict(
+            inserted_rows_gpu,
+            post_bwd_evicted_indices,
+            actions_count_cpu,
+            self.ssd_stream,
+            self.ssd_event_backward,
+            self.ssd_event_evict_sp,
+        )
 
     def _compute_cache_ptrs(
         self,
@@ -508,11 +531,13 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
         current_stream = torch.cuda.current_stream()
 
         # Ensure the previous iterations l3_db.set(..) has completed.
-        current_stream.wait_event(self.ssd_set_end)
+        current_stream.wait_event(self.ssd_event_evict)
+        current_stream.wait_event(self.ssd_event_evict_sp)
+
         self.ssd_db.get_cuda(
             self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu
         )
-        current_stream.record_event(self.ssd_set_start)
+        current_stream.record_event(self.ssd_event_get)
         # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
         # Should we allocate on HBM?
         inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)
@@ -525,7 +550,14 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
         )
 
         # Evict rows from cache to SSD
-        self.evict(evicted_rows, evicted_indices, actions_count_cpu)
+        self.evict(
+            evicted_rows,
+            evicted_indices,
+            actions_count_cpu,
+            self.ssd_stream,
+            self.ssd_event_get,
+            self.ssd_event_evict,
+        )
 
         # TODO: keep only necessary tensors
         self.ssd_prefetch_data.append(
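
The change replaces the single ssd_set_start/ssd_set_end event pair with four named events (ssd_event_get, ssd_event_evict, ssd_event_backward, ssd_event_evict_sp) so that the prefetch get, the cache eviction, and the post-backward scratch-pad eviction are each ordered against the correct stream, and prefetch waits on both eviction events before reusing buffers. Below is a minimal, standalone sketch of the same CUDA-event handshake (record an event on the producer stream, wait on it from the consumer stream, record a completion event, wait on that before reuse). The names main_stream, side_stream, data_ready, side_done, and buf are illustrative placeholders, not FBGEMM APIs.

# Illustrative-only sketch of the event-based stream handshake applied in this commit.
# Placeholder names: main_stream, side_stream, data_ready, side_done, buf.
import torch

if torch.cuda.is_available():
    main_stream = torch.cuda.current_stream()
    (low_priority, high_priority) = torch.cuda.Stream.priority_range()
    side_stream = torch.cuda.Stream(priority=low_priority)
    data_ready = torch.cuda.Event()  # plays the role of ssd_event_get / ssd_event_backward
    side_done = torch.cuda.Event()   # plays the role of ssd_event_evict / ssd_event_evict_sp

    buf = torch.randn(1024, 128, device="cuda")

    for _ in range(3):
        # Producer work on the main stream, then mark the buffer as ready to consume.
        buf.mul_(2.0)
        main_stream.record_event(data_ready)

        with torch.cuda.stream(side_stream):
            # The consumer must not read `buf` before the producer has finished.
            side_stream.wait_event(data_ready)
            # Tell the caching allocator that the side stream uses `buf`
            # (mirrors the record_stream calls in evict above).
            buf.record_stream(side_stream)
            buf_cpu = buf.to("cpu", non_blocking=True)
            # Mark the side-stream work as finished.
            side_stream.record_event(side_done)

        # The next iteration must not overwrite `buf` until the side stream is done with it.
        main_stream.wait_event(side_done)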
