Commit 800f3d1

mxz297 authored and facebook-github-bot committed
fix llm shapes in quantize bench and add ldm shapes (pytorch#689)
Summary:
X-link: pytorch#3611
Pull Request resolved: facebookresearch/FBGEMM#689

As title

Reviewed By: jwfromm

Differential Revision: D68594150

fbshipit-source-id: 252e31519be5a05819d8eb3af817f422e3b70b62
1 parent ff8d3df commit 800f3d1

File tree: 1 file changed (+49, -9 lines)


fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 49 additions & 9 deletions
@@ -31,29 +31,61 @@ def set_amd_env_vars() -> None:
     os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"


-def get_llama_shapes() -> List[Tuple[int, int, int]]:
+def get_llama_shapes() -> List[Tuple[int, int, int, int]]:
     # Helper function that returns a list of shapes relevant to llama.

     llama_shapes = []
     for M in [1, 16, 32, 64, 96, 128, 16384]:
         # Add shapes for llama 70B
         llama_shapes += [
-            (M, 1280, 8192),
-            (M, 8192, 1024),
-            (M, 7168, 8192),
-            (M, 8192, 3584),
+            (1, M, 1280, 8192),
+            (1, M, 8192, 1024),
+            (1, M, 7168, 8192),
+            (1, M, 8192, 3584),
         ]
         # Add shapes for llama 405B
         llama_shapes += [
-            (M, 13312, 6656),
-            (M, 13312, 16384),
-            (M, 16384, 6656),
-            (M, 16384, 16384),
+            (1, M, 13312, 6656),
+            (1, M, 13312, 16384),
+            (1, M, 16384, 6656),
+            (1, M, 16384, 16384),
         ]

     return llama_shapes


+def get_ldm_shapes() -> List[Tuple[int, int, int, int]]:
+    # Helper function that returns a list of shapes relevant to ldm.
+    return [
+        (1, 1536, 3584, 3584),
+        (1, 8192, 9728, 3584),
+        (1, 8192, 3584, 9728),
+        (1, 8192, 3584, 3584),
+        (1, 4096, 3584, 3584),
+        (1, 768, 3584, 3584),
+        (1, 4096, 9728, 3584),
+        (1, 4096, 3584, 9728),
+        (1, 7200, 3584, 3584),
+        (1, 7200, 9728, 3584),
+        (1, 7200, 3584, 9728),
+        (1, 3600, 3584, 3584),
+        (1, 3600, 9728, 3584),
+        (1, 3600, 3584, 9728),
+        (1, 1536, 4096, 4096),
+        (1, 3600, 4096, 4096),
+        (1, 3600, 11008, 4096),
+        (1, 3600, 4096, 11008),
+        (1, 4096, 4096, 4096),
+        (1, 4096, 11008, 4096),
+        (1, 4096, 4096, 11008),
+        (1, 32768, 128, 8192),
+        (1, 32768, 8192, 1024),
+        (1, 32768, 8192, 3072),
+        (1, 32768, 3072, 8192),
+        (1, 32768, 1024, 8192),
+    ]
+
+
 def benchmark_grouped(
     quantize_ops: List[QuantizeOpBase],
     b: List[int],
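Note on the hunk above: the llama entries gain a leading batch dimension so that every preset now matches the four-element (B, M, N, K) layout also returned by get_ldm_shapes, which is why the return annotation changes to Tuple[int, int, int, int]. Below is a minimal sketch of how such 4-tuples would be consumed, assuming they are unpacked in (B, M, N, K) order with the usual GEMM reading of M, N, and K; the function and variable names are illustrative only and are not the actual quantize_bench.py code.

from typing import List, Tuple

def iter_gemm_cases(shapes: List[Tuple[int, int, int, int]]) -> None:
    # Assumed reading: each entry describes B GEMM problems whose output is
    # M x N with reduction dimension K; B is 1 for every preset above.
    for B, M, N, K in shapes:
        print(f"B={B} M={M} N={N} K={K}")

iter_gemm_cases([(1, 16, 1280, 8192), (1, 16, 8192, 1024)])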
@@ -297,6 +329,8 @@ def main(args: Any):
     B = [int(b) for b in args.B.strip().split(",")]
     if args.use_llama_shapes:
         MNK = get_llama_shapes()
+    elif args.use_ldm_shapes:
+        MNK = get_ldm_shapes()
     else:
         if args.M is None:
             M = [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
@@ -419,6 +453,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, benchmark using fixed shapes relevant to llama workloads.",
     )
+    parser.add_argument(
+        "--use_ldm_shapes",
+        default=False,
+        action="store_true",
+        help="If set, benchmark using fixed shapes relevant to ldm workloads.",
+    )

     args = parser.parse_args()
     main(args)
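With the new flag wired up, the ldm presets can presumably be selected the same way the llama presets are, e.g. with something like python fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py --use_ldm_shapes (the exact invocation and any other required arguments depend on parts of the script not shown in this diff). Because main() checks args.use_llama_shapes before args.use_ldm_shapes, passing both flags falls through to the llama shapes.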
