Skip to content

Commit dda9d7e

Browse files
duduyi2013 facebook-github-bot
authored and committed
pass in kernel tbe id into rocksdb wrapper (pytorch#2930)
Summary: X-link: facebookresearch/FBGEMM#32 Pull Request resolved: pytorch#2930 The reason we need this is that we constantly see the port conflict error in rocksdb initialization. Before this diff we call getFreePort to get an available port. For each ssd tbe we will create 32 rocksdb shards, so in total there are 256 ports needed per host. This works fine with 4 hosts, but not once we are running a 16-host training job, as we need to make sure all 16 hosts don't get into the corner cases where multiple db shards get assigned the same free port. Reviewed By: sryap Differential Revision: D60635718
1 parent 48fd1f1 commit dda9d7e

File tree

3 files changed

+29
-18
lines changed

3 files changed

+29
-18
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,25 @@ def __init__(
258258
prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
259259
)
260260
# logging.info("DEBUG: weights_precision {}".format(weights_precision))
261+
262+
# create tbe unique id using rank index | local tbe idx
263+
if tbe_unique_id == -1:
264+
SSDTableBatchedEmbeddingBags._local_instance_index += 1
265+
if dist.is_initialized():
266+
assert (
267+
SSDTableBatchedEmbeddingBags._local_instance_index < 1024
268+
), f"{SSDTableBatchedEmbeddingBags._local_instance_index}, more than 1024 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
269+
tbe_unique_id = (
270+
dist.get_rank() << 10
271+
| SSDTableBatchedEmbeddingBags._local_instance_index
272+
)
273+
else:
274+
logging.warning("dist is not initialized, treating as single gpu cases")
275+
tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
276+
logging.info(f"tbe_unique_id: {tbe_unique_id}")
261277
if not ps_hosts:
262278
logging.info(
263-
f"Logging SSD offloading setup "
279+
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id},"
264280
f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
265281
f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
266282
f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},rate_limit_mbps={ssd_rate_limit_mbps},"
@@ -290,19 +306,9 @@ def __init__(
290306
weights_precision.bit_rate(), # row_storage_bitwidth
291307
ssd_block_cache_size_per_tbe,
292308
use_passed_in_path,
309+
tbe_unique_id,
293310
)
294311
else:
295-
# create tbe unique id using rank index | local tbe idx
296-
if tbe_unique_id == -1:
297-
SSDTableBatchedEmbeddingBags._local_instance_index += 1
298-
assert (
299-
SSDTableBatchedEmbeddingBags._local_instance_index < 8
300-
), f"{SSDTableBatchedEmbeddingBags._local_instance_index}, more than 8 TBE instance is created in one rank, the tbe unique id won't be unique in this case."
301-
tbe_unique_id = (
302-
dist.get_rank() << 3
303-
| SSDTableBatchedEmbeddingBags._local_instance_index
304-
)
305-
logging.info(f"tbe_unique_id: {tbe_unique_id}")
306312
# pyre-fixme[4]: Attribute must be annotated.
307313
# pyre-ignore[16]
308314
self.ssd_db = torch.classes.fbgemm.EmbeddingParameterServerWrapper(

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
168168
double uniform_init_upper,
169169
int64_t row_storage_bitwidth = 32,
170170
int64_t cache_size = 0,
171-
bool use_passed_in_path = false)
171+
bool use_passed_in_path = false,
172+
int64_t tbe_unique_id = 0)
172173
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
173174
path,
174175
num_shards,
@@ -186,7 +187,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
186187
uniform_init_upper,
187188
row_storage_bitwidth,
188189
cache_size,
189-
use_passed_in_path)) {}
190+
use_passed_in_path,
191+
tbe_unique_id)) {}
190192

191193
void
192194
set_cuda(Tensor indices, Tensor weights, Tensor count, int64_t timestep) {
@@ -238,7 +240,8 @@ static auto embedding_rocks_db_wrapper =
238240
double,
239241
int64_t,
240242
int64_t,
241-
bool>(),
243+
bool,
244+
int64_t>(),
242245
"",
243246
{
244247
torch::arg("path"),
@@ -258,6 +261,7 @@ static auto embedding_rocks_db_wrapper =
258261
torch::arg("row_storage_bitwidth"),
259262
torch::arg("cache_size"),
260263
torch::arg("use_passed_in_path") = true,
264+
torch::arg("tbe_unique_id") = 0,
261265
})
262266
.def("set_cuda", &EmbeddingRocksDBWrapper::set_cuda)
263267
.def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda)

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <torch/nn/init.h>
1313
#include <iostream>
1414
#ifdef FBGEMM_FBCODE
15-
#include "common/network/PortUtil.h"
1615
#include "common/strings/UUID.h"
1716
#include "fb_rocksdb/DBMonitor/DBMonitor.h"
1817
#include "fb_rocksdb/FbRocksDb.h"
@@ -40,6 +39,7 @@ constexpr size_t kRowInitBufferSize = 32 * 1024;
4039
#ifdef FBGEMM_FBCODE
4140
constexpr size_t num_ssd_drives = 8;
4241
const std::string ssd_mount_point = "/data00_nvidia";
42+
const size_t base_port = 136000;
4343
#endif
4444

4545
class Initializer {
@@ -132,7 +132,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
132132
float uniform_init_upper,
133133
int64_t row_storage_bitwidth = 32,
134134
int64_t cache_size = 0,
135-
bool use_passed_in_path = false) {
135+
bool use_passed_in_path = false,
136+
int64_t tbe_unqiue_id = 0) {
136137
// TODO: lots of tunables. NNI or something for this?
137138
rocksdb::Options options;
138139
options.create_if_missing = true;
@@ -256,7 +257,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
256257
rocksdb::DB* db;
257258

258259
#ifdef FBGEMM_FBCODE
259-
db_monitor_options.port = facebook::network::getFreePort();
260+
db_monitor_options.port = base_port + tbe_unqiue_id;
260261
auto s = facebook::fb_rocksdb::openRocksDB(
261262
options,
262263
shard_path,

0 commit comments

Comments
 (0)