
Commit 25d2486

sryap authored and facebook-github-bot committed
Make indices related to cache eviction UVA tensors (pytorch#3077)
Summary:
Pull Request resolved: pytorch#3077
X-link: facebookresearch/FBGEMM#171

This is a follow-up diff of D62114877, which makes the indices related to L1 cache eviction UVA buffers.

Differential Revision: D62114882
1 parent e178ed9 commit 25d2486
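The diff below makes the eviction-related index buffers persistent UVA (host-mapped unified) tensors instead of per-iteration device tensors that must be copied to pinned CPU memory. As a rough sketch of the allocation pattern introduced in training.py (assuming fbgemm_gpu is installed; make_uva_eviction_buffers, cache_rows, device, and host_mapped are hypothetical stand-ins for the module's own attributes):

import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

def make_uva_eviction_buffers(cache_rows, device, host_mapped):
    # Embedding indices of the rows staged for eviction; the first argument is
    # a prototype tensor supplying dtype/device, the second is the real shape.
    evicted_indices = torch.ops.fbgemm.new_unified_tensor(
        torch.zeros(1, device=device, dtype=torch.long),
        (cache_rows,),
        is_host_mapped=host_mapped,
    )
    # Cache slots whose rows are being evicted
    evicted_slots = torch.ops.fbgemm.new_unified_tensor(
        torch.zeros(1, device=device, dtype=torch.int),
        (cache_rows,),
        is_host_mapped=host_mapped,
    )
    # Single-element counter holding the number of rows to evict
    evicted_count = torch.ops.fbgemm.new_unified_tensor(
        torch.zeros(1, device=device, dtype=torch.int),
        (1,),
        is_host_mapped=host_mapped,
    )
    return evicted_indices, evicted_slots, evicted_count

Because the buffers are host-mapped, the CPU-side eviction path can read them directly, which is why the separate "## ssd_d2h_evicted_indices ##" device-to-host copy is removed further down.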

File tree

2 files changed: +76 −33 lines


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 54 additions & 27 deletions
@@ -309,6 +309,7 @@ def __init__(
             * self.lxu_cache_weights.element_size()
         ), "The precomputed cache_size does not match the actual cache size"
 
+        # Buffers for cache eviction
         # For storing weights to evict
         # The max number of rows to be evicted is limited by the number of
         # slots in the cache. Thus, we allocate `lxu_cache_evicted_weights` to
@@ -325,6 +326,49 @@ def __init__(
                 is_host_mapped=self.uvm_host_mapped,
             ),
         )
+
+        # For storing embedding indices to evict to
+        self.register_buffer(
+            "lxu_cache_evicted_indices",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=torch.long,
+                ),
+                (self.lxu_cache_weights.shape[0],),
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
+
+        # For storing cache slots to evict
+        self.register_buffer(
+            "lxu_cache_evicted_slots",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=torch.int,
+                ),
+                (self.lxu_cache_weights.shape[0],),
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
+
+        # For storing the number of evicted rows
+        self.register_buffer(
+            "lxu_cache_evicted_count",
+            torch.ops.fbgemm.new_unified_tensor(
+                torch.zeros(
+                    1,
+                    device=self.current_device,
+                    dtype=torch.int,
+                ),
+                (1,),
+                is_host_mapped=self.uvm_host_mapped,
+            ),
+        )
+
         self.timestep = 0
 
         # Dummy profile configuration for measuring the SSD get/set time
@@ -1083,35 +1127,30 @@ def prefetch( # noqa C901
             self.local_ssd_cache_stats,
         )
 
-        # Allocate output tensors for compact_indices
-        compact_evicted_indices = torch.empty_like(evicted_indices)
-        compact_assigned_cache_slots = torch.empty_like(assigned_cache_slots)
-        compact_actions_count_gpu = torch.empty_like(actions_count_gpu)
-
         # Defrag indices based on evicted_indices (removing -1 and making
         # the non -1 elements contiguous). We need to do this because the
         # number of rows in `lxu_cache_evicted_weights` might be smaller
         # than the number of elements in `evicted_indices`. Without this
         # step, we can run into the index out of bound issue
         current_stream.wait_event(self.ssd_event_cache_evict)
         torch.ops.fbgemm.compact_indices(
-            compact_indices=[compact_evicted_indices, compact_assigned_cache_slots],
-            compact_count=compact_actions_count_gpu,
+            compact_indices=[
+                self.lxu_cache_evicted_indices,
+                self.lxu_cache_evicted_slots,
+            ],
+            compact_count=self.lxu_cache_evicted_count,
             indices=[evicted_indices, assigned_cache_slots],
             masks=torch.where(evicted_indices != -1, 1, 0),
             count=actions_count_gpu,
         )
 
-        evicted_indices = compact_evicted_indices
-
         with record_function("## ssd_d2h_inserted_indices ##"):
             # Transfer actions_count and insert_indices right away to
             # incrase an overlap opportunity
-            actions_count_cpu, compact_actions_count_cpu, inserted_indices_cpu = (
+            actions_count_cpu, inserted_indices_cpu = (
                 self.to_pinned_cpu_on_stream_wait_on_another_stream(
                     tensors=[
                         actions_count_gpu,
-                        compact_actions_count_gpu,
                         inserted_indices,
                     ],
                     stream=self.ssd_memcpy_stream,
@@ -1120,26 +1159,14 @@ def prefetch( # noqa C901
                 )
             )
 
-        with record_function("## ssd_d2h_evicted_indices ##"):
-            # Transfer evicted indices from GPU to CPU right away to increase a
-            # chance of overlapping with compute on the default stream
-            (evicted_indices_cpu,) = (
-                self.to_pinned_cpu_on_stream_wait_on_another_stream(
-                    tensors=[evicted_indices],
-                    stream=self.ssd_eviction_stream,
-                    stream_to_wait_on=current_stream,
-                    post_event=None,
-                )
-            )
-
         # Copy rows to be evicted into a separate buffer (will be evicted
         # later in the prefetch step)
         with record_function("## ssd_compute_evicted_rows ##"):
             torch.ops.fbgemm.masked_index_select(
                 self.lxu_cache_evicted_weights,
-                compact_assigned_cache_slots,
+                self.lxu_cache_evicted_slots,
                 self.lxu_cache_weights,
-                compact_actions_count_gpu,
+                self.lxu_cache_evicted_count,
             )
 
         # Allocation a scratch pad for the current iteration. The scratch
@@ -1293,8 +1320,8 @@ def prefetch( # noqa C901
             # Evict rows from cache to SSD
             self.evict(
                 rows=self.lxu_cache_evicted_weights,
-                indices_cpu=evicted_indices_cpu,
-                actions_count_cpu=compact_actions_count_cpu,
+                indices_cpu=self.lxu_cache_evicted_indices,
+                actions_count_cpu=self.lxu_cache_evicted_count,
                 stream=self.ssd_eviction_stream,
                 pre_event=self.ssd_event_get,
                 # Record completion event after scratch pad eviction
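As the comments in the prefetch hunk note, the compaction step exists to drop the -1 sentinels from evicted_indices so the surviving entries fit into lxu_cache_evicted_weights. Below is a minimal, hypothetical torch illustration of what torch.ops.fbgemm.compact_indices computes; the actual fused kernel writes its results into the UVA buffers above.

import torch

# Toy inputs: -1 marks cache lines with nothing to evict
evicted_indices = torch.tensor([42, -1, 7, -1, 19], dtype=torch.long)
assigned_cache_slots = torch.tensor([3, 0, 5, 1, 2], dtype=torch.int32)

mask = evicted_indices != -1
compact_evicted_indices = evicted_indices[mask]    # tensor([42,  7, 19])
compact_cache_slots = assigned_cache_slots[mask]   # tensor([3, 5, 2], dtype=torch.int32)
compact_count = int(mask.sum())                    # 3 rows to actually evict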

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp

Lines changed: 22 additions & 6 deletions
@@ -16,16 +16,31 @@
 
 namespace kv_db {
 
+namespace {
+
+/// Read a scalar value from a tensor that is maybe a UVM tensor
+/// Note that `tensor.item<type>()` is not allowed on a UVM tensor in
+/// PyTorch
+inline int64_t get_maybe_uvm_scalar(const at::Tensor& tensor) {
+  return tensor.scalar_type() == at::ScalarType::Long
+      ? *(tensor.data_ptr<int64_t>())
+      : *(tensor.data_ptr<int32_t>());
+}
+
+}; // namespace
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> tensor_copy(
     const at::Tensor& indices,
     const at::Tensor& weights,
     const at::Tensor& count) {
-  auto num_sets = count.item<long>();
+  auto num_sets = get_maybe_uvm_scalar(count);
   auto new_indices = at::empty(
       num_sets, at::TensorOptions().device(at::kCPU).dtype(indices.dtype()));
   auto new_weights = at::empty(
       {num_sets, weights.size(1)},
       at::TensorOptions().device(at::kCPU).dtype(weights.dtype()));
+  auto new_count =
+      at::empty({1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       weights.scalar_type(), "cache_memcpy", [&] {
         auto indices_addr = indices.data_ptr<int64_t>();
@@ -42,7 +57,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> tensor_copy(
             weights_addr + num_sets * weights.size(1),
             new_weightss_addr); // dst_start
       });
-  return std::make_tuple(new_indices, new_weights, count.clone());
+  *new_count.data_ptr<int64_t>() = num_sets;
+  return std::make_tuple(new_indices, new_weights, new_count);
 }
 
 EmbeddingKVDB::EmbeddingKVDB(
@@ -182,7 +198,7 @@ void EmbeddingKVDB::set(
     const at::Tensor& weights,
     const at::Tensor& count,
     const bool is_bwd) {
-  if (auto num_evictions = count.item<long>(); num_evictions <= 0) {
+  if (auto num_evictions = get_maybe_uvm_scalar(count); num_evictions <= 0) {
     XLOG_EVERY_MS(INFO, 60000)
         << "[" << unique_id_ << "]skip set_cuda since number evictions is "
         << num_evictions;
@@ -204,7 +220,7 @@ void EmbeddingKVDB::get(
     const at::Tensor& indices,
     const at::Tensor& weights,
     const at::Tensor& count) {
-  if (auto num_lookups = count.item<long>(); num_lookups <= 0) {
+  if (auto num_lookups = get_maybe_uvm_scalar(count); num_lookups <= 0) {
     XLOG_EVERY_MS(INFO, 60000)
         << "[" << unique_id_ << "]skip get_cuda since number lookups is "
         << num_lookups;
@@ -255,7 +271,7 @@ std::shared_ptr<CacheContext> EmbeddingKVDB::get_cache(
   }
   auto start_ts = facebook::WallClockUtil::NowInUsecFast();
   auto indices_addr = indices.data_ptr<int64_t>();
-  auto num_lookups = count.item<long>();
+  auto num_lookups = get_maybe_uvm_scalar(count);
   auto cache_context = std::make_shared<CacheContext>(num_lookups);
 
   auto num_shards = executor_tp_->numThreads();
@@ -348,7 +364,7 @@ folly::Optional<std::pair<at::Tensor, at::Tensor>> EmbeddingKVDB::set_cache(
 
   l2_cache_->init_tensor_for_l2_eviction(indices, weights, count);
   auto indices_addr = indices.data_ptr<int64_t>();
-  auto num_lookups = count.item<long>();
+  const int64_t num_lookups = get_maybe_uvm_scalar(count);
   auto num_shards = executor_tp_->numThreads();
 
   std::vector<folly::coro::TaskWithExecutor<void>> tasks;
