
Commit c859bd6

mingfeima, chunyuan-w, yanbing-j, and blzheng authored and committed
Add optimized native kernels in sgl-kernel (sgl-project#5150)
Co-authored-by: Chunyuan WU <[email protected]>
Co-authored-by: YanbingJiang <[email protected]>
Co-authored-by: blzheng <[email protected]>
1 parent 2d60fa1 commit c859bd6

20 files changed: +7792 -0 lines changed

sgl-kernel/csrc/cpu/activation.cpp

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
#include "common.h"
#include "vec.h"

namespace {

template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t num_tokens,
    int64_t dim,
    const func_t& f,
    const vec_func_t& vf) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int64_t kVecSize = bVec::size();
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
      const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
      scalar_t* __restrict__ output_ptr = output + i * dim;

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= dim - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec y_bvec = bVec::loadu(input_other_ptr + d);
        fVec y_fvec0, y_fvec1;
        std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);

        x_fvec0 = vf(x_fvec0);
        x_fvec1 = vf(x_fvec1);

        x_fvec0 = x_fvec0 * y_fvec0;
        x_fvec1 = x_fvec1 * y_fvec1;

        x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(output_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < dim; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float y_val = static_cast<float>(input_other_ptr[d]);
        output_ptr[d] = f(x_val) * y_val;
      }
    }
  });
}

}  // anonymous namespace

// input  : {num_tokens, 2 * d}
// output : {num_tokens, d}
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
  auto sizes = input.sizes().vec();
  int64_t last_dim = input.ndimension() - 1;
  int64_t d = sizes[last_dim] / 2;
  sizes[last_dim] = d;
  int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(sizes, input.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        num_tokens,
        d,
        [](float x) { return x / (1.f + std::exp(-x)); },
        [](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
  });
  return out;
}
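Note (not part of the commit): the semantics of the vectorized kernel above can be expressed as a plain scalar loop. The sketch below is a minimal reference only, assuming a contiguous row-major {num_tokens, 2 * d} float input as described in the shape comments; the real kernel vectorizes this loop and dispatches on bfloat16/float16.

#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference of silu_and_mul: out[i][j] = silu(in[i][j]) * in[i][d + j].
// Hypothetical helper for illustration only.
std::vector<float> silu_and_mul_ref(const std::vector<float>& input, int64_t num_tokens, int64_t d) {
  std::vector<float> out(num_tokens * d);
  for (int64_t i = 0; i < num_tokens; ++i) {
    const float* x = input.data() + i * 2 * d;  // gate half
    const float* y = x + d;                     // up half
    for (int64_t j = 0; j < d; ++j) {
      float silu = x[j] / (1.f + std::exp(-x[j]));
      out[i * d + j] = silu * y[j];
    }
  }
  return out;
}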

sgl-kernel/csrc/cpu/bmm.cpp

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"

namespace {

template <typename scalar_t>
void bmm_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const scalar_t* __restrict__ mat2,
    int64_t B,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideB,
    int64_t mat1_strideM,
    int64_t out_strideB,
    int64_t out_strideM,
    float scale = 0.f) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  // mat2 contiguous in [B, N, K]
  int64_t mat2_strideB = N * K;
  int64_t mat2_strideN = K;

  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

  // parallel on [B, MB, NB]
  at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, mb{0}, nb{0};
    data_index_init(begin, bs, B, mb, MB, nb, NB);

    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

    for (int i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(N - nb_start, BLOCK_N);

      tinygemm_kernel<scalar_t>(
          /* A   */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
          /* B   */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
          /* C   */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
          /* Ctmp*/ Ctmp,
          /* M   */ mb_size,
          /* N   */ nb_size,
          /* K   */ K,
          /* lda */ mat1_strideM,
          /* ldb */ nb_size,
          /* ldc */ out_strideM,
          /* brg */ use_brgemm);

      // move to the next index
      data_index_step(bs, B, mb, MB, nb, NB);
    }

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}

}  // anonymous namespace

// mat1 : [B, M, K]
// mat2 : [B, N, K] or [B, OC, IC]
// out  : [B, M, N]
// scale: [] 0-dim tensor for per tensor quant
//
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
  RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));

  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);

  // input and out could be non-contiguous
  // weight needs to be contiguous in [OC, IC] order
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
  CHECK_INPUT(mat2);
  CHECK_DIM(3, out);
  CHECK_DIM(3, mat1);
  CHECK_DIM(3, mat2);

  int64_t B = mat1.size(0);
  int64_t M = mat1.size(1);
  int64_t N = mat2.size(1);
  int64_t K = mat1.size(2);

  TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.");
  TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");

  int64_t mat1_strideB = mat1.stride(0);
  int64_t mat1_strideM = mat1.stride(1);
  int64_t out_strideB = out.stride(0);
  int64_t out_strideM = out.stride(1);

  // check shapes
  TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
  TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");

  AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
    bmm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<scalar_t>(),
        B,
        M,
        N,
        K,
        mat1_strideB,
        mat1_strideM,
        out_strideB,
        out_strideM);
  });
}
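Note (not part of the commit): mat2 is laid out as [B, N, K], i.e. transposed per batch, so the blocked kernel computes out[b][m][n] = sum_k mat1[b][m][k] * mat2[b][n][k]. The sketch below is a minimal scalar reference under the assumption of contiguous float buffers; it ignores the M/N blocking, VNNI weight packing, and brgemm paths of the real kernel.

#include <cstdint>
#include <vector>

// Scalar reference of the batched GEMM above. Illustration only.
std::vector<float> bmm_ref(
    const std::vector<float>& mat1,  // [B, M, K], contiguous
    const std::vector<float>& mat2,  // [B, N, K], contiguous
    int64_t B, int64_t M, int64_t N, int64_t K) {
  std::vector<float> out(B * M * N, 0.f);
  for (int64_t b = 0; b < B; ++b)
    for (int64_t m = 0; m < M; ++m)
      for (int64_t n = 0; n < N; ++n) {
        float acc = 0.f;
        for (int64_t k = 0; k < K; ++k)
          acc += mat1[(b * M + m) * K + k] * mat2[(b * N + n) * K + k];
        out[(b * M + m) * N + n] = acc;
      }
  return out;
}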

sgl-kernel/csrc/cpu/common.h

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>

#if defined(_OPENMP)
#include <omp.h>
#endif

namespace {

// dispatch bool
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
  [&] {                                          \
    if (BOOL_V) {                                \
      constexpr bool BOOL_NAME = true;           \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr bool BOOL_NAME = false;          \
      return __VA_ARGS__();                      \
    }                                            \
  }()

// dispatch: bfloat16, float16, int8_t
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                     \
  [&] {                                                          \
    switch (TYPE) {                                              \
      case at::ScalarType::BFloat16: {                           \
        using packed_t = at::BFloat16;                           \
        return __VA_ARGS__();                                    \
      }                                                          \
      case at::ScalarType::Half: {                               \
        using packed_t = at::Half;                               \
        return __VA_ARGS__();                                    \
      }                                                          \
      case at::ScalarType::Char: {                               \
        using packed_t = int8_t;                                 \
        return __VA_ARGS__();                                    \
      }                                                          \
      default:                                                   \
        TORCH_CHECK(false, "Unsupported floating data type.\n"); \
    }                                                            \
  }()

#define UNUSED(x) (void)(x)

#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")

#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")

#define CHECK_INPUT(x) \
  CHECK_CPU(x);        \
  CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CPU(x);                            \
  CHECK_LAST_DIM_CONTIGUOUS(x)

#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

// parallel routines
constexpr int GRAIN_SIZE = 1024;

template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
  return (x + y - 1) / y;
}

template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
  // onednn partition pattern
  T& n_my = n_end;
  if (nth <= 1 || n == 0) {
    n_start = 0;
    n_my = n;
  } else {
    T n1 = div_up(n, nth);
    T n2 = n1 - 1;
    T T1 = n - n2 * nth;
    n_my = ith < T1 ? n1 : n2;
    n_start = ith <= T1 ? ith * n1 : T1 * n1 + (ith - T1) * n2;
  }
  n_end += n_start;
#else
  // pytorch aten partition pattern
  T n_my = div_up(n, nth);
  n_start = ith * n_my;
  n_end = std::min(n_start + n_my, n);
#endif
}

template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
  {
    int nth = omp_get_num_threads();
    int ith = omp_get_thread_num();
    int tbegin, tend;
    balance211(n, nth, ith, tbegin, tend);
    f(tbegin, tend);
  }
#else
  f(0, n);
#endif
}

// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}

// forced unroll for perf critical path

#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif

template <int n>
struct Unroll {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    Unroll<n - 1>{}(f, args...);
    f(std::integral_constant<int, n - 1>{}, args...);
  }
};

template <>
struct Unroll<1> {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    f(std::integral_constant<int, 0>{}, args...);
  }
};

}  // anonymous namespace
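Note (not part of the commit): data_index_init/data_index_step collapse a flat loop index into a multi-dimensional index with the last dimension varying fastest, which is how bmm_kernel_impl walks its [B, MB, NB] iteration space inside at::parallel_for. The standalone sketch below re-declares these two helpers exactly as defined above so it compiles on its own, and walks a small hypothetical 2x2x3 space.

#include <cstdint>
#include <cstdio>
#include <utility>

// Standalone copies of the indexing helpers from common.h, for illustration only.
template <typename T>
inline T data_index_init(T offset) { return offset; }

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() { return true; }

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;  // carry to the next (outer) dimension
  }
  return false;
}

int main() {
  // Walk the flat range [0, B*MB*NB) as a 3-D index (bs, mb, nb);
  // nb varies fastest, matching the "parallel on [B, MB, NB]" comment in bmm.cpp.
  const int64_t B = 2, MB = 2, NB = 3;
  int64_t bs = 0, mb = 0, nb = 0;
  const int64_t begin = 0, end = B * MB * NB;
  data_index_init(begin, bs, B, mb, MB, nb, NB);
  for (int64_t i = begin; i < end; ++i) {
    std::printf("i=%lld -> (bs=%lld, mb=%lld, nb=%lld)\n",
                (long long)i, (long long)bs, (long long)mb, (long long)nb);
    data_index_step(bs, B, mb, MB, nb, NB);
  }
  return 0;
}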
