Skip to content

Commit 2aecfa8

Browse files
committed
[update] initialize tensors on cuda directly for benchmarking
1 parent ecb0aa2 commit 2aecfa8

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

bench/bench_baseline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
print(f"is_causal: {is_causal}")
3030
for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
3131
flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
32-
q = torch.randn(batch, head, seq_len, headdim).half().cuda()
33-
k = torch.randn(batch, head, seq_len, headdim).half().cuda()
34-
v = torch.randn(batch, head, seq_len, headdim).half().cuda()
32+
q = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
33+
k = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
34+
v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
3535
for i in range(5): sdpa(q, k, v, is_causal=is_causal)
3636
torch.cuda.synchronize()
3737
_, time = benchmark_forward(sdpa, q, k, v, is_causal=is_causal, repeats=100, verbose=False, desc='Triton')
@@ -41,9 +41,9 @@
4141
print(f"is_causal: {is_causal}")
4242
for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
4343
flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
44-
q = torch.randn(batch, head, seq_len, headdim).half().cuda()
45-
k = torch.randn(batch, head, seq_len, headdim).half().cuda()
46-
v = torch.randn(batch, head, seq_len, headdim).half().cuda()
44+
q = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
45+
k = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
46+
v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
4747
for i in range(5): sdpa(q, k, v, is_causal=is_causal)
4848
torch.cuda.synchronize()
4949
_, time = benchmark_forward(sdpa, q, k, v, is_causal=is_causal, repeats=100, verbose=False, desc='Triton')

0 commit comments

Comments
 (0)