|
1 | 1 | import torch
|
2 | 2 | import triton
|
3 | 3 |
|
| 4 | +from transformers import __version__ as transformers_version |
4 | 5 | from transformers.models.llama.configuration_llama import LlamaConfig
|
5 | 6 | from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
6 | 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
@@ -32,7 +33,13 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
|
32 | 33 |
|
33 | 34 | head_dim = hidden_size // num_q_heads
|
34 | 35 | llama_config = LlamaConfig(head_dim=head_dim)
|
35 |
| - rotary_emb = LlamaRotaryEmbedding(llama_config, device=device) |
| 36 | + |
| 37 | + if tuple(map(int, transformers_version.split(".")[:2])) < (4, 48): |
| 38 | + # LlamaRotaryEmbedding constructor signature changed in transformers 4.48.0 |
| 39 | + rotary_emb = LlamaRotaryEmbedding(head_dim=head_dim, device=device) |
| 40 | + else: |
| 41 | + llama_config = LlamaConfig(head_dim=head_dim) |
| 42 | + rotary_emb = LlamaRotaryEmbedding(llama_config, device=device) |
36 | 43 | q = torch.randn(
|
37 | 44 | (1, seq_len, num_q_heads, head_dim),
|
38 | 45 | device=device,
|
@@ -107,8 +114,13 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
|
107 | 114 | seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
|
108 | 115 |
|
109 | 116 | head_dim = hidden_size // num_q_heads
|
110 |
| - llama_config = LlamaConfig(head_dim=head_dim) |
111 |
| - rotary_emb = LlamaRotaryEmbedding(llama_config, device=device) |
| 117 | + |
| 118 | + if tuple(map(int, transformers_version.split(".")[:2])) < (4, 48): |
| 119 | + # LlamaRotaryEmbedding constructor signature changed in transformers 4.48.0 |
| 120 | + rotary_emb = LlamaRotaryEmbedding(head_dim=head_dim, device=device) |
| 121 | + else: |
| 122 | + llama_config = LlamaConfig(head_dim=head_dim) |
| 123 | + rotary_emb = LlamaRotaryEmbedding(llama_config, device=device) |
112 | 124 | q = torch.randn(
|
113 | 125 | (1, seq_len, num_q_heads, head_dim),
|
114 | 126 | device=device,
|
|
0 commit comments