@@ -278,7 +278,7 @@ def __init__(
         self.ssd_event_evict_sp = torch.cuda.Event()

         self.timesteps_prefetched: List[int] = []
-        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
+        self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
         # TODO: add type annotation
         # pyre-fixme[4]: Attribute must be annotated.
         self.ssd_prefetch_data = []
@@ -397,48 +397,71 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:

     def evict(
         self,
-        evicted_rows: Tensor,
-        evicted_indices: Tensor,
+        rows: Tensor,
+        indices: Tensor,
         actions_count_cpu: Tensor,
-        eviction_stream: torch.cuda.Stream,
+        stream: torch.cuda.Stream,
         pre_event: torch.cuda.Event,
         post_event: torch.cuda.Event,
+        is_rows_uvm: bool,
     ) -> None:
         """
         Evict data from the given input tensors to SSD via RocksDB
+
+        Args:
+            rows (Tensor): The 2D tensor that contains rows to evict
+            indices (Tensor): The 1D tensor that contains the row indices that
+                              the rows will be evicted to
+            actions_count_cpu (Tensor): A scalar tensor that contains the
+                                        number of rows that the evict function
+                                        has to process
+            stream (Stream): The CUDA stream that cudaStreamAddCallback will
+                             synchronize the host function with. Moreover, the
+                             asynchronous D->H memory copies will operate on
+                             this stream
+            pre_event (Event): The CUDA event that the stream has to wait on
+            post_event (Event): The CUDA event that the current stream will
+                                record on when the eviction is done
+            is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
+                                tensor (which is accessible on both host and
+                                device)
+        Returns:
+            None
         """
-        with torch.cuda.stream(eviction_stream):
-            eviction_stream.wait_event(pre_event)
+        with torch.cuda.stream(stream):
+            stream.wait_event(pre_event)

-            evicted_rows_cpu = self.to_pinned_cpu(evicted_rows)
-            evicted_indices_cpu = self.to_pinned_cpu(evicted_indices)
+            rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
+            indices_cpu = self.to_pinned_cpu(indices)

-            evicted_rows.record_stream(eviction_stream)
-            evicted_indices.record_stream(eviction_stream)
+            rows.record_stream(stream)
+            indices.record_stream(stream)

             self.ssd_db.set_cuda(
-                evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep
+                indices_cpu, rows_cpu, actions_count_cpu, self.timestep
             )

             # TODO: is this needed?
             # Need a way to synchronize
             # actions_count_cpu.record_stream(self.ssd_stream)
-            eviction_stream.record_event(post_event)
+            stream.record_event(post_event)

     def _evict_from_scratch_pad(self, grad: Tensor) -> None:
         assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
-        (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = (
+        (inserted_rows, post_bwd_evicted_indices, actions_count_cpu, do_evict) = (
             self.ssd_scratch_pads.pop(0)
         )
-        torch.cuda.current_stream().record_event(self.ssd_event_backward)
-        self.evict(
-            inserted_rows_gpu,
-            post_bwd_evicted_indices,
-            actions_count_cpu,
-            self.ssd_stream,
-            self.ssd_event_backward,
-            self.ssd_event_evict_sp,
-        )
+        if do_evict:
+            torch.cuda.current_stream().record_event(self.ssd_event_backward)
+            self.evict(
+                inserted_rows,
+                post_bwd_evicted_indices,
+                actions_count_cpu,
+                self.ssd_stream,
+                self.ssd_event_backward,
+                self.ssd_event_evict_sp,
+                is_rows_uvm=True,
+            )

     def _compute_cache_ptrs(
         self,
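For context, here is a minimal sketch (not part of the diff) of the pre_event/post_event handshake that the new `evict()` docstring describes: a side stream waits on an event recorded by the producer of `rows`, stages the rows into a pinned host buffer (the step that `is_rows_uvm=True` lets the real code skip), and records a completion event for later work to wait on. The names `evict_like` and the placeholder host-side write are hypothetical stand-ins, not FBGEMM APIs.

```python
import torch


def evict_like(
    rows: torch.Tensor,
    side_stream: torch.cuda.Stream,
    pre_event: torch.cuda.Event,
    post_event: torch.cuda.Event,
) -> None:
    # Run the D->H copy (and, in the real module, ssd_db.set_cuda) on a side
    # stream so it overlaps with compute on the default stream.
    with torch.cuda.stream(side_stream):
        # Do not start until the producer of `rows` has recorded pre_event.
        side_stream.wait_event(pre_event)

        # Stage into pinned host memory so the copy can be asynchronous.
        rows_cpu = torch.empty(rows.shape, dtype=rows.dtype, pin_memory=True)
        rows_cpu.copy_(rows, non_blocking=True)

        # Keep `rows` alive until the side stream has finished reading it.
        rows.record_stream(side_stream)

        # ... host-side write of rows_cpu would be enqueued here ...

        # Signal completion so the next iteration can wait on post_event.
        side_stream.record_event(post_event)


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    pre, post = torch.cuda.Event(), torch.cuda.Event()
    rows = torch.randn(8, 4, device="cuda")
    torch.cuda.current_stream().record_event(pre)  # producer: rows are ready
    evict_like(rows, stream, pre, post)
    torch.cuda.current_stream().wait_event(post)   # consumer: wait for eviction
```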
@@ -447,7 +470,7 @@ def _compute_cache_ptrs(
         linear_index_inverse_indices: torch.Tensor,
         unique_indices_count_cumsum: torch.Tensor,
         cache_set_inverse_indices: torch.Tensor,
-        inserted_rows_gpu: torch.Tensor,
+        inserted_rows: torch.Tensor,
         unique_indices_length: torch.Tensor,
         inserted_indices: torch.Tensor,
         actions_count_cpu: torch.Tensor,
@@ -468,7 +491,7 @@ def _compute_cache_ptrs(
             unique_indices_count_cumsum,
             cache_set_inverse_indices,
             self.lxu_cache_weights,
-            inserted_rows_gpu,
+            inserted_rows,
             unique_indices_length,
             inserted_indices,
         )
@@ -477,14 +500,19 @@ def _compute_cache_ptrs(
         with record_function("## ssd_scratch_pads ##"):
             # Store scratch pad info for post backward eviction
             self.ssd_scratch_pads.append(
-                (inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu)
+                (
+                    inserted_rows,
+                    post_bwd_evicted_indices,
+                    actions_count_cpu,
+                    linear_cache_indices.numel() > 0,
+                )
             )

         # pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, Tensor,
         # typing.Any, Tensor]`.
         return (
             lxu_cache_ptrs,
-            inserted_rows_gpu,
+            inserted_rows,
             post_bwd_evicted_indices,
             actions_count_cpu,
         )
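To make the new fourth tuple element concrete, here is a hedged, CPU-only sketch (plain Python stand-ins, not the module's code) of the scratch-pad queue discipline: prefetch appends a 4-tuple whose last element records whether the batch was non-empty, and the post-backward hook pops in FIFO order and only evicts when that flag is set.

```python
from typing import List, Tuple

import torch
from torch import Tensor

# Hypothetical stand-in for the module's self.ssd_scratch_pads bookkeeping.
scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []


def on_prefetch(
    inserted_rows: Tensor,
    post_bwd_evicted_indices: Tensor,
    actions_count_cpu: Tensor,
    linear_cache_indices: Tensor,
) -> None:
    # Store everything the post-backward eviction needs, plus a flag that
    # records whether there is anything to evict for this iteration.
    scratch_pads.append(
        (
            inserted_rows,
            post_bwd_evicted_indices,
            actions_count_cpu,
            linear_cache_indices.numel() > 0,
        )
    )


def on_post_backward() -> None:
    assert len(scratch_pads) > 0, "There must be at least one scratch pad"
    inserted_rows, indices, count, do_evict = scratch_pads.pop(0)  # FIFO order
    if do_evict:
        print(f"would evict {int(count.item())} rows")  # evict(...) in the real module
    else:
        print("empty batch: skip eviction entirely")


# Empty batch: no rows flow through, the flag is False, eviction is skipped.
on_prefetch(
    torch.empty(0, 4),
    torch.empty(0, dtype=torch.long),
    torch.tensor(0),
    torch.empty(0, dtype=torch.long),
)
on_post_backward()
```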
@@ -522,42 +550,50 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
         evicted_rows = self.lxu_cache_weights[
             assigned_cache_slots.clamp(min=0).long(), :
         ]
-        inserted_rows = torch.empty(
-            evicted_rows.shape,
-            dtype=self.lxu_cache_weights.dtype,
-            pin_memory=True,
-        )
+
+        if linear_cache_indices.numel() > 0:
+            inserted_rows = torch.ops.fbgemm.new_managed_tensor(
+                torch.zeros(
+                    1, device=self.current_device, dtype=self.lxu_cache_weights.dtype
+                ),
+                evicted_rows.shape,
+            )
+        else:
+            inserted_rows = torch.empty(
+                evicted_rows.shape,
+                dtype=self.lxu_cache_weights.dtype,
+                device=self.current_device,
+            )

         current_stream = torch.cuda.current_stream()

+        inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)
+
         # Ensure the previous iterations l3_db.set(..) has completed.
         current_stream.wait_event(self.ssd_event_evict)
         current_stream.wait_event(self.ssd_event_evict_sp)
-
-        self.ssd_db.get_cuda(
-            self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu
-        )
+        if linear_cache_indices.numel() > 0:
+            self.ssd_db.get_cuda(inserted_indices_cpu, inserted_rows, actions_count_cpu)
         current_stream.record_event(self.ssd_event_get)
-        # TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
-        # Should we allocate on HBM?
-        inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)

         torch.ops.fbgemm.masked_index_put(
             self.lxu_cache_weights,
             assigned_cache_slots,
-            inserted_rows_gpu,
+            inserted_rows,
             actions_count_gpu,
         )

-        # Evict rows from cache to SSD
-        self.evict(
-            evicted_rows,
-            evicted_indices,
-            actions_count_cpu,
-            self.ssd_stream,
-            self.ssd_event_get,
-            self.ssd_event_evict,
-        )
+        if linear_cache_indices.numel() > 0:
+            # Evict rows from cache to SSD
+            self.evict(
+                evicted_rows,
+                evicted_indices,
+                actions_count_cpu,
+                self.ssd_stream,
+                self.ssd_event_get,
+                self.ssd_event_evict,
+                is_rows_uvm=False,
+            )

         # TODO: keep only necessary tensors
         self.ssd_prefetch_data.append(
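A hedged sketch of why this hunk can drop the pinned-CPU staging buffer and the `inserted_rows.cuda(non_blocking=True)` copy: `torch.ops.fbgemm.new_managed_tensor` (from `fbgemm_gpu`, used verbatim in the diff above) returns a UVM tensor that the host-side SSD read and the GPU `masked_index_put` kernel can both address, which is also why the later `evict()` call on this buffer passes `is_rows_uvm=True`. The shapes below are illustrative, not taken from the source.

```python
import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

if torch.cuda.is_available():
    device = torch.device("cuda")
    rows, dim = 128, 64  # illustrative sizes

    # UVM (CUDA managed) buffer: one allocation visible to both host and
    # device, so an SSD read can fill it on the host and a cache-update
    # kernel can consume it on the device without an explicit H->D copy.
    uvm_rows = torch.ops.fbgemm.new_managed_tensor(
        torch.zeros(1, device=device, dtype=torch.float32),  # dtype/device prototype
        [rows, dim],  # desired shape
    )

    # Previous approach (shown only for contrast): a pinned host buffer that
    # must later be copied to the GPU with .cuda(non_blocking=True).
    pinned_rows = torch.empty((rows, dim), dtype=torch.float32, pin_memory=True)
    gpu_copy = pinned_rows.cuda(non_blocking=True)

    print(uvm_rows.shape, gpu_copy.shape)
```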
@@ -567,7 +603,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
                 linear_index_inverse_indices,
                 unique_indices_count_cumsum,
                 cache_set_inverse_indices,
-                inserted_rows_gpu,
+                inserted_rows,
                 unique_indices_length,
                 inserted_indices,
                 actions_count_cpu,
@@ -593,7 +629,7 @@ def forward(
         prefetch_data = self.ssd_prefetch_data.pop(0)
         (
             lxu_cache_ptrs,
-            inserted_rows_gpu,
+            inserted_rows,
             post_bwd_evicted_indices,
             actions_count_cpu,
         ) = self._compute_cache_ptrs(*prefetch_data)
@@ -635,7 +671,7 @@ def forward(
             # codegen/genscript/optimizer_args.py
             ssd_tensors={
                 "row_addrs": lxu_cache_ptrs,
-                "inserted_rows": inserted_rows_gpu,
+                "inserted_rows": inserted_rows,
                 "post_bwd_evicted_indices": post_bwd_evicted_indices,
                 "actions_count": actions_count_cpu,
             },