Create base docker images for CUDA 12.8 with custom FlashAttention 3 installed #2685
Conversation
""" WalkthroughThe changes update CUDA versions and add a new matrix entry with a "hopper" suffix and restricted CUDA architectures in workflows. Dockerfiles and CI scripts are modified to install a specific prebuilt flash-attn wheel for CUDA 12.6 and PyTorch 2.6. The model loader adds conditional support for Flash Attention 3, controlled by a new config option. Tests are updated for this feature and to reset monkeypatches. Changes
Sequence Diagram(s)
sequenceDiagram
participant Config as Config File
participant ModelLoader
participant FlashAttnUtils as FlashAttention Utils
participant GPU
Config->>ModelLoader: Provide use_flash_attention_3 setting
ModelLoader->>GPU: Query compute capability
alt use_flash_attention_3 enabled (True or auto + GPU >= 90)
ModelLoader->>FlashAttnUtils: Import FA3 functions
ModelLoader->>FlashAttnUtils: Override default flash-attn funcs with FA3
else
ModelLoader->>FlashAttnUtils: Use default flash-attn functions
end
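To make the flow concrete, here is a minimal Python sketch of the gating logic the diagram describes. The helper names and the patch targets below are assumptions for illustration only; the PR's actual implementation lives in src/axolotl/utils/models.py and may differ in detail.

    # Illustrative sketch only -- module and attribute names are assumptions, not the PR's exact code.
    import importlib.util
    import torch

    def should_use_fa3(use_flash_attention_3) -> bool:
        """Mirror the decision in the diagram: explicit True, or "auto" on Hopper-class GPUs."""
        if use_flash_attention_3 is True:
            return True
        if use_flash_attention_3 == "auto":
            # FA3 kernels target compute capability 9.0 (Hopper) and newer
            return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
        return False

    def maybe_enable_fa3(use_flash_attention_3) -> None:
        if not should_use_fa3(use_flash_attention_3):
            return  # keep the default flash-attn 2 code path
        if importlib.util.find_spec("flash_attn_interface") is None:
            return  # FA3 wheel not installed; keep defaults
        import transformers.modeling_flash_attention_utils as fa_utils
        import flash_attn_interface

        # Swap the module-level flash-attn entry points that transformers dispatches to
        # (attribute names here are assumed for the sketch).
        fa_utils.flash_attn_func = flash_attn_interface.flash_attn_func
        fa_utils.flash_attn_varlen_func = flash_attn_interface.flash_attn_varlen_func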
📜 Recent review details
Configuration used: CodeRabbit UI
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms (10)
✨ Finishing Touches
Actionable comments posted: 3
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)
41-48: Improve code indentation for readability. The indentation in the conditional block is inconsistent, making it harder to read. Consider standardizing the indentation.
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ - git clone https://github.com/Dao-AILab/flash-attention.git; \ - git checkout v2.7.4.post1; \ - cd flash-attention/hopper; \ - FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \ - elif if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ - pip3 install flash-attn==2.7.4.post1; \ - fi + git clone https://github.com/Dao-AILab/flash-attention.git; \ + git checkout v2.7.4.post1; \ + cd flash-attention/hopper; \ + FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \ +elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ + pip3 install flash-attn==2.7.4.post1; \ +fi
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting
📒 Files selected for processing (2)
.github/workflows/base.yml (2 hunks)
docker/Dockerfile-base (1 hunks)
🧰 Additional context used
🪛 actionlint (1.7.7)
.github/workflows/base.yml
97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}
(expression)
⏰ Context from checks skipped due to timeout of 90000ms (14)
- GitHub Check: build-base (128, 12.8.1, 3.11, next, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: build-base (128, 12.8.1, 3.11, nightly, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (3)
.github/workflows/base.yml (2)
50-50: CUDA version update looks good. Updating to the latest CUDA 12.8.1 version ensures compatibility with newer drivers and GPUs.
55-61: Well-structured Hopper-specific configuration. This new matrix entry creates a specialized Docker image optimized specifically for NVIDIA Hopper architecture (H100 GPUs) by:
- Using a restricted CUDA architecture list "9.0+PTX"
- Adding the "-hopper" suffix for clear identification
- Using PyTorch 2.7.0 which has optimizations for this architecture
This corresponds with the FlashAttention build logic in the Dockerfile.
docker/Dockerfile-base (1)
41-46: Good optimization for Hopper architecture. Building flash-attention from source with Hopper-specific optimizations is a good approach. The special flags disable features not needed for Hopper and increase build parallelism.
Codecov Report
Attention: Patch coverage is …
📢 Thoughts on this report? Let us know!
Actionable comments posted: 1
🧹 Nitpick comments (2)
docker/Dockerfile-base (2)
41-50: Consider adding explanatory comments for maintainability. While the implementation is correct, adding comments explaining why different installation methods are needed for different architectures would improve maintainability for future contributors.
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ + # Build flash-attention from source for Hopper architecture (9.0+PTX) \ + # with specific optimizations for this architecture \ git clone https://github.com/Dao-AILab/flash-attention.git; \ cd flash-attention; \ git checkout v2.7.4.post1; \ cd hopper; \ FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \ cd ../..; \ elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ + # Use pre-built package for PyTorch 2.7.0 with other architectures \ pip3 install flash-attn==2.7.4.post1; \ fi
41-50: Add error handling for git operations. The current implementation doesn't check if git operations succeed. Consider adding basic error checking to make debugging easier if the build fails.
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ - git clone https://github.com/Dao-AILab/flash-attention.git; \ - cd flash-attention; \ - git checkout v2.7.4.post1; \ - cd hopper; \ - FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \ - cd ../..; \ + git clone https://github.com/Dao-AILab/flash-attention.git && \ + cd flash-attention && \ + git checkout v2.7.4.post1 && \ + cd hopper && \ + FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install && \ + cd ../.. || echo "Failed to build flash-attention from source" && exit 1; \ elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ pip3 install flash-attn==2.7.4.post1; \ fi
📜 Review details
📒 Files selected for processing (1)
docker/Dockerfile-base
(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (14)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: build-base (128, 12.8.1, 3.11, next, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
🔇 Additional comments (2)
docker/Dockerfile-base (2)
41-47: Custom build path for Hopper architecture looks good. The implementation correctly builds FlashAttention from source when targeting the Hopper architecture (CUDA arch 9.0+PTX). Using the environment variables FLASH_ATTENTION_DISABLE_SM80=TRUE and FLASH_ATTENTION_DISABLE_FP8=TRUE is appropriate for optimizing for this specific architecture. The directory navigation has been properly implemented with the correct cd commands, addressing previous issues.
48-48: Conditional logic for PyTorch 2.7.0 is correctly implemented. The conditional statement has the correct syntax, with no duplicate "if" keywords that were present in previous versions.
Actionable comments posted: 0
♻️ Duplicate comments (1)
docker/Dockerfile-base (1)
41-50: ⚠️ Potential issue: Missing fallback for other PyTorch/CUDA combinations.
The current implementation only installs flash-attention when TORCH_CUDA_ARCH_LIST is "9.0+PTX" or when PYTORCH_VERSION is "2.7.0". However, the workflow runs with other combinations that don't satisfy either condition. Add an else clause to ensure flash-attention is always installed:
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ git clone https://github.com/Dao-AILab/flash-attention.git; \ cd flash-attention; \ git checkout v2.7.4.post1; \ cd hopper; \ FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=64 python setup.py install; \ cd ../..; \ elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ pip3 install flash-attn==2.7.4.post1; \ - fi + else \ + # Fallback for all other PyTorch/CUDA combinations + pip3 install flash-attn==2.7.4.post1; \ + fi
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)
46-46: Consider using environment variable for MAX_JOBS. You're setting MAX_JOBS=64 directly in the build command, but there's a MAX_JOBS ARG defined at line 4 that's set to 4 by default. For consistency, consider using that variable.
-    FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=64 python setup.py install; \
+    FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=$MAX_JOBS python setup.py install; \
📜 Review details
📒 Files selected for processing (1)
docker/Dockerfile-base
(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (13)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: build-base (128, 12.8.1, 3.11, next, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
🔇 Additional comments (1)
docker/Dockerfile-base (1)
41-47: Good implementation for Hopper architecture-specific build. The changes correctly set up building flash-attention from source for the "9.0+PTX" CUDA architecture (NVIDIA Hopper/H100 GPUs). The code properly:
- Clones the repository
- Checks out the specific version
- Builds from the hopper directory with appropriate flags
- Navigates back to the original directory
docker/Dockerfile-base
Outdated
cd flash-attention; \
git checkout v2.7.4.post1; \
cd hopper; \
FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=64 python setup.py install; \
Is this MAX_JOBS too high? What kind of hardware is the CPU building on?
docker/Dockerfile-base
Outdated
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ | ||
git clone https://github.com/Dao-AILab/flash-attention.git; \ | ||
cd flash-attention; \ | ||
git checkout v2.7.4.post1; \ |
I saw that this release is from Jan 2025 and there have been a lot of changes to the hopper dir since then.
- cuda: "126" | ||
cuda_version: 12.6.3 |
Taken from the readme:
We highly recommend CUDA 12.8 for best performance.
As only PyTorch 2.7 has 12.8, should we swap to that?
fa3 doesn't compile with torch 2.7.0 due to an error that I don't have offhand, but is related to an API change in torch.
Actionable comments posted: 0
♻️ Duplicate comments (1)
docker/Dockerfile-base (1)
41-50: Missing fallback for other PyTorch versions. The installation logic only handles two cases:
- When TORCH_CUDA_ARCH_LIST="9.0+PTX" (build from source)
- When PYTORCH_VERSION="2.7.0" (pip install)
Other combinations from the workflow matrix (PyTorch 2.5.1, 2.6.0, nightly, next) will not have flash-attention installed.
#!/bin/bash
# Extract matrix combinations from workflow file to verify coverage
grep -r "PYTORCH_VERSION" .github/workflows/ | grep -v "2.7.0"
grep -r "TORCH_CUDA_ARCH_LIST" .github/workflows/ | grep -v "9.0+PTX"

Consider adding a fallback else clause:
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ git clone https://github.com/Dao-AILab/flash-attention.git; \ cd flash-attention; \ git checkout v2.7.4.post1; \ cd hopper; \ FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=32 python setup.py install; \ cd ../..; \ elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ pip3 install flash-attn==2.7.4.post1; \ + else \ + # Default installation for other PyTorch/CUDA combinations + pip3 install flash-attn==2.7.4.post1; \ fi
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)
41-47: Consider adding error handling. The multi-step process of cloning, checking out, and building from source lacks error handling. If any step fails, the build might proceed with an incomplete or missing flash-attention installation.
Consider adding error checking:
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ + set -e && \ git clone https://github.com/Dao-AILab/flash-attention.git; \ cd flash-attention; \ git checkout v2.7.4.post1; \ cd hopper; \ FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=32 python setup.py install; \ cd ../..; \
This ensures that if any command fails, the entire RUN instruction will fail, preventing partial installations.
📜 Review details
📒 Files selected for processing (1)
docker/Dockerfile-base
(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (12)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (2)
docker/Dockerfile-base (2)
41-47: Building flash-attention from source for Hopper GPUs looks good. The implementation correctly clones and builds the flash-attention package from source when targeting Hopper GPUs (CUDA arch 9.0+PTX). The directory navigation has been fixed as suggested in the previous review.
46-46: MAX_JOBS value has been reduced. Good adjustment of MAX_JOBS from 128 to 32, which addresses the previous concern about setting it too high. This provides a more reasonable balance between build parallelism and resource usage.
Actionable comments posted: 1
♻️ Duplicate comments (1)
docker/Dockerfile-base (1)
41-50: Add a fallback clause for flash-attention installation. The current implementation only handles when TORCH_CUDA_ARCH_LIST is "9.0+PTX" or when PYTORCH_VERSION is "2.7.0". You need to add a fallback clause for cases where neither condition is met to ensure flash-attention is always installed.
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
    git clone https://github.com/Dao-AILab/flash-attention.git; \
    cd flash-attention; \
    git checkout v2.7.4.post1; \
    cd hopper; \
    FLASH_ATTENTION_DISABLE_SM80=TRUE MAX_JOBS=${MAX_JOBS} python setup.py install; \
    cd ../..; \
elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
    pip3 install flash-attn==2.7.4.post1; \
    fi
+    else \
+        # Fallback for all other PyTorch/CUDA combinations
+        pip3 install flash-attn==2.7.4.post1; \
+    fi
📜 Review details
📒 Files selected for processing (1)
docker/Dockerfile-base
(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (12)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (3)
docker/Dockerfile-base (3)
46-47: Approve usage of MAX_JOBS variable. Good improvement using the configurable MAX_JOBS ARG instead of a hardcoded value, and the proper directory navigation with cd .. commands has been implemented correctly.
1-2: Approve CUDA version updates. The updates to CUDA version (12.4.1) and empty CUDNN version string are appropriate for modern ML environments.
10-12: Approve Python and PyTorch updates. The updates to Python 3.11 and PyTorch 2.5.1 with CUDA 124 are consistent with the objective of creating updated base images.
Actionable comments posted: 1
♻️ Duplicate comments (1)
docker/Dockerfile-base (1)
41-47: 🛠️ Refactor suggestion: Add a fallback for other matrix combinations.
The updated logic now handles two specific cases:
- When TORCH_CUDA_ARCH_LIST="9.0+PTX", it installs the custom FlashAttention 3 beta wheel
- When PYTORCH_VERSION="2.7.0", it installs flash-attn 2.7.4.post1
However, it still doesn't have a fallback for other matrix combinations (e.g., PyTorch 2.5.1 with multi-SM architecture lists).
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ pip3 install flash-attn==2.7.4.post1; \ + else \ + # Fallback for all other PyTorch/CUDA combinations + pip3 install flash-attn==2.7.4.post1; \ fi
🧹 Nitpick comments (1)
docker/Dockerfile-base (1)
41-44: Consider making the wheel URL more configurable. The wheel URL has hardcoded components (version 3.0.0b1, Python 3.11) but also uses variables (${CUDA}, ${PYTORCH_VERSION}). For better maintainability, consider making more components configurable through ARG/ENV variables.
+ARG FLASH_ATTN_VERSION="3.0.0b1"
+ARG PYTHON_WHEEL_TAG="cp311-cp311"
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
-    curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
-    pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
-    rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+    curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-${FLASH_ATTN_VERSION}-${PYTHON_WHEEL_TAG}-linux_x86_64.whl; \
+    pip3 install --no-cache-dir flash_attn_interface-${FLASH_ATTN_VERSION}-${PYTHON_WHEEL_TAG}-linux_x86_64.whl; \
+    rm flash_attn_interface-${FLASH_ATTN_VERSION}-${PYTHON_WHEEL_TAG}-linux_x86_64.whl; \
📜 Review details
📒 Files selected for processing (1)
docker/Dockerfile-base
(3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (11)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (3)
docker/Dockerfile-base (3)
10-12: LGTM: Version updates are consistent with each other. The Python and PyTorch version updates are properly synchronized, and the CUDA shorthand "124" correctly matches the CUDA version 12.4.1.
31-31: LGTM: Added curl dependency for wheel download. Adding curl to the pip install list is necessary for the flash-attn wheel download in the updated installation logic.
1-2: ❓ Verification inconclusive
Version inconsistency between PR title and implementation.
The PR title mentions "CUDA 12.8" but the Dockerfile specifies CUDA 12.4.1. Also, the CUDNN_VERSION has been changed from a specific version to an empty string. Please clarify if this is intentional.
🏁 Script executed:
#!/bin/bash
# Description: Check CUDA references in other files to ensure consistency
echo "Searching for CUDA version references in workflow files:"
grep -r "CUDA.*12.8" --include="*.yml" .github
grep -r "CUDA.*12.4" --include="*.yml" .github

Length of output: 220
🏁 Script executed:
#!/bin/bash
# Verify all CUDA_VERSION declarations
echo "🔍 Checking for all ARG CUDA_VERSION in the repo"
grep -R "ARG CUDA_VERSION" -n .
# Verify all CUDNN_VERSION declarations
echo "🔍 Checking for all ARG CUDNN_VERSION in the repo"
grep -R "ARG CUDNN_VERSION" -n .
# Look for any literal "12.8" occurrences
echo "🔍 Searching for any references to '12.8'"
grep -R "12\.8" -n .
# Look for any literal "12.4" occurrences beyond Dockerfile-base
echo "🔍 Searching for any references to '12.4'"
grep -R "12\.4" -n .

Length of output: 4101
Clarify CUDA & cuDNN version alignment across Dockerfiles & workflows
It looks like the PR introduces CUDA 12.8.1 in the “next” and “nightly” Dockerfiles but leaves the main docker/Dockerfile-base at CUDA 12.4.1 with an empty CUDNN_VERSION. Meanwhile, CI workflows reference both versions.
• docker/Dockerfile-base (lines 1–2):
ARG CUDA_VERSION="12.4.1"
ARG CUDNN_VERSION=""
• docker/Dockerfile-base-next & docker/Dockerfile-base-nightly:
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
• .github/workflows/**/*: jobs defined for both cuda_version: 12.4.1 and cuda_version: 12.8.1
Please confirm the intended strategy:
- Keep the stable base at CUDA 12.4.1 and use separate images for CUDA 12.8.1?
- Or bump the main Dockerfile-base to CUDA 12.8.1 (and pin CUDNN)?
- Should the blank CUDNN_VERSION in the main Dockerfile be restored or documented?
Once clarified, we can align the ARGs and workflows accordingly.
docker/Dockerfile-base
Outdated
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ | ||
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ | ||
pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ | ||
rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ |
💡 Verification agent
🧩 Analysis chain
Verify FlashAttention 3 wheel availability.
The code downloads a specific FlashAttention 3.0.0b1 wheel that's hardcoded for Python 3.11. Verify that this wheel is publicly available and will remain accessible at the specified URL.
🏁 Script executed:
#!/bin/bash
# Description: Test if the FlashAttention 3.0.0b1 wheel URL is accessible
curl -I -L https://d1dttdx32dkk5p.cloudfront.net/fa3/cu124/torch-2.5.1/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl
Length of output: 746
FlashAttention wheel download is inaccessible (HTTP 403)
The hardcoded FlashAttention 3.0.0b1 wheel URL returns a 403 Forbidden, so the Docker build will fail at the curl / pip install step. This wheel does not appear to be publicly available.
Please update or replace this step:
- Confirm the correct public URL for the FlashAttention 3.0.0b1 wheel matching ${CUDA} and Python version, or host a mirror that allows downloads.
- Alternatively, fall back to installing via pip install flash-attn (if available) or include a prebuilt wheel in the repo.
Locations to update:
- docker/Dockerfile-base, lines 41–44: the curl -L -O ...flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl and subsequent pip3 install commands.
🤖 Prompt for AI Agents
In docker/Dockerfile-base around lines 41 to 44, the current step downloads a
FlashAttention 3.0.0b1 wheel from a hardcoded URL that returns a 403 Forbidden
error, causing the Docker build to fail. To fix this, replace the curl download
and pip install commands with either a valid public URL for the wheel matching
the CUDA and Python version or switch to installing FlashAttention via pip
directly using 'pip install flash-attn'. Alternatively, include a prebuilt wheel
in the repository and install from there. Update the Dockerfile accordingly to
ensure the wheel is accessible and installable during the build.
Actionable comments posted: 1
♻️ Duplicate comments (2)
docker/Dockerfile-base (2)
41-47: ⚠️ Potential issue: Fix missing fallback for other PyTorch/CUDA combinations.
The conditional logic only handles two specific cases:
- When TORCH_CUDA_ARCH_LIST="9.0+PTX"
- When PYTORCH_VERSION="2.7.0"
There's no fallback for other matrix combinations, which means flash-attn won't be installed in those cases.
Add an else clause to ensure flash-attention is installed for all combinations:
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \ pip3 install flash-attn==2.7.4.post1; \ + else \ + # Fallback for all other PyTorch/CUDA combos + pip3 install flash-attn==2.7.4.post1; \ fi
41-44: ⚠️ Potential issue: Fix inaccessible wheel URL (403 Forbidden).
The wheel URL returns a 403 Forbidden error, which will cause Docker builds to fail. This issue was previously identified but hasn't been addressed.
Consider these alternatives:
- Obtain and use a publicly accessible URL for the wheel
- Host the wheel in your own repository or storage
- Build the wheel from source
- Use a publicly available version from PyPI
Example fix option:
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \ - curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ - pip3 install --no-cache-dir flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ - rm flash_attn_interface-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ + # Build from source for Hopper architecture + git clone https://github.com/Dao-AILab/flash-attention.git && \ + cd flash-attention && \ + git checkout v3.0.0 && \ # Use appropriate tag for FlashAttention 3 + FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=4 python setup.py install && \ + cd .. && \ + rm -rf flash-attention; \
📜 Review details
📒 Files selected for processing (1)
docker/Dockerfile-base
(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (12)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
Actionable comments posted: 0
🧹 Nitpick comments (4)
cicd/Dockerfile.jinja (1)
35-39: Conditionally installing Flash Attention 3 wheel. The conditional installation of the Flash Attention 3 wheel for the specific combination of PyTorch 2.6.0 and CUDA 12.6 is appropriate. Using a prebuilt wheel from CloudFront is efficient.
However, consider the following improvements:
- Add a comment explaining why this specific wheel is needed for Hopper GPUs
- Consider using a more version-controlled source for the wheel file or add a hash verification step
- Add error handling in case the download fails
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \ + # Install prebuilt Flash Attention 3 wheel for Hopper GPUs (compute capability 9.0+) \ curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ + if [ ! -f flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl ]; then \ + echo "Failed to download Flash Attention 3 wheel"; \ + exit 1; \ + fi; \ pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \ fisrc/axolotl/utils/models.py (3)
635-636: Consider simplifying the nested condition. The nested if statements could be simplified for better readability.
-elif self.cfg.use_flash_attention_3 == "auto":
-    if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
+elif self.cfg.use_flash_attention_3 == "auto" and int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90:
🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single if statement instead of nested if statements (SIM102)
723-724: Note the TODO for deprecation. This TODO comment indicates intent to deprecate this code path soon. Consider creating a tracking issue if one doesn't already exist.
731-732: Another deprecation TODO marked. Similar to the previous TODO, this marks another code section for future deprecation. Consider consolidating these TODOs into a single tracking issue for better maintainability.
📜 Review details
📒 Files selected for processing (7)
.github/workflows/multi-gpu-e2e.yml (1 hunks)
cicd/Dockerfile.jinja (1 hunks)
docker/Dockerfile-base (2 hunks)
src/axolotl/utils/models.py (2 hunks)
src/axolotl/utils/schemas/config.py (1 hunks)
tests/conftest.py (1 hunks)
tests/e2e/test_packing_loss.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/schemas/config.py
🚧 Files skipped from review as they are similar to previous changes (1)
- docker/Dockerfile-base
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/e2e/test_packing_loss.py (2)
tests/e2e/utils.py (1)
check_tensorboard (135-149)
tests/conftest.py (1)
temp_dir (414-419)
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py
635-636: Use a single if statement instead of nested if statements (SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (10)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
🔇 Additional comments (7)
tests/conftest.py (1)
461-461: Good addition to the module reset list. Including "transformers.modeling_flash_attention_utils" in the modules to reset ensures that any monkeypatches to Flash Attention 3 utilities are properly cleaned up between tests, maintaining test isolation when testing the new Flash Attention 3 functionality.
.github/workflows/multi-gpu-e2e.yml (1)
35-41: LGTM: Proper matrix configuration for Hopper GPUs. The addition of a new matrix entry for CUDA 12.6.3 with PyTorch 2.6.0 and the "-hopper" suffix aligns with the Flash Attention 3 support being added in this PR. This configuration will allow testing on Hopper architecture GPUs.
tests/e2e/test_packing_loss.py (4)
8-8: Good replacement of unittest with pytest. Replacing unittest with pytest aligns with the project's testing approach. This change is appropriate.
23-23: Clean class refactoring. Refactoring the test class to follow pytest patterns rather than using unittest.TestCase improves consistency with the project's testing style.
28-32: Excellent test parameterization for Flash Attention 3. Using pytest's parameterization to test both without Flash Attention 3 and with "auto" mode ensures the feature works correctly in both configurations. This is a good testing practice.
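A sketch of what such a parameterization can look like; the test name, fixtures, and config keys below are illustrative rather than the exact contents of tests/e2e/test_packing_loss.py.

    import pytest

    @pytest.mark.parametrize("use_flash_attention_3", [False, "auto"])
    def test_packing_loss_fa3_modes(tmp_path, use_flash_attention_3):
        # Thread the flag into the training config so both the default FA2 path
        # and the FA3 "auto" path (Hopper-only at runtime) are exercised.
        cfg = {
            "flash_attention": True,
            "use_flash_attention_3": use_flash_attention_3,
            "sample_packing": True,
            "output_dir": str(tmp_path),
        }
        # ... the real test would run training with `cfg` and check the loss ...
        assert cfg["use_flash_attention_3"] in (False, "auto")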
60-60: Good configuration addition. Adding the use_flash_attention_3 parameter to the test configuration ensures that the test properly exercises the new Flash Attention 3 support.
src/axolotl/utils/models.py (1)
632-652: Well-implemented Flash Attention 3 integration. The conditional support for Flash Attention 3 is thoughtfully implemented, with proper checks for both configuration intent and hardware capability. The code intelligently enables FA3 based on the use_flash_attention_3 config setting (either explicitly enabled or automatically enabled for Hopper GPUs with compute capability 90+).
🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single if statement instead of nested if statements (SIM102)
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/axolotl/utils/models.py (3)
632-660
: Good implementation of Flash Attention 3 support.The conditional Flash Attention 3 support is well-implemented with automatic detection for Hopper GPUs. The wrapper functions elegantly handle the API differences between versions.
Two minor improvements worth considering:
- The nested if statements could be simplified for better readability
- Consider adding a brief comment explaining the key differences between FA2 and FA3 that necessitate these wrappers
- use_fa3 = False - if self.cfg.use_flash_attention_3 is True: - use_fa3 = True - elif self.cfg.use_flash_attention_3 == "auto": - if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90: - # FA3 is only available on Hopper GPUs and newer - use_fa3 = True + # FA3 is enabled explicitly or automatically for Hopper GPUs (compute capability 90+) + use_fa3 = (self.cfg.use_flash_attention_3 is True or + (self.cfg.use_flash_attention_3 == "auto" and + int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90))🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single
if
statement instead of nestedif
statements(SIM102)
731-732
: Consider providing more context for the TODO comment.This TODO comment suggests future deprecation, but doesn't provide context about why or when it should be deprecated, or reference any tracking issue. This makes it harder for contributors to understand the roadmap.
- # TODO deprecate soon + # TODO: Deprecate btlm flash attention patch soon, as it's being replaced by the standard implementation in #ISSUE_NUMBER
739-740
: Consider providing more context for the TODO comment.Similar to the previous TODO, this comment lacks context about the deprecation plan.
- # TODO deprecate soon + # TODO: Deprecate stablelm_epoch flash attention patch for sample packing soon, as it's being replaced by the standard implementation in #ISSUE_NUMBER
📜 Review details
📒 Files selected for processing (2)
src/axolotl/utils/models.py
(2 hunks)tests/conftest.py
(4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/conftest.py
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py
635-636: Use a single if
statement instead of nested if
statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (17)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
- GitHub Check: pre-commit
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/axolotl/utils/models.py (3)
632-660
: Well-implemented Flash Attention 3 conditional support.The implementation intelligently enables FA3 based on GPU capabilities while maintaining backward compatibility. The wrapper functions handle the interface differences between Flash Attention 2 and 3 appropriately.
Consider simplifying the nested if statements for determining when to use FA3:
- use_fa3 = False - if self.cfg.use_flash_attention_3 is True: - use_fa3 = True - elif self.cfg.use_flash_attention_3 == "auto": - if int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90: - # FA3 is only available on Hopper GPUs and newer - use_fa3 = True + use_fa3 = self.cfg.use_flash_attention_3 is True or ( + self.cfg.use_flash_attention_3 == "auto" and + int(self.cfg.capabilities.compute_capability.split("_")[-1]) >= 90 # Hopper GPUs or newer + )🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single
if
statement instead of nestedif
statements(SIM102)
731-732
: Consider expanding on the deprecation plan.The TODO comment doesn't specify the reason for deprecation or a timeline. Consider adding more context.
- # TODO deprecate soon + # TODO: Deprecate soon - Replace with the new Flash Attention 3 implementation (issue #XXXX)
739-740
: Consider expanding on the deprecation plan.Similar to the previous TODO, adding more context would be helpful.
- # TODO deprecate soon + # TODO: Deprecate soon - Replace with the new Flash Attention 3 implementation (issue #XXXX)
📜 Review details
📒 Files selected for processing (1)
src/axolotl/utils/models.py
(2 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py
635-636: Use a single if
statement instead of nested if
statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (14)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/utils/models.py (1)
731-732
: Consider providing more context in TODO comments.These TODO comments about deprecating specific code paths don't provide enough information about when or why they should be deprecated.
Consider expanding these comments to include:
- Reason for deprecation
- Approximate timeline
- Alternative approach or replacement
Also applies to: 739-740
📜 Review details
📒 Files selected for processing (3)
src/axolotl/utils/models.py
(2 hunks)tests/e2e/multigpu/test_llama.py
(2 hunks)tests/e2e/test_packing_loss.py
(1 hunks)
✅ Files skipped from review due to trivial changes (1)
- tests/e2e/test_packing_loss.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/e2e/multigpu/test_llama.py (2)
tests/e2e/multigpu/test_gemma3.py (1)
test_lora_ddp_packed
(36-100)tests/conftest.py (1)
temp_dir
(414-419)
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py
635-636: Use a single if
statement instead of nested if
statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (14)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: pre-commit
🔇 Additional comments (5)
tests/e2e/multigpu/test_llama.py (3)
104-107
: Good addition of parameterized testing for Flash Attention 3 support.Adding a new parameter to test both disabled and automatic detection settings is a good approach for testing this new feature.
109-110
: Correctly updated method signature for new parameter.The method signature correctly includes the new parameter from the parameterization.
147-147
: Properly passing the parameter to the configuration.This ensures the Flash Attention 3 setting will be passed through to the model loader where it's needed.
src/axolotl/utils/models.py (2)
632-660
: Well-implemented conditional Flash Attention 3 support.The implementation nicely handles:
- Explicit enabling via configuration
- Automatic detection based on GPU capability (CUDA 9.0+ for Hopper)
- Package availability check
- Function signature adaptation through wrapper functions
- Appropriate logging
However, the nested if statements at lines 635-636 were flagged in static analysis. While functionally correct, consider a cleaner approach:
- if self.cfg.use_flash_attention_3 is True: - use_fa3 = True - elif self.cfg.use_flash_attention_3 == "auto": - if torch.cuda.get_device_capability() >= (9, 0): - # FA3 is only available on Hopper GPUs and newer - use_fa3 = True + use_fa3 = (self.cfg.use_flash_attention_3 is True or + (self.cfg.use_flash_attention_3 == "auto" and + torch.cuda.get_device_capability() >= (9, 0))) # FA3 is only available on Hopper GPUs and newerThis is optional as the current implementation is also clear and readable.
🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single
if
statement instead of nestedif
statements(SIM102)
645-651
: Appropriate wrapper functions to handle API differences.The wrapper functions correctly handle removing the
dropout_p
parameter and extracting just the first return value from the FA3 functions. This approach properly bridges the differences between the Flash Attention 3 interface and what transformers expects.
Actionable comments posted: 0
🧹 Nitpick comments (3)
src/axolotl/utils/models.py (3)
632-666
: Well-implemented Flash Attention 3 support with GPU capability checkThe implementation conditionally enables Flash Attention 3 based on configuration settings and hardware capabilities, with proper fallback mechanisms in place. The wrapper functions handle the interface differences between FA2 and FA3 effectively.
The nested if statements could be simplified into a single conditional:
- if self.cfg.use_flash_attention_3 is True: - use_fa3 = True - elif self.cfg.use_flash_attention_3 == "auto": - if torch.cuda.get_device_capability() >= (9, 0): - # FA3 is only available on Hopper GPUs and newer - use_fa3 = True + use_fa3 = (self.cfg.use_flash_attention_3 is True or + (self.cfg.use_flash_attention_3 == "auto" and + torch.cuda.get_device_capability() >= (9, 0))) + # FA3 is only available on Hopper GPUs and newer when in auto mode🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single
if
statement instead of nestedif
statements(SIM102)
737-738
: Consider adding more context to the TODO commentThe TODO comment is somewhat vague. Adding more details about why this needs to be deprecated and the preferred alternative would help future developers.
- # TODO deprecate soon + # TODO deprecate soon - Replace with proper Flash Attention integration for all model types
745-746
: Consider adding more context to the TODO commentSame issue as the previous TODO - providing more context would be helpful for future maintenance.
- # TODO deprecate soon + # TODO deprecate soon - Replace with proper Flash Attention integration for all model types
📜 Review details
📒 Files selected for processing (2)
src/axolotl/utils/models.py
(2 hunks)src/axolotl/utils/schemas/config.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/utils/schemas/config.py
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py
635-636: Use a single if
statement instead of nested if
statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (15)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
Actionable comments posted: 1
🧹 Nitpick comments (4)
src/axolotl/utils/models.py (3)
632-639
: Good implementation of conditional Flash Attention 3 with auto-detection.The code correctly determines when to use Flash Attention 3 based on either explicit configuration (
True
) or auto-detection ("auto"
) by checking the CUDA device capability. This ensures FA3 is only enabled on compatible hardware (Hopper GPUs with CUDA capability 9.0+).I do have one minor suggestion to improve readability:
- if self.cfg.use_flash_attention_3 is True: - use_fa3 = True - elif self.cfg.use_flash_attention_3 == "auto": - if torch.cuda.get_device_capability() >= (9, 0): - # FA3 is only available on Hopper GPUs and newer - use_fa3 = True + if self.cfg.use_flash_attention_3 is True or ( + self.cfg.use_flash_attention_3 == "auto" and + torch.cuda.get_device_capability() >= (9, 0) # FA3 is only available on Hopper GPUs and newer + ): + use_fa3 = True🧰 Tools
🪛 Ruff (0.11.9)
635-636: Use a single
if
statement instead of nestedif
statements(SIM102)
737-738
: Consider adding more context to the TODO comment.The "TODO deprecate soon" comment should include more information about why this needs to be deprecated and possibly a timeline or reference.
- # TODO deprecate soon + # TODO: Deprecate btlm flash attention patch soon as it's being replaced by the more generic FA3 approach
745-746
: Consider adding more context to the TODO comment.Similar to the previous comment, this TODO should be more descriptive about why it needs to be deprecated.
- # TODO deprecate soon + # TODO: Deprecate stablelm_epoch flash attention patch soon as it's being replaced by the more generic FA3 approach.github/workflows/base.yml (1)
55-61
: Introduce Hopper-specific build variant
You’ve added a CUDA 12.6.3 + PyTorch 2.6.0 matrix entry targeting only compute capability9.0+PTX
with a-hopper
suffix. This correctly scopes the build to Hopper GPUs. Consider trimming the leading dash from thesuffix
value (e.g.suffix: "hopper"
) and moving the dash into the tag expression—this keeps labels and values consistent and avoids embedding formatting in data.
📜 Review details
📒 Files selected for processing (9)
.github/workflows/base.yml (2 hunks)
.github/workflows/multi-gpu-e2e.yml (1 hunks)
cicd/Dockerfile.jinja (1 hunks)
docker/Dockerfile-base (2 hunks)
src/axolotl/utils/models.py (2 hunks)
src/axolotl/utils/schemas/config.py (1 hunks)
tests/conftest.py (4 hunks)
tests/e2e/multigpu/test_llama.py (2 hunks)
tests/e2e/test_packing_loss.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
- src/axolotl/utils/schemas/config.py
- cicd/Dockerfile.jinja
- .github/workflows/multi-gpu-e2e.yml
- tests/conftest.py
- tests/e2e/test_packing_loss.py
- tests/e2e/multigpu/test_llama.py
- docker/Dockerfile-base
🧰 Additional context used
🪛 actionlint (1.7.7)
.github/workflows/base.yml
97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}
(expression)
🪛 Ruff (0.11.9)
src/axolotl/utils/models.py
635-636: Use a single if statement instead of nested if statements (SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (15)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, -hopper, 9.0+PTX)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, -hopper, 2)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.5.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX)
🔇 Additional comments (2)
src/axolotl/utils/models.py (1)
639-666: Well-implemented Flash Attention 3 integration with backward compatibility. This section properly imports Flash Attention 3 functions after verifying the package is available, and wraps them to maintain compatibility with the transformer's existing function signatures. The wrapper functions handle the differences between FA2 and FA3 by:
- Removing the dropout_p parameter
- Adjusting argument positions when softmax_scale is provided
- Only returning the first element of the result
The implementation is thorough and includes proper logging when FA3 is activated.
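As an illustration of the wrapper pattern described above, here is a minimal Python sketch. The exact wrappers in models.py differ in detail, and the FA3 signature assumed here (no dropout_p argument, a tuple return of output plus softmax LSE) is inferred from this review rather than taken from the wheel's documented API.

    # Hedged sketch: adapt an FA3-style function to FA2-style call sites in transformers.
    # `fa3_flash_attn_func` is a stand-in for the function exported by the FA3 wheel.
    def wrap_fa3_for_fa2_callers(fa3_flash_attn_func):
        def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
            # FA3 has no dropout_p argument, so it is accepted here and dropped.
            out = fa3_flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal)
            # FA3 returns a tuple (output, softmax_lse); FA2-style callers expect just the output.
            if isinstance(out, tuple):
                out = out[0]
            return out

        return flash_attn_func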
.github/workflows/base.yml (1)
50-50: Consistent CUDA version bump for the 2.7.0 build. The change from 12.6.3 to 12.8.1 aligns with the PR objective to support CUDA 12.8. Confirm that your docker/Dockerfile-base (and any FlashAttention install steps) have been updated to handle CUDA 12.8.1 accordingly.
@@ -87,7 +94,7 @@ jobs:
      context: .
      file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
      push: ${{ github.event_name != 'pull_request' }}
      tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
      tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
Fix tag-generation referencing an undefined property
The expression still uses matrix.axolotl_extras, which isn't defined in any include block and will evaluate to an error. Since you've introduced matrix.suffix, drop all references to axolotl_extras. For example:
- tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
+ tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.suffix || '' }}
This will ensure your CI tags render correctly.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.suffix || '' }}
🧰 Tools
🪛 actionlint (1.7.7)
97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}
(expression)
🤖 Prompt for AI Agents
In .github/workflows/base.yml at line 97, the tag-generation expression
references an undefined property matrix.axolotl_extras, causing errors. Remove
all references to matrix.axolotl_extras from the tags expression and rely solely
on matrix.suffix for any additional tag suffixes to ensure the CI tags render
correctly without errors.
@casper-hansen Doesn't seem any faster (slower if anything). [screenshot]
@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi

RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
    curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
Is this your built wheel?
yeah, built it in a docker image on a lambda instance, and then extracted it
@winglian Unless there is something wrong with the new flash attention version or the speed improvement advertisement is wrong, I think it must be something with the setup. Quick checks:
Summary by CodeRabbit
New Features
Chores
Tests