@@ -309,6 +309,7 @@ def __init__(
                 * self.lxu_cache_weights.element_size()
             ), "The precomputed cache_size does not match the actual cache size"
 
+        # Buffers for cache eviction
         # For storing weights to evict
         # The max number of rows to be evicted is limited by the number of
         # slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
@@ -325,6 +326,49 @@ def __init__(
                 is_host_mapped=self.uvm_host_mapped,
             ),
         )
+
+        # For storing embedding indices to evict
+        self.register_buffer(
+            "lxu_cache_evicted_indices",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=torch.long,
+                ),
+                (self.lxu_cache_weights.shape[0],),
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
+
+        # For storing cache slots to evict
+        self.register_buffer(
+            "lxu_cache_evicted_slots",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=torch.int,
+                ),
+                (self.lxu_cache_weights.shape[0],),
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
+
+        # For storing the number of evicted rows
+        self.register_buffer(
+            "lxu_cache_evicted_count",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=torch.int,
+                ),
+                (1,),
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
+
         self.timestep = 0
 
         # Dummy profile configuration for measuring the SSD get/set time
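The three new buffers follow the same pattern as the existing `lxu_cache_evicted_weights`: each is preallocated once in `__init__` and sized so it can never overflow, since at most one row per cache slot can be evicted in a single prefetch. Below is a minimal sketch of that sizing logic; the `EvictionBuffers` module name is hypothetical, plain `torch.zeros` stands in for `torch.ops.fbgemm.new_unified_tensor` (which additionally places the storage in UVM / host-mapped memory), and `cache_slots` mirrors `self.lxu_cache_weights.shape[0]`.

```python
import torch


class EvictionBuffers(torch.nn.Module):
    """Hypothetical stand-in for the buffer registration above."""

    def __init__(self, cache_slots: int) -> None:
        super().__init__()
        # Embedding (row) indices of the rows being evicted
        self.register_buffer(
            "lxu_cache_evicted_indices",
            torch.zeros(cache_slots, dtype=torch.long),
        )
        # Cache slots that the evicted rows currently occupy
        self.register_buffer(
            "lxu_cache_evicted_slots",
            torch.zeros(cache_slots, dtype=torch.int),
        )
        # Scalar counter: how many rows are actually evicted this iteration
        self.register_buffer(
            "lxu_cache_evicted_count",
            torch.zeros(1, dtype=torch.int),
        )
```

Preallocating here, instead of calling `torch.empty_like` on every prefetch as the old code did, avoids a per-iteration allocation and gives the buffers a stable address that both the GPU kernels and the CPU-side eviction can refer to.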
@@ -1083,35 +1127,30 @@ def prefetch(  # noqa C901
             self.local_ssd_cache_stats,
         )
 
-        # Allocate output tensors for compact_indices
-        compact_evicted_indices = torch.empty_like(evicted_indices)
-        compact_assigned_cache_slots = torch.empty_like(assigned_cache_slots)
-        compact_actions_count_gpu = torch.empty_like(actions_count_gpu)
-
         # Defrag indices based on evicted_indices (removing -1 and making
         # the non -1 elements contiguous). We need to do this because the
         # number of rows in `lxu_cache_evicted_weights` might be smaller
        # than the number of elements in `evicted_indices`. Without this
         # step, we can run into an index-out-of-bounds issue
         current_stream.wait_event(self.ssd_event_cache_evict)
         torch.ops.fbgemm.compact_indices(
-            compact_indices=[compact_evicted_indices, compact_assigned_cache_slots],
-            compact_count=compact_actions_count_gpu,
+            compact_indices=[
+                self.lxu_cache_evicted_indices,
+                self.lxu_cache_evicted_slots,
+            ],
+            compact_count=self.lxu_cache_evicted_count,
             indices=[evicted_indices, assigned_cache_slots],
             masks=torch.where(evicted_indices != -1, 1, 0),
             count=actions_count_gpu,
         )
 
-        evicted_indices = compact_evicted_indices
-
         with record_function("## ssd_d2h_inserted_indices ##"):
             # Transfer actions_count and insert_indices right away to
             # increase an overlap opportunity
-            actions_count_cpu, compact_actions_count_cpu, inserted_indices_cpu = (
+            actions_count_cpu, inserted_indices_cpu = (
                 self.to_pinned_cpu_on_stream_wait_on_another_stream(
                     tensors=[
                         actions_count_gpu,
-                        compact_actions_count_gpu,
                         inserted_indices,
                     ],
                     stream=self.ssd_memcpy_stream,
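For intuition, here is a pure-PyTorch sketch of the compaction semantics the call above relies on, as my reading of its comment (not FBGEMM's actual kernel): among the first `count` entries, keep only the masked-in ones, pack them to the front of the preallocated output buffers, and record how many survived.

```python
import torch


def compact_indices_ref(compact_indices, compact_count, indices, masks, count):
    """Eager, single-threaded sketch of the compaction used above."""
    n = int(count.item())           # number of valid leading elements
    keep = masks[:n].bool()         # True where evicted_indices != -1
    kept = int(keep.sum())
    for out, src in zip(compact_indices, indices):
        out[:kept] = src[:n][keep]  # pack surviving elements to the front
    compact_count.fill_(kept)


# Example: -1 marks "nothing was evicted from this slot"
evicted_indices = torch.tensor([3, -1, 7, -1, 9])
assigned_cache_slots = torch.tensor([10, 11, 12, 13, 14])
out_indices = torch.empty_like(evicted_indices)
out_slots = torch.empty_like(assigned_cache_slots)
out_count = torch.zeros(1, dtype=torch.int)
compact_indices_ref(
    compact_indices=[out_indices, out_slots],
    compact_count=out_count,
    indices=[evicted_indices, assigned_cache_slots],
    masks=torch.where(evicted_indices != -1, 1, 0),
    count=torch.tensor([5]),
)
# out_indices[:3] == [3, 7, 9]; out_slots[:3] == [10, 12, 14]; out_count == [3]
```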
@@ -1120,26 +1159,14 @@ def prefetch(  # noqa C901
                 )
             )
 
-        with record_function("## ssd_d2h_evicted_indices ##"):
-            # Transfer evicted indices from GPU to CPU right away to increase a
-            # chance of overlapping with compute on the default stream
-            (evicted_indices_cpu,) = (
-                self.to_pinned_cpu_on_stream_wait_on_another_stream(
-                    tensors=[evicted_indices],
-                    stream=self.ssd_eviction_stream,
-                    stream_to_wait_on=current_stream,
-                    post_event=None,
-                )
-            )
-
         # Copy rows to be evicted into a separate buffer (will be evicted
         # later in the prefetch step)
         with record_function("## ssd_compute_evicted_rows ##"):
             torch.ops.fbgemm.masked_index_select(
                 self.lxu_cache_evicted_weights,
-                compact_assigned_cache_slots,
+                self.lxu_cache_evicted_slots,
                 self.lxu_cache_weights,
-                compact_actions_count_gpu,
+                self.lxu_cache_evicted_count,
             )
 
         # Allocate a scratch pad for the current iteration. The scratch
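The gather itself has simple semantics. A pure-PyTorch sketch of what this `masked_index_select` step computes, under the assumption (implied by the surrounding code) that row `i` of the output receives `cache[slots[i]]` for the first `count` rows:

```python
import torch


def gather_evicted_rows_ref(out, slots, cache, count):
    """Sketch: stage the first `count` rows addressed by `slots` into `out`."""
    n = int(count.item())
    out[:n] = cache[slots[:n].long()]


# Tiny example: stage rows 2 and 0 of a 4-row cache into the eviction buffer
cache = torch.arange(8.0).reshape(4, 2)
out = torch.zeros(4, 2)
gather_evicted_rows_ref(
    out,
    torch.tensor([2, 0, 0, 0], dtype=torch.int),
    cache,
    torch.tensor([2], dtype=torch.int),
)
# out[:2] == [[4., 5.], [0., 1.]]
```

Staging the rows into `lxu_cache_evicted_weights` decouples the slow, SSD-bound eviction from the cache itself, so the freed slots can be refilled while the old rows are still in flight.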
@@ -1293,8 +1320,8 @@ def prefetch(  # noqa C901
             # Evict rows from cache to SSD
             self.evict(
                 rows=self.lxu_cache_evicted_weights,
-                indices_cpu=evicted_indices_cpu,
-                actions_count_cpu=compact_actions_count_cpu,
+                indices_cpu=self.lxu_cache_evicted_indices,
+                actions_count_cpu=self.lxu_cache_evicted_count,
                 stream=self.ssd_eviction_stream,
                 pre_event=self.ssd_event_get,
                 # Record completion event after scratch pad eviction
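Passing the unified buffers directly as `indices_cpu`/`actions_count_cpu` is what makes the removed `## ssd_d2h_evicted_indices ##` transfer unnecessary: assuming `new_unified_tensor(..., is_host_mapped=True)` yields memory addressable from both CPU and GPU, the eviction path only needs to order the CPU read after the GPU write instead of scheduling a separate device-to-host copy. A minimal sketch of that ordering pattern, using a pinned CPU tensor as a stand-in for a host-mapped unified tensor:

```python
import torch


def gpu_fill_then_cpu_read() -> torch.Tensor:
    # Pinned host memory stands in for a host-mapped unified tensor
    buf = torch.empty(8, dtype=torch.long, pin_memory=True)
    stream = torch.cuda.Stream()
    done = torch.cuda.Event()
    with torch.cuda.stream(stream):
        src = torch.arange(8, device="cuda")
        # GPU writes straight into host-visible memory (no staging buffer)
        buf.copy_(src, non_blocking=True)
        done.record(stream)
    done.synchronize()  # order the CPU read after the GPU write
    return buf  # now safe to consume on the CPU
```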