
Commit a25356f

Authored by q10, committed by facebook-github-bot
Cleanups for the EEG-based TBE benchmark CLI, pt 2 (#3815)
Summary:
Pull Request resolved: #3815
X-link: facebookresearch/FBGEMM#890

- Cleanups for the EEG-based TBE benchmark CLI, pt 2

Reviewed By: jiawenliu64

Differential Revision: D70426271
1 parent 51140fb commit a25356f

File tree

8 files changed: +224 additions, -92 deletions


fbgemm_gpu/bench/tbe/tbe_training_benchmark.py

Lines changed: 32 additions & 54 deletions
@@ -19,11 +19,8 @@
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
-    BoundsCheckMode,
     CacheAlgorithm,
     EmbeddingLocation,
-    str_to_embedding_location,
-    str_to_pooling_mode,
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     ComputeDevice,
@@ -32,6 +29,7 @@
 )
 from fbgemm_gpu.tbe.bench import (
     benchmark_requests,
+    EmbeddingOpsCommonConfigLoader,
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
@@ -50,50 +48,39 @@ def cli() -> None:
 
 
 @cli.command()
-@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
-@click.option("--cache-precision", type=SparseType, default=None)
-@click.option("--stoc", is_flag=True, default=False)
-@click.option(
-    "--managed",
-    default="device",
-    type=click.Choice(["device", "managed", "managed_caching"], case_sensitive=False),
-)
 @click.option(
     "--emb-op-type",
     default="split",
     type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
+    help="The type of the embedding op to benchmark",
+)
+@click.option(
+    "--row-wise/--no-row-wise",
+    default=True,
+    help="Whether to use row-wise adagrad optimzier or not",
 )
-@click.option("--row-wise/--no-row-wise", default=True)
-@click.option("--pooling", type=str, default="sum")
-@click.option("--weighted-num-requires-grad", type=int, default=None)
-@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
-@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
 @click.option(
-    "--uvm-host-mapped",
-    is_flag=True,
-    default=False,
-    help="Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister)",
+    "--weighted-num-requires-grad",
+    type=int,
+    default=None,
+    help="The number of weighted tables that require gradient",
 )
 @click.option(
-    "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
+    "--ssd-prefix",
+    type=str,
+    default="/tmp/ssd_benchmark",
+    help="SSD directory prefix",
 )
 @click.option("--cache-load-factor", default=0.2)
 @TBEBenchmarkingConfigLoader.options
 @TBEDataConfigLoader.options
+@EmbeddingOpsCommonConfigLoader.options
 @click.pass_context
 def device(  # noqa C901
     context: click.Context,
     emb_op_type: click.Choice,
-    weights_precision: SparseType,
-    cache_precision: Optional[SparseType],
-    stoc: bool,
-    managed: click.Choice,
     row_wise: bool,
-    pooling: str,
     weighted_num_requires_grad: Optional[int],
-    bounds_check_mode: int,
-    output_dtype: SparseType,
-    uvm_host_mapped: bool,
     cache_load_factor: float,
     # SSD params
     ssd_prefix: str,
@@ -110,6 +97,9 @@ def device(  # noqa C901
     # Load TBE data configuration from cli arguments
     tbeconfig = TBEDataConfigLoader.load(context)
 
+    # Load common embedding op configuration from cli arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+
     # Generate feature_requires_grad
     feature_requires_grad = (
         tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
@@ -123,22 +113,8 @@ def device(  # noqa C901
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD
 
-    # Determine the embedding location
-    embedding_location = str_to_embedding_location(str(managed))
-    if embedding_location is EmbeddingLocation.DEVICE and not torch.cuda.is_available():
-        embedding_location = EmbeddingLocation.HOST
-
-    # Determine the pooling mode
-    pooling_mode = str_to_pooling_mode(pooling)
-
     # Construct the common split arguments for the embedding op
-    common_split_args: Dict[str, Any] = {
-        "weights_precision": weights_precision,
-        "stochastic_rounding": stoc,
-        "output_dtype": output_dtype,
-        "pooling_mode": pooling_mode,
-        "bounds_check_mode": BoundsCheckMode(bounds_check_mode),
-        "uvm_host_mapped": uvm_host_mapped,
+    common_split_args: Dict[str, Any] = embconfig.split_args() | {
         "optimizer": optimizer,
         "learning_rate": 0.1,
         "eps": 0.1,
@@ -154,7 +130,7 @@ def device(  # noqa C901
                 )
                 for d in Ds
             ],
-            pooling_mode=pooling_mode,
+            pooling_mode=embconfig.pooling_mode,
             use_cpu=not torch.cuda.is_available(),
         )
     elif emb_op_type == "ssd":
@@ -177,7 +153,7 @@ def device(  # noqa C901
                 (
                     tbeconfig.E,
                     d,
-                    embedding_location,
+                    embconfig.embedding_location,
                     (
                         ComputeDevice.CUDA
                         if torch.cuda.is_available()
@@ -187,25 +163,27 @@ def device(  # noqa C901
                 for d in Ds
             ],
             cache_precision=(
-                weights_precision if cache_precision is None else cache_precision
+                embconfig.weights_dtype
+                if embconfig.cache_dtype is None
+                else embconfig.cache_dtype
             ),
             cache_algorithm=CacheAlgorithm.LRU,
            cache_load_factor=cache_load_factor,
             **common_split_args,
         )
     embedding_op = embedding_op.to(get_device())
 
-    if weights_precision == SparseType.INT8:
+    if embconfig.weights_dtype == SparseType.INT8:
         # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
         #  min_val: float, max_val: float) -> None, (self:
         #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
         #  None, Tensor, Module]` is not a function.
         embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)
 
     nparams = sum(d * tbeconfig.E for d in Ds)
-    param_size_multiplier = weights_precision.bit_rate() / 8.0
-    output_size_multiplier = output_dtype.bit_rate() / 8.0
-    if pooling_mode.do_pooling():
+    param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
+    output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
+    if embconfig.pooling_mode.do_pooling():
         read_write_bytes = (
             output_size_multiplier * tbeconfig.batch_params.B * sum(Ds)
             + param_size_multiplier
@@ -225,7 +203,7 @@ def device(  # noqa C901
             * tbeconfig.pooling_params.L
         )
 
-    logging.info(f"Managed option: {managed}")
+    logging.info(f"Managed option: {embconfig.embedding_location}")
     logging.info(
         f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
@@ -274,11 +252,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )
 
-    if output_dtype == SparseType.INT8:
+    if embconfig.output_dtype == SparseType.INT8:
         # backward bench not representative
         return
 
-    if pooling_mode.do_pooling():
+    if embconfig.pooling_mode.do_pooling():
         grad_output = torch.randn(tbeconfig.batch_params.B, sum(Ds)).to(get_device())
     else:
         grad_output = torch.randn(
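
The net effect of this diff is that the one-off --weights-precision, --cache-precision, --stoc, --managed, --pooling, --bounds-check-mode, --output-dtype, and --uvm-host-mapped flags are replaced by the shared EmbeddingOpsCommonConfigLoader: its options decorator attaches the common flags to the command, load(context) rebuilds a config object from the parsed CLI parameters, and split_args() is merged into the command-specific arguments with the dict union operator. The sketch below illustrates this click loader pattern in isolation; DemoConfig and DemoConfigLoader are hypothetical names for illustration only, not the FBGEMM API.

    import dataclasses
    from typing import Any, Callable, Dict

    import click


    @dataclasses.dataclass
    class DemoConfig:
        weights_precision: str
        stochastic_rounding: bool

        def split_args(self) -> Dict[str, Any]:
            # Arguments shared by every embedding op constructor.
            return dataclasses.asdict(self)


    class DemoConfigLoader:
        @classmethod
        def options(cls, func: Callable) -> Callable:
            # Attach the shared CLI flags to any command that needs them.
            func = click.option("--weights-precision", default="fp32")(func)
            func = click.option("--stoc", is_flag=True, default=False)(func)
            return func

        @classmethod
        def load(cls, context: click.Context) -> DemoConfig:
            # Read the shared flags back out of the parsed CLI parameters.
            params = context.params
            return DemoConfig(
                weights_precision=params["weights_precision"],
                stochastic_rounding=params["stoc"],
            )


    @click.command()
    @DemoConfigLoader.options
    @click.pass_context
    def device(context: click.Context, **kwargs: Any) -> None:
        config = DemoConfigLoader.load(context)
        # dict union (Python 3.9+): right-hand keys win on collisions,
        # mirroring `embconfig.split_args() | {...}` in the diff above.
        common_split_args = config.split_args() | {"optimizer": "rowwise_adagrad"}
        click.echo(str(common_split_args))


    if __name__ == "__main__":
        device()

Running this hypothetical script as `device --weights-precision fp16 --stoc` would print the merged dict, with the shared entries coming from split_args() and the command-specific optimizer entry supplied by the right-hand operand of |.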

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 26 additions & 24 deletions
@@ -33,19 +33,20 @@ class EmbeddingLocation(enum.IntEnum):
     HOST = 3
     MTIA = 4
 
-
-def str_to_embedding_location(key: str) -> EmbeddingLocation:
-    lookup = {
-        "device": EmbeddingLocation.DEVICE,
-        "managed": EmbeddingLocation.MANAGED,
-        "managed_caching": EmbeddingLocation.MANAGED_CACHING,
-        "host": EmbeddingLocation.HOST,
-        "mtia": EmbeddingLocation.MTIA,
-    }
-    if key in lookup:
-        return lookup[key]
-    else:
-        raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "device": EmbeddingLocation.DEVICE,
+            "managed": EmbeddingLocation.MANAGED,
+            "managed_caching": EmbeddingLocation.MANAGED_CACHING,
+            "host": EmbeddingLocation.HOST,
+            "mtia": EmbeddingLocation.MTIA,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
 
 
 class CacheAlgorithm(enum.Enum):
@@ -74,17 +75,18 @@ class PoolingMode(enum.IntEnum):
     def do_pooling(self) -> bool:
         return self is not PoolingMode.NONE
 
-
-def str_to_pooling_mode(key: str) -> PoolingMode:
-    lookup = {
-        "sum": PoolingMode.SUM,
-        "mean": PoolingMode.MEAN,
-        "none": PoolingMode.NONE,
-    }
-    if key in lookup:
-        return lookup[key]
-    else:
-        raise ValueError(f"Cannot parse value into PoolingMode: {key}")
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "sum": PoolingMode.SUM,
+            "mean": PoolingMode.MEAN,
+            "none": PoolingMode.NONE,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into PoolingMode: {key}")
 
 
 class BoundsCheckMode(enum.IntEnum):
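
With this change the string parsers move from module-level helpers onto the enums themselves as from_str classmethods, so callers no longer import str_to_embedding_location or str_to_pooling_mode. A brief usage sketch, with behavior as defined in the diff above:

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        EmbeddingLocation,
        PoolingMode,
    )

    # The accepted keys are exactly those listed in the lookup tables above.
    assert EmbeddingLocation.from_str("managed_caching") is EmbeddingLocation.MANAGED_CACHING
    assert PoolingMode.from_str("sum") is PoolingMode.SUM

    # Unknown keys still raise ValueError, as the old str_to_* helpers did.
    try:
        EmbeddingLocation.from_str("gpu")
    except ValueError as err:
        print(err)  # Cannot parse value into EmbeddingLocation: gpu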

fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -19,12 +19,17 @@
     benchmark_requests_refer,
     benchmark_vbe,
 )
-from .config import TBEDataConfig  # noqa F401
-from .config_loader import TBEDataConfigLoader  # noqa F401
-from .config_param_models import BatchParams, IndicesParams, PoolingParams  # noqa F401
+from .embedding_ops_common_config import EmbeddingOpsCommonConfigLoader  # noqa F401
 from .eval_compression import (  # noqa F401
     benchmark_eval_compression,
     EvalCompressionBenchmarkOutput,
 )
 from .reporter import BenchmarkReporter  # noqa F401
+from .tbe_data_config import TBEDataConfig  # noqa F401
+from .tbe_data_config_loader import TBEDataConfigLoader  # noqa F401
+from .tbe_data_config_param_models import (  # noqa F401
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+)
 from .utils import fill_random_scale_bias  # noqa F401
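
Because the renamed tbe_data_config* modules and the new embedding_ops_common_config module are re-exported through the package __init__, the public import surface stays the same; callers such as the benchmark CLI above can keep importing everything from the package, for example:

    # Import from the package, not from the renamed submodules.
    from fbgemm_gpu.tbe.bench import (
        BatchParams,
        EmbeddingOpsCommonConfigLoader,
        IndicesParams,
        PoolingParams,
        TBEDataConfig,
        TBEDataConfigLoader,
    )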
