Skip to content

Commit 33bdde9

Browse files
Improve VBE benchmark (pytorch#955)
Summary: X-link: pytorch#3867 Pull Request resolved: facebookresearch/FBGEMM#955 This Diff improves the VBE benchmark in FBGEMM-GPU as follows: 1. Make argument names consistent with other benchmarks 2. Add more arguments supported by other benchmarks 3. Nice usage/help with `--help` 4. Use a standard request creation mechanism 5. Support measurement on CPU. ## 1. Argument names I carelessly used "a plural form" for a comma-separated list of values, like `--batch-sizes` and `--embedding-dims`, but other benchmarks use a "list" suffix, like `--batch-size-list` and `--embedding-dim-list`. Since VBE is a newer benchmark, let's rename the arguments to make them consistent. ## 2. More arguments The current VBE benchmark does not support the arguments that other TBE benchmarks support. Let's add more arguments to make the benchmark configurable. ## 3. Nice usage/help with `--help`. Using the new config loaders makes the usage/help look much nicer (thanks Benson!) ``` TBEBenchmarkingConfigLoader.options EmbeddingOpsCommonConfigLoader.options ``` ## 4. Standard request creation mechanism Previously I created indices/offsets myself, but it appears that `generate_requests()` works for VBE too. Let's use this. ## 5. Measurement on CPU Because the benchmark code uses `torch.cuda`, we couldn't run a benchmark on CPU. We changed the benchmarking logic to support CPU (using `time.time()`). Reviewed By: sryap Differential Revision: D71596622 fbshipit-source-id: 38941010c43335d09dd44dc61626321e83c79285
1 parent 6f0eede commit 33bdde9

File tree

2 files changed

+215
-101
lines changed

2 files changed

+215
-101
lines changed

fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py

Lines changed: 141 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
benchmark_pipelined_requests,
3737
benchmark_requests,
3838
benchmark_vbe,
39+
EmbeddingOpsCommonConfigLoader,
40+
TBEBenchmarkingConfigLoader,
3941
)
4042
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
4143
from fbgemm_gpu.tbe.utils import generate_requests, get_device, round_up, TBERequest
@@ -1202,56 +1204,118 @@ def device_with_spec( # noqa C901
12021204

12031205

12041206
@cli.command()
1205-
@click.option("--batch-sizes", default="128000,1280")
1206-
@click.option("--embedding-dims", default="1024,16")
1207-
@click.option("--bag-sizes", default="5,2")
1208-
@click.option("--nums-embeddings", default="10000,1000000")
1209-
@click.option("--num-tables", default=2)
1210-
@click.option("--iters", default=100)
1207+
@click.option(
1208+
"--batch-size-list",
1209+
type=str,
1210+
required=True,
1211+
help="A comma separated list of batch sizes (B) for each table.",
1212+
)
1213+
@click.option(
1214+
"--embedding-dim-list",
1215+
type=str,
1216+
required=True,
1217+
help="A comma separated list of embedding dimensions (D) for each table.",
1218+
)
1219+
@click.option(
1220+
"--bag-size-list",
1221+
type=str,
1222+
required=True,
1223+
help="A comma separated list of bag sizes (L) for each table.",
1224+
)
1225+
@click.option(
1226+
"--bag-size-sigma-list",
1227+
type=str,
1228+
default="None",
1229+
help="A comma separated list of standard deviations for generating bag sizes per table. "
1230+
"If 'None' is set, bag sizes are fixed per table.",
1231+
)
1232+
@click.option(
1233+
"--num-embeddings-list",
1234+
type=str,
1235+
required=True,
1236+
help="A comma separated list of number of embeddings (E) for each table.",
1237+
)
1238+
@click.option(
1239+
"--alpha-list",
1240+
type=str,
1241+
default="None",
1242+
help="A comma separated list of ZipF-alpha values for index distribution for each table. "
1243+
"If 'None' is set, uniform distribution is used.",
1244+
)
1245+
@click.option(
1246+
"--num-tables",
1247+
type=int,
1248+
required=True,
1249+
help="The number of tables.",
1250+
)
1251+
@click.option(
1252+
"--weighted",
1253+
is_flag=True,
1254+
default=False,
1255+
help="Whether the table is weighted or not",
1256+
)
1257+
@TBEBenchmarkingConfigLoader.options
1258+
@EmbeddingOpsCommonConfigLoader.options
1259+
@click.pass_context
12111260
def vbe(
1212-
batch_sizes: str,
1213-
embedding_dims: str,
1214-
bag_sizes: str,
1215-
nums_embeddings: str,
1261+
context: click.Context,
1262+
batch_size_list: str,
1263+
embedding_dim_list: str,
1264+
bag_size_list: str,
1265+
bag_size_sigma_list: str,
1266+
num_embeddings_list: str,
1267+
alpha_list: str,
12161268
num_tables: int,
1217-
iters: int,
1269+
weighted: bool,
1270+
# pyre-ignore[2]
1271+
**kwargs,
12181272
) -> None:
12191273
"""
12201274
A benchmark function to evaluate variable batch-size table-batched
12211275
embedding (VBE) kernels for both forward and backward. Unlike TBE,
12221276
batch sizes can be specified per table for VBE.
1277+
"""
1278+
np.random.seed(42)
1279+
torch.manual_seed(42)
12231280

1224-
Args:
1225-
batch_sizes (str):
1226-
A comma separated list of batch sizes for each table.
1227-
1228-
embedding_dims (str):
1229-
A comma separated list of embedding dimensions for each table.
1281+
# Load general TBE benchmarking configuration from cli arguments
1282+
benchconfig = TBEBenchmarkingConfigLoader.load(context)
1283+
if benchconfig.num_requests != benchconfig.iterations:
1284+
raise ValueError("--bench-num-requests is not supported.")
12301285

1231-
bag_sizes (str):
1232-
A comma separated list of bag sizes for each table.
1286+
if benchconfig.flush_gpu_cache_size_mb != 0:
1287+
raise ValueError("--bench-flush-gpu-cache-size is not supported.")
12331288

1234-
num_embeddings (str):
1235-
A comma separated list of number of embeddings for each table.
1289+
if benchconfig.export_trace:
1290+
raise ValueError("--bench-export-trace is not supported.")
12361291

1237-
num_tables (int):
1238-
The number of tables.
1292+
# Load common embedding op configuration from cli arguments
1293+
embconfig = EmbeddingOpsCommonConfigLoader.load(context)
1294+
if embconfig.uvm_host_mapped:
1295+
raise ValueError("--emb-uvm-host-mapped is not supported.")
12391296

1240-
iters (int):
1241-
The number of iterations to run the benchmark for.
1242-
"""
1243-
1244-
torch.manual_seed(42)
1245-
Bs = [int(v) for v in batch_sizes.split(",")]
1246-
Ds = [int(v) for v in embedding_dims.split(",")]
1247-
Ls = [int(v) for v in bag_sizes.split(",")]
1248-
Es = [int(v) for v in nums_embeddings.split(",")]
12491297
T = num_tables
1298+
alphas = (
1299+
[float(alpha) for alpha in alpha_list.split(",")]
1300+
if alpha_list != "None"
1301+
else [0.0] * T
1302+
)
1303+
Bs = [int(v) for v in batch_size_list.split(",")]
1304+
Ds = [int(v) for v in embedding_dim_list.split(",")]
1305+
Ls = [int(v) for v in bag_size_list.split(",")]
1306+
sigma_Ls = (
1307+
[int(sigma) for sigma in bag_size_sigma_list.split(",")]
1308+
if bag_size_sigma_list != "None"
1309+
else [0] * T
1310+
)
1311+
Es = [int(v) for v in num_embeddings_list.split(",")]
12501312

12511313
# All these variables must have the same length.
1314+
assert T == len(alphas)
12521315
assert T == len(Bs)
12531316
assert T == len(Ds)
12541317
assert T == len(Ls)
1318+
assert T == len(sigma_Ls)
12551319
assert T == len(Es)
12561320

12571321
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
@@ -1260,7 +1324,6 @@ def vbe(
12601324
if get_available_compute_device() != ComputeDevice.CPU
12611325
else EmbeddingLocation.HOST
12621326
)
1263-
pooling_mode = PoolingMode.SUM
12641327

12651328
emb = SplitTableBatchedEmbeddingBagsCodegen(
12661329
[
@@ -1275,66 +1338,67 @@ def vbe(
12751338
optimizer=optimizer,
12761339
learning_rate=0.1,
12771340
eps=0.1,
1278-
weights_precision=SparseType.FP32,
1279-
stochastic_rounding=False,
1280-
output_dtype=SparseType.FP32,
1281-
pooling_mode=pooling_mode,
1282-
bounds_check_mode=BoundsCheckMode(BoundsCheckMode.NONE.value),
1341+
cache_precision=embconfig.cache_dtype,
1342+
weights_precision=embconfig.weights_dtype,
1343+
stochastic_rounding=embconfig.stochastic_rounding,
1344+
output_dtype=embconfig.output_dtype,
1345+
pooling_mode=embconfig.pooling_mode,
1346+
bounds_check_mode=embconfig.bounds_check_mode,
12831347
).to(get_device())
12841348

1285-
lengths_list: List[torch.Tensor] = []
1286-
num_values_per_table: List[int] = []
1287-
for t, B in enumerate(Bs):
1288-
L = Ls[t]
1289-
# Assume a uniformly distributed random number in [0, 2L)
1290-
# On average it should be L.
1291-
lengths_list.append(
1292-
torch.randint(
1293-
low=0, high=2 * L, size=(B,), dtype=torch.int64, device=get_device()
1294-
)
1349+
all_requests = {
1350+
"indices": [[] for _ in range(benchconfig.iterations)],
1351+
"offsets": [[] for _ in range(benchconfig.iterations)],
1352+
"weights": [[] for _ in range(benchconfig.iterations)],
1353+
}
1354+
for t, (E, B, L, sigma_L, alpha) in enumerate(zip(Es, Bs, Ls, sigma_Ls, alphas)):
1355+
# Generate a request for a single table.
1356+
local_requests = generate_requests(
1357+
benchconfig.iterations,
1358+
B,
1359+
1,
1360+
L,
1361+
E,
1362+
alpha=alpha,
1363+
weighted=weighted,
1364+
sigma_L=sigma_L,
1365+
zipf_oversample_ratio=3 if L > 5 else 5,
1366+
use_cpu=get_available_compute_device() == ComputeDevice.CPU,
1367+
index_dtype=torch.long,
1368+
offset_dtype=torch.long,
12951369
)
12961370

1297-
# num_values is used later.
1298-
# Note: sum().tolist() returns a scalar value.
1299-
# pyre-ignore
1300-
num_values: int = torch.sum(lengths_list[-1]).tolist()
1301-
num_values_per_table.append(num_values)
1302-
1303-
lengths = torch.cat(lengths_list, 0)
1304-
1305-
# Convert lengths into offsets.
1306-
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths).long()
1307-
1308-
# Set up values.
1309-
values_list: List[torch.Tensor] = []
1310-
for t, E in enumerate(Es):
1311-
# Assuming that an index distribution is uniform [0, E)
1312-
values_list.append(
1313-
torch.randint(
1314-
low=0,
1315-
high=E,
1316-
size=(num_values_per_table[t],),
1317-
dtype=torch.int32,
1318-
device=get_device(),
1319-
)
1320-
)
1321-
values = torch.cat(values_list, 0).long()
1371+
# Store requests for each table in all_requests.
1372+
for i, req in enumerate(local_requests):
1373+
indices, offsets, weights = req.unpack_3()
1374+
all_requests["indices"][i].append(indices)
1375+
if t > 0:
1376+
offsets = offsets[1:] # remove the first element
1377+
offsets += all_requests["offsets"][i][t - 1][-1]
1378+
all_requests["offsets"][i].append(offsets)
1379+
all_requests["weights"][i].append(weights)
13221380

1381+
    # Combine the requests for all tables by concatenating the per-table
    # indices, offsets, and (if weighted) weights for each iteration.
13231382
requests = [
13241383
(
1325-
values,
1326-
offsets,
1384+
torch.concat(all_requests["indices"][i]),
1385+
torch.concat(all_requests["offsets"][i]),
1386+
torch.concat(all_requests["weights"][i]) if weighted else None,
13271387
)
1328-
for _ in range(iters)
1388+
for i in range(benchconfig.iterations)
13291389
]
13301390

1391+
del all_requests
1392+
13311393
fwd_time_sec, bwd_time_sec = benchmark_vbe(
13321394
requests,
1333-
func=lambda indices, offsets: emb.forward(
1395+
func=lambda indices, offsets, per_sample_weights: emb.forward(
13341396
indices,
13351397
offsets,
1398+
per_sample_weights,
13361399
batch_size_per_feature_per_rank=[[B] for B in Bs],
13371400
),
1401+
num_warmups=benchconfig.warmup_iterations,
13381402
)
13391403
logging.info(
13401404
f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n"

0 commit comments

Comments
 (0)