
Commit 9c1ce42

Add backward compatible condition for test_rope.py
1 parent: ba44c57

File tree

2 files changed: 22 additions, 5 deletions

benchmark/scripts/benchmark_rope.py
test/transformers/test_rope.py


benchmark/scripts/benchmark_rope.py

Lines changed: 15 additions & 3 deletions
@@ -1,6 +1,7 @@
 import torch
 import triton
 
+from transformers import __version__ as transformers_version
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -32,7 +33,13 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
 
     head_dim = hidden_size // num_q_heads
     llama_config = LlamaConfig(head_dim=head_dim)
-    rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
+
+    if transformers_version < "4.48.0":
+        # LlamaRotaryEmbedding constructor signature changed in transformers 4.48.0
+        rotary_emb = LlamaRotaryEmbedding(head_dim=head_dim, device=device)
+    else:
+        llama_config = LlamaConfig(head_dim=head_dim)
+        rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
     q = torch.randn(
         (1, seq_len, num_q_heads, head_dim),
         device=device,
@@ -107,8 +114,13 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
     seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
 
     head_dim = hidden_size // num_q_heads
-    llama_config = LlamaConfig(head_dim=head_dim)
-    rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
+
+    if transformers_version < "4.48.0":
+        # LlamaRotaryEmbedding constructor signature changed in transformers 4.48.0
+        rotary_emb = LlamaRotaryEmbedding(head_dim=head_dim, device=device)
+    else:
+        llama_config = LlamaConfig(head_dim=head_dim)
+        rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
     q = torch.randn(
         (1, seq_len, num_q_heads, head_dim),
         device=device,
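The new gate compares transformers_version against "4.48.0" as plain strings, which distinguishes the relevant 4.4x releases but is not a general version comparison. A minimal sketch of the same construction using packaging.version instead (an alternative, not part of this commit; it assumes the packaging package is available, as it normally is alongside transformers):

# Sketch only: same backward-compatible construction, with an explicit
# version parse instead of a raw string comparison.
from packaging import version

from transformers import __version__ as transformers_version
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim = 64    # hypothetical value; the benchmark derives it from hidden_size // num_q_heads
device = "cpu"   # hypothetical; the benchmark takes the device from the run input

if version.parse(transformers_version) < version.parse("4.48.0"):
    # Older transformers: LlamaRotaryEmbedding is built from head_dim directly.
    rotary_emb = LlamaRotaryEmbedding(head_dim=head_dim, device=device)
else:
    # transformers >= 4.48.0: the constructor takes a LlamaConfig.
    llama_config = LlamaConfig(head_dim=head_dim)
    rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)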

test/transformers/test_rope.py

Lines changed: 7 additions & 2 deletions
@@ -2,6 +2,7 @@
 import torch
 
 from test.utils import supports_bfloat16
+from transformers import __version__ as transformers_version
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -58,8 +59,12 @@ def test_correctness(
     atol,
     rtol,
 ):
-    llama_config = LlamaConfig(head_dim=head_dim)
-    rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
+    if transformers_version < "4.48.0":
+        # LlamaRotaryEmbedding constructor signature changed in transformers 4.48.0
+        rotary_emb = LlamaRotaryEmbedding(head_dim=head_dim, device=device)
+    else:
+        llama_config = LlamaConfig(head_dim=head_dim)
+        rotary_emb = LlamaRotaryEmbedding(llama_config, device=device)
 
     _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype)
 
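For context on what the test exercises: the constructed rotary_emb supplies the cos/sin tables that apply_rotary_pos_emb consumes. A minimal usage sketch under a recent transformers release (config-based constructor); the shapes and device below are made up for illustration, whereas test_correctness parametrizes them:

import torch

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

# Hypothetical shapes; the real test sweeps bsz, seq_len, head counts, head_dim and dtype.
bsz, seq_len, num_q_heads, num_kv_heads, head_dim = 1, 128, 32, 8, 64
device = "cpu"

rotary_emb = LlamaRotaryEmbedding(LlamaConfig(head_dim=head_dim), device=device)

q = torch.randn(bsz, num_q_heads, seq_len, head_dim, device=device)
k = torch.randn(bsz, num_kv_heads, seq_len, head_dim, device=device)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

# Recent LlamaRotaryEmbedding returns (cos, sin) for the given positions;
# apply_rotary_pos_emb rotates q and k with them.
cos, sin = rotary_emb(q, position_ids)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)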