Fix for Qwen with Yarn #85

Merged (9 commits) on Jul 7, 2025
Changes from all commits
24 changes: 22 additions & 2 deletions evaluation/evaluate.py
@@ -4,7 +4,7 @@
import json
import logging
from pathlib import Path
from typing import Optional
from typing import Optional, Any

import torch
from datasets import load_dataset
@@ -103,6 +103,8 @@ def evaluate(
max_context_length: Optional[int] = None,
compress_questions: bool = False,
key_channel_compression_ratio: float = 0.5,
rope_scaling: Optional[dict] = None,
max_position_embeddings: Optional[int] = None,
):
"""
Evaluate a model on a dataset using a press and save the results
@@ -131,6 +133,14 @@
Whether to compress the questions as well, by default False
key_channel_compression_ratio : float, optional
        Key channel compression ratio for the channel press, by default 0.5
rope_scaling : dict, optional
        RoPE-scaling configuration dictionary passed to the model config's ``rope_scaling`` field
        (e.g. ``{"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}``),
        by default None. If set, ``max_position_embeddings`` must also be provided.
max_position_embeddings : int, optional
The value to set for ``max_position_embeddings`` in the model config when ``rope_scaling`` is used.
Required if ``rope_scaling`` is not ``None``; ignored otherwise.
"""

assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
@@ -184,7 +194,7 @@ def evaluate(
press.compression_ratio = compression_ratio # type:ignore[attr-defined]

# Initialize pipeline with the correct attention implementation
model_kwargs = {"torch_dtype": "auto"}
model_kwargs: dict[str, Any] = {"torch_dtype": "auto"}
if isinstance(press, ObservedAttentionPress):
model_kwargs["attn_implementation"] = "eager"
else:
@@ -194,6 +204,16 @@
model_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
pass
if rope_scaling is not None:
if max_position_embeddings is None:
raise ValueError("max_position_embeddings must be given when rope_scaling is used")

model_kwargs.update(
{
"max_position_embeddings": max_position_embeddings,
"rope_scaling": rope_scaling,
}
)

if device == "auto":
pipe = pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
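
The new overrides travel through `model_kwargs` to the model loader, so `rope_scaling` is always applied together with `max_position_embeddings`. As a rough illustration (not part of this PR; the checkpoint name is a placeholder), the branch added above amounts to loading the model with the same config overrides directly:

from transformers import AutoModelForCausalLM

# Config overrides equivalent to the model_kwargs update above; rope_scaling
# must be accompanied by max_position_embeddings, as the ValueError enforces.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",  # placeholder checkpoint
    torch_dtype="auto",
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    },
    max_position_embeddings=131072,
)
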
16 changes: 5 additions & 11 deletions kvpress/presses/finch_press.py
@@ -7,10 +7,10 @@

import torch
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.key_rerotation_press import KeyRerotationPress


@dataclass
@@ -93,18 +93,12 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
chunk_indices = i + chunk_scores.topk(n_kept, dim=-1).indices
indices.append(chunk_indices)
indices = torch.cat(indices, dim=-1)

indices = torch.sort(indices, dim=2).values
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Rerotate keys
if self.rerotate_keys:
cos, sin = kwargs["position_embeddings"]
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
keys = keys.gather(2, indices).contiguous()
cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
indices = torch.sort(indices, dim=2).values
keys = KeyRerotationPress.rerotate_keys(module, indices, keys)
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
else:
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
keys = keys.gather(2, indices).contiguous()

values = values.gather(2, indices).contiguous()
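
For reference, the expand-and-gather pattern used in both branches above simply selects the kept tokens along the sequence axis of the (bsz, num_heads, q_len, head_dim) cache tensors. A small illustration in plain torch (shapes and values are made up):

import torch

bsz, n_heads, q_len, head_dim = 1, 2, 6, 4
keys = torch.randn(bsz, n_heads, q_len, head_dim)

# Per-head indices of kept tokens, shape (bsz, n_heads, n_kept), already sorted.
indices = torch.tensor([[[0, 2, 5], [1, 3, 4]]])

# Broadcast the token indices over head_dim, then gather along the sequence dim.
expanded = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
kept_keys = keys.gather(2, expanded)  # (bsz, n_heads, n_kept, head_dim)

assert torch.equal(kept_keys[0, 0, 1], keys[0, 0, 2])  # head 0 keeps token 2 second
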
96 changes: 80 additions & 16 deletions kvpress/presses/key_rerotation_press.py
@@ -30,6 +30,85 @@ class KeyRerotationPress(BasePress):
def __post_init__(self):
assert isinstance(self.press, ScorerPress)

@staticmethod
def _rerotate_cos_sin(x, inv_freq, selected_positions):
"""
Compute cosine and sine rotary positional embeddings required to
re-rotate pruned keys back into the canonical RoPE space.

Parameters
----------
x : torch.Tensor
Any key-like tensor that provides ``dtype`` and ``device``.
Shape ``(bsz, num_key_value_heads, q_len, d)``.
inv_freq : torch.Tensor
``module.rotary_emb.inv_freq``. Shape ``(d//2,)``.
selected_positions : torch.Tensor
Indices of the *kept* tokens.
Shape ``(bsz, num_key_value_heads, n_kept)``.

Returns
-------
cos, sin : torch.Tensor
Cosine and sine embeddings, each of shape
``(bsz, num_key_value_heads, n_kept, d)``, matching ``dtype``/``device`` of ``x``.
"""
bsz, num_key_value_heads, n_kept = selected_positions.shape
device = selected_positions.device
device_type = x.device.type
dtype = x.dtype
# Original positional indices
idx = torch.arange(0, n_kept, device=device) # (n_kept,)
idx = idx.unsqueeze(0) # (1, n_kept)
inv_freq = inv_freq[None, None, :, None].float().expand(bsz, num_key_value_heads, -1, 1)
idx = idx[:, None, :].float().expand(bsz, num_key_value_heads, n_kept)
# Compute delta between original and selected positions
delta_pos = idx - selected_positions # (bsz, num_key_value_heads, n_kept)
delta_pos = delta_pos.unsqueeze(2) # (bsz, num_key_value_heads, 1, n_kept)

device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

with torch.autocast(device_type=device_type, enabled=False):
# Compute the new freq by scaling inv_freq by delta
freqs = delta_pos.float() * inv_freq.float() # (bsz, num_key_value_heads, d//2, n_kept)
freqs = freqs.transpose(2, 3) # (bsz, num_key_value_heads, n_kept, d//2)
emb = torch.cat((freqs, freqs), dim=-1)
# Compute cosine and sine required to re-rotate keys to selected positions
cos = emb.cos().contiguous()
sin = emb.sin().contiguous()
return cos.to(dtype=dtype), sin.to(dtype=dtype)

@staticmethod
def rerotate_keys(
module: nn.Module,
indices: torch.Tensor,
keys: torch.Tensor,
) -> torch.Tensor:
"""
        Rerotate the kept keys so that, after pruning, they carry a consistent RoPE representation at their new positions.

Parameters
----------
module : nn.Module
The model module containing the rotary embedding.
indices : torch.Tensor
Indices of the kept tokens after pruning.
keys : torch.Tensor
The keys tensor to be rerotated.

Returns
-------
torch.Tensor
The rerotated keys tensor of shape
``(bsz, num_heads, n_kept, d)``.
"""
        new_cos, new_sin = KeyRerotationPress._rerotate_cos_sin(
            keys, module.rotary_emb.inv_freq, indices
        )
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
keys = keys.gather(2, indices).contiguous()
return (keys * new_cos) + (rotate_half(keys) * new_sin)

def compress(
self,
module: nn.Module,
@@ -50,22 +129,7 @@ def compress(
n_kept = int(q_len * (1 - self.press.compression_ratio))
indices = scores.topk(n_kept, dim=-1).indices
indices = torch.sort(indices, dim=2).values
keys = self.rerotate_keys(module, indices, keys)
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

cos, sin = kwargs["position_embeddings"]
# Rerotate as follows
# 1. keys = RoPE(W_k * hidden_states)
# 2. keys_unrotated = RoPE^-1(keys)
# 3. keys_pruned = prune(keys_unrotated)
# 4. keys = RoPE(keys_pruned)

# 2. Inverse of rotation matrix is equivalent to setting sin -> -sin in the equation below
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
# 3. Prune keys
keys = keys.gather(2, indices).contiguous()
# 4. Apply RoPE
cos, sin = cos[:, :n_kept], sin[:, :n_kept]
keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))

values = values.gather(2, indices).contiguous()
return keys, values
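
The correctness of `_rerotate_cos_sin` rests on RoPE rotations composing additively: re-rotating an already-rotated cached key by the delta between its new (post-pruning) index and its original position is equivalent to applying RoPE to the un-rotated key at the new position. A self-contained sketch of that identity in plain torch (helper names here are illustrative, not taken from the library):

import torch

def rotate_half(x):
    # Same half-pairing convention as transformers' rotary helpers.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, positions, inv_freq):
    # x: (seq, d); positions: (seq,); inv_freq: (d // 2,)
    freqs = positions[:, None] * inv_freq[None, :]  # (seq, d // 2)
    emb = torch.cat((freqs, freqs), dim=-1)         # (seq, d)
    return x * emb.cos() + rotate_half(x) * emb.sin()

torch.manual_seed(0)
d, seq_len = 8, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))
raw_keys = torch.randn(seq_len, d)

# Keys as stored in the cache: rotated at their original positions 0 .. seq_len - 1.
cached_keys = apply_rope(raw_keys, torch.arange(seq_len).float(), inv_freq)

# Keep every other token; after pruning they occupy positions 0 .. n_kept - 1.
kept = torch.arange(0, seq_len, 2)
new_pos = torch.arange(len(kept)).float()

# Re-rotate by the delta (new index - original position), as _rerotate_cos_sin does.
delta = new_pos - kept.float()
freqs = delta[:, None] * inv_freq[None, :]
emb = torch.cat((freqs, freqs), dim=-1)
rerotated = cached_keys[kept] * emb.cos() + rotate_half(cached_keys[kept]) * emb.sin()

# Reference: rotate the raw kept keys directly at their new positions.
assert torch.allclose(rerotated, apply_rope(raw_keys[kept], new_pos, inv_freq), atol=1e-5)
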
57 changes: 49 additions & 8 deletions tests/presses/test_key_rerotation_press_rope.py
@@ -1,32 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

#
# Extended to test both the *default* and the *YaRN-scaled* rotary-embedding
# variants with the smallest possible code changes.

import inspect
from dataclasses import dataclass
from copy import deepcopy

import pytest
import torch
from torch import nn
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM, rotate_half
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaRotaryEmbedding,
rotate_half,
)
from transformers import Gemma3ForCausalLM

from kvpress import KeyRerotationPress, ScorerPress
from tests.fixtures import unit_test_model # noqa: F401


@pytest.mark.parametrize("rope_variant", ["default", "yarn"])
@pytest.mark.parametrize("precision", ["full", "half"])
def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: LlamaForCausalLM, precision): # noqa: F811
def test_rerotate_keys_is_matches_reference_implementation(
unit_test_model: LlamaForCausalLM, # noqa: F811
rope_variant,
precision,
):
"""
Compare KeyRerotationPress' rerotation of keys with the reference implementation.
In the reference implementation, we are computing
Compare KeyRerotationPress' rerotation of keys with the reference
implementation.

Reference path:
1. keys = W_k * hidden_states
2. keys_pruned = prune(keys)
3. keys = RoPE(keys_pruned)

Press path:
1. keys = W_k * hidden_states
2. keys = RoPE(keys)
3. keys_pruned = KeyRerotationPress.rerotate_keys(...)
"""
if rope_variant == "yarn":
layer0 = unit_test_model.model.layers[0]
cfg = deepcopy(layer0.self_attn.config)
cfg.rope_scaling = {
"factor": 4.0,
"original_max_position_embeddings": 32768,
"rope_type": "yarn",
}
cfg.max_position_embeddings = 131072
try:
unit_test_model.model.rotary_emb = LlamaRotaryEmbedding(cfg, device=unit_test_model.device)
except KeyError:
pytest.skip("YaRN rotary-embedding not available in this transformers version.")

for layer in unit_test_model.model.layers:
if isinstance(unit_test_model, Gemma3ForCausalLM) and layer.is_sliding:
# Skip layers with sliding window attention, only for Gemma3
continue
layer.self_attn.rotary_emb = unit_test_model.model.rotary_emb

if precision == "half" and torch.cuda.is_available():
unit_test_model = unit_test_model.cuda().half()
elif precision == "half" and not torch.cuda.is_available():
pytest.skip("Half precision test is skipped because CUDA is not available.")
elif precision == "half":
pytest.skip("Half-precision test skipped because CUDA is not available.")

original_press = RandomPressStoreIndices(compression_ratio=0.5)
key_rerotation_press = KeyRerotationPress(press=original_press)
Expand All @@ -47,7 +88,7 @@ def test_rerotate_keys_is_matches_reference_implementation(unit_test_model: Llam
keys,
values,
attentions=None,
kwargs={"position_embeddings": get_rope_embeddings(module, keys)},
kwargs={},
)

indices = original_press.indices