@@ -270,8 +270,13 @@ def __init__(
270
270
# pyre-fixme[20]: Argument `self` expected.
271
271
(low_priority , high_priority ) = torch .cuda .Stream .priority_range ()
272
272
self .ssd_stream = torch .cuda .Stream (priority = low_priority )
273
- self .ssd_set_start = torch .cuda .Event ()
274
- self .ssd_set_end = torch .cuda .Event ()
273
+
274
+ # SSD events
275
+ self .ssd_event_get = torch .cuda .Event ()
276
+ self .ssd_event_evict = torch .cuda .Event ()
277
+ self .ssd_event_backward = torch .cuda .Event ()
278
+ self .ssd_event_evict_sp = torch .cuda .Event ()
279
+
275
280
self .timesteps_prefetched : List [int ] = []
276
281
self .ssd_scratch_pads : List [Tuple [Tensor , Tensor , Tensor ]] = []
277
282
# TODO: add type annotation
@@ -391,31 +396,49 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:
391
396
return t_cpu
392
397
393
398
def evict(
    self,
    evicted_rows: Tensor,
    evicted_indices: Tensor,
    actions_count_cpu: Tensor,
    eviction_stream: torch.cuda.Stream,
    pre_event: torch.cuda.Event,
    post_event: torch.cuda.Event,
) -> None:
    """
    Push evicted rows out to SSD through the RocksDB-backed `ssd_db`.

    Every operation is issued while `eviction_stream` is the current CUDA
    stream, bracketed by the two events supplied by the caller.

    Args:
        evicted_rows: Rows to write out; staged through `to_pinned_cpu`.
        evicted_indices: Indices that go with `evicted_rows`; staged the
            same way.
        actions_count_cpu: Forwarded unchanged to `ssd_db.set_cuda`.
        eviction_stream: Stream on which the staging and the write are issued.
        pre_event: Event `eviction_stream` waits on before any staging starts.
        post_event: Event recorded on `eviction_stream` once the write has
            been issued, so callers can wait on it.
    """
    with torch.cuda.stream(eviction_stream):
        # Hold the eviction stream back until the caller's producer work,
        # marked by pre_event, has been reached.
        eviction_stream.wait_event(pre_event)

        rows_host = self.to_pinned_cpu(evicted_rows)
        indices_host = self.to_pinned_cpu(evicted_indices)

        # Register the source tensors as in use on the eviction stream.
        # NOTE(review): presumably this keeps the allocator from recycling
        # their memory while the staging copies are pending -- confirm.
        for source in (evicted_rows, evicted_indices):
            source.record_stream(eviction_stream)

        self.ssd_db.set_cuda(
            indices_host, rows_host, actions_count_cpu, self.timestep
        )

        # TODO: decide whether actions_count_cpu also needs a
        # record_stream(...) call here; some way to synchronize is needed.
        eviction_stream.record_event(post_event)
412
427
413
428
def _evict_from_scratch_pad(self, grad: Tensor) -> None:
    """
    Evict the oldest queued scratch pad to SSD.

    Removes the first entry of `self.ssd_scratch_pads`, records
    `self.ssd_event_backward` on the current CUDA stream, and calls
    `self.evict` on `self.ssd_stream`, waiting on that event and recording
    `self.ssd_event_evict_sp` afterwards.

    Args:
        grad: Not read by this method.
            NOTE(review): presumably present to satisfy a backward-hook
            signature -- confirm against the registration site.
    """
    assert self.ssd_scratch_pads, "There must be at least one scratch pad"

    oldest_pad = self.ssd_scratch_pads.pop(0)
    inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu = oldest_pad

    # Mark the current position on the current stream; the eviction below
    # waits on this event before it starts.
    torch.cuda.current_stream().record_event(self.ssd_event_backward)

    self.evict(
        inserted_rows_gpu,
        post_bwd_evicted_indices,
        actions_count_cpu,
        self.ssd_stream,
        self.ssd_event_backward,
        self.ssd_event_evict_sp,
    )
419
442
420
443
def _compute_cache_ptrs (
421
444
self ,
@@ -508,11 +531,13 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
508
531
current_stream = torch .cuda .current_stream ()
509
532
510
533
# Ensure the previous iteration's l3_db.set(..) has completed.
511
- current_stream .wait_event (self .ssd_set_end )
534
+ current_stream .wait_event (self .ssd_event_evict )
535
+ current_stream .wait_event (self .ssd_event_evict_sp )
536
+
512
537
self .ssd_db .get_cuda (
513
538
self .to_pinned_cpu (inserted_indices ), inserted_rows , actions_count_cpu
514
539
)
515
- current_stream .record_event (self .ssd_set_start )
540
+ current_stream .record_event (self .ssd_event_get )
516
541
# TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
517
542
# Should we allocate on HBM?
518
543
inserted_rows_gpu = inserted_rows .cuda (non_blocking = True )
@@ -525,7 +550,14 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
525
550
)
526
551
527
552
# Evict rows from cache to SSD
528
- self .evict (evicted_rows , evicted_indices , actions_count_cpu )
553
+ self .evict (
554
+ evicted_rows ,
555
+ evicted_indices ,
556
+ actions_count_cpu ,
557
+ self .ssd_stream ,
558
+ self .ssd_event_get ,
559
+ self .ssd_event_evict ,
560
+ )
529
561
530
562
# TODO: keep only necessary tensors
531
563
self .ssd_prefetch_data .append (
0 commit comments