@@ -33,6 +33,30 @@ class EmbeddingLocation(enum.IntEnum):
33
33
HOST = 3
34
34
MTIA = 4
35
35
36
+ @classmethod
37
+ # pyre-ignore [3]
38
+ def from_str (cls , key : str ):
39
+ lookup = {
40
+ "device" : EmbeddingLocation .DEVICE ,
41
+ "managed" : EmbeddingLocation .MANAGED ,
42
+ "managed_caching" : EmbeddingLocation .MANAGED_CACHING ,
43
+ "host" : EmbeddingLocation .HOST ,
44
+ "mtia" : EmbeddingLocation .MTIA ,
45
+ }
46
+ if key in lookup :
47
+ return lookup [key ]
48
+ else :
49
+ raise ValueError (f"Cannot parse value into { cls } : { key } " )
50
+
51
+ def __str__ (self ) -> str :
52
+ return {
53
+ EmbeddingLocation .DEVICE : "device" ,
54
+ EmbeddingLocation .MANAGED : "managed" ,
55
+ EmbeddingLocation .MANAGED_CACHING : "managed_caching" ,
56
+ EmbeddingLocation .HOST : "host" ,
57
+ EmbeddingLocation .MTIA : "mtia" ,
58
+ }[self ]
59
+
36
60
37
61
class CacheAlgorithm (enum .Enum ):
38
62
LRU = 0
@@ -57,6 +81,29 @@ class PoolingMode(enum.IntEnum):
57
81
MEAN = 1
58
82
NONE = 2
59
83
84
+ @classmethod
85
+ # pyre-ignore [3]
86
+ def from_str (cls , key : str ):
87
+ lookup = {
88
+ "sum" : PoolingMode .SUM ,
89
+ "mean" : PoolingMode .MEAN ,
90
+ "none" : PoolingMode .NONE ,
91
+ }
92
+ if key in lookup :
93
+ return lookup [key ]
94
+ else :
95
+ raise ValueError (f"Cannot parse value into { cls } : { key } " )
96
+
97
+ def __str__ (self ) -> str :
98
+ return {
99
+ PoolingMode .SUM : "sum" ,
100
+ PoolingMode .MEAN : "mean" ,
101
+ PoolingMode .NONE : "none" ,
102
+ }[self ]
103
+
104
+ def do_pooling (self ) -> bool :
105
+ return self is not PoolingMode .NONE
106
+
60
107
61
108
class BoundsCheckMode (enum .IntEnum ):
62
109
# Raise an exception (CPU) or device-side assert (CUDA)
0 commit comments