Skip to content

Commit 6375086

Browse files
sryap authored and facebook-github-bot committed
Make indices related to cache eviction UVA tensors (pytorch#3077)
Summary: Pull Request resolved: pytorch#3077 X-link: facebookresearch/FBGEMM#171 This is a follow-up diff of D62114877, which makes the indices related to L1 cache eviction UVA buffers. Differential Revision: D62114882
1 parent b5f1d96 commit 6375086

File tree

2 files changed

+57
-28
lines changed

2 files changed

+57
-28
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def __init__(
309309
* self.lxu_cache_weights.element_size()
310310
), "The precomputed cache_size does not match the actual cache size"
311311

312+
# Buffers for cache eviction
312313
# For storing weights to evict
313314
# The max number of rows to be evicted is limited by the number of
314315
# slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
@@ -325,6 +326,49 @@ def __init__(
325326
is_host_mapped=self.uvm_host_mapped,
326327
),
327328
)
329+
330+
# For storing embedding indices to evict to
331+
self.register_buffer(
332+
"lxu_cache_evicted_indices",
333+
torch.ops.fbgemm.new_unified_tensor(
334+
torch.zeros(
335+
1,
336+
device=self.current_device,
337+
dtype=torch.long,
338+
),
339+
(self.lxu_cache_weights.shape[0],),
340+
is_host_mapped=self.uvm_host_mapped,
341+
),
342+
)
343+
344+
# For storing cache slots to evict
345+
self.register_buffer(
346+
"lxu_cache_evicted_slots",
347+
torch.ops.fbgemm.new_unified_tensor(
348+
torch.zeros(
349+
1,
350+
device=self.current_device,
351+
dtype=torch.int,
352+
),
353+
(self.lxu_cache_weights.shape[0],),
354+
is_host_mapped=self.uvm_host_mapped,
355+
),
356+
)
357+
358+
# For storing the number of evicted rows
359+
self.register_buffer(
360+
"lxu_cache_evicted_count",
361+
torch.ops.fbgemm.new_unified_tensor(
362+
torch.zeros(
363+
1,
364+
device=self.current_device,
365+
dtype=torch.int,
366+
),
367+
(1,),
368+
is_host_mapped=self.uvm_host_mapped,
369+
),
370+
)
371+
328372
self.timestep = 0
329373

330374
# Dummy profile configuration for measuring the SSD get/set time
@@ -1081,34 +1125,29 @@ def prefetch( # noqa C901
10811125
self.local_ssd_cache_stats,
10821126
)
10831127

1084-
# Allocate output tensors for compact_indices
1085-
compact_evicted_indices = torch.empty_like(evicted_indices)
1086-
compact_assigned_cache_slots = torch.empty_like(assigned_cache_slots)
1087-
compact_actions_count_gpu = torch.empty_like(actions_count_gpu)
1088-
10891128
# Defrag indices based on evicted_indices (removing -1 and making
10901129
# the non -1 elements contiguous). We need to do this because the
10911130
# number of rows in `lxu_cache_evicted_weights` might be smaller
10921131
# than the number of elements in `evicted_indices`. Without this
10931132
# step, we can run into the index out of bound issue
10941133
torch.ops.fbgemm.compact_indices(
1095-
compact_indices=[compact_evicted_indices, compact_assigned_cache_slots],
1096-
compact_count=compact_actions_count_gpu,
1134+
compact_indices=[
1135+
self.lxu_cache_evicted_indices,
1136+
self.lxu_cache_evicted_slots,
1137+
],
1138+
compact_count=self.lxu_cache_evicted_count,
10971139
indices=[evicted_indices, assigned_cache_slots],
10981140
masks=torch.where(evicted_indices != -1, 1, 0),
10991141
count=actions_count_gpu,
11001142
)
11011143

1102-
evicted_indices = compact_evicted_indices
1103-
11041144
with record_function("## ssd_d2h_inserted_indices ##"):
11051145
# Transfer actions_count and insert_indices right away to
11061146
# incrase an overlap opportunity
1107-
actions_count_cpu, compact_actions_count_cpu, inserted_indices_cpu = (
1147+
actions_count_cpu, inserted_indices_cpu = (
11081148
self.to_pinned_cpu_on_stream_wait_on_another_stream(
11091149
tensors=[
11101150
actions_count_gpu,
1111-
compact_actions_count_gpu,
11121151
inserted_indices,
11131152
],
11141153
stream=self.ssd_memcpy_stream,
@@ -1117,26 +1156,14 @@ def prefetch( # noqa C901
11171156
)
11181157
)
11191158

1120-
with record_function("## ssd_d2h_evicted_indices ##"):
1121-
# Transfer evicted indices from GPU to CPU right away to increase a
1122-
# chance of overlapping with compute on the default stream
1123-
(evicted_indices_cpu,) = (
1124-
self.to_pinned_cpu_on_stream_wait_on_another_stream(
1125-
tensors=[evicted_indices],
1126-
stream=self.ssd_eviction_stream,
1127-
stream_to_wait_on=current_stream,
1128-
post_event=None,
1129-
)
1130-
)
1131-
11321159
# Copy rows to be evicted into a separate buffer (will be evicted
11331160
# later in the prefetch step)
11341161
with record_function("## ssd_compute_evicted_rows ##"):
11351162
torch.ops.fbgemm.masked_index_select(
11361163
self.lxu_cache_evicted_weights,
1137-
compact_assigned_cache_slots,
1164+
self.lxu_cache_evicted_slots,
11381165
self.lxu_cache_weights,
1139-
compact_actions_count_gpu,
1166+
self.lxu_cache_evicted_count,
11401167
)
11411168

11421169
# Allocation a scratch pad for the current iteration. The scratch
@@ -1290,8 +1317,8 @@ def prefetch( # noqa C901
12901317
# Evict rows from cache to SSD
12911318
self.evict(
12921319
rows=self.lxu_cache_evicted_weights,
1293-
indices_cpu=evicted_indices_cpu,
1294-
actions_count_cpu=compact_actions_count_cpu,
1320+
indices_cpu=self.lxu_cache_evicted_indices,
1321+
actions_count_cpu=self.lxu_cache_evicted_count,
12951322
stream=self.ssd_eviction_stream,
12961323
pre_event=self.ssd_event_get,
12971324
# Record completion event after scratch pad eviction

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,9 @@ folly::Optional<std::pair<at::Tensor, at::Tensor>> EmbeddingKVDB::set_cache(
348348

349349
l2_cache_->init_tensor_for_l2_eviction(indices, weights, count);
350350
auto indices_addr = indices.data_ptr<int64_t>();
351-
auto num_lookups = count.item<long>();
351+
const int64_t num_lookups = count.scalar_type() == at::ScalarType::Long
352+
? *(count.data_ptr<int64_t>())
353+
: *(count.data_ptr<int32_t>());
352354
auto num_shards = executor_tp_->numThreads();
353355

354356
std::vector<folly::coro::TaskWithExecutor<void>> tasks;

0 commit comments

Comments (0)