@@ -31,29 +31,61 @@ def set_amd_env_vars() -> None:
     os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"


-def get_llama_shapes() -> List[Tuple[int, int, int]]:
+def get_llama_shapes() -> List[Tuple[int, int, int, int]]:
     # Helper function that returns a list of shapes relevant to llama.

     llama_shapes = []
     for M in [1, 16, 32, 64, 96, 128, 16384]:
         # Add shapes for llama 70B
         llama_shapes += [
-            (M, 1280, 8192),
-            (M, 8192, 1024),
-            (M, 7168, 8192),
-            (M, 8192, 3584),
+            (1, M, 1280, 8192),
+            (1, M, 8192, 1024),
+            (1, M, 7168, 8192),
+            (1, M, 8192, 3584),
         ]
         # Add shapes for llama 405B
         llama_shapes += [
-            (M, 13312, 6656),
-            (M, 13312, 16384),
-            (M, 16384, 6656),
-            (M, 16384, 16384),
+            (1, M, 13312, 6656),
+            (1, M, 13312, 16384),
+            (1, M, 16384, 6656),
+            (1, M, 16384, 16384),
         ]

     return llama_shapes


+def get_ldm_shapes() -> List[Tuple[int, int, int, int]]:
+    # Helper function that returns a list of shapes relevant to ldm.
+    return [
+        (1, 1536, 3584, 3584),
+        (1, 8192, 9728, 3584),
+        (1, 8192, 3584, 9728),
+        (1, 8192, 3584, 3584),
+        (1, 4096, 3584, 3584),
+        (1, 768, 3584, 3584),
+        (1, 4096, 9728, 3584),
+        (1, 4096, 3584, 9728),
+        (1, 7200, 3584, 3584),
+        (1, 7200, 9728, 3584),
+        (1, 7200, 3584, 9728),
+        (1, 3600, 3584, 3584),
+        (1, 3600, 9728, 3584),
+        (1, 3600, 3584, 9728),
+        (1, 1536, 4096, 4096),
+        (1, 3600, 4096, 4096),
+        (1, 3600, 11008, 4096),
+        (1, 3600, 4096, 11008),
+        (1, 4096, 4096, 4096),
+        (1, 4096, 11008, 4096),
+        (1, 4096, 4096, 11008),
+        (1, 32768, 128, 8192),
+        (1, 32768, 8192, 1024),
+        (1, 32768, 8192, 3072),
+        (1, 32768, 3072, 8192),
+        (1, 32768, 1024, 8192),
+    ]
+
+
 def benchmark_grouped(
     quantize_ops: List[QuantizeOpBase],
     b: List[int],
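Note on the signature change in this hunk: every shape tuple gains a leading group/batch dimension, so consumers of these helpers now unpack four values per entry instead of three. A minimal sketch of that consumption, assuming the benchmark loop destructures each tuple as (B, M, N, K); the helper name and loop below are illustrative, not code from this file:

from typing import List, Tuple

def iterate_shapes(mnk: List[Tuple[int, int, int, int]]) -> None:
    # Each entry is (B, M, N, K): a leading batch/group size (fixed to 1
    # by the helpers above) followed by the GEMM dimensions M, N, K.
    for B, M, N, K in mnk:
        print(f"B={B}, M={M}, N={N}, K={K}")

# Example with one llama-70B-style entry from the hunk above:
iterate_shapes([(1, 16, 1280, 8192)])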
@@ -297,6 +329,8 @@ def main(args: Any):
         B = [int(b) for b in args.B.strip().split(",")]
     if args.use_llama_shapes:
         MNK = get_llama_shapes()
+    elif args.use_ldm_shapes:
+        MNK = get_ldm_shapes()
     else:
         if args.M is None:
             M = [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
@@ -419,6 +453,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, benchmark using fixed shapes relevant to llama workloads.",
     )
+    parser.add_argument(
+        "--use_ldm_shapes",
+        default=False,
+        action="store_true",
+        help="If set, benchmark using fixed shapes relevant to ldm workloads.",
+    )

     args = parser.parse_args()
     main(args)
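One behavioral note on the wiring above: main() selects shapes with if/elif, so when both --use_llama_shapes and --use_ldm_shapes are passed, the llama shapes silently win. A minimal, self-contained sketch of the flag parsing that mirrors (but does not reuse) the file's parser:

import argparse

# Reproduces the two flags exactly as declared in the diff; the parser
# object here is standalone, not the script's actual invoke_main() parser.
parser = argparse.ArgumentParser()
parser.add_argument("--use_llama_shapes", default=False, action="store_true")
parser.add_argument("--use_ldm_shapes", default=False, action="store_true")

args = parser.parse_args(["--use_ldm_shapes"])
assert args.use_ldm_shapes and not args.use_llama_shapes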