Skip to content

Commit 6ed389f

Browse files
duduyi2013facebook-github-bot
authored and committed
add enable_async_update into tbe signature (pytorch#518)
Summary: X-link: pytorch#3461 X-link: pytorch/torchrec#2599 X-link: pytorch#3431 Pull Request resolved: facebookresearch/FBGEMM#518 Add enable_async_update into the TBE config. Reviewed By: chrisxcai Differential Revision: D66802199 fbshipit-source-id: 56d83586396f293b5932f5b6ad7661493278dd8c
1 parent 391c9ef commit 6ed389f

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def __init__(
145145
# Set to True to alloc a UVM tensor using malloc+cudaHostRegister.
146146
# Set to False to use cudaMallocManaged
147147
uvm_host_mapped: bool = False,
148+
enable_async_update: bool = True, # whether enable L2/rocksdb write to async background thread
148149
) -> None:
149150
super(SSDTableBatchedEmbeddingBags, self).__init__()
150151

@@ -427,7 +428,7 @@ def __init__(
427428
logging.info(f"tbe_unique_id: {tbe_unique_id}")
428429
if not ps_hosts:
429430
logging.info(
430-
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
431+
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}"
431432
f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
432433
f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
433434
f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
@@ -459,6 +460,7 @@ def __init__(
459460
use_passed_in_path,
460461
tbe_unique_id,
461462
l2_cache_size,
463+
enable_async_update,
462464
)
463465
else:
464466
# pyre-fixme[4]: Attribute must be annotated.

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
279279
int64_t cache_size = 0,
280280
bool use_passed_in_path = false,
281281
int64_t tbe_unique_id = 0,
282-
int64_t l2_cache_size_gb = 0)
282+
int64_t l2_cache_size_gb = 0,
283+
bool enable_async_update = false)
283284
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
284285
path,
285286
num_shards,
@@ -299,7 +300,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
299300
cache_size,
300301
use_passed_in_path,
301302
tbe_unique_id,
302-
l2_cache_size_gb)) {}
303+
l2_cache_size_gb,
304+
enable_async_update)) {}
303305

304306
void set_cuda(
305307
Tensor indices,
@@ -481,7 +483,8 @@ static auto embedding_rocks_db_wrapper =
481483
int64_t,
482484
bool,
483485
int64_t,
484-
int64_t>(),
486+
int64_t,
487+
bool>(),
485488
"",
486489
{
487490
torch::arg("path"),
@@ -503,6 +506,7 @@ static auto embedding_rocks_db_wrapper =
503506
torch::arg("use_passed_in_path") = true,
504507
torch::arg("tbe_unique_id") = 0,
505508
torch::arg("l2_cache_size_gb") = 0,
509+
torch::arg("enable_async_update") = true,
506510
})
507511
.def(
508512
"set_cuda",

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,15 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
192192
int64_t cache_size = 0,
193193
bool use_passed_in_path = false,
194194
int64_t tbe_unqiue_id = 0,
195-
int64_t l2_cache_size_gb = 0)
195+
int64_t l2_cache_size_gb = 0,
196+
bool enable_async_update = false)
196197
: kv_db::EmbeddingKVDB(
197198
num_shards,
198199
max_D,
199200
l2_cache_size_gb,
200201
tbe_unqiue_id,
201-
row_storage_bitwidth / 8),
202+
row_storage_bitwidth / 8,
203+
enable_async_update),
202204
max_D_(max_D) {
203205
// TODO: lots of tunables. NNI or something for this?
204206
rocksdb::Options options;

0 commit comments

Comments
 (0)