Skip to content

Modernized packaging metadata #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ For the stable Triton-only version, refer to [SageAttention-1](https://github.co
pip install sageattention==1.0.6
```

To use SageAttention 2.1.1, please **compile from source**:
To use SageAttention 2.1.1, please first
[install PyTorch](https://pytorch.org/get-started/locally/) and then
**install from source**:
```
git clone https://github.com/thu-ml/SageAttention.git
cd sageattention
python setup.py install # or pip install -e .
pip install --no-build-isolation git+https://github.com/thu-ml/SageAttention.git
```

To benchmark the speed against FlashAttention3, please compile FlashAttention3 from source:
Expand Down
29 changes: 29 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[build-system]
requires = [
"packaging",
"setuptools >= 64",
]
build-backend = "setuptools.build_meta"

[project]
name = "sageattention"
description = "Accurate and efficient plug-and-play low-bit attention"
version = "2.1.1"
readme = "README.md"
authors = [
{name = "SageAttention team"}
]
license = "Apache-2.0"
requires-python = ">= 3.9"
dependencies = [
"torch >= 2.3.0",
"triton >= 3.0.0",
]

[project.urls]
"Source code" = "https://github.com/thu-ml/SageAttention"
"Issue tracker" = "https://github.com/thu-ml/SageAttention/issues"

[tool.setuptools.packages.find]
include = ["sageattention*"]
namespaces = false
327 changes: 165 additions & 162 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,175 +14,178 @@
limitations under the License.
"""

import os
import subprocess
from packaging.version import parse, Version
from typing import List, Set
import sys
import warnings

from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME

HAS_SM80 = False
HAS_SM86 = False
HAS_SM89 = False
HAS_SM90 = False
HAS_SM120 = False

# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}

# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--use_fast_math",
"--threads=8",
"-Xptxas=-v",
"-diag-suppress=174", # suppress the specific warning
]

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

if CUDA_HOME is None:
raise RuntimeError(
"Cannot find CUDA_HOME. CUDA must be available to build the package.")

def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"""Get the CUDA version from nvcc.

Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
from typing import Any

from setuptools import setup


def get_extra_args() -> dict[str, Any]:
# Bypass detection unless we're building the extensions
build_commands = {"bdist_wheel", "build_ext", "develop", "install"}
if not build_commands.intersection(sys.argv):
return {}

from packaging.version import parse, Version

try:
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
except ImportError as exc:
raise ImportError(
"Failed to import 'torch' which is required to build this package. "
"Ensure that 'torch' is installed and visible to the Python interpreter, "
"and that you are using the '--no-build-isolation' flag when installing "
"this package with 'pip'."
) from exc

HAS_SM80 = False
HAS_SM86 = False
HAS_SM89 = False
HAS_SM90 = False
HAS_SM120 = False

# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--use_fast_math",
"--threads=8",
"-Xptxas=-v",
"-diag-suppress=174", # suppress the specific warning
]

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

if CUDA_HOME is None:
raise RuntimeError(
"Cannot find a CUDA installation. "
"CUDA must be available to build the package."
)

# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
compute_capabilities = set()
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
continue
compute_capabilities.add(f"{major}.{minor}")

# Get the CUDA version from nvcc.
# Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], text=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version

# Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
compute_capabilities = set()
device_count = torch.cuda.device_count()
for i in range(device_count):
major, minor = torch.cuda.get_device_capability(i)
if major < 8:
warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}")
continue
compute_capabilities.add(f"{major}.{minor}")

nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
else:
print(f"Detect GPUs with compute capabilities: {compute_capabilities}")

# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("12.0"):
raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
if nvcc_cuda_version < Version("12.4") and any(cc.startswith("8.9") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 12.4 or higher is required for compute capability 8.9.")
if nvcc_cuda_version < Version("12.3") and any(cc.startswith("9.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 12.3 or higher is required for compute capability 9.0.")
if nvcc_cuda_version < Version("12.8") and any(cc.startswith("12.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 12.8 or higher is required for compute capability 12.0.")

# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
if capability.startswith("8.0"):
HAS_SM80 = True
num = "80"
elif capability.startswith("8.6"):
HAS_SM86 = True
num = "86"
elif capability.startswith("8.9"):
HAS_SM89 = True
num = "89"
elif capability.startswith("9.0"):
HAS_SM90 = True
num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction.
elif capability.startswith("12.0"):
HAS_SM120 = True
num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]

ext_modules = []

if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm80",
sources=[
"csrc/qattn/pybind_sm80.cpp",
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(qattn_extension)

if HAS_SM89 or HAS_SM120:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm89",
sources=[
"csrc/qattn/pybind_sm89.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(qattn_extension)

if HAS_SM90:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm90",
sources=[
"csrc/qattn/pybind_sm90.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
],
if not compute_capabilities:
raise RuntimeError("No GPUs found. Please specify the target GPU architectures or build on a machine with GPUs.")
else:
print(f"Detect GPUs with compute capabilities: {compute_capabilities}")

# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("12.0"):
raise RuntimeError("CUDA 12.0 or higher is required to build the package.")
if nvcc_cuda_version < Version("12.4") and any(cc.startswith("8.9") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 12.4 or higher is required for compute capability 8.9.")
if nvcc_cuda_version < Version("12.3") and any(cc.startswith("9.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 12.3 or higher is required for compute capability 9.0.")
if nvcc_cuda_version < Version("12.8") and any(cc.startswith("12.0") for cc in compute_capabilities):
raise RuntimeError(
"CUDA 12.8 or higher is required for compute capability 12.0.")

# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
if capability.startswith("8.0"):
HAS_SM80 = True
num = "80"
elif capability.startswith("8.6"):
HAS_SM86 = True
num = "86"
elif capability.startswith("8.9"):
HAS_SM89 = True
num = "89"
elif capability.startswith("9.0"):
HAS_SM90 = True
num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction.
elif capability.startswith("12.0"):
HAS_SM120 = True
num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.

NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]

ext_modules = []

if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm80",
sources=[
"csrc/qattn/pybind_sm80.cpp",
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(qattn_extension)

if HAS_SM89 or HAS_SM120:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm89",
sources=[
"csrc/qattn/pybind_sm89.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm89.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(qattn_extension)

if HAS_SM90:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm90",
sources=[
"csrc/qattn/pybind_sm90.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
extra_link_args=['-lcuda'],
)
ext_modules.append(qattn_extension)

# Fused kernels.
fused_extension = CUDAExtension(
name="sageattention._fused",
sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
extra_link_args=['-lcuda'],
)
ext_modules.append(qattn_extension)

# Fused kernels.
fused_extension = CUDAExtension(
name="sageattention._fused",
sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(fused_extension)

setup(
name='sageattention',
version='2.1.1',
author='SageAttention team',
license='Apache 2.0 License',
description='Accurate and efficient plug-and-play low-bit attention.',
long_description=open('README.md', encoding='utf-8').read(),
long_description_content_type='text/markdown',
url='https://github.com/thu-ml/SageAttention',
packages=find_packages(),
python_requires='>=3.9',
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)
ext_modules.append(fused_extension)

return {
"extensions": ext_modules,
"cmdclass": {"build_ext": BuildExtension},
}


setup(**get_extra_args())