Skip to content

Commit 1807f8c

Browse files
q10facebook-github-bot
authored andcommitted
Add more helper methods for TBE benchmarking (pytorch#3747)
Summary: X-link: facebookresearch/FBGEMM#828 - Integrate TBEDataConfig into TBE device benchmark CLI Reviewed By: sryap Differential Revision: D70296132
1 parent 14f618d commit 1807f8c

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,30 @@ class EmbeddingLocation(enum.IntEnum):
3333
HOST = 3
3434
MTIA = 4
3535

36+
@classmethod
37+
# pyre-ignore [3]
38+
def from_str(cls, key: str):
39+
lookup = {
40+
"device": EmbeddingLocation.DEVICE,
41+
"managed": EmbeddingLocation.MANAGED,
42+
"managed_caching": EmbeddingLocation.MANAGED_CACHING,
43+
"host": EmbeddingLocation.HOST,
44+
"mtia": EmbeddingLocation.MTIA,
45+
}
46+
if key in lookup:
47+
return lookup[key]
48+
else:
49+
raise ValueError(f"Cannot parse value into {cls}: {key}")
50+
51+
def __str__(self) -> str:
52+
return {
53+
EmbeddingLocation.DEVICE: "device",
54+
EmbeddingLocation.MANAGED: "managed",
55+
EmbeddingLocation.MANAGED_CACHING: "managed_caching",
56+
EmbeddingLocation.HOST: "host",
57+
EmbeddingLocation.MTIA: "mtia",
58+
}[self]
59+
3660

3761
class CacheAlgorithm(enum.Enum):
3862
LRU = 0
@@ -57,6 +81,29 @@ class PoolingMode(enum.IntEnum):
5781
MEAN = 1
5882
NONE = 2
5983

84+
@classmethod
85+
# pyre-ignore [3]
86+
def from_str(cls, key: str):
87+
lookup = {
88+
"sum": PoolingMode.SUM,
89+
"mean": PoolingMode.MEAN,
90+
"none": PoolingMode.NONE,
91+
}
92+
if key in lookup:
93+
return lookup[key]
94+
else:
95+
raise ValueError(f"Cannot parse value into {cls}: {key}")
96+
97+
def __str__(self) -> str:
98+
return {
99+
PoolingMode.SUM: "sum",
100+
PoolingMode.MEAN: "mean",
101+
PoolingMode.NONE: "none",
102+
}[self]
103+
104+
def do_pooling(self) -> bool:
105+
return self is not PoolingMode.NONE
106+
60107

61108
class BoundsCheckMode(enum.IntEnum):
62109
# Raise an exception (CPU) or device-side assert (CUDA)

0 commit comments

Comments
 (0)