Commit f3b9b27

doehyun authored and facebook-github-bot committed
Implement cache miss emulation in UVM_CACHING (pytorch#1637)
Summary:
Pull Request resolved: pytorch#1637

Enforce cache misses (even if trace-driven testing doesn't experience cache misses due to the limited trace size) so that we can evaluate performance under cache misses. Note that these are not true cache misses: access to UVM is enforced by overriding lxu_cache_locations for N out of every 256 requests.

Differential Revision: D42194019

fbshipit-source-id: 8c64eb32393eaaa06b5419633711fdfdd2fde890
1 parent 75dd112 commit f3b9b27
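Note on the mechanism: the enforced-miss pattern is deterministic. Within every aligned group of 256 lookup positions, the first tbe_uvm_cache_enforced_misses entries have their lxu_cache_locations overridden to the miss sentinel (-1), forcing those accesses to go to UVM. The sketch below is a hypothetical host-side helper that mirrors the kernel's selection rule for illustration only; the actual change performs this on the GPU in emulate_cache_miss_kernel, shown in the diff further down.

#include <cstddef>
#include <cstdint>
#include <vector>

// Same sentinel value the cache code uses for "not in cache" (suffix added to
// make clear this is a local sketch, not the library constant).
constexpr int32_t kCacheLocationMissingSketch = -1;

// Hypothetical helper: applies the same (n & 0xFF) < enforced_misses_per_256
// test as the CUDA kernel, so the override pattern can be inspected on the CPU.
std::vector<int32_t> emulate_cache_miss_cpu_sketch(
    std::vector<int32_t> lxu_cache_locations, // cache slot per request; -1 == miss
    int64_t enforced_misses_per_256) {
  for (std::size_t n = 0; n < lxu_cache_locations.size(); ++n) {
    // Position within the current 256-request group; the first
    // enforced_misses_per_256 positions are forced to miss.
    if (static_cast<int64_t>(n & 0xFF) < enforced_misses_per_256) {
      lxu_cache_locations[n] = kCacheLocationMissingSketch;
    }
  }
  return lxu_cache_locations;
}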

File tree: 4 files changed, +258 −24 lines changed

fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp
Lines changed: 56 additions & 9 deletions

@@ -4,12 +4,12 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+
 #include <ATen/ATen.h>
 #include <ATen/TypeDefault.h>
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/library.h>
-#include <algorithm>
 #include "c10/core/ScalarType.h"
 #ifdef FBCODE_CAFFE2
 #include "common/stats/Stats.h"
@@ -18,6 +18,8 @@
 #include "fbgemm_gpu/sparse_ops_utils.h"
 #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"

+#include <algorithm>
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;

@@ -37,14 +39,29 @@ DEFINE_quantile_stat(
     facebook::fb303::ExportTypeConsts::kNone,
     std::array<double, 4>{{.25, .50, .75, .99}});

-// Miss rate due to conflict in cache associativity.
+// (Unique) Miss rate due to conflict in cache associativity.
 // # unique misses due to conflict / # requested indices.
 DEFINE_quantile_stat(
     tbe_uvm_cache_conflict_unique_miss_rate,
     "tbe_uvm_cache_conflict_unique_miss_rate_per_mille",
     facebook::fb303::ExportTypeConsts::kNone,
     std::array<double, 4>{{.25, .50, .75, .99}});

+// Miss rate due to conflict in cache associativity.
+// # misses due to conflict / # requested indices.
+DEFINE_quantile_stat(
+    tbe_uvm_cache_conflict_miss_rate,
+    "tbe_uvm_cache_conflict_miss_rate_per_mille",
+    facebook::fb303::ExportTypeConsts::kNone,
+    std::array<double, 4>{{.25, .50, .75, .99}});
+
+// Total miss rate.
+DEFINE_quantile_stat(
+    tbe_uvm_cache_total_miss_rate,
+    "tbe_uvm_cache_total_miss_rate_per_mille",
+    facebook::fb303::ExportTypeConsts::kNone,
+    std::array<double, 4>{{.25, .50, .75, .99}});
+
 // FLAGs to control UVMCacheStats.
 DEFINE_int32(
     tbe_uvm_cache_stat_report,
@@ -58,6 +75,12 @@ DEFINE_int32(
     "If tbe_uvm_cache_stat_report is enabled, more detailed raw stats will be printed with this "
     "period. This should be an integer multiple of tbe_uvm_cache_stat_report.");

+DEFINE_int32(
+    tbe_uvm_cache_enforced_misses,
+    0,
+    "If set to non-zero, some cache lookups (tbe_uvm_cache_enforced_misses / 256) are enforced to be misses; "
+    "this is performance evaluation purposes only; and should be zero otherwise.");
+
 // TODO: align this with uvm_cache_stats_index in
 // split_embeddings_cache_cuda.cu.
 const int kUvmCacheStatsSize = 6;
@@ -84,10 +107,11 @@ void process_uvm_cache_stats(
   // uvm_cache_stats_counters[0]: num_req_indices
   // uvm_cache_stats_counters[1]: num_unique_indices
   // uvm_cache_stats_counters[2]: num_unique_misses
-  // uvm_cache_stats_counters[3]: num_unique_conflict_misses
+  // uvm_cache_stats_counters[3]: num_conflict_unique_misses
+  // uvm_cache_stats_counters[4]: num_conflict_misses
   // They should be zero-out after the calculated rates are populated into
   // cache counters.
-  static std::vector<int64_t> uvm_cache_stats_counters(4);
+  static std::vector<int64_t> uvm_cache_stats_counters(5);

   // Export cache stats.
   auto uvm_cache_stats_cpu = uvm_cache_stats.cpu();
@@ -107,19 +131,32 @@ void process_uvm_cache_stats(
   // Calculate cache related ratios based on the cumulated numbers and
   // push them into the counter pools.
   if (populate_uvm_stats && uvm_cache_stats_counters[0] > 0) {
-    double unique_rate =
+    const double unique_rate =
        static_cast<double>(uvm_cache_stats_counters[1]) /
        uvm_cache_stats_counters[0] * 1000;
-    double unique_miss_rate =
+    const double unique_miss_rate =
        static_cast<double>(uvm_cache_stats_counters[2]) /
        uvm_cache_stats_counters[0] * 1000;
-    double unique_conflict_miss_rate =
+    const double conflict_unique_miss_rate =
        static_cast<double>(uvm_cache_stats_counters[3]) /
        uvm_cache_stats_counters[0] * 1000;
+    const double conflict_miss_rate =
+        static_cast<double>(uvm_cache_stats_counters[4]) /
+        uvm_cache_stats_counters[0] * 1000;
+    // total # misses = unique misses - conflict_unique_misses + conflict
+    // misses.
+    const double total_miss_rate =
+        static_cast<double>(
+            uvm_cache_stats_counters[2] - uvm_cache_stats_counters[3] +
+            uvm_cache_stats_counters[4]) /
+        uvm_cache_stats_counters[0] * 1000;
+
     STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
     STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
     STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
-        unique_conflict_miss_rate);
+        conflict_unique_miss_rate);
+    STATS_tbe_uvm_cache_conflict_miss_rate.addValue(conflict_miss_rate);
+    STATS_tbe_uvm_cache_total_miss_rate.addValue(total_miss_rate);

     // Fill all the elements of the vector uvm_cache_stats_counters as 0
     // to zero out the cumulated counters.
@@ -365,7 +402,7 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
     // cache_index_table_map: (linearized) index to table number map.
     // 1D tensor, dtype=int32.
     c10::optional<Tensor> cache_index_table_map,
-    // lxu_cache_state: Cache state (cached idnex, or invalid).
+    // lxu_cache_state: Cache state (cached index, or invalid).
     // 2D tensor: # sets x assoc. dtype=int64.
     c10::optional<Tensor> lxu_cache_state,
     // lxu_state: meta info for replacement (time stamp for LRU).
@@ -461,6 +498,16 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         uvm_cache_stats);

 #ifdef FBCODE_CAFFE2
+  if (FLAGS_tbe_uvm_cache_enforced_misses > 0) {
+    // Override some lxu_cache_locations (N for every 256 indices) with cache
+    // miss to enforce access to UVM.
+    lxu_cache_locations = emulate_cache_miss(
+        lxu_cache_locations.value(),
+        FLAGS_tbe_uvm_cache_enforced_misses,
+        gather_uvm_stats,
+        uvm_cache_stats);
+  }
+
   process_uvm_cache_stats(
       signature,
       total_cache_hash_size.value(),
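
To make the new counters concrete, here is a small worked example of the per-mille rate computation added to process_uvm_cache_stats above. The counter values are hypothetical; only the formulas are taken from the diff.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical cumulated values, in the same order as
  // uvm_cache_stats_counters[0..4] above: num_req_indices, num_unique_indices,
  // num_unique_misses, num_conflict_unique_misses, num_conflict_misses.
  const std::vector<int64_t> c = {10000, 4000, 1200, 50, 90};
  const double per_mille = 1000.0 / static_cast<double>(c[0]);
  // total # misses = unique misses - conflict unique misses + conflict misses
  const double total_miss_rate =
      static_cast<double>(c[2] - c[3] + c[4]) * per_mille;
  std::printf(
      "unique_rate=%.0f unique_miss_rate=%.0f conflict_unique=%.0f "
      "conflict=%.0f total=%.0f\n",
      c[1] * per_mille,  // 400
      c[2] * per_mille,  // 120
      c[3] * per_mille,  //   5
      c[4] * per_mille,  //   9
      total_miss_rate);  // (1200 - 50 + 90) / 10000 * 1000 = 124
  return 0;
}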

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
Lines changed: 6 additions & 0 deletions

@@ -155,6 +155,12 @@ at::Tensor lxu_cache_lookup_cuda(
     bool gather_cache_stats,
     c10::optional<at::Tensor> uvm_cache_stats);

+at::Tensor emulate_cache_miss(
+    at::Tensor lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_cache_stats,
+    at::Tensor uvm_cache_stats);
+
 ///@ingroup table-batched-embed-cuda
 /// Lookup the LRU/LFU cache: find the cache weights location for all indices.
 /// Look up the slots in the cache corresponding to `linear_cache_indices`, with

fbgemm_gpu/src/split_embeddings_cache_cuda.cu
Lines changed: 77 additions & 15 deletions

@@ -79,6 +79,18 @@ enum uvm_cache_stats_index {
   num_conflict_misses = 5,
 };

+// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
+// not sensitive to grid size as long as the number thread blocks per SM is not
+// too small nor too big.
+constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;
+
+int get_max_thread_blocks_for_cache_kernels_() {
+  cudaDeviceProp* deviceProp =
+      at::cuda::getDeviceProperties(c10::cuda::current_device());
+  return deviceProp->multiProcessorCount *
+      MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
+}
+
 } // namespace

 int64_t host_lxu_cache_slot(int64_t h_in, int64_t C) {
@@ -495,6 +507,69 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> get_unique_indices_cuda(

 namespace {

+template <typename index_t>
+__global__ __launch_bounds__(kMaxThreads) void emulate_cache_miss_kernel(
+    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_cache_stats,
+    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        uvm_cache_stats) {
+  const int32_t N = lxu_cache_locations.size(0);
+  int64_t n_enforced_misses = 0;
+  CUDA_KERNEL_LOOP(n, N) {
+    if ((n & 0x00FF) < enforced_misses_per_256) {
+      if (lxu_cache_locations[n] >= 0) {
+        n_enforced_misses++;
+      }
+      lxu_cache_locations[n] = kCacheLocationMissing;
+    }
+  }
+  if (gather_cache_stats && n_enforced_misses > 0) {
+    atomicAdd(
+        &uvm_cache_stats[uvm_cache_stats_index::num_conflict_misses],
+        n_enforced_misses);
+  }
+}
+} // namespace
+
+Tensor emulate_cache_miss(
+    Tensor lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_cache_stats,
+    Tensor uvm_cache_stats) {
+  TENSOR_ON_CUDA_GPU(lxu_cache_locations);
+  TENSOR_ON_CUDA_GPU(uvm_cache_stats);
+
+  const auto N = lxu_cache_locations.numel();
+  if (lxu_cache_locations.numel() == 0) {
+    // nothing to do
+    return lxu_cache_locations;
+  }
+
+  const dim3 blocks(std::min(
+      div_round_up(N, kMaxThreads),
+      get_max_thread_blocks_for_cache_kernels_()));
+
+  AT_DISPATCH_INDEX_TYPES(
+      lxu_cache_locations.scalar_type(), "emulate_cache_miss", [&] {
+        emulate_cache_miss_kernel<<<
+            blocks,
+            kMaxThreads,
+            0,
+            at::cuda::getCurrentCUDAStream()>>>(
+            lxu_cache_locations
+                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
+            enforced_misses_per_256,
+            gather_cache_stats,
+            uvm_cache_stats
+                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+  return lxu_cache_locations;
+}
+
+namespace {
 template <typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel(
     const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
@@ -622,19 +697,6 @@ __launch_bounds__(kMaxThreads) void direct_mapped_lru_cache_find_uncached_kernel
       }
     }
   }
-
-// Experiments showed that performance of lru/lxu_cache_find_uncached_kernel is
-// not sensitive to grid size as long as the number thread blocks per SM is not
-// too small nor too big.
-constexpr int MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS = 16;
-
-int get_max_thread_blocks_for_cache_kernels_() {
-  cudaDeviceProp* deviceProp =
-      at::cuda::getDeviceProperties(c10::cuda::current_device());
-  return deviceProp->multiProcessorCount *
-      MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS;
-}
-
 } // namespace

 std::pair<Tensor, Tensor> lru_cache_find_uncached_cuda(
@@ -798,8 +860,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
     at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         uvm_cache_stats) {
   const int32_t C = lxu_cache_state.size(0);
-  int64_t n_conflict_misses = 0;
-  int64_t n_inserted = 0;
+  int32_t n_conflict_misses = 0;
+  int32_t n_inserted = 0;
   for (int32_t n = blockIdx.x * blockDim.y + threadIdx.y; n < *N_unique;
        n += gridDim.x * blockDim.y) {
     // check if this warp is responsible for this whole segment.
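
A note on the launch configuration in emulate_cache_miss above: the grid is the number of blocks needed to cover the N cache locations, capped at MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS (16) blocks per SM, and CUDA_KERNEL_LOOP grid-strides over any remainder. The sketch below reproduces that arithmetic on the host; kMaxThreads = 1024 and an 80-SM device are assumptions for illustration, not values taken from this diff.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed for illustration only.
constexpr int64_t kMaxThreadsAssumed = 1024;
constexpr int64_t kSmCountAssumed = 80;
constexpr int64_t kMaxBlocksPerSm = 16; // MAX_THREAD_BLOCKS_PER_SM_FOR_CACHE_KERNELS

// Matches the usual ceil-division helper used as div_round_up in the diff.
int64_t div_round_up_sketch(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t N = 1000000; // number of lxu_cache_locations entries
  const int64_t blocks = std::min(
      div_round_up_sketch(N, kMaxThreadsAssumed), // 977 blocks to cover N
      kSmCountAssumed * kMaxBlocksPerSm);         // cap at 80 * 16 = 1280
  std::printf(
      "launch %lld blocks x %lld threads\n",
      static_cast<long long>(blocks),             // 977
      static_cast<long long>(kMaxThreadsAssumed));
  return 0;
}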

(new file: unit test for emulate_cache_miss)
Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <gtest/gtest.h>
+
+#include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
+
+using namespace ::testing;
+
+// Helper function that generates input tensor for emulate_cache_miss testing.
+at::Tensor generate_lxu_cache_locations(
+    const int64_t num_requests,
+    const int64_t num_sets,
+    const int64_t associativity = 32) {
+  const auto lxu_cache_locations = at::randint(
+      0,
+      num_sets * associativity,
+      {num_requests},
+      at::device(at::kCPU).dtype(at::kInt));
+  return lxu_cache_locations;
+}
+
+// Wrapper function that takes lxu_cache_locations on CPU, copies it to GPU,
+// runs emulate_cache_miss(), and then returns the result, placed on CPU.
+std::pair<at::Tensor, at::Tensor> run_emulate_cache_miss(
+    at::Tensor lxu_cache_locations,
+    const int64_t enforced_misses_per_256,
+    const bool gather_uvm_stats = false) {
+  at::Tensor lxu_cache_locations_copy = at::_to_copy(lxu_cache_locations);
+  const auto options =
+      lxu_cache_locations.options().device(at::kCUDA).dtype(at::kInt);
+  const auto uvm_cache_stats =
+      gather_uvm_stats ? at::zeros({6}, options) : at::empty({0}, options);
+
+  const auto lxu_cache_location_with_cache_misses = emulate_cache_miss(
+      lxu_cache_locations_copy.to(at::kCUDA),
+      enforced_misses_per_256,
+      gather_uvm_stats,
+      uvm_cache_stats);
+  return {lxu_cache_location_with_cache_misses.cpu(), uvm_cache_stats.cpu()};
+}
+
+TEST(uvm_cache_miss_emulate_test, no_cache_miss) {
+  constexpr int64_t num_requests = 10000;
+  constexpr int64_t num_sets = 32768;
+  constexpr int64_t associativity = 32;
+
+  auto lxu_cache_locations_cpu =
+      generate_lxu_cache_locations(num_requests, num_sets, associativity);
+  auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
+      run_emulate_cache_miss(lxu_cache_locations_cpu, 0);
+  auto lxu_cache_location_with_cache_misses =
+      lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
+  EXPECT_TRUE(
+      at::equal(lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));
+}
+
+TEST(uvm_cache_miss_emulate_test, enforced_cache_miss) {
+  constexpr int64_t num_requests = 10000;
+  constexpr int64_t num_sets = 32768;
+  constexpr int64_t associativity = 32;
+  constexpr std::array<int64_t, 6> enforced_misses_per_256_for_testing = {
+      1, 5, 7, 33, 100, 256};
+
+  for (const bool miss_in_lxu_cache_locations : {false, true}) {
+    for (const bool gather_cache_stats : {false, true}) {
+      for (const auto enforced_misses_per_256 :
+           enforced_misses_per_256_for_testing) {
+        auto lxu_cache_locations_cpu =
+            generate_lxu_cache_locations(num_requests, num_sets, associativity);
+        if (miss_in_lxu_cache_locations) {
+          // one miss in the original lxu_cache_locations; shouldn't be counted
+          // as enforced misses from emulate_cache_miss().
+          auto z = lxu_cache_locations_cpu.data_ptr<int32_t>();
+          z[0] = -1;
+        }
+        auto lxu_cache_location_with_cache_misses_and_uvm_cache_stats =
+            run_emulate_cache_miss(
+                lxu_cache_locations_cpu,
+                enforced_misses_per_256,
+                gather_cache_stats);
+        auto lxu_cache_location_with_cache_misses =
+            lxu_cache_location_with_cache_misses_and_uvm_cache_stats.first;
+        EXPECT_FALSE(at::equal(
+            lxu_cache_locations_cpu, lxu_cache_location_with_cache_misses));
+
+        auto x = lxu_cache_locations_cpu.data_ptr<int32_t>();
+        auto y = lxu_cache_location_with_cache_misses.data_ptr<int32_t>();
+        int64_t enforced_misses = 0;
+        for (int32_t i = 0; i < lxu_cache_locations_cpu.numel(); ++i) {
+          if (x[i] != y[i]) {
+            EXPECT_EQ(y[i], -1);
+            enforced_misses++;
+          }
+        }
+        int64_t num_requests_over_256 =
+            static_cast<int64_t>(num_requests / 256);
+        int64_t expected_misses = num_requests_over_256 *
+                enforced_misses_per_256 +
+            std::min((num_requests - num_requests_over_256 * 256),
+                     enforced_misses_per_256);
+        if (miss_in_lxu_cache_locations) {
+          expected_misses--;
+        }
+        EXPECT_EQ(expected_misses, enforced_misses);
+        if (gather_cache_stats) {
+          auto uvm_cache_stats =
+              lxu_cache_location_with_cache_misses_and_uvm_cache_stats.second;
+          auto cache_stats_ptr = uvm_cache_stats.data_ptr<int32_t>();
+          // enforced misses are recorded as conflict misses.
+          EXPECT_EQ(expected_misses, cache_stats_ptr[5]);
+        }
+      }
+    }
+  }
+}
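
The expected_misses arithmetic in the test follows directly from the per-256 pattern: each full 256-index group contributes enforced_misses_per_256 forced misses, and the trailing partial group contributes at most its own length. A small standalone check for the test's num_requests = 10000 (sketch only; result verified by hand):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_requests = 10000;                     // as in the test above
  const int64_t enforced_misses_per_256 = 5;              // one of the tested values
  const int64_t full_groups = num_requests / 256;         // 39 full 256-index groups
  const int64_t tail = num_requests - full_groups * 256;  // 16 trailing indices
  const int64_t expected = full_groups * enforced_misses_per_256 +
      std::min(tail, enforced_misses_per_256);
  // With miss_in_lxu_cache_locations == true the test subtracts 1, because
  // index 0 was already a miss and is not counted as an enforced one.
  std::printf("expected enforced misses = %lld\n",
              static_cast<long long>(expected)); // 39 * 5 + 5 = 200
  return 0;
}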
