@@ -309,6 +309,7 @@ def __init__(
309
309
* self .lxu_cache_weights .element_size ()
310
310
), "The precomputed cache_size does not match the actual cache size"
311
311
312
+ # Buffers for cache eviction
312
313
# For storing weights to evict
313
314
# The max number of rows to be evicted is limited by the number of
314
315
# slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
@@ -325,6 +326,49 @@ def __init__(
325
326
is_host_mapped = self .uvm_host_mapped ,
326
327
),
327
328
)
329
+
330
+ # For storing embedding indices to evict to
331
+ self .register_buffer (
332
+ "lxu_cache_evicted_indices" ,
333
+ torch .ops .fbgemm .new_unified_tensor (
334
+ torch .zeros (
335
+ 1 ,
336
+ device = self .current_device ,
337
+ dtype = torch .long ,
338
+ ),
339
+ (self .lxu_cache_weights .shape [0 ],),
340
+ is_host_mapped = self .uvm_host_mapped ,
341
+ ),
342
+ )
343
+
344
+ # For storing cache slots to evict
345
+ self .register_buffer (
346
+ "lxu_cache_evicted_slots" ,
347
+ torch .ops .fbgemm .new_unified_tensor (
348
+ torch .zeros (
349
+ 1 ,
350
+ device = self .current_device ,
351
+ dtype = torch .int ,
352
+ ),
353
+ (self .lxu_cache_weights .shape [0 ],),
354
+ is_host_mapped = self .uvm_host_mapped ,
355
+ ),
356
+ )
357
+
358
+ # For storing the number of evicted rows
359
+ self .register_buffer (
360
+ "lxu_cache_evicted_count" ,
361
+ torch .ops .fbgemm .new_unified_tensor (
362
+ torch .zeros (
363
+ 1 ,
364
+ device = self .current_device ,
365
+ dtype = torch .int ,
366
+ ),
367
+ (1 ,),
368
+ is_host_mapped = self .uvm_host_mapped ,
369
+ ),
370
+ )
371
+
328
372
self .timestep = 0
329
373
330
374
# Dummy profile configuration for measuring the SSD get/set time
@@ -1081,34 +1125,29 @@ def prefetch( # noqa C901
1081
1125
self .local_ssd_cache_stats ,
1082
1126
)
1083
1127
1084
- # Allocate output tensors for compact_indices
1085
- compact_evicted_indices = torch .empty_like (evicted_indices )
1086
- compact_assigned_cache_slots = torch .empty_like (assigned_cache_slots )
1087
- compact_actions_count_gpu = torch .empty_like (actions_count_gpu )
1088
-
1089
1128
# Defrag indices based on evicted_indices (removing -1 and making
1090
1129
# the non -1 elements contiguous). We need to do this because the
1091
1130
# number of rows in `lxu_cache_evicted_weights` might be smaller
1092
1131
# than the number of elements in `evicted_indices`. Without this
1093
1132
# step, we can run into the index out of bound issue
1094
1133
torch .ops .fbgemm .compact_indices (
1095
- compact_indices = [compact_evicted_indices , compact_assigned_cache_slots ],
1096
- compact_count = compact_actions_count_gpu ,
1134
+ compact_indices = [
1135
+ self .lxu_cache_evicted_indices ,
1136
+ self .lxu_cache_evicted_slots ,
1137
+ ],
1138
+ compact_count = self .lxu_cache_evicted_count ,
1097
1139
indices = [evicted_indices , assigned_cache_slots ],
1098
1140
masks = torch .where (evicted_indices != - 1 , 1 , 0 ),
1099
1141
count = actions_count_gpu ,
1100
1142
)
1101
1143
1102
- evicted_indices = compact_evicted_indices
1103
-
1104
1144
with record_function ("## ssd_d2h_inserted_indices ##" ):
1105
1145
# Transfer actions_count and insert_indices right away to
1106
1146
# incrase an overlap opportunity
1107
- actions_count_cpu , compact_actions_count_cpu , inserted_indices_cpu = (
1147
+ actions_count_cpu , inserted_indices_cpu = (
1108
1148
self .to_pinned_cpu_on_stream_wait_on_another_stream (
1109
1149
tensors = [
1110
1150
actions_count_gpu ,
1111
- compact_actions_count_gpu ,
1112
1151
inserted_indices ,
1113
1152
],
1114
1153
stream = self .ssd_memcpy_stream ,
@@ -1117,26 +1156,14 @@ def prefetch( # noqa C901
1117
1156
)
1118
1157
)
1119
1158
1120
- with record_function ("## ssd_d2h_evicted_indices ##" ):
1121
- # Transfer evicted indices from GPU to CPU right away to increase a
1122
- # chance of overlapping with compute on the default stream
1123
- (evicted_indices_cpu ,) = (
1124
- self .to_pinned_cpu_on_stream_wait_on_another_stream (
1125
- tensors = [evicted_indices ],
1126
- stream = self .ssd_eviction_stream ,
1127
- stream_to_wait_on = current_stream ,
1128
- post_event = None ,
1129
- )
1130
- )
1131
-
1132
1159
# Copy rows to be evicted into a separate buffer (will be evicted
1133
1160
# later in the prefetch step)
1134
1161
with record_function ("## ssd_compute_evicted_rows ##" ):
1135
1162
torch .ops .fbgemm .masked_index_select (
1136
1163
self .lxu_cache_evicted_weights ,
1137
- compact_assigned_cache_slots ,
1164
+ self . lxu_cache_evicted_slots ,
1138
1165
self .lxu_cache_weights ,
1139
- compact_actions_count_gpu ,
1166
+ self . lxu_cache_evicted_count ,
1140
1167
)
1141
1168
1142
1169
# Allocation a scratch pad for the current iteration. The scratch
@@ -1290,8 +1317,8 @@ def prefetch( # noqa C901
1290
1317
# Evict rows from cache to SSD
1291
1318
self .evict (
1292
1319
rows = self .lxu_cache_evicted_weights ,
1293
- indices_cpu = evicted_indices_cpu ,
1294
- actions_count_cpu = compact_actions_count_cpu ,
1320
+ indices_cpu = self . lxu_cache_evicted_indices ,
1321
+ actions_count_cpu = self . lxu_cache_evicted_count ,
1295
1322
stream = self .ssd_eviction_stream ,
1296
1323
pre_event = self .ssd_event_get ,
1297
1324
# Record completion event after scratch pad eviction
0 commit comments