Support cuda 12.8.1 and SBSA wheels #1507

Open · wants to merge 5 commits into base: main

46 changes: 23 additions & 23 deletions .github/workflows/publish.yml
@@ -36,27 +36,25 @@ jobs:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
# Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-22.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1']
cuda-version: ['12.9.0']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.5 does not support Python 3.13
- torch-version: '2.4.0'
python-version: '3.13'

# Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ ubuntu-22.04, ubuntu-22.04-arm ]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
torch-version: ['2.4.0', '2.5.1', '2.6.0', '2.7.1']
cuda-version: ['12.9.0']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.5 does not support Python 3.13
- torch-version: '2.4.0'
python-version: '3.13'
steps:
- name: Checkout
uses: actions/checkout@v4
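
For context on the cxx11_abi axis above: PyTorch wheels and NVIDIA NGC (nvcr) images are built with different values of _GLIBCXX_USE_CXX11_ABI, so a wheel built with the wrong setting fails to import with the undefined-symbol error quoted in the comment. Below is a minimal sketch of how a build script could turn such a matrix value into a compile flag; the FLASH_ATTN_CXX11_ABI variable name is assumed for illustration and may not match what this repository actually uses.

```python
# Hypothetical sketch: map an env var onto the C++11 ABI compile flag.
# The variable name FLASH_ATTN_CXX11_ABI is assumed for illustration only.
import os

def cxx11_abi_flag() -> str:
    # "TRUE"/"FALSE" mirrors the values used in the workflow matrix above.
    use_cxx11_abi = os.getenv("FLASH_ATTN_CXX11_ABI", "FALSE") == "TRUE"
    return f"-D_GLIBCXX_USE_CXX11_ABI={1 if use_cxx11_abi else 0}"

# The flag would then be appended to the extension's extra_compile_args,
# so each matrix entry produces a wheel linked against one ABI variant.
print(cxx11_abi_flag())
```
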
@@ -72,6 +70,7 @@ jobs:
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_ARCH=$(echo ${{ matrix.arch }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
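
The awk one-liners above condense the matrix values into shorter identifiers used later in wheel filenames and URLs. As a rough Python equivalent of the same string handling (example values only, taken from the matrix in this workflow):

```python
# Illustrative Python equivalent of the awk derivations above,
# using example values from the matrix in this workflow.
torch_version = "2.7.1"
cuda_version = "12.9.0"
python_version = "3.13"
arch = "aarch64"

# awk -F . '{print $1 "." $2}' -> major.minor of the torch version
matrix_torch_version = ".".join(torch_version.split(".")[:2])    # "2.7"
# awk -F . '{print $1}'        -> CUDA major version
wheel_cuda_version = cuda_version.split(".")[0]                  # "12"
# awk -F . '{print $1 $2}'     -> Python tag without the dot
matrix_python_version = "".join(python_version.split(".")[:2])   # "313"
# the same awk on an arch string with no dot leaves it unchanged
matrix_arch = "".join(arch.split(".")[:2])                       # "aarch64"

print(matrix_torch_version, wheel_cuda_version, matrix_python_version, matrix_arch)
```
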
@@ -103,6 +102,8 @@ jobs:
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools==75.8.0
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
pip install typing-extensions==4.12.2
@@ -118,10 +119,10 @@ jobs:
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
# pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
# Can't use --no-deps because we need cudnn etc.
# Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
# Hard-coding this version of pytorch-triton for torch 2.8.0.dev20250425
pip install jinja2
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.3.0+gitab727c40-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_${{ matrix.arch }}.whl
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_${{ matrix.arch }}.whl
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
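
For the nightly branch above, the wheel filename now ends in linux_${{ matrix.arch }} instead of a hard-coded linux_x86_64, which is what makes the SBSA (aarch64) runners work. A sketch of how the requested URL is assembled, with example values standing in for the matrix and environment variables:

```python
# Sketch of how the nightly torch wheel URL above is assembled.
# All values are examples; the real ones come from the job's matrix/env.
torch_version = "2.8.0.dev20250425"   # example nightly version
torch_cuda_version = "129"            # e.g. cu129
python_tag = "313"                    # cp313
arch = "aarch64"                      # "x86_64" or "aarch64" (SBSA runners)

url = (
    "https://download.pytorch.org/whl/nightly/"
    f"cu{torch_cuda_version}/"
    f"torch-{torch_version}%2Bcu{torch_cuda_version}"
    f"-cp{python_tag}-cp{python_tag}-linux_{arch}.whl"
)
print(url)
```
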
@@ -132,7 +133,6 @@ jobs:
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash

- name: Build wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
@@ -188,7 +188,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.12'

- name: Install dependencies
run: |
40 changes: 34 additions & 6 deletions setup.py
@@ -70,12 +70,15 @@ def cuda_archs() -> str:
return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";")


def get_platform():
def get_arch():
"""
Returns the platform name as used in wheel filenames.
Returns the machine architecture for the current system.
"""
if sys.platform.startswith("linux"):
return f'linux_{platform.uname().machine}'
if platform.machine() == "x86_64":
return "x86_64"
if platform.machine() == "arm64" or platform.machine() == "aarch64":
return "aarch64"
elif sys.platform == "darwin":
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
return f"macosx_{mac_version}_x86_64"
@@ -85,6 +88,28 @@ def get_platform():
raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_system() -> str:
"""
Returns the system name as used in wheel filenames.
"""
if platform.system() == "Windows":
return "win"
elif platform.system() == "Darwin":
mac_version = ".".join(platform.mac_ver()[0].split(".")[:1])
return f"macos_{mac_version}"
elif platform.system() == "Linux":
return "linux"
else:
raise ValueError("Unsupported system: {}".format(platform.system()))


def get_platform() -> str:
"""
Returns the platform name as used in wheel filenames.
"""
return f"{get_system()}_{get_arch()}"


def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
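
With get_platform() now composed from get_system() and get_arch(), the wheel platform tag picks up the host architecture automatically. A small usage sketch, assuming the helpers above are in scope (outputs depend on the machine the build runs on):

```python
# Usage sketch for the helpers above (values depend on the host machine).
# On an x86_64 Linux box:         get_platform() -> "linux_x86_64"
# On an aarch64 (SBSA) Linux box: get_platform() -> "linux_aarch64"
print(get_system())    # e.g. "linux"
print(get_arch())      # e.g. "aarch64"
print(get_platform())  # e.g. "linux_aarch64"
```
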
@@ -189,6 +214,9 @@ def validate_and_update_archs(archs):
if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_120,code=sm_120")
if bare_metal_version >= Version("13.0") and "110" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_110,code=sm_110")

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
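
The new gencode block follows the same pattern as the existing ones: an architecture is emitted only if the detected nvcc version is new enough and the architecture appears in FLASH_ATTN_CUDA_ARCHS (which defaults to "80;90;100;120" near the top of this file). A condensed sketch of that gating, using the thresholds visible in this diff:

```python
# Condensed sketch of the gencode gating above; thresholds are from the diff.
from packaging.version import Version

def gencode_flags(bare_metal_version: Version, archs: list[str]) -> list[str]:
    flags: list[str] = []
    # (minimum nvcc version, compute capability) pairs shown in this hunk
    for min_ver, cc in [("12.8", "120"), ("13.0", "110")]:
        if bare_metal_version >= Version(min_ver) and cc in archs:
            flags += ["-gencode", f"arch=compute_{cc},code=sm_{cc}"]
    return flags

print(gencode_flags(Version("13.0"), ["80", "90", "100", "110", "120"]))
```
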
@@ -419,7 +447,7 @@ def validate_and_update_archs(archs):
)


def get_package_version():
def get_package_version() -> str:
with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
public_version = ast.literal_eval(version_match.group(1))
@@ -430,7 +458,7 @@ def get_package_version():
return str(public_version)


def get_wheel_url():
def get_wheel_url() -> tuple[str, str]:
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
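
get_wheel_url() combines the Python tag, the platform tag from get_platform(), and the torch/CUDA versions into a prebuilt-wheel download location. The exact filename scheme is outside this hunk, so the snippet below is only an illustrative guess at how such a name could be assembled from those pieces; the pattern itself is an assumption, not confirmed by this diff.

```python
# Illustrative only: the real filename scheme lives in the unshown part of
# get_wheel_url(); this just shows how the computed tags could be combined.
package_version = "2.8.0"        # example, from get_package_version()
cuda_tag = "12"                  # example WHEEL_CUDA_VERSION-style value
torch_tag = "2.7"                # example MATRIX_TORCH_VERSION-style value
cxx11_abi = "FALSE"
python_version = "cp313"
platform_name = "linux_aarch64"  # from get_platform() on an SBSA host

wheel_filename = (
    f"flash_attn-{package_version}+cu{cuda_tag}torch{torch_tag}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
```
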
@@ -469,7 +497,7 @@ class CachedWheelsCommand(_bdist_wheel):
wheel available and short-circuits the standard full build pipeline.
"""

def run(self):
def run(self) -> None:
if FORCE_BUILD:
return super().run()
