Create base docker images for CUDA 12.8 with custom FlashAttention 3 installed #2685
base: main
CI workflow (build matrix):

```diff
@@ -47,11 +47,18 @@ jobs:
           pytorch: 2.7.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
         - cuda: "128"
-          cuda_version: 12.6.3
+          cuda_version: 12.8.1
           cudnn_version: ""
           python_version: "3.11"
           pytorch: 2.7.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+        - cuda: "126"
+          cuda_version: 12.6.3
+          cudnn_version: ""
+          python_version: "3.11"
+          pytorch: 2.6.0
+          suffix: "-hopper"
+          torch_cuda_arch_list: "9.0+PTX"
         - cuda: "128"
           cuda_version: 12.8.1
           cudnn_version: ""
```
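The new `-hopper` entry pins `torch_cuda_arch_list` to `9.0+PTX`, i.e. Hopper (sm90) only. For context, this matrix value typically feeds the `TORCH_CUDA_ARCH_LIST` environment variable that PyTorch's extension builder reads when compiling CUDA kernels; a minimal sketch of that usage (an assumption about how the build arg is consumed, not a line from this PR):

```bash
# Sketch only: restrict a CUDA extension build (e.g. flash-attn) to Hopper
# (compute capability 9.0) plus forward-compatible PTX; MAX_JOBS bounds the
# number of parallel nvcc jobs so compilation does not exhaust memory.
export TORCH_CUDA_ARCH_LIST="9.0+PTX"
MAX_JOBS=4 pip install --no-build-isolation flash-attn
```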
```diff
@@ -87,7 +94,7 @@ jobs:
           context: .
           file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
           labels: ${{ steps.metadata.outputs.labels }}
           build-args: |
             CUDA_VERSION=${{ matrix.cuda_version }}
```
Review comment: Fix tag generation referencing an undefined property. `matrix.axolotl_extras` is not defined in this matrix, so drop it from the tag expression:

```diff
-          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
+          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.suffix || '' }}
```

This will ensure your CI tags render correctly.

actionlint (1.7.7): `97-97: property "axolotl_extras" is not defined in object type {cuda: number; cuda_version: string; cudnn_version: string; python_version: number; pytorch: string; suffix: string; torch_cuda_arch_list: string} (expression)`
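To make the tag scheme concrete, here is a rough shell rendering of the names the two relevant matrix entries would produce. The registry and base tag are placeholders for `steps.metadata.outputs.tags`; only the pattern itself comes from the workflow:

```bash
# Placeholder values throughout; only the tag pattern is taken from the workflow.
base="example.io/axolotl:main"          # stand-in for steps.metadata.outputs.tags
while read -r py cuda torch suffix; do
  [ "$suffix" = "none" ] && suffix=""   # mimics the ${{ matrix.suffix || '' }} fallback
  echo "${base}-base-py${py}-cu${cuda}-${torch}${suffix}"
done <<'EOF'
3.11 128 2.7.0 none
3.11 126 2.6.0 -hopper
EOF
# -> example.io/axolotl:main-base-py3.11-cu128-2.7.0
# -> example.io/axolotl:main-base-py3.11-cu126-2.6.0-hopper
```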
Base image Dockerfile:

```diff
@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     fi
 
 RUN pip install packaging==23.2 setuptools==75.8.0
+RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
+        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+    fi
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
```

Review comment (on the wheel URL): Is this your built wheel?

Reply: Yeah, built it in a docker image on a Lambda instance, and then extracted it.
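The thread above only says the wheel was built inside a Docker image on a Lambda instance and then extracted; the exact recipe is not in the PR. A sketch of one plausible way to do it, under stated assumptions (the image tag, paths, and flags are guesses; the `hopper/` subdirectory is where FA3 lives in the upstream flash-attention repo):

```bash
# Assumption-heavy sketch, not the author's actual recipe.
docker run -d --name fa3-build pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel sleep infinity
docker exec fa3-build bash -c '
  apt-get update && apt-get install -y --no-install-recommends git &&
  pip install packaging setuptools wheel ninja &&
  git clone https://github.com/Dao-AILab/flash-attention.git &&
  cd flash-attention/hopper &&
  TORCH_CUDA_ARCH_LIST="9.0+PTX" python setup.py bdist_wheel
'
# Copy the built wheel out of the container (pytorch images default to /workspace).
docker cp fa3-build:/workspace/flash-attention/hopper/dist/. .
docker rm -f fa3-build
```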
Review comment: Taken from the readme. As only PyTorch 2.7 has CUDA 12.8, should we swap to that?

Reply: FA3 doesn't compile with torch 2.7.0 due to an error I don't have offhand, but it's related to an API change in torch.
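Assuming the wheel installed cleanly, a quick smoke test of the resulting `-hopper` image might look like the following. The image tag is a placeholder, and `flash_attn_interface` is FA3's import name in the upstream repo; both are assumptions here:

```bash
# Placeholder image tag; substitute the real -hopper base image from this workflow.
docker run --rm --gpus all example.io/axolotl:main-base-py3.11-cu126-2.6.0-hopper \
  python -c "import flash_attn_interface; print('flash-attn 3 import OK')"
```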