Skip to content

Commit 50b5b37

Browse files
duduyi2013 authored and facebook-github-bot committed
add L2 flush (pytorch#197)
Summary: X-link: pytorch#3110 Pull Request resolved: facebookresearch/FBGEMM#197 add L2 flush support for checkpoint Reviewed By: q10 Differential Revision: D62462352 fbshipit-source-id: dfd59f0ebd43b27b1ce6f8a684b8956bf8672191
1 parent ad0122b commit 50b5b37

File tree

7 files changed

+120
-18
lines changed

7 files changed

+120
-18
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,6 +1643,8 @@ def flush(self) -> None:
16431643
False,
16441644
)
16451645

1646+
self.ssd_db.flush()
1647+
16461648
def prepare_inputs(
16471649
self,
16481650
indices: Tensor,

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,13 @@ class CacheLibCache {
3333
public:
3434
using Cache = facebook::cachelib::LruAllocator;
3535
struct CacheConfig {
36-
size_t cacheSizeBytes;
36+
size_t cache_size_bytes;
37+
size_t item_size_bytes;
38+
size_t num_shards;
39+
int64_t max_D_;
3740
};
3841

39-
explicit CacheLibCache(size_t cacheSizeBytes, int64_t num_shards);
42+
explicit CacheLibCache(const CacheConfig& cache_config);
4043

4144
std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config);
4245

@@ -85,6 +88,20 @@ class CacheLibCache {
8588
/// @note cache_->allocation will trigger eviction callback func
8689
bool put(int64_t key, const at::Tensor& data);
8790

91+
/// iterate through all items in L2 cache, fill them in indices and weights
92+
/// respectively and return indices, weights and count
93+
///
94+
/// @return indices The 1D embedding index tensor, should skip on negative
95+
/// value
96+
/// @return weights The 2D tensor that each row(embeddings) is paired up with
97+
/// relative element in <indices>
98+
/// @return count A single element tensor that contains the number of indices
99+
/// to be processed
100+
///
101+
/// @note this isn't thread safe, caller needs to make sure put isn't called
102+
/// while this is executed.
103+
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_all_items();
104+
88105
/// instantiate eviction related indices and weights tensors(size of <count>)
89106
/// for L2 eviction using the same dtype and device from <indices> and
90107
/// <weights> , managed on the caller side

fbgemm_gpu/src/ps_split_embeddings_cache/ps_table_batched_embeddings.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class EmbeddingParameterServer : public kv_db::EmbeddingKVDB {
5959
RECORD_USER_SCOPE("EmbeddingParameterServer::get");
6060
co_await tps_client_->get(indices, weights, count.item().toLong());
6161
}
62-
void flush() override {}
62+
void flush() {}
6363
void compact() override {}
6464
// cleanup cached results in server side
6565
// This is a test helper, please do not use it in production

fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,32 @@
1313
namespace l2_cache {
1414

1515
using Cache = facebook::cachelib::LruAllocator;
16-
CacheLibCache::CacheLibCache(size_t cacheSizeBytes, int64_t num_shards)
17-
: cache_config_(CacheConfig{.cacheSizeBytes = cacheSizeBytes}),
16+
17+
// this is a general predictor for weights data type, might not be general
18+
// enough for all the cases
19+
at::ScalarType bytes_to_dtype(int num_bytes) {
20+
switch (num_bytes) {
21+
case 1:
22+
return at::kByte;
23+
case 2:
24+
return at::kHalf;
25+
case 4:
26+
return at::kFloat;
27+
case 8:
28+
return at::kDouble;
29+
default:
30+
throw std::runtime_error("Unsupported dtype");
31+
}
32+
}
33+
34+
CacheLibCache::CacheLibCache(const CacheConfig& cache_config)
35+
: cache_config_(cache_config),
1836
cache_(initializeCacheLib(cache_config_)),
1937
admin_(createCacheAdmin(*cache_)) {
20-
for (int i = 0; i < num_shards; i++) {
38+
for (size_t i = 0; i < cache_config_.num_shards; i++) {
2139
pool_ids_.push_back(cache_->addPool(
2240
fmt::format("shard_{}", i),
23-
cache_->getCacheMemoryStats().ramCacheSize / num_shards));
41+
cache_->getCacheMemoryStats().ramCacheSize / cache_config_.num_shards));
2442
}
2543
}
2644

@@ -51,7 +69,7 @@ std::unique_ptr<Cache> CacheLibCache::initializeCacheLib(
5169
});
5270
};
5371
Cache::Config cacheLibConfig;
54-
cacheLibConfig.setCacheSize(static_cast<uint64_t>(config.cacheSizeBytes))
72+
cacheLibConfig.setCacheSize(static_cast<uint64_t>(config.cache_size_bytes))
5573
.setRemoveCallback(eviction_cb)
5674
.setCacheName("TBEL2Cache")
5775
.setAccessConfig({25 /* bucket power */, 10 /* lock power */})
@@ -99,6 +117,44 @@ bool CacheLibCache::put(int64_t key, const at::Tensor& data) {
99117
return true;
100118
}
101119

