
feat: add sageattention #2823


Draft · NanoCode012 wants to merge 14 commits into main

Conversation

NanoCode012 (Collaborator) commented Jun 23, 2025

Description

Adds SageAttention (https://github.com/thu-ml/SageAttention/).

Since it has a similar interface to sdpa_attention, I reused that implementation and cross-checked against the flash attention implementation in transformers. A rough sketch of the wrapper is shown below.
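
For context, this is a minimal sketch (not the PR's exact code) of an sdpa-style forward that delegates to SageAttention. The function name, the GQA/MQA expansion via `repeat_interleave`, and the `is_causal` inference are assumptions; the `sageattn` call follows the upstream package's documented `tensor_layout="HND"` interface.

```python
from typing import Optional, Tuple

import torch
from sageattention import sageattn  # optional dependency


def sage_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,   # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,     # (batch, num_kv_heads, kv_len, head_dim)
    value: torch.Tensor,   # (batch, num_kv_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    # This integration does not support explicit attention masks.
    if attention_mask is not None:
        raise ValueError("SageAttention integration does not support attention_mask")

    # Expand K/V heads for grouped-query / multi-query attention so key/value
    # have the same number of heads as query.
    n_rep = query.shape[1] // key.shape[1]
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)

    # Treat the call as causal only when more than one query token is present.
    is_causal = getattr(module, "is_causal", True) and query.shape[2] > 1

    # SageAttention's "HND" layout matches (batch, heads, seq_len, head_dim).
    attn_output = sageattn(query, key, value, tensor_layout="HND", is_causal=is_causal)

    # transformers expects (batch, seq_len, heads, head_dim) back.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
```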

Motivation and Context

How has this been tested?

No tests yet.

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features
    • Added support for SageAttention as a selectable attention mechanism.
    • Introduced a configuration option to enable SageAttention.
  • Bug Fixes
    • Added validation to prevent enabling SageAttention alongside unsupported sample packing.
    • Added validation to enforce hardware compatibility for SageAttention.
  • Documentation
    • Updated configuration descriptions with references to SageAttention for user guidance.

coderabbitai bot commented Jun 23, 2025

Important

Review skipped: draft detected. To trigger a single review, invoke the @coderabbitai review command.

Walkthrough

Support for SageAttention, a new attention implementation, has been integrated. This includes configuration schema updates, a monkeypatch for Hugging Face transformers, conditional patch application logic, and internal model loader changes to select SageAttention. Validation prevents incompatible use with sample packing and enforces GPU compute capability requirements. No public APIs were changed; all modifications are internal or configuration-related.

Changes

File(s) and change summary:

  • src/axolotl/utils/schemas/config.py: Added the sage_attention config option and validators that disallow sample packing with SageAttention and enforce GPU compute capability.
  • src/axolotl/loaders/model.py: Updated internal logic to support "sage_attention" as an attention implementation option in _set_attention_config.
  • src/axolotl/loaders/patch_manager.py: Added an _apply_sageattn_patches method to conditionally apply the SageAttention patch during pre-model-load patching (sketched below).
  • src/axolotl/monkeypatch/attention/sageattn.py: New file implementing the SageAttention monkeypatch integration with transformers, including forward and patch functions.
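
To make the walkthrough concrete, here is a hypothetical sketch of the conditional patch application in patch_manager.py. Only the `_apply_sageattn_patches` name comes from the change summary; the config attribute access and the `patch_sageattn` entry point are assumptions.

```python
class PatchManager:
    """Illustrative only; the real class carries more responsibilities."""

    def __init__(self, cfg):
        self.cfg = cfg

    def apply_pre_model_load_patches(self) -> None:
        # Other pre-model-load patches would be applied here as well.
        self._apply_sageattn_patches()

    def _apply_sageattn_patches(self) -> None:
        if getattr(self.cfg, "sage_attention", False):
            # Lazy import so the optional sageattention dependency is only
            # required when the feature is actually enabled.
            from axolotl.monkeypatch.attention.sageattn import patch_sageattn  # assumed name

            patch_sageattn()
```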

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Config
    participant ModelLoader
    participant PatchManager
    participant Transformers
    participant SageAttention

    User->>Config: Set sage_attention=True
    Config->>Config: Validate config (disallow sample_packing + sage_attention, check GPU capability)
    ModelLoader->>Config: Read sage_attention flag
    ModelLoader->>PatchManager: Apply SageAttention patch if enabled
    PatchManager->>Transformers: Register sage_attention_forward
    ModelLoader->>Transformers: Set attn_implementation="sage_attention"
    Transformers->>SageAttention: Use SageAttention for attention calls

Poem

In the warren, code hops anew,
SageAttention joins the view!
If you pack your samples tight,
Sage will say, "That’s not right!"
GPUs must meet the call,
Or SageAttention won’t run at all.
So patch and load with rabbit glee—
Smarter models, soon you’ll see!
🥕✨



coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/axolotl/monkeypatch/attention/sageattn.py (1)

41-112: Well-implemented attention forward function with minor style improvements.

The function correctly handles SageAttention's limitations, GQA/MQA support, and tensor layout transformations. The extensive validation ensures clear error messages for unsupported features.

Apply these minor style improvements suggested by static analysis:

     if (
         kwargs.get("output_attentions", False)
-        or kwargs.get("head_mask", None) is not None
+        or kwargs.get("head_mask") is not None
     ):

-    if kwargs.get("position_ids", None) is not None:
+    if kwargs.get("position_ids") is not None:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0494359 and 2d130f5.

📒 Files selected for processing (4)
  • src/axolotl/loaders/model.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/attention/sageattn.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/attention/sageattn.py

60-60: Use kwargs.get("head_mask") instead of kwargs.get("head_mask", None)

Replace kwargs.get("head_mask", None) with kwargs.get("head_mask")

(SIM910)


77-77: Use kwargs.get("position_ids") instead of kwargs.get("position_ids", None)

Replace kwargs.get("position_ids", None) with kwargs.get("position_ids")

(SIM910)

🔇 Additional comments (5)
src/axolotl/loaders/model.py (1)

550-554: LGTM! Consistent attention implementation pattern.

The SageAttention integration follows the same pattern as other attention implementations and correctly sets both the model kwargs and config attributes.
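
As a rough illustration of that pattern (names other than "sage_attention" are assumptions, not the PR's code):

```python
def set_attention_implementation(cfg, model_kwargs: dict, model_config) -> None:
    """Select SageAttention as the attn_implementation when enabled."""
    if getattr(cfg, "sage_attention", False):
        # Both the from_pretrained kwargs and the loaded config point at the
        # registered "sage_attention" implementation.
        model_kwargs["attn_implementation"] = "sage_attention"
        model_config._attn_implementation = "sage_attention"
```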

src/axolotl/utils/schemas/config.py (2)

497-502: LGTM! Well-documented configuration field.

The SageAttention configuration field follows the established pattern and includes a helpful description with a link to the source repository.


886-894: Good validation logic for incompatible features.

The validator correctly prevents using SageAttention with sample packing, which aligns with the current limitations documented in the monkeypatch implementation.
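
A minimal pydantic sketch of the two checks described here; field names beyond sage_attention/sample_packing, the exact compute-capability threshold, and the error wording are assumptions.

```python
import torch
from pydantic import BaseModel, model_validator


class ConfigSketch(BaseModel):
    sage_attention: bool = False
    sample_packing: bool = False

    @model_validator(mode="after")
    def validate_sage_attention(self):
        if self.sage_attention:
            if self.sample_packing:
                raise ValueError("sage_attention is not compatible with sample_packing")
            if torch.cuda.is_available():
                # Threshold is illustrative; the PR enforces its own minimum.
                major, _minor = torch.cuda.get_device_capability()
                if major < 8:
                    raise ValueError("sage_attention requires GPU compute capability >= 8.0")
        return self
```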

src/axolotl/monkeypatch/attention/sageattn.py (2)

17-39: LGTM! Robust availability checking and import logic.

The implementation properly handles the optional dependency with clear error messages and appropriate fallback behavior.


115-122: LGTM! Clean patch registration.

The patch function properly checks for SageAttention availability and registers the forward function with transformers.
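
For reference, registering a custom attention function with recent transformers releases looks roughly like this; the availability check is simplified, and sage_attention_forward stands in for the module's forward wrapper.

```python
from transformers import AttentionInterface

from axolotl.monkeypatch.attention.sageattn import sage_attention_forward  # assumed export


def patch_sageattn() -> None:
    try:
        import sageattention  # noqa: F401  # optional dependency check
    except ImportError as exc:
        raise ImportError(
            "sage_attention requires the sageattention package; see "
            "https://github.com/thu-ml/SageAttention/"
        ) from exc

    # Makes attn_implementation="sage_attention" resolvable by transformers.
    AttentionInterface.register("sage_attention", sage_attention_forward)
```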

codecov bot commented Jun 23, 2025

Codecov Report

Attention: Patch coverage is 15.90909% with 74 lines in your changes missing coverage. Please review.

Files with missing lines (patch % / lines missing):
  • src/axolotl/monkeypatch/attention/sage_attn.py: 0.00%, 67 missing ⚠️
  • src/axolotl/utils/schemas/config.py: 76.92%, 3 missing ⚠️
  • src/axolotl/loaders/model.py: 33.33%, 2 missing ⚠️
  • src/axolotl/loaders/patch_manager.py: 60.00%, 2 missing ⚠️


coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/axolotl/monkeypatch/attention/sageattn.py (2)

32-39: Consider making installation instructions more maintainable.

The hardcoded commit hash in the installation command may become outdated. Consider either:

  1. Using a version tag instead of a commit hash
  2. Adding a comment to remind maintainers to keep this updated
  3. Referring to the official installation documentation
-            "`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
+            "`pip install git+https://github.com/thu-ml/SageAttention.git` or follow installation instructions at https://github.com/thu-ml/SageAttention/"

58-82: Address static analysis hints and document significant limitations.

The error handling is thorough, but there are two style improvements suggested by static analysis tools, and the limitations should be prominently documented.

Apply these style improvements:

-        or kwargs.get("head_mask", None) is not None
+        or kwargs.get("head_mask") is not None
-    if kwargs.get("position_ids", None) is not None:
+    if kwargs.get("position_ids") is not None:

Important: The lack of support for attention_mask and position_ids significantly limits this integration's applicability. Consider adding a prominent warning in the docstring about these constraints.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9a269c4 and 67bc55b.

📒 Files selected for processing (1)
  • src/axolotl/monkeypatch/attention/sageattn.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/monkeypatch/attention/sageattn.py

60-60: Use kwargs.get("head_mask") instead of kwargs.get("head_mask", None)

Replace kwargs.get("head_mask", None) with kwargs.get("head_mask")

(SIM910)


77-77: Use kwargs.get("position_ids") instead of kwargs.get("position_ids", None)

Replace kwargs.get("position_ids", None) with kwargs.get("position_ids")

(SIM910)

🔇 Additional comments (3)
src/axolotl/monkeypatch/attention/sageattn.py (3)

1-30: LGTM! Well-structured conditional import pattern.

The import section properly handles the optional SageAttention dependency with clear documentation and a standard conditional import pattern.


115-123: LGTM! Proper integration with transformers attention registry.

The patch function correctly registers SageAttention with transformers' global attention function registry following the established pattern.


83-112: Verify causal mask inference logic and confirm tensor layout assumptions.

The GQA/MQA handling and tensor operations look correct, but the causal mask inference should be verified.

Please verify that the causal mask inference logic matches transformers' behavior:

#!/bin/bash
# Search for similar causal mask inference patterns in transformers codebase
rg -A 5 -B 5 "is_causal.*query.*shape" --type py
rg -A 5 -B 5 "getattr.*is_causal" --type py

The tensor layout conversion from "HND" (batch, heads, seq_len, dim) to transformers format (batch, seq_len, heads, dim) using transpose(1, 2) appears mathematically correct.
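
A small self-contained check of that layout claim and the usual causal-flag convention (illustrative only):

```python
import torch

batch, heads, q_len, head_dim = 2, 8, 16, 64

# SageAttention "HND" output: (batch, heads, seq_len, head_dim)
out_hnd = torch.randn(batch, heads, q_len, head_dim)

# transformers-style attention forwards hand back (batch, seq_len, heads, head_dim)
out_bshd = out_hnd.transpose(1, 2).contiguous()
assert out_bshd.shape == (batch, q_len, heads, head_dim)

# Common causal-flag convention in sdpa-style forwards: only treat the call as
# causal when more than one query token is present; single-token decode steps
# attend to the full cached sequence anyway.
is_causal = q_len > 1
```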

NanoCode012 marked this pull request as draft June 24, 2025 11:48
winglian (Collaborator) commented:

@NanoCode012 what's the sage vs flash attn VRAM usage?

NanoCode012 (Collaborator, Author) commented:

@winglian, weirdly, I'm not getting the VRAM savings seen in the benchmarks.

Early wandb results show it is about 20% faster with the same VRAM usage. However, kernel benchmarking showed it using less VRAM (at least with <32k context).

More runs still need to be done.
