
Create base docker images for CUDA 12.8 with custom FlashAttention 3 installed #2685


Open: wants to merge 20 commits into base: main

Conversation

@winglian (Collaborator) commented May 16, 2025

Summary by CodeRabbit

  • New Features

    • Added support for a new CUDA 12.8.1 configuration with a restricted CUDA architecture list and a "-hopper" suffix.
    • Introduced optional support for Flash Attention version 3 with automatic detection based on GPU architecture.
    • Added configuration option to enable or auto-select Flash Attention version 3.
  • Chores

    • Updated CUDA and Python versions across configurations and Docker base images.
    • Improved Docker image tag naming to include optional suffixes for better differentiation.
    • Enhanced the installation process for specific package dependencies based on CUDA architecture.
  • Tests

    • Updated tests to parameterize and verify behavior with different Flash Attention versions.
    • Improved test cleanup to reset monkeypatches related to Flash Attention utilities.
    • Refined test classes and decorators for streamlined testing processes.

coderabbitai bot commented May 16, 2025

"""

Walkthrough

The changes update CUDA versions and add a new matrix entry with a "hopper" suffix and restricted CUDA architectures in workflows. Dockerfiles and CI scripts are modified to install a specific prebuilt flash-attn wheel for CUDA 12.6 and PyTorch 2.6. The model loader adds conditional support for Flash Attention 3, controlled by a new config option. Tests are updated for this feature and to reset monkeypatches.

Changes

File(s) and change summary:

  • .github/workflows/base.yml, .github/workflows/multi-gpu-e2e.yml: Updated CUDA versions in matrix entries; added new CUDA 12.6.3 matrix entry with "-hopper" suffix and restricted CUDA arch; removed nightly_build flag; updated environment variable exports; modified Docker image tag generation to append suffix.
  • docker/Dockerfile-base, cicd/Dockerfile.jinja: Changed CUDA, Python, and PyTorch versions; added curl to apt-get installs; updated flash-attn installation to download and install a prebuilt wheel when TORCH_CUDA_ARCH_LIST is "9.0+PTX" and PyTorch version is 2.6.0 or 2.7.0.
  • src/axolotl/utils/models.py: Added conditional support for Flash Attention 3 in ModelLoader.apply_patches based on config and GPU compute capability; replaced default flash-attn functions with FA3 versions when enabled; minor comments added.
  • src/axolotl/utils/schemas/config.py: Added new optional config field use_flash_attention_3 with default None to control Flash Attention 3 usage (see the sketch after this list).
  • tests/conftest.py: Added transformers.modeling_flash_attention_utils to modules reloaded in the pytest fixture to reset monkeypatches after tests; restore original flash-attn functions after tests.
  • tests/e2e/test_packing_loss.py: Converted test class from unittest to pytest style; removed with_temp_dir decorator from the test method; removed unittest and with_temp_dir imports.
  • tests/e2e/multigpu/test_llama.py: Parameterized the test_lora_ddp_packed test method over use_flash_attention_3 values [False, "auto"]; updated the test config accordingly.
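
As a rough illustration of the config change, here is a minimal Pydantic sketch of such a tri-state field. The field name and default come from the summary above; the container class name and exact typing are assumptions, not the actual definition in src/axolotl/utils/schemas/config.py.

```python
# Sketch only: a tri-state toggle as described above.
# None   -> off by default
# True   -> force Flash Attention 3
# "auto" -> let the model loader enable it on capable (SM 9.0+) GPUs
from typing import Literal, Optional, Union

from pydantic import BaseModel


class AxolotlAttentionConfig(BaseModel):  # hypothetical container class
    use_flash_attention_3: Optional[Union[bool, Literal["auto"]]] = None


# Example usage
cfg = AxolotlAttentionConfig(use_flash_attention_3="auto")
```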

Sequence Diagram(s)

sequenceDiagram
    participant Config as Config File
    participant ModelLoader
    participant FlashAttnUtils as FlashAttention Utils
    participant GPU

    Config->>ModelLoader: Provide use_flash_attention_3 setting
    ModelLoader->>GPU: Query compute capability
    alt use_flash_attention_3 enabled (True or auto + GPU >= 90)
        ModelLoader->>FlashAttnUtils: Import FA3 functions
        ModelLoader->>FlashAttnUtils: Override default flash-attn funcs with FA3
    else
        ModelLoader->>FlashAttnUtils: Use default flash-attn functions
    end
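
To make the flow in the diagram concrete, here is a minimal Python sketch of the decision and patching step. It assumes the FA3 wheel exposes a flash_attn_interface module and that transformers.modeling_flash_attention_utils holds module-level flash_attn_func/flash_attn_varlen_func names; the real logic in ModelLoader.apply_patches may target different symbols.

```python
# Sketch of the FA3 selection/patching flow shown above (not the actual
# axolotl implementation).
import torch


def should_use_fa3(use_flash_attention_3) -> bool:
    """True when FA3 is forced, or when set to "auto" on Hopper (SM 9.0+) GPUs."""
    if use_flash_attention_3 is True:
        return True
    if use_flash_attention_3 == "auto":
        return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
    return False


def maybe_patch_flash_attention(use_flash_attention_3) -> None:
    if not should_use_fa3(use_flash_attention_3):
        return  # keep the default flash-attn (FA2) code path

    # Assumed import path based on the wheel referenced in this PR; the
    # installed package may expose a different module name.
    import flash_attn_interface
    import transformers.modeling_flash_attention_utils as fa_utils

    # Replace the FA2 entry points that transformers resolved at import time
    # with their FA3 counterparts (signature-adapting wrappers omitted here).
    fa_utils.flash_attn_func = flash_attn_interface.flash_attn_func
    fa_utils.flash_attn_varlen_func = flash_attn_interface.flash_attn_varlen_func
```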

Poem

🐇 In code's green meadow, new hops appear,
CUDA climbs higher, the "hopper" is near.
Flash Attention 3 leaps in with grace,
Patching models to quicken the race.
Tests now dance in pytest's bright light,
Docker tags shine with suffixes bright!
🌿✨
"""


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between d6f64a3 and 9bdf4b1.

📒 Files selected for processing (1)
  • src/axolotl/utils/models.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/models.py
⏰ Context from checks skipped due to timeout of 90000ms (10)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
docker/Dockerfile-base (1)

41-48: Improve code indentation for readability.

The indentation in the conditional block is inconsistent, making it harder to read. Consider standardizing the indentation.

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
-        git clone https://github.com/Dao-AILab/flash-attention.git; \
-        git checkout v2.7.4.post1; \
-        cd flash-attention/hopper; \
-        FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \
-    elif if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
-        pip3 install flash-attn==2.7.4.post1; \
-    fi
+    git clone https://github.com/Dao-AILab/flash-attention.git; \
+    git checkout v2.7.4.post1; \
+    cd flash-attention/hopper; \
+    FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \
+elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
+    pip3 install flash-attn==2.7.4.post1; \
+fi
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 3a5b495 and 0484f64.

📒 Files selected for processing (2)
  • .github/workflows/base.yml (2 hunks)
  • docker/Dockerfile-base (1 hunks)
🧰 Additional context used
🪛 actionlint (1.7.7)
.github/workflows/base.yml

97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}

(expression)

⏰ Context from checks skipped due to timeout of 90000ms (14)
  • GitHub Check: build-base (128, 12.8.1, 3.11, next, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: build-base (128, 12.8.1, 3.11, nightly, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (3)
.github/workflows/base.yml (2)

50-50: CUDA version update looks good.

Updating to the latest CUDA 12.8.1 version ensures compatibility with newer drivers and GPUs.


55-61: Well-structured Hopper-specific configuration.

This new matrix entry creates a specialized Docker image optimized specifically for NVIDIA Hopper architecture (H100 GPUs) by:

  1. Using a restricted CUDA architecture list "9.0+PTX"
  2. Adding the "-hopper" suffix for clear identification
  3. Using PyTorch 2.7.0 which has optimizations for this architecture

This corresponds with the FlashAttention build logic in the Dockerfile.

docker/Dockerfile-base (1)

41-46: Good optimization for Hopper architecture.

Building flash-attention from source with Hopper-specific optimizations is a good approach. The special flags disable features not needed for Hopper and increase build parallelism.

codecov bot commented May 16, 2025

Codecov Report

Attention: Patch coverage is 74.07407% with 7 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/axolotl/utils/models.py | 73.07% | 7 Missing ⚠️


@winglian winglian mentioned this pull request May 16, 2025

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
docker/Dockerfile-base (2)

41-50: Consider adding explanatory comments for maintainability.

While the implementation is correct, adding comments explaining why different installation methods are needed for different architectures would improve maintainability for future contributors.

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
+        # Build flash-attention from source for Hopper architecture (9.0+PTX) \
+        # with specific optimizations for this architecture \
         git clone https://github.com/Dao-AILab/flash-attention.git; \
         cd flash-attention; \
         git checkout v2.7.4.post1; \
         cd hopper; \
         FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \
         cd ../..; \
     elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
+        # Use pre-built package for PyTorch 2.7.0 with other architectures \
         pip3 install flash-attn==2.7.4.post1; \
     fi

41-50: Add error handling for git operations.

The current implementation doesn't check if git operations succeed. Consider adding basic error checking to make debugging easier if the build fails.

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
-        git clone https://github.com/Dao-AILab/flash-attention.git; \
-        cd flash-attention; \
-        git checkout v2.7.4.post1; \
-        cd hopper; \
-        FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \
-        cd ../..; \
+        git clone https://github.com/Dao-AILab/flash-attention.git && \
+        cd flash-attention && \
+        git checkout v2.7.4.post1 && \
+        cd hopper && \
+        FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install && \
+        cd ../.. || echo "Failed to build flash-attention from source" && exit 1; \
     elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
     fi
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between e499848 and bc98ad2.

📒 Files selected for processing (1)
  • docker/Dockerfile-base (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (14)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: build-base (128, 12.8.1, 3.11, next, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit
🔇 Additional comments (2)
docker/Dockerfile-base (2)

41-47: Custom build path for Hopper architecture looks good.

The implementation correctly builds FlashAttention from source when targeting the Hopper architecture (CUDA 9.0+PTX). Using environment variables FLASH_ATTENTION_DISABLE_SM80=TRUE and FLASH_ATTENTION_DISABLE_FP8=TRUE is appropriate for optimizing for this specific architecture.

The directory navigation has been properly implemented with the correct cd commands, addressing previous issues.


48-48: Conditional logic for PyTorch 2.7.0 is correctly implemented.

The conditional statement has the correct syntax, with no duplicate "if" keywords that were present in previous versions.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
docker/Dockerfile-base (1)

41-50: ⚠️ Potential issue

Missing fallback for other PyTorch/CUDA combinations.

The current implementation only installs flash-attention when TORCH_CUDA_ARCH_LIST is "9.0+PTX" or when PYTORCH_VERSION is "2.7.0". However, the workflow runs with other combinations that don't satisfy either condition.

Add an else clause to ensure flash-attention is always installed:

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
         git clone https://github.com/Dao-AILab/flash-attention.git; \
         cd flash-attention; \
         git checkout v2.7.4.post1; \
         cd hopper; \
         FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=64 python setup.py install; \
         cd ../..; \
     elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
-    fi
+    else \
+        # Fallback for all other PyTorch/CUDA combinations
+        pip3 install flash-attn==2.7.4.post1; \
+    fi
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)

46-46: Consider using environment variable for MAX_JOBS.

You're setting MAX_JOBS=64 directly in the build command, but there's a MAX_JOBS ARG defined at line 4 that's set to 4 by default. For consistency, consider using that variable.

-        FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=64 python setup.py install; \
+        FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=$MAX_JOBS python setup.py install; \
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0f97473 and b70aa7b.

📒 Files selected for processing (1)
  • docker/Dockerfile-base (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (13)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: build-base (128, 12.8.1, 3.11, next, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit
🔇 Additional comments (1)
docker/Dockerfile-base (1)

41-47: Good implementation for Hopper architecture-specific build.

The changes correctly set up building flash-attention from source for the "9.0+PTX" CUDA architecture (NVIDIA Hopper/H100 GPUs). The code properly:

  • Clones the repository
  • Checks out the specific version
  • Builds from the hopper directory with appropriate flags
  • Navigates back to the original directory

cd flash-attention; \
git checkout v2.7.4.post1; \
cd hopper; \
FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=64 python setup.py install; \
Collaborator

Is this MAX_JOBS too high? What kind of hardware is the CPU building on?

RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
git clone https://github.com/Dao-AILab/flash-attention.git; \
cd flash-attention; \
git checkout v2.7.4.post1; \
Collaborator

I saw that this release is Jan 2025 and there's been a lot of changes to hopper dir since then.

Comment on lines +55 to +56
- cuda: "126"
cuda_version: 12.6.3
Collaborator

Taken from readme

We highly recommend CUDA 12.8 for best performance.

As only pytorch 2.7 has 12.8, should we swap to that?

Collaborator Author

fa3 doesn't compile with torch 2.7.0 due to an error that I don't have offhand, but is related to an API change in torch.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
docker/Dockerfile-base (1)

41-50: Missing fallback for other PyTorch versions

The installation logic only handles two cases:

  1. When TORCH_CUDA_ARCH_LIST="9.0+PTX" (build from source)
  2. When PYTORCH_VERSION="2.7.0" (pip install)

Other combinations from the workflow matrix (PyTorch 2.5.1, 2.6.0, nightly, next) will not have flash-attention installed.

#!/bin/bash
# Extract matrix combinations from workflow file to verify coverage
grep -r "PYTORCH_VERSION" .github/workflows/ | grep -v "2.7.0"
grep -r "TORCH_CUDA_ARCH_LIST" .github/workflows/ | grep -v "9.0+PTX"

Consider adding a fallback else clause:

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
         git clone https://github.com/Dao-AILab/flash-attention.git; \
         cd flash-attention; \
         git checkout v2.7.4.post1; \
         cd hopper; \
         FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=32 python setup.py install; \
         cd ../..; \
     elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
+    else \
+        # Default installation for other PyTorch/CUDA combinations
+        pip3 install flash-attn==2.7.4.post1; \
     fi
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)

41-47: Consider adding error handling

The multi-step process of cloning, checking out, and building from source lacks error handling. If any step fails, the build might proceed with an incomplete or missing flash-attention installation.

Consider adding error checking:

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
+        set -e && \
         git clone https://github.com/Dao-AILab/flash-attention.git; \
         cd flash-attention; \
         git checkout v2.7.4.post1; \
         cd hopper; \
         FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=32 python setup.py install; \
         cd ../..; \

This ensures that if any command fails, the entire RUN instruction will fail, preventing partial installations.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between b70aa7b and f822c29.

📒 Files selected for processing (1)
  • docker/Dockerfile-base (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (12)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (2)
docker/Dockerfile-base (2)

41-47: Building flash-attention from source for Hopper GPUs looks good

The implementation correctly clones and builds the flash-attention package from source when targeting Hopper GPUs (CUDA arch 9.0+PTX). The directory navigation has been fixed as suggested in the previous review.


46-46: MAX_JOBS value has been reduced

Good adjustment of MAX_JOBS from 128 to 32, which addresses the previous concern about setting it too high. This provides a more reasonable balance between build parallelism and resource usage.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
docker/Dockerfile-base (1)

41-50: Add a fallback clause for flash-attention installation

The current implementation only handles when TORCH_CUDA_ARCH_LIST is "9.0+PTX" or when PYTORCH_VERSION is "2.7.0". You need to add a fallback clause for cases where neither condition is met to ensure flash-attention is always installed.

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
         git clone https://github.com/Dao-AILab/flash-attention.git; \
         cd flash-attention; \
         git checkout v2.7.4.post1; \
         cd hopper; \
         FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=${MAX_JOBS} python setup.py install; \
         cd ../..; \
     elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
+    else \
+        # Fallback for all other PyTorch/CUDA combinations
+        pip3 install flash-attn==2.7.4.post1; \
     fi
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 74f0651 and c207ecb.

📒 Files selected for processing (1)
  • docker/Dockerfile-base (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (12)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (3)
docker/Dockerfile-base (3)

46-47: Approve usage of MAX_JOBS variable

Good improvement using the configurable MAX_JOBS ARG instead of a hardcoded value, and the proper directory navigation with cd .. commands has been implemented correctly.


1-2: Approve CUDA version updates

The updates to CUDA version (12.4.1) and empty CUDNN version string are appropriate for modern ML environments.


10-12: Approve Python and PyTorch updates

The updates to Python 3.11 and PyTorch 2.5.1 with CUDA 124 are consistent with the objective of creating updated base images.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
docker/Dockerfile-base (1)

41-47: 🛠️ Refactor suggestion

Add a fallback for other matrix combinations.

The updated logic now handles two specific cases:

  1. When TORCH_CUDA_ARCH_LIST="9.0+PTX", it installs the custom FlashAttention 3 beta wheel
  2. When PYTORCH_VERSION="2.7.0", it installs flash-attn 2.7.4.post1

However, it still doesn't have a fallback for other matrix combinations (e.g., PyTorch 2.5.1 with multi-SM architecture lists).

RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
        pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
        rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
    elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
        pip3 install flash-attn==2.7.4.post1; \
+    else \
+        # Fallback for all other PyTorch/CUDA combinations
+        pip3 install flash-attn==2.7.4.post1; \
    fi
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)

41-44: Consider making the wheel URL more configurable.

The wheel URL has hardcoded components (version 3.0.0b1, Python 3.11) but also uses variables (${CUDA}, ${PYTORCH_VERSION}). For better maintainability, consider making more components configurable through ARG/ENV variables.

+ARG FLASH_ATTN_VERSION="3.0.0b1"
+ARG PYTHON_WHEEL_TAG="cp311-cp311"

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
-        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
-        pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
-        rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-${FLASH_ATTN_VERSION}-${PYTHON_WHEEL_TAG}-linux_x86_64.whl; \
+        pip3 install --no-cache-dir flash_attn_interface-${FLASH_ATTN_VERSION}-${PYTHON_WHEEL_TAG}-linux_x86_64.whl; \
+        rm flash_attn_interface-${FLASH_ATTN_VERSION}-${PYTHON_WHEEL_TAG}-linux_x86_64.whl; \
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between aeea1ed and 72ac14c.

📒 Files selected for processing (1)
  • docker/Dockerfile-base (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (11)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (3)
docker/Dockerfile-base (3)

10-12: LGTM: Version updates are consistent with each other.

The Python and PyTorch version updates are properly synchronized, and the CUDA shorthand "124" correctly matches the CUDA version 12.4.1.


31-31: LGTM: Added curl dependency for wheel download.

Adding curl to the pip install list is necessary for the flash-attn wheel download in the updated installation logic.


1-2:

❓ Verification inconclusive

Version inconsistency between PR title and implementation.

The PR title mentions "CUDA 12.8" but the Dockerfile specifies CUDA 12.4.1. Also, the CUDNN_VERSION has been changed from a specific version to an empty string. Please clarify if this is intentional.


🏁 Script executed:

#!/bin/bash
# Description: Check CUDA references in other files to ensure consistency
echo "Searching for CUDA version references in workflow files:"
grep -r "CUDA.*12.8" --include="*.yml" .github
grep -r "CUDA.*12.4" --include="*.yml" .github

Length of output: 220


🏁 Script executed:

#!/bin/bash
# Verify all CUDA_VERSION declarations
echo "🔍 Checking for all ARG CUDA_VERSION in the repo"
grep -R "ARG CUDA_VERSION" -n .

# Verify all CUDNN_VERSION declarations
echo "🔍 Checking for all ARG CUDNN_VERSION in the repo"
grep -R "ARG CUDNN_VERSION" -n .

# Look for any literal “12.8” occurrences
echo "🔍 Searching for any references to ‘12.8’"
grep -R "12\.8" -n .

# Look for any literal “12.4” occurrences beyond Dockerfile-base
echo "🔍 Searching for any references to ‘12.4’"
grep -R "12\.4" -n .

Length of output: 4101


Clarify CUDA & cuDNN version alignment across Dockerfiles & workflows

It looks like the PR introduces CUDA 12.8.1 in the “next” and “nightly” Dockerfiles but leaves the main docker/Dockerfile-base at CUDA 12.4.1 with an empty CUDNN_VERSION. Meanwhile, CI workflows reference both versions.

• docker/Dockerfile-base (lines 1–2):
ARG CUDA_VERSION="12.4.1"
ARG CUDNN_VERSION=""
• docker/Dockerfile-base-next & docker/Dockerfile-base-nightly:
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
• .github/workflows/**/*: jobs defined for both cuda_version: 12.4.1 and cuda_version: 12.8.1

Please confirm the intended strategy:

  1. Keep the stable base at CUDA 12.4.1 and use separate images for CUDA 12.8.1?
  2. Or bump the main Dockerfile-base to CUDA 12.8.1 (and pin CUDNN)?
  3. Should the blank CUDNN_VERSION in the main Dockerfile be restored or documented?

Once clarified, we can align the ARGs and workflows accordingly.

Comment on lines 41 to 44
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \

💡 Verification agent

🧩 Analysis chain

Verify FlashAttention 3 wheel availability.

The code downloads a specific FlashAttention 3.0.0b1 wheel that's hardcoded for Python 3.11. Verify that this wheel is publicly available and will remain accessible at the specified URL.


🏁 Script executed:

#!/bin/bash
# Description: Test if the FlashAttention 3.0.0b1 wheel URL is accessible
curl -I -L https://d1dttdx32dkk5p.cloudfront.net/fa3/cu124/torch-2.5.1/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl

Length of output: 746


FlashAttention wheel download is inaccessible (HTTP 403)
The hardcoded FlashAttention 3.0.0b1 wheel URL returns a 403 Forbidden, so the Docker build will fail at the curl/pip install step. This wheel does not appear to be publicly available.

Please update or replace this step:

  • Confirm the correct public URL for the FlashAttention 3.0.0b1 wheel matching ${CUDA} and Python version, or host a mirror that allows downloads.
  • Alternatively, fall back to installing via pip install flash-attn (if available) or include a prebuilt wheel in the repo.

Locations to update:

  • docker/Dockerfile-base, lines 41–44: the curl -L -O ...flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl and subsequent pip3 install commands.
🤖 Prompt for AI Agents
In docker/Dockerfile-base around lines 41 to 44, the current step downloads a
FlashAttention 3.0.0b1 wheel from a hardcoded URL that returns a 403 Forbidden
error, causing the Docker build to fail. To fix this, replace the curl download
and pip install commands with either a valid public URL for the wheel matching
the CUDA and Python version or switch to installing FlashAttention via pip
directly using 'pip install flash-attn'. Alternatively, include a prebuilt wheel
in the repository and install from there. Update the Dockerfile accordingly to
ensure the wheel is accessible and installable during the build.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
docker/Dockerfile-base (2)

41-47: ⚠️ Potential issue

Fix missing fallback for other PyTorch/CUDA combinations.

The conditional logic only handles two specific cases:

  1. When TORCH_CUDA_ARCH_LIST="9.0+PTX"
  2. When PYTORCH_VERSION="2.7.0"

There's no fallback for other matrix combinations, which means flash-attn won't be installed in those cases.

Add an else clause to ensure flash-attention is installed for all combinations:

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
         curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
         pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
         rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
     elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
+    else \
+        # Fallback for all other PyTorch/CUDA combos
+        pip3 install flash-attn==2.7.4.post1; \
     fi

41-44: ⚠️ Potential issue

Fix inaccessible wheel URL (403 Forbidden).

The wheel URL returns a 403 Forbidden error, which will cause Docker builds to fail. This issue was previously identified but hasn't been addressed.

Consider these alternatives:

  1. Obtain and use a publicly accessible URL for the wheel
  2. Host the wheel in your own repository or storage
  3. Build the wheel from source
  4. Use a publicly available version from PyPI

Example fix option:

 RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
-        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
-        pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
-        rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        # Build from source for Hopper architecture
+        git clone https://github.com/Dao-AILab/flash-attention.git && \
+        cd flash-attention && \
+        # Use an appropriate FlashAttention 3 tag
+        git checkout v3.0.0 && \
+        FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=4 python setup.py install && \
+        cd .. && \
+        rm -rf flash-attention; \
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 72ac14c and eb3de69.

📒 Files selected for processing (1)
  • docker/Dockerfile-base (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (12)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (4)
cicd/Dockerfile.jinja (1)

35-39: Conditionally installing Flash Attention 3 wheel.

The conditional installation of the Flash Attention 3 wheel for the specific combination of PyTorch 2.6.0 and CUDA 12.6 is appropriate. Using a prebuilt wheel from CloudFront is efficient.

However, consider the following improvements:

  1. Add a comment explaining why this specific wheel is needed for Hopper GPUs
  2. Consider using a more version-controlled source for the wheel file or add a hash verification step
  3. Add error handling in case the download fails
 RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
+        # Install prebuilt Flash Attention 3 wheel for Hopper GPUs (compute capability 9.0+) \
         curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        if [ ! -f flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl ]; then \
+            echo "Failed to download Flash Attention 3 wheel"; \
+            exit 1; \
+        fi; \
         pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
         rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
     fi
src/axolotl/utils/models.py (3)

635-636: Consider simplifying the nested condition.

The nested if statements could be simplified for better readability.

-            elif self.cfg.use_flash_attention_3 == "auto":
-                if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
+            elif self.cfg.use_flash_attention_3 == "auto" and int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)


723-724: Note the TODO for deprecation.

This TODO comment indicates intent to deprecate this code path soon. Consider creating a tracking issue if one doesn't already exist.


731-732: Another deprecation TODO marked.

Similar to the previous TODO, this marks another code section for future deprecation. Consider consolidating these TODOs into a single tracking issue for better maintainability.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between eb3de69 and e46f094.

📒 Files selected for processing (7)
  • .github/workflows/multi-gpu-e2e.yml (1 hunks)
  • cicd/Dockerfile.jinja (1 hunks)
  • docker/Dockerfile-base (2 hunks)
  • src/axolotl/utils/models.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/conftest.py (1 hunks)
  • tests/e2e/test_packing_loss.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/utils/schemas/config.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • docker/Dockerfile-base
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/e2e/test_packing_loss.py (2)
tests/e2e/utils.py (1)
  • check_tensorboard (135-149)
tests/conftest.py (1)
  • temp_dir (414-419)
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py

635-636: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (10)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
🔇 Additional comments (7)
tests/conftest.py (1)

461-461: Good addition to the module reset list.

Including "transformers.modeling_flash_attention_utils" in the modules to reset ensures that any monkeypatches to Flash Attention 3 utilities are properly cleaned up between tests, maintaining test isolation when testing the new Flash Attention 3 functionality.

.github/workflows/multi-gpu-e2e.yml (1)

35-41: LGTM: Proper matrix configuration for Hopper GPUs.

The addition of a new matrix entry for CUDA 12.6.3 with PyTorch 2.6.0 and the "-hopper" suffix aligns with the Flash Attention 3 support being added in this PR. This configuration will allow testing on Hopper architecture GPUs.

tests/e2e/test_packing_loss.py (4)

8-8: Good replacement of unittest with pytest.

Replacing unittest with pytest aligns with the project's testing approach. This change is appropriate.


23-23: Clean class refactoring.

Refactoring the test class to follow pytest patterns rather than using unittest.TestCase improves consistency with the project's testing style.


28-32: Excellent test parameterization for Flash Attention 3.

Using pytest's parameterization to test both without Flash Attention 3 and with "auto" mode ensures the feature works correctly in both configurations. This is a good testing practice.


60-60: Good configuration addition.

Adding the use_flash_attention_3 parameter to the test configuration ensures that the test properly exercises the new Flash Attention 3 support.
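
For illustration, a stripped-down sketch of the parameterization pattern praised above; the real e2e test builds a full axolotl config and launches training, whereas this toy only shows how the [False, "auto"] values flow into the config dict (keys other than use_flash_attention_3 are illustrative).

```python
import pytest


@pytest.mark.parametrize("use_flash_attention_3", [False, "auto"])
def test_packed_training_config(use_flash_attention_3, tmp_path):
    # Minimal stand-in for the training config the e2e test assembles.
    cfg = {
        "flash_attention": True,
        "use_flash_attention_3": use_flash_attention_3,
        "sample_packing": True,
        "output_dir": str(tmp_path),
    }
    # The real tests run training and check tensorboard metrics; here we only
    # assert that both parameterized values are carried through.
    assert cfg["use_flash_attention_3"] in (False, "auto")
```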

src/axolotl/utils/models.py (1)

632-652: Well-implemented Flash Attention 3 integration.

The conditional support for Flash Attention 3 is thoughtfully implemented, with proper checks for both configuration intent and hardware capability. The code intelligently enables FA3 based on the use_flash_attention_3 config setting (either explicitly enabled or automatically enabled for Hopper GPUs with compute capability 90+).

🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
src/axolotl/utils/models.py (3)

632-660: Good implementation of Flash Attention 3 support.

The conditional Flash Attention 3 support is well-implemented with automatic detection for Hopper GPUs. The wrapper functions elegantly handle the API differences between versions.

Two minor improvements worth considering:

  1. The nested if statements could be simplified for better readability
  2. Consider adding a brief comment explaining the key differences between FA2 and FA3 that necessitate these wrappers
-            use_fa3 = False
-            if self.cfg.use_flash_attention_3 is True:
-                use_fa3 = True
-            elif self.cfg.use_flash_attention_3 == "auto":
-                if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
-                    # FA3 is only available on Hopper GPUs and newer
-                    use_fa3 = True
+            # FA3 is enabled explicitly or automatically for Hopper GPUs (compute capability 90+)
+            use_fa3 = (self.cfg.use_flash_attention_3 is True or 
+                      (self.cfg.use_flash_attention_3 == "auto" and 
+                       int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90))
🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)


731-732: Consider providing more context for the TODO comment.

This TODO comment suggests future deprecation, but doesn't provide context about why or when it should be deprecated, or reference any tracking issue. This makes it harder for contributors to understand the roadmap.

-            # TODO deprecate soon
+            # TODO: Deprecate btlm flash attention patch soon, as it's being replaced by the standard implementation in #ISSUE_NUMBER

739-740: Consider providing more context for the TODO comment.

Similar to the previous TODO, this comment lacks context about the deprecation plan.

-            # TODO deprecate soon
+            # TODO: Deprecate stablelm_epoch flash attention patch for sample packing soon, as it's being replaced by the standard implementation in #ISSUE_NUMBER
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between e46f094 and c039543.

📒 Files selected for processing (2)
  • src/axolotl/utils/models.py (2 hunks)
  • tests/conftest.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/conftest.py
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py

635-636: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (17)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
  • GitHub Check: pre-commit
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
src/axolotl/utils/models.py (3)

632-660: Well-implemented Flash Attention 3 conditional support.

The implementation intelligently enables FA3 based on GPU capabilities while maintaining backward compatibility. The wrapper functions handle the interface differences between Flash Attention 2 and 3 appropriately.

Consider simplifying the nested if statements for determining when to use FA3:

-            use_fa3 = False
-            if self.cfg.use_flash_attention_3 is True:
-                use_fa3 = True
-            elif self.cfg.use_flash_attention_3 == "auto":
-                if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
-                    # FA3 is only available on Hopper GPUs and newer
-                    use_fa3 = True
+            use_fa3 = self.cfg.use_flash_attention_3 is True or (
+                self.cfg.use_flash_attention_3 == "auto" and
+                int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90  # Hopper GPUs or newer
+            )
🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)


731-732: Consider expanding on the deprecation plan.

The TODO comment doesn't specify the reason for deprecation or a timeline. Consider adding more context.

-            # TODO deprecate soon
+            # TODO: Deprecate soon - Replace with the new Flash Attention 3 implementation (issue #XXXX)

739-740: Consider expanding on the deprecation plan.

Similar to the previous TODO, adding more context would be helpful.

-            # TODO deprecate soon
+            # TODO: Deprecate soon - Replace with the new Flash Attention 3 implementation (issue #XXXX) 
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 701c6e2 and b995828.

📒 Files selected for processing (1)
  • src/axolotl/utils/models.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py

635-636: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (14)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/axolotl/utils/models.py (1)

731-732: Consider providing more context in TODO comments.

These TODO comments about deprecating specific code paths don't provide enough information about when or why they should be deprecated.

Consider expanding these comments to include:

  • Reason for deprecation
  • Approximate timeline
  • Alternative approach or replacement

Also applies to: 739-740

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between b995828 and 9a9c905.

📒 Files selected for processing (3)
  • src/axolotl/utils/models.py (2 hunks)
  • tests/e2e/multigpu/test_llama.py (2 hunks)
  • tests/e2e/test_packing_loss.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tests/e2e/test_packing_loss.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/e2e/multigpu/test_llama.py (2)
tests/e2e/multigpu/test_gemma3.py (1)
  • test_lora_ddp_packed (36-100)
tests/conftest.py (1)
  • temp_dir (414-419)
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py

635-636: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (14)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: pre-commit
🔇 Additional comments (5)
tests/e2e/multigpu/test_llama.py (3)

104-107: Good addition of parameterized testing for Flash Attention 3 support.

Adding a new parameter to test both disabled and automatic detection settings is a good approach for testing this new feature.


109-110: Correctly updated method signature for new parameter.

The method signature correctly includes the new parameter from the parameterization.


147-147: Properly passing the parameter to the configuration.

This ensures the Flash Attention 3 setting will be passed through to the model loader where it's needed.

src/axolotl/utils/models.py (2)

632-660: Well-implemented conditional Flash Attention 3 support.

The implementation nicely handles:

  1. Explicit enabling via configuration
  2. Automatic detection based on GPU capability (CUDA 9.0+ for Hopper)
  3. Package availability check
  4. Function signature adaptation through wrapper functions
  5. Appropriate logging

However, the nested if statements at lines 635-636 were flagged in static analysis. While functionally correct, consider a cleaner approach:

-            if self.cfg.use_flash_attention_3 is True:
-                use_fa3 = True
-            elif self.cfg.use_flash_attention_3 == "auto":
-                if torch.cuda.get_device_capability() >= (9, 0):
-                    # FA3 is only available on Hopper GPUs and newer
-                    use_fa3 = True
+            use_fa3 = (self.cfg.use_flash_attention_3 is True or 
+                      (self.cfg.use_flash_attention_3 == "auto" and 
+                       torch.cuda.get_device_capability() >= (9, 0)))  # FA3 is only available on Hopper GPUs and newer

This is optional as the current implementation is also clear and readable.

🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)


645-651: Appropriate wrapper functions to handle API differences.

The wrapper functions correctly handle removing the dropout_p parameter and extracting just the first return value from the FA3 functions. This approach properly bridges the differences between the Flash Attention 3 interface and what transformers expects.
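
To make that bridging concrete, here is a minimal hedged sketch of such a wrapper; the flash_attn_interface import path, the keyword names, and the (out, softmax_lse) return shape are assumptions for illustration, not the PR's actual code.

# Hedged sketch only: an FA2-signature-compatible wrapper around an FA3-style kernel.
# The `flash_attn_interface` import path and (out, softmax_lse) return tuple are assumptions.
try:
    from flash_attn_interface import flash_attn_func as fa3_flash_attn_func  # assumed FA3 entry point
except ImportError:
    fa3_flash_attn_func = None  # FA3 wheel absent; caller should keep using FA2


def flash_attn_func_fa2_compat(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    """Accept the FA2-style call signature transformers uses, but dispatch to FA3."""
    del dropout_p, kwargs  # FA3 takes no dropout_p; ignore other FA2-only kwargs
    out, _softmax_lse = fa3_flash_attn_func(
        q, k, v, softmax_scale=softmax_scale, causal=causal
    )
    return out  # only the attention output is returned, matching the FA2 contract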

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
src/axolotl/utils/models.py (3)

632-666: Well-implemented Flash Attention 3 support with GPU capability check

The implementation conditionally enables Flash Attention 3 based on configuration settings and hardware capabilities, with proper fallback mechanisms in place. The wrapper functions handle the interface differences between FA2 and FA3 effectively.

The nested if statements could be simplified into a single conditional:

-            if self.cfg.use_flash_attention_3 is True:
-                use_fa3 = True
-            elif self.cfg.use_flash_attention_3 == "auto":
-                if torch.cuda.get_device_capability() >= (9, 0):
-                    # FA3 is only available on Hopper GPUs and newer
-                    use_fa3 = True
+            use_fa3 = (self.cfg.use_flash_attention_3 is True or 
+                      (self.cfg.use_flash_attention_3 == "auto" and 
+                       torch.cuda.get_device_capability() >= (9, 0)))
+            # FA3 is only available on Hopper GPUs and newer when in auto mode
🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)


737-738: Consider adding more context to the TODO comment

The TODO comment is somewhat vague. Adding more details about why this needs to be deprecated and the preferred alternative would help future developers.

-            # TODO deprecate soon
+            # TODO deprecate soon - Replace with proper Flash Attention integration for all model types

745-746: Consider adding more context to the TODO comment

Same issue as the previous TODO - providing more context would be helpful for future maintenance.

-            # TODO deprecate soon
+            # TODO deprecate soon - Replace with proper Flash Attention integration for all model types
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 9a9c905 and 4199074.

📒 Files selected for processing (2)
  • src/axolotl/utils/models.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/schemas/config.py
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py

635-636: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (15)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (4)
src/axolotl/utils/models.py (3)

632-639: Good implementation of conditional Flash Attention 3 with auto-detection.

The code correctly determines when to use Flash Attention 3 based on either explicit configuration (True) or auto-detection ("auto") by checking the CUDA device capability. This ensures FA3 is only enabled on compatible hardware (Hopper GPUs with CUDA capability 9.0+).

I do have one minor suggestion to improve readability:

-            if self.cfg.use_flash_attention_3 is True:
-                use_fa3 = True
-            elif self.cfg.use_flash_attention_3 == "auto":
-                if torch.cuda.get_device_capability() >= (9, 0):
-                    # FA3 is only available on Hopper GPUs and newer
-                    use_fa3 = True
+            if self.cfg.use_flash_attention_3 is True or (
+                self.cfg.use_flash_attention_3 == "auto" and 
+                torch.cuda.get_device_capability() >= (9, 0)  # FA3 is only available on Hopper GPUs and newer
+            ):
+                use_fa3 = True
🧰 Tools
🪛 Ruff (0.11.9)

635-636: Use a single if statement instead of nested if statements

(SIM102)


737-738: Consider adding more context to the TODO comment.

The "TODO deprecate soon" comment should include more information about why this needs to be deprecated and possibly a timeline or reference.

-            # TODO deprecate soon
+            # TODO: Deprecate btlm flash attention patch soon as it's being replaced by the more generic FA3 approach

745-746: Consider adding more context to the TODO comment.

Similar to the previous comment, this TODO should be more descriptive about why it needs to be deprecated.

-            # TODO deprecate soon
+            # TODO: Deprecate stablelm_epoch flash attention patch soon as it's being replaced by the more generic FA3 approach
.github/workflows/base.yml (1)

55-61: Introduce Hopper-specific build variant
You’ve added a CUDA 12.6.3 + PyTorch 2.6.0 matrix entry targeting only compute capability 9.0+PTX with a -hopper suffix. This correctly scopes the build to Hopper GPUs. Consider trimming the leading dash from the suffix value (e.g. suffix: "hopper") and moving the dash into the tag expression—this keeps labels and values consistent and avoids embedding formatting in data.
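
As a quick illustration of that naming choice (illustrative Python, not the workflow's actual tag logic):

# Illustration only: keep the suffix value clean and add the '-' in the expression.
def image_tag(base: str, py: str, cuda: str, torch: str, suffix: str = "") -> str:
    tag = f"{base}-base-py{py}-cu{cuda}-{torch}"
    # suffix holds a plain value like "hopper"; the dash lives in the expression,
    # rather than being embedded as "-hopper" in the matrix data itself.
    return f"{tag}-{suffix}" if suffix else tag

print(image_tag("main", "3.11", "126", "2.6.0", suffix="hopper"))
# main-base-py3.11-cu126-2.6.0-hopper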

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 4199074 and d6f64a3.

📒 Files selected for processing (9)
  • .github/workflows/base.yml (2 hunks)
  • .github/workflows/multi-gpu-e2e.yml (1 hunks)
  • cicd/Dockerfile.jinja (1 hunks)
  • docker/Dockerfile-base (2 hunks)
  • src/axolotl/utils/models.py (2 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/conftest.py (4 hunks)
  • tests/e2e/multigpu/test_llama.py (2 hunks)
  • tests/e2e/test_packing_loss.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
  • src/axolotl/utils/schemas/config.py
  • cicd/Dockerfile.jinja
  • .github/workflows/multi-gpu-e2e.yml
  • tests/conftest.py
  • tests/e2e/test_packing_loss.py
  • tests/e2e/multigpu/test_llama.py
  • docker/Dockerfile-base
🧰 Additional context used
🪛 actionlint (1.7.7)
.github/workflows/base.yml

97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}

(expression)

🪛 Ruff (0.11.9)
src/axolotl/utils/models.py

635-636: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (15)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (2)
src/axolotl/utils/models.py (1)

639-666: Well-implemented Flash Attention 3 integration with backward compatibility.

This section properly imports Flash Attention 3 functions after verifying the package is available, and wraps them to maintain compatibility with the transformer's existing function signatures. The wrapper functions handle the differences between FA2 and FA3 by:

  1. Removing the dropout_p parameter
  2. Adjusting argument positions when softmax_scale is provided
  3. Only returning the first element of the result

The implementation is thorough and includes proper logging when FA3 is activated.

.github/workflows/base.yml (1)

50-50: Consistent CUDA version bump for the 2.7.0 build
The change from 12.6.3 to 12.8.1 aligns with the PR objective to support CUDA 12.8. Confirm that your docker/Dockerfile-base (and any FlashAttention install steps) have been updated to handle CUDA 12.8.1 accordingly.
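
A quick runtime sanity check along these lines (a sketch, not part of the PR) can confirm the image's torch build and GPU generation before expecting FA3 to work:

# Sanity check inside the built image: report the CUDA toolkit torch was built
# against and whether the visible GPU is Hopper-class (compute capability 9.0+).
import torch

print("torch:", torch.__version__)                 # e.g. 2.7.0
print("built against CUDA:", torch.version.cuda)   # e.g. 12.8
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    print("compute capability:", capability)       # (9, 0) or newer for Hopper
    print("FA3-capable GPU:", capability >= (9, 0))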

@@ -87,7 +94,7 @@ jobs:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
- tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+ tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}

⚠️ Potential issue

Fix tag-generation referencing an undefined property
The expression still uses matrix.axolotl_extras, which isn't defined in any matrix include entry and is flagged as an error by actionlint. Since you've introduced matrix.suffix, drop all references to axolotl_extras. For example:

- tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
+ tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.suffix || '' }}

This will ensure your CI tags render correctly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.suffix || '' }}
🧰 Tools
🪛 actionlint (1.7.7)

97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}

(expression)

🤖 Prompt for AI Agents
In .github/workflows/base.yml at line 97, the tag-generation expression
references an undefined property matrix.axolotl_extras, causing errors. Remove
all references to matrix.axolotl_extras from the tags expression and rely solely
on matrix.suffix for any additional tag suffixes to ensure the CI tags render
correctly without errors.

@winglian
Collaborator Author

@casper-hansen Doesn't seem any faster (slower if anything)
https://wandb.ai/axolotl-ai/fa3-h100/workspace?nw=nwuserwingaxolotl

[Screenshot: 2025-05-18 at 4:47:13 PM]

@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi

RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
Collaborator

Is this your built wheel?

Collaborator Author

yeah, built it in a docker image on a lambda instance, and then extracted it

@casper-hansen
Collaborator

casper-hansen commented May 19, 2025

@winglian Unless there is something wrong with the new flash attention version, or the advertised speed improvement doesn't hold here, I think it must be something with the setup.

Quick checks:

[image]
