### Description

**Bug:** When using `FinchPress` with Qwen models and YARN RoPE scaling, the rerotation logic for rotary positional embeddings is incorrect. This results in a major degradation in performance.
### To Reproduce

Use `Qwen2.5-7B-Instruct` with the following model config and `FinchPress`:
```python
model_kwargs.update({
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
        "type": "yarn",
    },
})
```
Evaluate on LongBench (e.g., NarrativeQA).
You will observe a sharp drop in performance unless the rerotation logic is corrected.
To reproduce the results with the fix, you can use this branch: https://github.com/giulio98/kvpress/tree/fix/yarn
### Repository version

- Branch: `main`
- Fix available at: `giulio98/fix/yarn`
### Result Comparison

#### `expected_attention` (compression=0.5, NarrativeQA)

| Setting | Score |
|---|---|
| expected_attention (no yarn) | 25.63 |
| expected_attention (yarn) | 24.05 |
| expected_attention (no yarn), rerotate | 29.03 |
| expected_attention (yarn), rerotate | 28.34 |
| expected_attention (no yarn), w/ fix | 29.03 |
| expected_attention (yarn), w/ fix | 28.34 |
Note: I did not observe a drop in `expected_attention` scores before vs. after the fix. This is likely because the `expected_attention` run only covers ~20 unique NarrativeQA contexts, which may not be sufficient to expose the bug clearly.
#### `finch` (compression=0.5, NarrativeQA)

| Setting | Score |
|---|---|
| finch (no yarn) | 28.96 |
| finch (yarn) | 6.84 |
| finch (no yarn), w/ fix | 28.84 |
| finch (yarn), w/ fix | 27.78 |
The proposed fix updates the rerotation logic to account for positional deltas in a way that is intended to be agnostic to RoPE scaling (including YARN and LLaMA 3). However, I would appreciate a second opinion or review to ensure the logic is general and robust.
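To illustrate the idea, here is a minimal sketch of a scaling-agnostic rerotation: query the model's own rotary-embedding module for cos/sin at the old and new positions, de-rotate the cached keys, then re-rotate them, so whatever scaling the module applies (YARN, LLaMA 3, ...) is respected. This is a hypothetical helper, not the code on the fix branch; in particular, the `attention_scaling` handling is an assumption based on how recent `transformers` rotary modules fold that factor into cos/sin.

```python
from transformers.models.llama.modeling_llama import rotate_half


def rerotate_keys(rotary_emb, keys, old_pos, new_pos):
    """Move already-rotated keys from old_pos to new_pos (hypothetical helper).

    keys:    (batch, num_heads, seq_len, head_dim), rotated at old_pos
    old_pos: (batch, seq_len) original position ids of the kept keys
    new_pos: (batch, seq_len) contiguous position ids after pruning
    """
    # Ask the model's own rotary module for cos/sin so any RoPE scaling
    # (YARN, LLaMA 3, ...) is reflected in the per-dimension frequencies.
    cos_old, sin_old = rotary_emb(keys, old_pos)  # (batch, seq_len, head_dim)
    cos_new, sin_new = rotary_emb(keys, new_pos)

    # Recent transformers rotary modules fold an attention-scaling factor
    # (e.g. the YARN mscale) into cos/sin. The cached keys already carry that
    # factor from the original apply_rotary_pos_emb, so rotate with the
    # unscaled cos/sin and let the factor carry through unchanged.
    scale = getattr(rotary_emb, "attention_scaling", 1.0)
    cos_old, sin_old = cos_old / scale, sin_old / scale
    cos_new, sin_new = cos_new / scale, sin_new / scale

    # Undo the rotation applied at the old positions (rotate by -theta_old) ...
    keys = keys * cos_old.unsqueeze(1) - rotate_half(keys) * sin_old.unsqueeze(1)
    # ... then apply the rotation for the new positions.
    return keys * cos_new.unsqueeze(1) + rotate_half(keys) * sin_new.unsqueeze(1)
```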
Thanks in advance!