29 | 29 | print(f"is_causal: {is_causal}")
30 | 30 | for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
31 | 31 |     flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
32 |    | -   q = torch.randn(batch, head, seq_len, headdim).half().cuda()
33 |    | -   k = torch.randn(batch, head, seq_len, headdim).half().cuda()
34 |    | -   v = torch.randn(batch, head, seq_len, headdim).half().cuda()
   | 32 | +   q = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
   | 33 | +   k = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
   | 34 | +   v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
35 | 35 |     for i in range(5): sdpa(q, k, v, is_causal=is_causal)
36 | 36 |     torch.cuda.synchronize()
37 | 37 |     _, time = benchmark_forward(sdpa, q, k, v, is_causal=is_causal, repeats=100, verbose=False, desc='Triton')
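For context on the flops line above: the attention forward pass is two batched matmuls, QK^T and then PV, each costing 2 * batch * head * seq_len * seq_len * headdim FLOPs, which gives the factor of 4; with a causal mask only about half of the score matrix is computed, hence the integer division by 2. A quick numeric check, with placeholder shapes (batch, head, and headdim are defined elsewhere in the script and assumed here):

    # Sanity check of the FLOP count with assumed placeholder shapes.
    batch, head, headdim, seq_len = 4, 32, 128, 4096
    full = 4 * head * batch * headdim * seq_len * seq_len  # QK^T plus PV, non-causal
    causal = full // 2                                     # causal mask roughly halves the work
    print(f"{full / 1e12:.3f} TFLOP (full), {causal / 1e12:.3f} TFLOP (causal)")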
41 | 41 | print(f"is_causal: {is_causal}")
42 | 42 | for seq_len in {1024, 2048, 4096, 8192, 16384, 32768}:
43 | 43 |     flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
44 |    | -   q = torch.randn(batch, head, seq_len, headdim).half().cuda()
45 |    | -   k = torch.randn(batch, head, seq_len, headdim).half().cuda()
46 |    | -   v = torch.randn(batch, head, seq_len, headdim).half().cuda()
   | 44 | +   q = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
   | 45 | +   k = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
   | 46 | +   v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
47 | 47 |     for i in range(5): sdpa(q, k, v, is_causal=is_causal)
48 | 48 |     torch.cuda.synchronize()
49 | 49 |     _, time = benchmark_forward(sdpa, q, k, v, is_causal=is_causal, repeats=100, verbose=False, desc='Triton')
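Putting it together, a minimal self-contained sketch of how the changed allocation sits inside the benchmark loop. Everything not shown in the hunks is an assumption: sdpa is taken to be torch.nn.functional.scaled_dot_product_attention, torch.utils.benchmark.Timer stands in for the repository's benchmark_forward helper, the outer loop over is_causal is only implied by the print statement, and the shapes are small placeholders.

    import torch
    from torch.nn.functional import scaled_dot_product_attention as sdpa  # assumed sdpa
    from torch.utils import benchmark

    batch, head, headdim = 1, 32, 128  # placeholder shapes, not taken from the diff

    for is_causal in (False, True):
        print(f"is_causal: {is_causal}")
        for seq_len in (1024, 2048, 4096, 8192, 16384, 32768):
            # Two matmuls (QK^T and PV) at 2*N*N*d FLOPs each; causal masking halves the work.
            flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1)
            # Allocate directly in fp16 on the GPU, matching the "+" lines above.
            q = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
            k = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
            v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda")
            for _ in range(5):  # warmup, as in the diff
                sdpa(q, k, v, is_causal=is_causal)
            torch.cuda.synchronize()
            t = benchmark.Timer(
                stmt="sdpa(q, k, v, is_causal=is_causal)",
                globals={"sdpa": sdpa, "q": q, "k": k, "v": v, "is_causal": is_causal},
            ).timeit(100)  # 100 timed runs, mirroring repeats=100 in the diff
            print(f"seq_len={seq_len}: {flops / t.mean * 1e-12:.2f} TFLOP/s")

On the change itself: passing dtype and device to torch.randn creates the tensors directly in fp16 on the GPU, whereas .half().cuda() first materializes a float32 tensor on the CPU, casts it, and then copies it to the device.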