120+
std::tuple<at::Tensor, at::Tensor, at::Tensor> CacheLibCache::get_all_items() {
121+
int total_num_items = 0;
122+
for (auto& pool_id : pool_ids_) {
123+
total_num_items += cache_->getPoolStats(pool_id).numItems();
124+
}
125+
auto weight_dim = cache_config_.max_D_;
126+
auto weights_dtype =
127+
bytes_to_dtype(cache_config_.item_size_bytes / weight_dim);
128+
auto indices = at::empty(
129+
total_num_items, at::TensorOptions().dtype(at::kLong).device(at::kCPU));
130+
auto weights = at::empty(
131+
{total_num_items, weight_dim},
132+
at::TensorOptions().dtype(weights_dtype).device(at::kCPU));
133+
FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
134+
weights.scalar_type(), "get_all_items", [&] {
135+
auto indices_data_ptr = indices.data_ptr<int64_t>();
136+
auto weights_data_ptr = weights.data_ptr<scalar_t>();
137+
int64_t item_idx = 0;
138+
for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) {
139+
const auto key_ptr =
140+
reinterpret_cast<const int64_t*>(itr->getKey().data());
141+
indices_data_ptr[item_idx] = *key_ptr;
142+
std::copy(
143+
reinterpret_cast<const scalar_t*>(itr->getMemory()),
144+
reinterpret_cast<const scalar_t*>(itr->getMemory()) + weight_dim,
145+
&weights_data_ptr[item_idx * weight_dim]); // dst_start
146+
item_idx++;
147+
}
148+
CHECK_EQ(total_num_items, item_idx);
149+
});
150+
return std::make_tuple(
151+
indices,
152+
weights,
153+
at::tensor(
154+
{total_num_items},
155+
at::TensorOptions().dtype(at::kLong).device(at::kCPU)));
156+
}
157+
102158
void CacheLibCache::init_tensor_for_l2_eviction(
103159
const at::Tensor& indices,
104160
const at::Tensor& weights,
@@ -130,7 +186,7 @@ CacheLibCache::get_evicted_indices_and_weights() {
130186

131187
std::vector<int64_t> CacheLibCache::get_cache_usage() {
132188
std::vector<int64_t> cache_mem_stats(2, 0); // freeBytes, capacity
133-
cache_mem_stats[1] = cache_config_.cacheSizeBytes;
189+
cache_mem_stats[1] = cache_config_.cache_size_bytes;
134190
for (auto& pool_id : pool_ids_) {
135191
auto pool_stats = cache_->getPoolStats(pool_id);
136192
cache_mem_stats[0] += pool_stats.freeMemoryBytes();

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,23 @@ EmbeddingKVDB::EmbeddingKVDB(
6565
int64_t num_shards,
6666
int64_t max_D,
6767
int64_t cache_size_gb,
68-
int64_t unique_id)
68+
int64_t unique_id,
69+
int64_t ele_size_bytes)
6970
: unique_id_(unique_id),
7071
num_shards_(num_shards),
7172
max_D_(max_D),
7273
executor_tp_(std::make_unique<folly::CPUThreadPoolExecutor>(num_shards)) {
7374
assert(num_shards > 0);
74-
l2_cache_ = cache_size_gb > 0
75-
? std::make_unique<l2_cache::CacheLibCache>(
76-
cache_size_gb * 1024 * 1024 * 1024, num_shards_)
77-
: nullptr;
75+
if (cache_size_gb > 0) {
76+
l2_cache::CacheLibCache::CacheConfig cache_config;
77+
cache_config.cache_size_bytes = cache_size_gb * 1024 * 1024 * 1024;
78+
cache_config.num_shards = num_shards_;
79+
cache_config.item_size_bytes = max_D_ * ele_size_bytes;
80+
cache_config.max_D_ = max_D_;
81+
l2_cache_ = std::make_unique<l2_cache::CacheLibCache>(cache_config);
82+
} else {
83+
l2_cache_ = nullptr;
84+
}
7885
cache_filling_thread_ = std::make_unique<std::thread>([=] {
7986
while (!stop_) {
8087
auto filling_item_ptr = weights_to_fill_queue_.try_peek();
@@ -114,6 +121,17 @@ EmbeddingKVDB::~EmbeddingKVDB() {
114121
cache_filling_thread_->join();
115122
}
116123

124+
void EmbeddingKVDB::flush() {
125+
wait_util_filling_work_done();
126+
if (l2_cache_) {
127+
auto tensor_tuple = l2_cache_->get_all_items();
128+
auto& indices = std::get<0>(tensor_tuple);
129+
auto& weights = std::get<1>(tensor_tuple);
130+
auto& count = std::get<2>(tensor_tuple);
131+
folly::coro::blockingWait(set_kv_db_async(indices, weights, count));
132+
}
133+
}
134+
117135
void EmbeddingKVDB::get_cuda(
118136
const at::Tensor& indices,
119137
const at::Tensor& weights,

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
7676
int64_t num_shards,
7777
int64_t max_D,
7878
int64_t cache_size_gb = 0,
79-
int64_t unique_id = 0);
79+
int64_t unique_id = 0,
80+
int64_t ele_size_bytes = 2 /*assume by default fp16*/);
8081

8182
virtual ~EmbeddingKVDB();
8283

@@ -140,7 +141,13 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
140141

141142
virtual void compact() = 0;
142143

143-
virtual void flush() = 0;
144+
/// Flush L2 cache into backend storage
145+
/// @return None
146+
/// @note caller side should manage the timing to make sure flush doesn't
147+
/// happen at the same time as get/set
148+
/// @note flush only flushes L2 cache, if there is cache on the backend
149+
/// storage, that flush should be called as well
150+
void flush();
144151

145152
// The function attaches the CUDA callback logic to the compute
146153
// stream to ensure that the data retrieval is carried out properly.

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
151151
num_shards,
152152
max_D,
153153
l2_cache_size_gb,
154-
tbe_unqiue_id) {
154+
tbe_unqiue_id,
155+
row_storage_bitwidth / 8) {
155156
// TODO: lots of tunables. NNI or something for this?
156157
rocksdb::Options options;
157158
options.create_if_missing = true;
@@ -580,7 +581,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
580581
}
581582
}
582583

583-
void flush() override {
584+
void flush() {
585+
kv_db::EmbeddingKVDB::flush();
584586
for (auto& db : dbs_) {
585587
db->Flush(rocksdb::FlushOptions());
586588
}

0 commit comments

Comments
 (0)