Skip to content

Fix LlamaRotaryEmbedding Tests [#520] #532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions benchmark/scripts/benchmark_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

head_dim = hidden_size // num_q_heads
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

config = LlamaConfig(
max_position_embeddings=512,
head_dim=head_dim,
)

rotary_emb = LlamaRotaryEmbedding(config=config, device=device)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
Expand Down Expand Up @@ -105,7 +111,13 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

head_dim = hidden_size // num_q_heads
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

config = LlamaConfig(
max_position_embeddings=512,
head_dim=head_dim,
)

rotary_emb = LlamaRotaryEmbedding(config=config, device=device)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
Expand Down
16 changes: 14 additions & 2 deletions test/transformers/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from test.utils import supports_bfloat16
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.llama.configuration_llama import LlamaConfig

from liger_kernel.ops.rope import LigerRopeFunction
from liger_kernel.transformers.functional import liger_rope
Expand Down Expand Up @@ -57,7 +58,13 @@ def test_correctness(
atol,
rtol,
):
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

config = LlamaConfig(
max_position_embeddings=512,
head_dim=head_dim,
)

rotary_emb = LlamaRotaryEmbedding(config=config, device=device)

_tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype)

Expand Down Expand Up @@ -133,7 +140,12 @@ def test_functional_correctness(
k1 = _k.clone().requires_grad_(True)
k2 = _k.clone().requires_grad_(True)

rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
config = LlamaConfig(
max_position_embeddings=512,
head_dim=head_dim,
)

rotary_emb = LlamaRotaryEmbedding(config=config, device=device)

pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
if expand_position_ids:
Expand Down