Skip to content

Commit 8ba4886

Browse files
q10facebook-github-bot
authored andcommitted
Add more helper methods for TBE benchmarking (pytorch#828)
Summary: Pull Request resolved: facebookresearch/FBGEMM#828 X-link: pytorch#3747 - Integrate TBEDataConfig into TBE device benchmark CLI Reviewed By: sryap Differential Revision: D70296132 fbshipit-source-id: 7ebc3971296d6bba246070137d2c67383bd69ae3
1 parent d647ec9 commit 8ba4886

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ class EmbeddingLocation(enum.IntEnum):
3434
MTIA = 4
3535

3636

37+
def str_to_embedding_location(key: str) -> EmbeddingLocation:
38+
lookup = {
39+
"device": EmbeddingLocation.DEVICE,
40+
"managed": EmbeddingLocation.MANAGED,
41+
"managed_caching": EmbeddingLocation.MANAGED_CACHING,
42+
"host": EmbeddingLocation.HOST,
43+
"mtia": EmbeddingLocation.MTIA,
44+
}
45+
if key in lookup:
46+
return lookup[key]
47+
else:
48+
raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
49+
50+
3751
class CacheAlgorithm(enum.Enum):
3852
LRU = 0
3953
LFU = 1
@@ -57,6 +71,21 @@ class PoolingMode(enum.IntEnum):
5771
MEAN = 1
5872
NONE = 2
5973

74+
def do_pooling(self) -> bool:
75+
return self is not PoolingMode.NONE
76+
77+
78+
def str_to_pooling_mode(key: str) -> PoolingMode:
79+
lookup = {
80+
"sum": PoolingMode.SUM,
81+
"mean": PoolingMode.MEAN,
82+
"none": PoolingMode.NONE,
83+
}
84+
if key in lookup:
85+
return lookup[key]
86+
else:
87+
raise ValueError(f"Cannot parse value into PoolingMode: {key}")
88+
6089

6190
class BoundsCheckMode(enum.IntEnum):
6291
# Raise an exception (CPU) or device-side assert (CUDA)

0 commit comments

Comments
 (0)