
Commit f94ae3c

split fa3 compile
1 parent d79da6e commit f94ae3c

7 files changed: +223 -112 lines

sgl-kernel/CMakeLists.txt

Lines changed: 46 additions & 22 deletions
@@ -1,9 +1,10 @@
 cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
 project(sgl-kernel LANGUAGES CXX CUDA)
 
-# we only want to download 3rd, but not build them.
-# FetchContent_MakeAvailable will build it.
 cmake_policy(SET CMP0169 OLD)
+
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
 set(BUILD_FA3, OFF)
 
 find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
@@ -23,6 +24,8 @@ elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
 endif()
 
 find_package(Torch REQUIRED)
+# clean Torch Flag
+clear_cuda_arches(CMAKE_FLAG)
 
 include(FetchContent)
 
@@ -93,14 +96,13 @@ set(SGL_KERNEL_CUDA_FLAGS
 "-gencode=arch=compute_90,code=sm_90"
 "-std=c++17"
 "-DFLASHINFER_ENABLE_F16"
+"-DCUTE_USE_PACKED_TUPLE=1"
 "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
 "-DCUTLASS_VERSIONS_GENERATED"
-"-DCUTE_USE_PACKED_TUPLE=1"
 "-DCUTLASS_TEST_LEVEL=0"
 "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
 "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
 "--expt-relaxed-constexpr"
-"--use_fast_math"
 "-Xcompiler=-Wconversion"
 "-Xcompiler=-fno-strict-aliasing"
 )
@@ -180,18 +182,36 @@ set(SOURCES
 "csrc/speculative/eagle_utils.cu"
 "csrc/speculative/speculative_sampling.cu"
 "csrc/speculative/packbit.cu"
-"csrc/torch_extension.cc"
+"csrc/common_extension.cc"
 "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
 "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
 "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
-"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
-"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
-"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
 )
 
 # set flash-attention sources file
 # BF16 source files
 if (BUILD_FA3)
+set(SGL_FLASH_KERNEL_CUDA_FLAGS
+"-DNDEBUG"
+"-DOPERATOR_NAMESPACE=sgl-kernel"
+"-O3"
+"-Xcompiler"
+"-fPIC"
+"-gencode=arch=compute_90a,code=sm_90a"
+"-std=c++17"
+"-DCUTE_USE_PACKED_TUPLE=1"
+"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
+"-DCUTLASS_VERSIONS_GENERATED"
+"-DCUTLASS_TEST_LEVEL=0"
+"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
+"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
+"--expt-relaxed-constexpr"
+"--expt-extended-lambda"
+"--use_fast_math"
+"-Xcompiler=-Wconversion"
+"-Xcompiler=-fno-strict-aliasing"
+)
+
 file(GLOB FA3_BF16_GEN_SRCS
 "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
 file(GLOB FA3_BF16_GEN_SRCS_
@@ -214,27 +234,23 @@ if (BUILD_FA3)
 
 set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
 
-list(APPEND SOURCES
+set(FLASH_SOURCES
+"csrc/flash_extension.cc"
 "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
 "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
 "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
-"${FA3_GEN_SRCS}")
-endif()
+"${FA3_GEN_SRCS}"
+)
 
-# Support abi3 for build
-Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
 
-target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
+target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
 
-target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
 
-target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
-
-install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
-
-# Add some flash-attention custom flag for inference
-if (BUILD_FA3)
-target_compile_definitions(common_ops PRIVATE
+target_compile_definitions(flash_ops PRIVATE
 FLASHATTENTION_DISABLE_SM8x
 FLASHATTENTION_DISABLE_BACKWARD
 FLASHATTENTION_DISABLE_DROPOUT
@@ -246,6 +262,14 @@ if (BUILD_FA3)
 )
 endif()
 
+Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+
+target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
+
+install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
+
 # JIT Logic
 # DeepGEMM
 
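Note on the split above: common_ops keeps the multi-arch SGL_KERNEL_CUDA_FLAGS, while the new flash_ops target is compiled only for sm_90a and only when BUILD_FA3 is ON; both shared objects are installed into the sgl_kernel package directory. A minimal sketch of what that means for users (the sgl_kernel.common_ops / sgl_kernel.flash_ops import paths are an assumption based on the install destination, not part of this commit):

```python
# Hedged sketch: module names assume the install(TARGETS ... DESTINATION "sgl_kernel") layout above.
import importlib.util

import sgl_kernel.common_ops  # always built; registers the common sgl_kernel ops on import

if importlib.util.find_spec("sgl_kernel.flash_ops") is not None:
    import sgl_kernel.flash_ops  # present only in builds configured with BUILD_FA3=ON (Hopper, sm_90a)
else:
    print("sgl-kernel was built without the FA3 forward kernels")
```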
sgl-kernel/cmake/utils.cmake

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
+#
+# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
+# `CUDA_ARCH_FLAGS`.
+#
+# Example:
+#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
+#   clear_cuda_arches(CUDA_ARCH_FLAGS)
+#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
+#   CMAKE_CUDA_FLAGS="-Wall"
+#
+macro(clear_cuda_arches CUDA_ARCH_FLAGS)
+  # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
+  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
+    ${CMAKE_CUDA_FLAGS})
+
+  # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
+  # and passed back via the `CUDA_ARCHITECTURES` property.
+  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
+    ${CMAKE_CUDA_FLAGS})
+endmacro()

sgl-kernel/csrc/torch_extension.cc renamed to sgl-kernel/csrc/common_extension.cc

Lines changed: 1 addition & 40 deletions
@@ -18,7 +18,7 @@ limitations under the License.
 
 #include "sgl_kernel_ops.h"
 
-TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
@@ -202,45 +202,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
-
-  /*
-   * From flash-attention
-   */
-  m.def(
-      "fwd(Tensor! q,"
-      " Tensor k,"
-      " Tensor v,"
-      " Tensor? k_new,"
-      " Tensor? v_new,"
-      " Tensor? q_v,"
-      " Tensor!? out,"
-      " Tensor? cu_seqlens_q,"
-      " Tensor? cu_seqlens_k,"
-      " Tensor? cu_seqlens_k_new,"
-      " Tensor? seqused_q,"
-      " Tensor? seqused_k,"
-      " int? max_seqlen_q,"
-      " int? max_seqlen_k,"
-      " Tensor? page_table,"
-      " Tensor? kv_batch_idx,"
-      " Tensor? leftpad_k,"
-      " Tensor? rotary_cos,"
-      " Tensor? rotary_sin,"
-      " Tensor? seqlens_rotary,"
-      " Tensor? q_descale,"
-      " Tensor? k_descale,"
-      " Tensor? v_descale,"
-      " float softmax_scale,"
-      " bool is_causal,"
-      " int window_size_left,"
-      " int window_size_right,"
-      " float softcap,"
-      " bool is_rotary_interleaved,"
-      " Tensor? scheduler_metadata,"
-      " int num_splits,"
-      " bool? pack_gqa,"
-      " int sm_margin) -> Tensor[]");
-  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
 
 REGISTER_EXTENSION(common_ops)
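Besides dropping the flash-attention schema, the functional change here is the switch from TORCH_LIBRARY_EXPAND (which, per the header below, simply wraps TORCH_LIBRARY and may therefore define the sgl_kernel namespace only once per process) to TORCH_LIBRARY_FRAGMENT, which lets each extension module contribute its own operators to the same namespace. A rough illustration from the Python side, with the import paths assumed as above:

```python
import torch

import sgl_kernel.common_ops                 # first fragment: norm/sampling/speculative ops, ...
print(hasattr(torch.ops.sgl_kernel, "fwd"))  # False: "fwd" is no longer registered by common_ops

import sgl_kernel.flash_ops                  # second fragment (FA3 builds only)
print(hasattr(torch.ops.sgl_kernel, "fwd"))  # True: the flash_ops fragment added "fwd"
```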

sgl-kernel/csrc/flash_extension.cc

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/all.h>
+#include <torch/library.h>
+
+#include "sgl_flash_kernel_ops.h"
+
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
+  /*
+   * From flash-attention
+   */
+  m.def(
+      "fwd(Tensor! q,"
+      " Tensor k,"
+      " Tensor v,"
+      " Tensor? k_new,"
+      " Tensor? v_new,"
+      " Tensor? q_v,"
+      " Tensor!? out,"
+      " Tensor? cu_seqlens_q,"
+      " Tensor? cu_seqlens_k,"
+      " Tensor? cu_seqlens_k_new,"
+      " Tensor? seqused_q,"
+      " Tensor? seqused_k,"
+      " int? max_seqlen_q,"
+      " int? max_seqlen_k,"
+      " Tensor? page_table,"
+      " Tensor? kv_batch_idx,"
+      " Tensor? leftpad_k,"
+      " Tensor? rotary_cos,"
+      " Tensor? rotary_sin,"
+      " Tensor? seqlens_rotary,"
+      " Tensor? q_descale,"
+      " Tensor? k_descale,"
+      " Tensor? v_descale,"
+      " float softmax_scale,"
+      " bool is_causal,"
+      " int window_size_left,"
+      " int window_size_right,"
+      " float softcap,"
+      " bool is_rotary_interleaved,"
+      " Tensor? scheduler_metadata,"
+      " int num_splits,"
+      " bool? pack_gqa,"
+      " int sm_margin) -> Tensor[]");
+  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
+}
+
+REGISTER_EXTENSION(flash_ops)
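Once flash_ops is loaded, the op defined above is reachable as torch.ops.sgl_kernel.fwd with the positional signature of the schema. A rough sketch of a direct call (assumptions: a Hopper GPU, bf16 tensors, head dim 128, and the import path from earlier; real callers in sglang go through higher-level attention wrappers rather than this raw op):

```python
import torch
import sgl_kernel.flash_ops  # assumed import path for the module built above

# Small self-attention problem: batch 1, seqlen 1024, 8 heads, head dim 128.
q = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# Arguments follow the schema order above; unused optionals are passed as None.
results = torch.ops.sgl_kernel.fwd(
    q, k, v,
    None, None, None,     # k_new, v_new, q_v
    None,                 # out (allocated by the kernel when None)
    None, None, None,     # cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new
    None, None,           # seqused_q, seqused_k
    None, None,           # max_seqlen_q, max_seqlen_k
    None, None, None,     # page_table, kv_batch_idx, leftpad_k
    None, None, None,     # rotary_cos, rotary_sin, seqlens_rotary
    None, None, None,     # q_descale, k_descale, v_descale
    q.shape[-1] ** -0.5,  # softmax_scale
    True,                 # is_causal
    -1, -1,               # window_size_left, window_size_right (no local window)
    0.0,                  # softcap (0 disables soft-capping)
    False,                # is_rotary_interleaved
    None,                 # scheduler_metadata
    0,                    # num_splits
    None,                 # pack_gqa
    0,                    # sm_margin
)
out = results[0]  # first returned tensor is the attention output in FA3's mha_fwd
```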
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Tensor.h>
+#include <Python.h>
+#include <torch/library.h>
+#include <torch/torch.h>
+
+#include <vector>
+
+#include "sgl_kernel_torch_shim.h"
+
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+#define REGISTER_EXTENSION(NAME)                                                                      \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                                                  \
+  }
+
+/*
+ * From flash-attention
+ */
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                          // h_k, d) if there is page_table.
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                          // page_size, h_k, dv) if there is page_table.
+    std::optional<const at::Tensor>&
+        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>&
+        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& out_,        // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
+    std::optional<const at::Tensor>&
+        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+    std::optional<const at::Tensor>&
+        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
+    std::optional<int> max_seqlen_q_,
+    // TODO: check if we need max_seqlen_k
+    std::optional<int> max_seqlen_k_,
+    std::optional<const at::Tensor>& page_table_,       // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& kv_batch_idx_,     // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_,        // b
+    std::optional<const at::Tensor>& rotary_cos_,       // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_,       // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_,   // b
+    std::optional<at::Tensor>& q_descale_,              // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_,              // (b, h_k)
+    std::optional<at::Tensor>& v_descale_,              // (b, h_k)
+    float const softmax_scale,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin);