19
19
import torch
20
20
from fbgemm_gpu .split_embedding_configs import EmbOptimType as OptimType , SparseType
21
21
from fbgemm_gpu .split_table_batched_embeddings_ops_common import (
22
- BoundsCheckMode ,
23
22
CacheAlgorithm ,
24
23
EmbeddingLocation ,
25
- str_to_embedding_location ,
26
- str_to_pooling_mode ,
27
24
)
28
25
from fbgemm_gpu .split_table_batched_embeddings_ops_training import (
29
26
ComputeDevice ,
32
29
)
33
30
from fbgemm_gpu .tbe .bench import (
34
31
benchmark_requests ,
32
+ EmbeddingOpsCommonConfigLoader ,
35
33
TBEBenchmarkingConfigLoader ,
36
34
TBEDataConfigLoader ,
37
35
)
@@ -50,50 +48,39 @@ def cli() -> None:
50
48
51
49
52
50
@cli .command ()
53
- @click .option ("--weights-precision" , type = SparseType , default = SparseType .FP32 )
54
- @click .option ("--cache-precision" , type = SparseType , default = None )
55
- @click .option ("--stoc" , is_flag = True , default = False )
56
- @click .option (
57
- "--managed" ,
58
- default = "device" ,
59
- type = click .Choice (["device" , "managed" , "managed_caching" ], case_sensitive = False ),
60
- )
61
51
@click .option (
62
52
"--emb-op-type" ,
63
53
default = "split" ,
64
54
type = click .Choice (["split" , "dense" , "ssd" ], case_sensitive = False ),
55
+ help = "The type of the embedding op to benchmark" ,
56
+ )
57
+ @click .option (
58
+ "--row-wise/--no-row-wise" ,
59
+ default = True ,
60
+ help = "Whether to use row-wise adagrad optimizer or not" ,
65
61
)
66
- @click .option ("--row-wise/--no-row-wise" , default = True )
67
- @click .option ("--pooling" , type = str , default = "sum" )
68
- @click .option ("--weighted-num-requires-grad" , type = int , default = None )
69
- @click .option ("--bounds-check-mode" , type = int , default = BoundsCheckMode .NONE .value )
70
- @click .option ("--output-dtype" , type = SparseType , default = SparseType .FP32 )
71
62
@click .option (
72
- "--uvm-host-mapped " ,
73
- is_flag = True ,
74
- default = False ,
75
- help = "Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister) " ,
63
+ "--weighted-num-requires-grad " ,
64
+ type = int ,
65
+ default = None ,
66
+ help = "The number of weighted tables that require gradient " ,
76
67
)
77
68
@click .option (
78
- "--ssd-prefix" , type = str , default = "/tmp/ssd_benchmark" , help = "SSD directory prefix"
69
+ "--ssd-prefix" ,
70
+ type = str ,
71
+ default = "/tmp/ssd_benchmark" ,
72
+ help = "SSD directory prefix" ,
79
73
)
80
74
@click .option ("--cache-load-factor" , default = 0.2 )
81
75
@TBEBenchmarkingConfigLoader .options
82
76
@TBEDataConfigLoader .options
77
+ @EmbeddingOpsCommonConfigLoader .options
83
78
@click .pass_context
84
79
def device ( # noqa C901
85
80
context : click .Context ,
86
81
emb_op_type : click .Choice ,
87
- weights_precision : SparseType ,
88
- cache_precision : Optional [SparseType ],
89
- stoc : bool ,
90
- managed : click .Choice ,
91
82
row_wise : bool ,
92
- pooling : str ,
93
83
weighted_num_requires_grad : Optional [int ],
94
- bounds_check_mode : int ,
95
- output_dtype : SparseType ,
96
- uvm_host_mapped : bool ,
97
84
cache_load_factor : float ,
98
85
# SSD params
99
86
ssd_prefix : str ,
@@ -110,6 +97,9 @@ def device( # noqa C901
110
97
# Load TBE data configuration from cli arguments
111
98
tbeconfig = TBEDataConfigLoader .load (context )
112
99
100
+ # Load common embedding op configuration from cli arguments
101
+ embconfig = EmbeddingOpsCommonConfigLoader .load (context )
102
+
113
103
# Generate feature_requires_grad
114
104
feature_requires_grad = (
115
105
tbeconfig .generate_feature_requires_grad (weighted_num_requires_grad )
@@ -123,22 +113,8 @@ def device( # noqa C901
123
113
# Determine the optimizer
124
114
optimizer = OptimType .EXACT_ROWWISE_ADAGRAD if row_wise else OptimType .EXACT_ADAGRAD
125
115
126
- # Determine the embedding location
127
- embedding_location = str_to_embedding_location (str (managed ))
128
- if embedding_location is EmbeddingLocation .DEVICE and not torch .cuda .is_available ():
129
- embedding_location = EmbeddingLocation .HOST
130
-
131
- # Determine the pooling mode
132
- pooling_mode = str_to_pooling_mode (pooling )
133
-
134
116
# Construct the common split arguments for the embedding op
135
- common_split_args : Dict [str , Any ] = {
136
- "weights_precision" : weights_precision ,
137
- "stochastic_rounding" : stoc ,
138
- "output_dtype" : output_dtype ,
139
- "pooling_mode" : pooling_mode ,
140
- "bounds_check_mode" : BoundsCheckMode (bounds_check_mode ),
141
- "uvm_host_mapped" : uvm_host_mapped ,
117
+ common_split_args : Dict [str , Any ] = embconfig .split_args () | {
142
118
"optimizer" : optimizer ,
143
119
"learning_rate" : 0.1 ,
144
120
"eps" : 0.1 ,
@@ -154,7 +130,7 @@ def device( # noqa C901
154
130
)
155
131
for d in Ds
156
132
],
157
- pooling_mode = pooling_mode ,
133
+ pooling_mode = embconfig . pooling_mode ,
158
134
use_cpu = not torch .cuda .is_available (),
159
135
)
160
136
elif emb_op_type == "ssd" :
@@ -177,7 +153,7 @@ def device( # noqa C901
177
153
(
178
154
tbeconfig .E ,
179
155
d ,
180
- embedding_location ,
156
+ embconfig . embedding_location ,
181
157
(
182
158
ComputeDevice .CUDA
183
159
if torch .cuda .is_available ()
@@ -187,25 +163,27 @@ def device( # noqa C901
187
163
for d in Ds
188
164
],
189
165
cache_precision = (
190
- weights_precision if cache_precision is None else cache_precision
166
+ embconfig .weights_dtype
167
+ if embconfig .cache_dtype is None
168
+ else embconfig .cache_dtype
191
169
),
192
170
cache_algorithm = CacheAlgorithm .LRU ,
193
171
cache_load_factor = cache_load_factor ,
194
172
** common_split_args ,
195
173
)
196
174
embedding_op = embedding_op .to (get_device ())
197
175
198
- if weights_precision == SparseType .INT8 :
176
+ if embconfig . weights_dtype == SparseType .INT8 :
199
177
# pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
200
178
# min_val: float, max_val: float) -> None, (self:
201
179
# SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
202
180
# None, Tensor, Module]` is not a function.
203
181
embedding_op .init_embedding_weights_uniform (- 0.0003 , 0.0003 )
204
182
205
183
nparams = sum (d * tbeconfig .E for d in Ds )
206
- param_size_multiplier = weights_precision .bit_rate () / 8.0
207
- output_size_multiplier = output_dtype .bit_rate () / 8.0
208
- if pooling_mode .do_pooling ():
184
+ param_size_multiplier = embconfig . weights_dtype .bit_rate () / 8.0
185
+ output_size_multiplier = embconfig . output_dtype .bit_rate () / 8.0
186
+ if embconfig . pooling_mode .do_pooling ():
209
187
read_write_bytes = (
210
188
output_size_multiplier * tbeconfig .batch_params .B * sum (Ds )
211
189
+ param_size_multiplier
@@ -225,7 +203,7 @@ def device( # noqa C901
225
203
* tbeconfig .pooling_params .L
226
204
)
227
205
228
- logging .info (f"Managed option: { managed } " )
206
+ logging .info (f"Managed option: { embconfig . embedding_location } " )
229
207
logging .info (
230
208
f"Embedding parameters: { nparams / 1.0e9 : .2f} GParam, "
231
209
f"{ nparams * param_size_multiplier / 1.0e9 : .2f} GB"
@@ -274,11 +252,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
274
252
f"T: { time_per_iter * 1.0e6 :.0f} us"
275
253
)
276
254
277
- if output_dtype == SparseType .INT8 :
255
+ if embconfig . output_dtype == SparseType .INT8 :
278
256
# backward bench not representative
279
257
return
280
258
281
- if pooling_mode .do_pooling ():
259
+ if embconfig . pooling_mode .do_pooling ():
282
260
grad_output = torch .randn (tbeconfig .batch_params .B , sum (Ds )).to (get_device ())
283
261
else :
284
262
grad_output = torch .randn (
0 commit comments