Create base docker images for CUDA 12.8 with custom FlashAttention 3 installed #2685

Open: wants to merge 20 commits into base: main

Changes from 19 commits

11 changes: 9 additions & 2 deletions .github/workflows/base.yml
@@ -47,11 +47,18 @@ jobs:
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
Comment on lines +55 to +56

Collaborator:
Taken from the readme:

We highly recommend CUDA 12.8 for best performance.

As only PyTorch 2.7 supports CUDA 12.8, should we swap to that?

Collaborator (Author):

FA3 doesn't compile with torch 2.7.0 due to an error I don't have offhand, but it's related to an API change in torch.

cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
suffix: "-hopper"
torch_cuda_arch_list: "9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -87,7 +94,7 @@ jobs:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}

⚠️ Potential issue

Fix tag generation referencing an undefined property
The expression still uses matrix.axolotl_extras, which isn't defined in any of this workflow's include blocks, so actionlint flags it as an error. Since you've introduced matrix.suffix, drop all references to axolotl_extras. For example:

- tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
+ tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.suffix || '' }}

This will ensure your CI tags render correctly.

🧰 Tools
🪛 actionlint (1.7.7)

97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string}

(expression)

🤖 Prompt for AI Agents
In .github/workflows/base.yml at line 97, the tag-generation expression
references an undefined property matrix.axolotl_extras, causing errors. Remove
all references to matrix.axolotl_extras from the tags expression and rely solely
on matrix.suffix for any additional tag suffixes to ensure the CI tags render
correctly without errors.

labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
11 changes: 7 additions & 4 deletions .github/workflows/multi-gpu-e2e.yml
@@ -32,21 +32,25 @@ jobs:
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
suffix: "-hopper"
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -68,7 +72,6 @@
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
5 changes: 5 additions & 0 deletions cicd/Dockerfile.jinja
@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi

RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \

Collaborator:

Is this your built wheel?

Collaborator (Author):

Yeah, built it in a Docker image on a Lambda instance and then extracted it.

pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
18 changes: 11 additions & 7 deletions docker/Dockerfile-base
@@ -1,22 +1,22 @@
ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG CUDA_VERSION="12.4.1"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.5.1"
ARG CUDA="124"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
@@ -38,6 +38,10 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi
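
Not part of the diff, but for reviewers: a minimal sanity-check sketch one might run inside the built image to confirm the custom FA3 wheel is usable. It assumes the wheel installs the flash_attn_interface module (the same module the models.py patch below imports) and that a Hopper-class GPU is visible.

```python
# Hypothetical post-build check, not part of this PR.
import importlib.util

import torch


def fa3_usable() -> bool:
    """Return True if the FA3 wheel is importable and the GPU is Hopper (sm90) or newer."""
    if importlib.util.find_spec("flash_attn_interface") is None:
        # The custom wheel did not install, or this is an FA2-only image.
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


if __name__ == "__main__":
    print("FA3 usable:", fa3_usable())
```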
37 changes: 37 additions & 0 deletions src/axolotl/utils/models.py
@@ -629,6 +629,41 @@ def apply_patches(self) -> None:
)

if self.cfg.flash_attention:
use_fa3 = False
if self.cfg.use_flash_attention_3 is True:
use_fa3 = True
elif self.cfg.use_flash_attention_3 == "auto":
if torch.cuda.get_device_capability() >= (9, 0):
# FA3 is only available on Hopper GPUs and newer
use_fa3 = True
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)

def flash_attn_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
if "softmax_scale" in kwargs and len(args) >= 4:
# if softmax_scale is passed as a kwarg, positional index 3 is dropout_p and needs to be dropped
args = (*args[:3],) + args[4:]
return flash_attn_func_v3(*args, **kwargs)[0]

def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
if "softmax_scale" in kwargs and len(args) >= 4:
# if softmax_scale is passed as a kwarg, positional index 3 is dropout_p and needs to be dropped
args = (*args[:3],) + args[4:]
return flash_attn_varlen_func_v3(*args, **kwargs)[0]

transformers.modeling_flash_attention_utils.flash_attn_func = (
flash_attn_func_v3_wrapper
)
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
flash_attn_varlen_func_v3_wrapper
)
LOG.info("Switched to Flash Attention v3")

self.patch_attention()

if self.cfg.sample_packing and self.cfg.s2_attention:
@@ -699,13 +734,15 @@ def patch_attention(self) -> None:

patch_mllama()

# TODO deprecate soon
if self.model_config.model_type == "btlm":
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
)

replace_btlm_attn_with_flash_attn(self.cfg.base_model)

# TODO deprecate soon
if (
self.model_config.model_type == "stablelm_epoch"
and self.cfg.sample_packing
1 change: 1 addition & 0 deletions src/axolotl/utils/schemas/config.py
@@ -233,6 +233,7 @@ class AxolotlInputConfig(
flash_attn_fuse_qkv: bool | None = None
flash_attn_fuse_mlp: bool | None = None
flash_optimum: bool | None = None
use_flash_attention_3: Literal["auto"] | bool | None = None

eager_attention: bool | None = None

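
For context (not part of the diff): a user config would opt in to the new flag roughly as the e2e test later in this PR does. This is only a sketch mirroring the keys added here; per the models.py change, the FA3 switch only runs when flash_attention is enabled, and "auto" enables it only on sm90+ GPUs.

```python
# Illustrative only: mirrors how tests/e2e/multigpu/test_llama.py enables FA3.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "flash_attention": True,  # FA3 patching only happens under the flash_attention path
        "use_flash_attention_3": "auto",  # True to force FA3, "auto" to enable it only on Hopper+ GPUs
    }
)
```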
20 changes: 20 additions & 0 deletions tests/conftest.py
@@ -421,6 +421,7 @@ def temp_dir():

@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
import transformers.modeling_flash_attention_utils
from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
@@ -434,6 +435,19 @@ def cleanup_monkeypatches():
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
original_fa_func = None
original_fa_varlen_func = None
if (
importlib.util.find_spec("flash_attn")
and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
and hasattr(
transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
)
):
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
original_fa_varlen_func = (
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
)
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
@@ -444,6 +458,11 @@
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
if original_fa_func:
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
original_fa_varlen_func
)

# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
@@ -458,6 +477,7 @@
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",),
("transformers.modeling_flash_attention_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
9 changes: 8 additions & 1 deletion tests/e2e/multigpu/test_llama.py
@@ -101,7 +101,13 @@ def test_lora_ddp(self, temp_dir):
"gradient_accumulation_steps",
[1, 2],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
@pytest.mark.parametrize(
"use_flash_attention_3",
[False, "auto"],
)
def test_lora_ddp_packed(
self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
@@ -138,6 +144,7 @@ def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"use_flash_attention_3": use_flash_attention_3,
}
)

6 changes: 2 additions & 4 deletions tests/e2e/test_packing_loss.py
@@ -4,7 +4,6 @@

import logging
import os
import unittest

from transformers.utils import is_torch_bf16_gpu_available

@@ -14,18 +13,17 @@
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from .utils import check_tensorboard, with_temp_dir
from .utils import check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestPackedLlama(unittest.TestCase):
class TestPackedLlama:
"""
Test case for Packed training of llama models
"""

@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(