feat: add sageattention #2823
base: main
Conversation
Important: Review skipped (draft detected). Please check the settings in the CodeRabbit UI or the repository's CodeRabbit configuration file. You can disable this status message in the CodeRabbit configuration.

Walkthrough

Support for SageAttention, a new attention implementation, has been integrated. This includes configuration schema updates, a monkeypatch for Hugging Face transformers, conditional patch application logic, and internal model loader changes to select SageAttention. Validation prevents incompatible use with sample packing and enforces GPU compute capability requirements. No public APIs were changed; all modifications are internal or configuration-related.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Config
    participant ModelLoader
    participant PatchManager
    participant Transformers
    participant SageAttention
    User->>Config: Set sage_attention=True
    Config->>Config: Validate config (disallow sample_packing + sage_attention, check GPU capability)
    ModelLoader->>Config: Read sage_attention flag
    ModelLoader->>PatchManager: Apply SageAttention patch if enabled
    PatchManager->>Transformers: Register sage_attention_forward
    ModelLoader->>Transformers: Set attn_implementation="sage_attention"
    Transformers->>SageAttention: Use SageAttention for attention calls
```
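For context, enabling this path from a user config is just a flag flip. A minimal sketch follows (shown as a Python dict for illustration; real axolotl configs are YAML, and any keys beyond `sage_attention` and `sample_packing` would be whatever the rest of the config already uses):

```python
# Hypothetical config fragment: opt into SageAttention and leave sample
# packing off, since the validator described in this review rejects the
# combination.
cfg = {
    "sage_attention": True,   # select the SageAttention backend
    "sample_packing": False,  # incompatible with sage_attention
}
```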
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/monkeypatch/attention/sageattn.py (1)
41-112: Well-implemented attention forward function with minor style improvements.

The function correctly handles SageAttention's limitations, GQA/MQA support, and tensor layout transformations. The extensive validation ensures clear error messages for unsupported features.

Apply these minor style improvements suggested by static analysis:

```diff
 if (
     kwargs.get("output_attentions", False)
-    or kwargs.get("head_mask", None) is not None
+    or kwargs.get("head_mask") is not None
 ):
```

```diff
-    if kwargs.get("position_ids", None) is not None:
+    if kwargs.get("position_ids") is not None:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/axolotl/loaders/model.py (1 hunks)
- src/axolotl/loaders/patch_manager.py (1 hunks)
- src/axolotl/monkeypatch/attention/sageattn.py (1 hunks)
- src/axolotl/utils/schemas/config.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/attention/sageattn.py
60-60: Use kwargs.get("head_mask")
instead of kwargs.get("head_mask", None)
Replace kwargs.get("head_mask", None)
with kwargs.get("head_mask")
(SIM910)
77-77: Use kwargs.get("position_ids")
instead of kwargs.get("position_ids", None)
Replace kwargs.get("position_ids", None)
with kwargs.get("position_ids")
(SIM910)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
🔇 Additional comments (5)
src/axolotl/loaders/model.py (1)
550-554: LGTM! Consistent attention implementation pattern.

The SageAttention integration follows the same pattern as other attention implementations and correctly sets both the model kwargs and config attributes.
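As a rough illustration of the pattern this comment refers to (a sketch, not the actual lines from `src/axolotl/loaders/model.py`; the attribute names mirror how transformers exposes the setting):

```python
# Hypothetical loader-side selection: follow the existing attention backends
# by setting both the from_pretrained kwargs and the loaded model config.
if cfg.sage_attention:
    model_kwargs["attn_implementation"] = "sage_attention"
    model_config._attn_implementation = "sage_attention"
```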
src/axolotl/utils/schemas/config.py (2)
497-502: LGTM! Well-documented configuration field.

The SageAttention configuration field follows the established pattern and includes a helpful description with a link to the source repository.
886-894: Good validation logic for incompatible features.

The validator correctly prevents using SageAttention with sample packing, which aligns with the current limitations documented in the monkeypatch implementation.
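For readers outside the schema module, such a validator is roughly the following shape (a minimal pydantic sketch with assumed field and method names, not the PR's actual code):

```python
from pydantic import BaseModel, model_validator


class ConfigSketch(BaseModel):
    """Minimal stand-in for axolotl's config schema."""

    sage_attention: bool = False
    sample_packing: bool = False

    @model_validator(mode="after")
    def check_sage_attention_compat(self):
        # SageAttention does not support the packed attention masks that
        # sample packing relies on, so reject the combination up front.
        if self.sage_attention and self.sample_packing:
            raise ValueError("sage_attention cannot be used with sample_packing")
        return self
```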
src/axolotl/monkeypatch/attention/sageattn.py (2)
17-39: LGTM! Robust availability checking and import logic.

The implementation properly handles the optional dependency with clear error messages and appropriate fallback behavior.
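The availability check presumably follows the usual optional-dependency pattern, along these lines (a sketch; the flag and helper names are assumptions):

```python
# Guarded import of the optional sageattention package.
try:
    from sageattention import sageattn  # upstream kernel entry point

    SAGE_ATTENTION_AVAILABLE = True
except ImportError:
    sageattn = None
    SAGE_ATTENTION_AVAILABLE = False


def is_sage_attention_available() -> bool:
    """Return True when the sageattention package can be imported."""
    return SAGE_ATTENTION_AVAILABLE
```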
115-122: LGTM! Clean patch registration.

The patch function properly checks for SageAttention availability and registers the forward function with transformers.
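The registration presumably goes through transformers' attention-function registry; a hedged sketch follows (the registry key and exact API call are assumptions, and `SAGE_ATTENTION_AVAILABLE` / `sage_attention_forward` refer to the sketches elsewhere in this thread):

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def patch_sage_attention() -> None:
    """Register the SageAttention forward so models can select it by name."""
    if not SAGE_ATTENTION_AVAILABLE:
        raise ImportError(
            "sageattention is not installed; see "
            "https://github.com/thu-ml/SageAttention/ for install instructions"
        )
    # After this, loading a model with attn_implementation="sage_attention"
    # routes attention calls through sage_attention_forward.
    ALL_ATTENTION_FUNCTIONS["sage_attention"] = sage_attention_forward
```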
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/monkeypatch/attention/sageattn.py (2)
32-39: Consider making installation instructions more maintainable.

The hardcoded commit hash in the installation command may become outdated. Consider either:
- Using a version tag instead of a commit hash
- Adding a comment to remind maintainers to keep this updated
- Referring to the official installation documentation
- "`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`" + "`pip install git+https://github.com/thu-ml/SageAttention.git` or follow installation instructions at https://github.com/thu-ml/SageAttention/"
58-82: Address static analysis hints and document significant limitations.

The error handling is thorough, but there are two style improvements suggested by static analysis tools, and the limitations should be prominently documented.
Apply these style improvements:
- or kwargs.get("head_mask", None) is not None + or kwargs.get("head_mask") is not None- if kwargs.get("position_ids", None) is not None: + if kwargs.get("position_ids") is not None:Important: The lack of support for
attention_mask
andposition_ids
significantly limits this integration's applicability. Consider adding a prominent warning in the docstring about these constraints.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/axolotl/monkeypatch/attention/sageattn.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/attention/sageattn.py
60-60: Use kwargs.get("head_mask")
instead of kwargs.get("head_mask", None)
Replace kwargs.get("head_mask", None)
with kwargs.get("head_mask")
(SIM910)
77-77: Use kwargs.get("position_ids")
instead of kwargs.get("position_ids", None)
Replace kwargs.get("position_ids", None)
with kwargs.get("position_ids")
(SIM910)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/monkeypatch/attention/sageattn.py (3)
1-30: LGTM! Well-structured conditional import pattern.

The import section properly handles the optional SageAttention dependency with clear documentation and a standard conditional import pattern.
115-123: LGTM! Proper integration with transformers attention registry.

The patch function correctly registers SageAttention with transformers' global attention function registry following the established pattern.
83-112: Verify causal mask inference logic and confirm tensor layout assumptions.

The GQA/MQA handling and tensor operations look correct, but the causal mask inference should be verified. Please verify that the causal mask inference logic matches transformers' behavior:

```bash
#!/bin/bash
# Search for similar causal mask inference patterns in the transformers codebase
rg -A 5 -B 5 "is_causal.*query.*shape" --type py
rg -A 5 -B 5 "getattr.*is_causal" --type py
```

The tensor layout conversion from "HND" (batch, heads, seq_len, dim) to transformers format (batch, seq_len, heads, dim) using `transpose(1, 2)` appears mathematically correct.
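To make the layout and causality discussion concrete, a heavily simplified wrapper in the spirit of the patched function might look like this. It is a sketch under assumptions: the `sageattn(q, k, v, tensor_layout=..., is_causal=...)` call is taken from the upstream SageAttention API, the signature follows what transformers passes to registered attention functions, and `repeat_kv` is a local re-implementation for illustration rather than the PR's code.

```python
import torch
from sageattention import sageattn  # assumed upstream entry point


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads for grouped-query / multi-query attention."""
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)


def sage_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
    # query/key/value arrive as (batch, heads, seq_len, head_dim), i.e. "HND".
    n_rep = getattr(module, "num_key_value_groups", 1)
    if n_rep > 1:
        key = repeat_kv(key, n_rep)
        value = repeat_kv(value, n_rep)

    # Infer causality the way sdpa-style backends do: causal when no explicit
    # mask is supplied and the query covers more than one position.
    is_causal = attention_mask is None and query.shape[2] > 1

    attn_output = sageattn(query, key, value, tensor_layout="HND", is_causal=is_causal)

    # transformers expects (batch, seq_len, heads, head_dim) back.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
```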
@NanoCode012 what's the sage vs flash attn VRAM usage?
@winglian, weirdly, I'm not getting the VRAM savings seen in the benchmarks. Current early wandb results show it's about 20% faster with the same VRAM usage. However, kernel benchmarking showed it using less VRAM (at least when <32k context). More runs still need to be done.
Description
Adds SageAttention https://github.com/thu-ml/SageAttention/
Since it has a similar interface to sdpa_attention, I used that implementation and flash attention in transformers to cross-check.
Motivation and Context
How has this been tested?
No test yet!
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit