     benchmark_pipelined_requests,
     benchmark_requests,
     benchmark_vbe,
+    EmbeddingOpsCommonConfigLoader,
+    TBEBenchmarkingConfigLoader,
 )
 from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
 from fbgemm_gpu.tbe.utils import generate_requests, get_device, round_up, TBERequest
@@ -1202,56 +1204,118 @@ def device_with_spec(  # noqa C901
 
 
 @cli.command()
-@click.option("--batch-sizes", default="128000,1280")
-@click.option("--embedding-dims", default="1024,16")
-@click.option("--bag-sizes", default="5,2")
-@click.option("--nums-embeddings", default="10000,1000000")
-@click.option("--num-tables", default=2)
-@click.option("--iters", default=100)
+@click.option(
+    "--batch-size-list",
+    type=str,
+    required=True,
+    help="A comma-separated list of batch sizes (B) for each table.",
+)
+@click.option(
+    "--embedding-dim-list",
+    type=str,
+    required=True,
+    help="A comma-separated list of embedding dimensions (D) for each table.",
+)
+@click.option(
+    "--bag-size-list",
+    type=str,
+    required=True,
+    help="A comma-separated list of bag sizes (L) for each table.",
+)
+@click.option(
+    "--bag-size-sigma-list",
+    type=str,
+    default="None",
+    help="A comma-separated list of standard deviations for generating bag sizes per table. "
+    "If 'None' is set, bag sizes are fixed per table.",
+)
+@click.option(
+    "--num-embeddings-list",
+    type=str,
+    required=True,
+    help="A comma-separated list of the number of embeddings (E) for each table.",
+)
+@click.option(
+    "--alpha-list",
+    type=str,
+    default="None",
+    help="A comma-separated list of Zipf alpha values for the index distribution of each table. "
+    "If 'None' is set, a uniform distribution is used.",
+)
+@click.option(
+    "--num-tables",
+    type=int,
+    required=True,
+    help="The number of tables.",
+)
+@click.option(
+    "--weighted",
+    is_flag=True,
+    default=False,
+    help="Whether the tables are weighted or not.",
+)
+@TBEBenchmarkingConfigLoader.options
+@EmbeddingOpsCommonConfigLoader.options
+@click.pass_context
 def vbe(
-    batch_sizes: str,
-    embedding_dims: str,
-    bag_sizes: str,
-    nums_embeddings: str,
+    context: click.Context,
+    batch_size_list: str,
+    embedding_dim_list: str,
+    bag_size_list: str,
+    bag_size_sigma_list: str,
+    num_embeddings_list: str,
+    alpha_list: str,
     num_tables: int,
-    iters: int,
+    weighted: bool,
+    # pyre-ignore[2]
+    **kwargs,
 ) -> None:
     """
     A benchmark function to evaluate variable batch-size table-batched
     embedding (VBE) kernels for both forward and backward. Unlike TBE,
     batch sizes can be specified per table for VBE.
+    """
+    np.random.seed(42)
+    torch.manual_seed(42)
 
-    Args:
-        batch_sizes (str):
-            A comma separated list of batch sizes for each table.
-
-        embedding_dims (str):
-            A comma separated list of embedding dimensions for each table.
+    # Load the general TBE benchmarking configuration from the CLI arguments
+    benchconfig = TBEBenchmarkingConfigLoader.load(context)
+    if benchconfig.num_requests != benchconfig.iterations:
+        raise ValueError("--bench-num-requests is not supported.")
 
-        bag_sizes (str):
-            A comma separated list of bag sizes for each table.
+    if benchconfig.flush_gpu_cache_size_mb != 0:
+        raise ValueError("--bench-flush-gpu-cache-size is not supported.")
 
-        num_embeddings (str):
-            A comma separated list of number of embeddings for each table.
+    if benchconfig.export_trace:
+        raise ValueError("--bench-export-trace is not supported.")
 
-        num_tables (int):
-            The number of tables.
+    # Load the common embedding op configuration from the CLI arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+    if embconfig.uvm_host_mapped:
+        raise ValueError("--emb-uvm-host-mapped is not supported.")
 
-        iters (int):
-            The number of iterations to run the benchmark for.
-    """
-
-    torch.manual_seed(42)
-    Bs = [int(v) for v in batch_sizes.split(",")]
-    Ds = [int(v) for v in embedding_dims.split(",")]
-    Ls = [int(v) for v in bag_sizes.split(",")]
-    Es = [int(v) for v in nums_embeddings.split(",")]
     T = num_tables
+    alphas = (
+        [float(alpha) for alpha in alpha_list.split(",")]
+        if alpha_list != "None"
+        else [0.0] * T
+    )
+    Bs = [int(v) for v in batch_size_list.split(",")]
+    Ds = [int(v) for v in embedding_dim_list.split(",")]
+    Ls = [int(v) for v in bag_size_list.split(",")]
+    sigma_Ls = (
+        [int(sigma) for sigma in bag_size_sigma_list.split(",")]
+        if bag_size_sigma_list != "None"
+        else [0] * T
+    )
+    Es = [int(v) for v in num_embeddings_list.split(",")]
 
     # All these variables must have the same length.
+    assert T == len(alphas)
     assert T == len(Bs)
     assert T == len(Ds)
     assert T == len(Ls)
+    assert T == len(sigma_Ls)
     assert T == len(Es)
 
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
@@ -1260,7 +1324,6 @@ def vbe(
         if get_available_compute_device() != ComputeDevice.CPU
         else EmbeddingLocation.HOST
     )
-    pooling_mode = PoolingMode.SUM
 
     emb = SplitTableBatchedEmbeddingBagsCodegen(
         [
@@ -1275,66 +1338,67 @@ def vbe(
         optimizer=optimizer,
         learning_rate=0.1,
         eps=0.1,
-        weights_precision=SparseType.FP32,
-        stochastic_rounding=False,
-        output_dtype=SparseType.FP32,
-        pooling_mode=pooling_mode,
-        bounds_check_mode=BoundsCheckMode(BoundsCheckMode.NONE.value),
+        cache_precision=embconfig.cache_dtype,
+        weights_precision=embconfig.weights_dtype,
+        stochastic_rounding=embconfig.stochastic_rounding,
+        output_dtype=embconfig.output_dtype,
+        pooling_mode=embconfig.pooling_mode,
+        bounds_check_mode=embconfig.bounds_check_mode,
     ).to(get_device())
 
-    lengths_list: List[torch.Tensor] = []
-    num_values_per_table: List[int] = []
-    for t, B in enumerate(Bs):
-        L = Ls[t]
-        # Assume a uniformly distributed random number in [0, 2L)
-        # On average it should be L.
-        lengths_list.append(
-            torch.randint(
-                low=0, high=2 * L, size=(B,), dtype=torch.int64, device=get_device()
-            )
+    all_requests = {
+        "indices": [[] for _ in range(benchconfig.iterations)],
+        "offsets": [[] for _ in range(benchconfig.iterations)],
+        "weights": [[] for _ in range(benchconfig.iterations)],
+    }
+    for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)):
+        # Generate requests for a single table.
+        local_requests = generate_requests(
+            benchconfig.iterations,
+            B,
+            1,
+            L,
+            E,
+            alpha=alpha,
+            weighted=weighted,
+            sigma_L=sigma_L,
+            zipf_oversample_ratio=3 if L > 5 else 5,
+            use_cpu=get_available_compute_device() == ComputeDevice.CPU,
+            index_dtype=torch.long,
+            offset_dtype=torch.long,
         )
 
-        # num_values is used later.
-        # Note: sum().tolist() returns a scalar value.
-        # pyre-ignore
-        num_values: int = torch.sum(lengths_list[-1]).tolist()
-        num_values_per_table.append(num_values)
-
-    lengths = torch.cat(lengths_list, 0)
-
-    # Convert lengths into offsets.
-    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths).long()
-
-    # Set up values.
-    values_list: List[torch.Tensor] = []
-    for t, E in enumerate(Es):
-        # Assuming that an index distribution is uniform [0, E)
-        values_list.append(
-            torch.randint(
-                low=0,
-                high=E,
-                size=(num_values_per_table[t],),
-                dtype=torch.int32,
-                device=get_device(),
-            )
-        )
-    values = torch.cat(values_list, 0).long()
+        # Store the requests for each table in all_requests.
+        for i, req in enumerate(local_requests):
+            indices, offsets, weights = req.unpack_3()
+            all_requests["indices"][i].append(indices)
+            if t > 0:
+                offsets = offsets[1:]  # remove the first element
+                offsets += all_requests["offsets"][i][t - 1][-1]
+            all_requests["offsets"][i].append(offsets)
+            all_requests["weights"][i].append(weights)
 
+    # Combine the requests for all tables by concatenating them per iteration.
     requests = [
         (
-            values,
-            offsets,
+            torch.concat(all_requests["indices"][i]),
+            torch.concat(all_requests["offsets"][i]),
+            torch.concat(all_requests["weights"][i]) if weighted else None,
         )
-        for _ in range(iters)
+        for i in range(benchconfig.iterations)
     ]
 
+    del all_requests
+
     fwd_time_sec, bwd_time_sec = benchmark_vbe(
         requests,
-        func=lambda indices, offsets: emb.forward(
+        func=lambda indices, offsets, per_sample_weights: emb.forward(
             indices,
             offsets,
+            per_sample_weights,
             batch_size_per_feature_per_rank=[[B] for B in Bs],
         ),
+        num_warmups=benchconfig.warmup_iterations,
     )
     logging.info(
         f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n"
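For reviewers, here is a minimal, self-contained sketch of the offset-merging step performed in the per-table loop above. It uses toy tensors in place of the output of generate_requests(); the shapes, values, and variable names are illustrative only.

import torch

# Per-table CSR-style offsets, each starting at 0: table 0 has 2 bags
# covering 5 indices; table 1 has 3 bags covering 6 indices.
offsets_t0 = torch.tensor([0, 2, 5])
offsets_t1 = torch.tensor([0, 1, 4, 6])

merged = [offsets_t0]
# For every table after the first, drop the leading 0 and shift by the last
# offset accumulated so far, so the concatenated offsets stay monotonic.
shifted = offsets_t1[1:] + merged[-1][-1]
merged.append(shifted)

offsets = torch.concat(merged)
print(offsets)  # tensor([ 0,  2,  5,  6,  9, 11])

This mirrors the `if t > 0:` branch in the loop: removing the duplicate leading zero and shifting by the previous table's final offset yields a single valid offsets tensor for the concatenated indices.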
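As a usage sketch, the reworked command could be smoke-tested with click's CliRunner. Only options visible in this diff are passed; it assumes that `cli` is the click group defined in this file and that the defaults contributed by TBEBenchmarkingConfigLoader.options and EmbeddingOpsCommonConfigLoader.options are usable as-is.

from click.testing import CliRunner

# Hypothetical invocation; the option values mirror the old command's defaults.
runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "vbe",
        "--batch-size-list", "128000,1280",
        "--embedding-dim-list", "1024,16",
        "--bag-size-list", "5,2",
        "--num-embeddings-list", "10000,1000000",
        "--num-tables", "2",
    ],
)
print(result.output)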