4
4
* This source code is licensed under the BSD-style license found in the
5
5
* LICENSE file in the root directory of this source tree.
6
6
*/
7
+
7
8
#include < ATen/ATen.h>
8
9
#include < ATen/TypeDefault.h>
9
10
#include < ATen/core/op_registration/op_registration.h>
10
11
#include < ATen/cuda/CUDAContext.h>
11
12
#include < torch/library.h>
12
- #include < algorithm>
13
13
#include " c10/core/ScalarType.h"
14
14
#ifdef FBCODE_CAFFE2
15
15
#include " common/stats/Stats.h"
18
18
#include " fbgemm_gpu/sparse_ops_utils.h"
19
19
#include " fbgemm_gpu/split_embeddings_cache_cuda.cuh"
20
20
21
+ #include < algorithm>
22
+
21
23
using Tensor = at::Tensor;
22
24
using namespace fbgemm_gpu ;
23
25
@@ -37,14 +39,29 @@ DEFINE_quantile_stat(
37
39
facebook::fb303::ExportTypeConsts::kNone ,
38
40
std::array<double , 4 >{{.25 , .50 , .75 , .99 }});
39
41
40
- // Miss rate due to conflict in cache associativity.
42
+ // (Unique) Miss rate due to conflict in cache associativity.
41
43
// # unique misses due to conflict / # requested indices.
42
44
DEFINE_quantile_stat (
43
45
tbe_uvm_cache_conflict_unique_miss_rate,
44
46
" tbe_uvm_cache_conflict_unique_miss_rate_per_mille" ,
45
47
facebook::fb303::ExportTypeConsts::kNone ,
46
48
std::array<double , 4 >{{.25 , .50 , .75 , .99 }});
47
49
50
+ // Miss rate due to conflict in cache associativity.
51
+ // # misses due to conflict / # requested indices.
52
+ DEFINE_quantile_stat (
53
+ tbe_uvm_cache_conflict_miss_rate,
54
+ " tbe_uvm_cache_conflict_miss_rate_per_mille" ,
55
+ facebook::fb303::ExportTypeConsts::kNone ,
56
+ std::array<double , 4 >{{.25 , .50 , .75 , .99 }});
57
+
58
+ // Total miss rate.
59
+ DEFINE_quantile_stat (
60
+ tbe_uvm_cache_total_miss_rate,
61
+ " tbe_uvm_cache_total_miss_rate_per_mille" ,
62
+ facebook::fb303::ExportTypeConsts::kNone ,
63
+ std::array<double , 4 >{{.25 , .50 , .75 , .99 }});
64
+
48
65
// FLAGs to control UVMCacheStats.
49
66
DEFINE_int32 (
50
67
tbe_uvm_cache_stat_report,
@@ -58,6 +75,12 @@ DEFINE_int32(
58
75
" If tbe_uvm_cache_stat_report is enabled, more detailed raw stats will be printed with this "
59
76
" period. This should be an integer multiple of tbe_uvm_cache_stat_report." );
60
77
78
+ DEFINE_int32 (
79
+ tbe_uvm_cache_enforced_misses,
80
+ 0 ,
81
+ " If set to non-zero, some cache lookups (tbe_uvm_cache_enforced_misses / 256) are enforced to be misses; "
82
+ " this is performance evaluation purposes only; and should be zero otherwise." );
83
+
61
84
// TODO: align this with uvm_cache_stats_index in
62
85
// split_embeddings_cache_cuda.cu.
63
86
const int kUvmCacheStatsSize = 6 ;
@@ -84,10 +107,11 @@ void process_uvm_cache_stats(
84
107
// uvm_cache_stats_counters[0]: num_req_indices
85
108
// uvm_cache_stats_counters[1]: num_unique_indices
86
109
// uvm_cache_stats_counters[2]: num_unique_misses
87
- // uvm_cache_stats_counters[3]: num_unique_conflict_misses
110
+ // uvm_cache_stats_counters[3]: num_conflict_unique_misses
111
+ // uvm_cache_stats_counters[4]: num_conflict_misses
88
112
// They should be zeroed out after the calculated rates are populated into
89
113
// cache counters.
90
- static std::vector<int64_t > uvm_cache_stats_counters (4 );
114
+ static std::vector<int64_t > uvm_cache_stats_counters (5 );
91
115
92
116
// Export cache stats.
93
117
auto uvm_cache_stats_cpu = uvm_cache_stats.cpu ();
@@ -107,19 +131,32 @@ void process_uvm_cache_stats(
107
131
// Calculate cache related ratios based on the cumulated numbers and
108
132
// push them into the counter pools.
109
133
if (populate_uvm_stats && uvm_cache_stats_counters[0 ] > 0 ) {
110
- double unique_rate =
134
+ const double unique_rate =
111
135
static_cast <double >(uvm_cache_stats_counters[1 ]) /
112
136
uvm_cache_stats_counters[0 ] * 1000 ;
113
- double unique_miss_rate =
137
+ const double unique_miss_rate =
114
138
static_cast <double >(uvm_cache_stats_counters[2 ]) /
115
139
uvm_cache_stats_counters[0 ] * 1000 ;
116
- double unique_conflict_miss_rate =
140
+ const double conflict_unique_miss_rate =
117
141
static_cast <double >(uvm_cache_stats_counters[3 ]) /
118
142
uvm_cache_stats_counters[0 ] * 1000 ;
143
+ const double conflict_miss_rate =
144
+ static_cast <double >(uvm_cache_stats_counters[4 ]) /
145
+ uvm_cache_stats_counters[0 ] * 1000 ;
146
+ // total # misses = unique misses - conflict_unique_misses + conflict
147
+ // misses.
148
+ const double total_miss_rate =
149
+ static_cast <double >(
150
+ uvm_cache_stats_counters[2 ] - uvm_cache_stats_counters[3 ] +
151
+ uvm_cache_stats_counters[4 ]) /
152
+ uvm_cache_stats_counters[0 ] * 1000 ;
153
+
119
154
STATS_tbe_uvm_cache_unique_rate.addValue (unique_rate);
120
155
STATS_tbe_uvm_cache_unique_miss_rate.addValue (unique_miss_rate);
121
156
STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue (
122
- unique_conflict_miss_rate);
157
+ conflict_unique_miss_rate);
158
+ STATS_tbe_uvm_cache_conflict_miss_rate.addValue (conflict_miss_rate);
159
+ STATS_tbe_uvm_cache_total_miss_rate.addValue (total_miss_rate);
123
160
124
161
// Fill all the elements of the vector uvm_cache_stats_counters with 0
125
162
// to zero out the cumulated counters.
@@ -365,7 +402,7 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
365
402
// cache_index_table_map: (linearized) index to table number map.
366
403
// 1D tensor, dtype=int32.
367
404
c10::optional<Tensor> cache_index_table_map,
368
- // lxu_cache_state: Cache state (cached idnex , or invalid).
405
+ // lxu_cache_state: Cache state (cached index , or invalid).
369
406
// 2D tensor: # sets x assoc. dtype=int64.
370
407
c10::optional<Tensor> lxu_cache_state,
371
408
// lxu_state: meta info for replacement (time stamp for LRU).
@@ -461,6 +498,16 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
461
498
uvm_cache_stats);
462
499
463
500
#ifdef FBCODE_CAFFE2
501
+ if (FLAGS_tbe_uvm_cache_enforced_misses > 0 ) {
502
+ // Override some lxu_cache_locations (N for every 256 indices) with cache
503
+ // miss to enforce access to UVM.
504
+ lxu_cache_locations = emulate_cache_miss (
505
+ lxu_cache_locations.value (),
506
+ FLAGS_tbe_uvm_cache_enforced_misses,
507
+ gather_uvm_stats,
508
+ uvm_cache_stats);
509
+ }
510
+
464
511
process_uvm_cache_stats (
465
512
signature,
466
513
total_cache_hash_size.value (),
0 commit comments