[Kernel] moe wna16 marlin kernel #14447

Merged
merged 83 commits into from
Apr 15, 2025
Changes from 72 commits
Commits
3b79aeb
moe wna16 marlin kernel
jinzhen-lin Mar 7, 2025
0fcf347
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Mar 8, 2025
8d60a34
reinit marlin_moe_wna16.cu with gptq_marlin.cu
jinzhen-lin Mar 8, 2025
77e301e
update marlin_moe_wna16.cu
jinzhen-lin Mar 8, 2025
16b8bdc
fix format error
jinzhen-lin Mar 8, 2025
f6519b9
fix format error
jinzhen-lin Mar 8, 2025
5dec3f5
add missing endif
jinzhen-lin Mar 8, 2025
b4e4a95
fix format error
jinzhen-lin Mar 8, 2025
1c83210
fix format error
jinzhen-lin Mar 8, 2025
a0e0264
fix dimension check of c
jinzhen-lin Mar 9, 2025
4b23fc0
optimize
jinzhen-lin Mar 9, 2025
2d6ab1b
update
jinzhen-lin Mar 9, 2025
8772450
support ep
jinzhen-lin Mar 9, 2025
e6896d3
support act order
jinzhen-lin Mar 10, 2025
9e8aa10
fix format error
jinzhen-lin Mar 10, 2025
81840fe
fix error
jinzhen-lin Mar 10, 2025
731dd07
fix error
jinzhen-lin Mar 10, 2025
92b7226
fix format error
jinzhen-lin Mar 10, 2025
9b2f324
update test marlin moe
jinzhen-lin Mar 10, 2025
b8d2da5
fix is_k_full = false
jinzhen-lin Mar 10, 2025
383368b
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Mar 11, 2025
6ff3292
fix format error
jinzhen-lin Mar 11, 2025
e4dc8b1
fix format error
jinzhen-lin Mar 11, 2025
6ae80b2
fix format error
jinzhen-lin Mar 11, 2025
c2e7c6a
fix rare case
jinzhen-lin Mar 11, 2025
58a58bd
update CMakeLists.txt
jinzhen-lin Mar 11, 2025
ae83b25
add workspace size check
jinzhen-lin Mar 11, 2025
724a673
update test atol
jinzhen-lin Mar 11, 2025
e70118c
update dtype and func name
jinzhen-lin Mar 11, 2025
82f4ff8
fix format error
jinzhen-lin Mar 11, 2025
3508760
fix fake ops name
jinzhen-lin Mar 11, 2025
60a26e7
fix int32 overflow issue
jinzhen-lin Mar 12, 2025
d7abbbf
fix
jinzhen-lin Mar 12, 2025
8aae8ac
fix moe config
jinzhen-lin Mar 12, 2025
7d74c3d
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Mar 14, 2025
c828c46
split kernel
jinzhen-lin Mar 14, 2025
fb0d062
fix error
jinzhen-lin Mar 14, 2025
b9c656c
fix format
jinzhen-lin Mar 14, 2025
d9b43ac
fix format error
jinzhen-lin Mar 14, 2025
2b8c977
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Mar 14, 2025
6d6b2cf
update topk weight loading
jinzhen-lin Mar 23, 2025
acb1d19
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Mar 23, 2025
dc17c89
fix format error
jinzhen-lin Mar 23, 2025
a999ee0
m_block_size_8; fused (w-zp)*s
jinzhen-lin Mar 30, 2025
2a6bd00
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Mar 30, 2025
21acabf
enable m_block_size_8
jinzhen-lin Mar 30, 2025
fdcd542
update
jinzhen-lin Mar 30, 2025
f3c5186
update
jinzhen-lin Mar 30, 2025
b56c70c
allow multi blocks per sm for small batch
jinzhen-lin Apr 1, 2025
eeece4b
fix
jinzhen-lin Apr 1, 2025
f24c7d8
rerun
jinzhen-lin Apr 1, 2025
204b735
fix
jinzhen-lin Apr 2, 2025
99fbbcc
fix
jinzhen-lin Apr 2, 2025
a574824
optimize and fix global reduce when use_atomic_add=false
jinzhen-lin Apr 2, 2025
ed47e34
set use_atomic_add=false for sm80 + bfloat16
jinzhen-lin Apr 2, 2025
191c1f4
Merge remote-tracking branch 'origin/main' into moe-wna16-marlin-kernel
jinzhen-lin Apr 2, 2025
e12861d
fix format error
jinzhen-lin Apr 2, 2025
86d5708
fix format error
jinzhen-lin Apr 2, 2025
f83b5f2
fix format error
jinzhen-lin Apr 2, 2025
7643a11
fix typo
jinzhen-lin Apr 2, 2025
8487d52
don't build old moe marlin kernel
jinzhen-lin Apr 2, 2025
2d32751
remove ku8 support
jinzhen-lin Apr 3, 2025
bdb13fb
remove ku8 support
jinzhen-lin Apr 3, 2025
264e17c
fix
jinzhen-lin Apr 3, 2025
d551b17
remove ununsed kernel
jinzhen-lin Apr 3, 2025
3a2a80f
remove unused kernels
jinzhen-lin Apr 3, 2025
2ddfc22
remove unused kernels
jinzhen-lin Apr 3, 2025
8750d75
add comment
jinzhen-lin Apr 6, 2025
b1085e8
add shape check and fallback
jinzhen-lin Apr 6, 2025
97ab9f5
fix format error
jinzhen-lin Apr 6, 2025
800b407
fix typo
jinzhen-lin Apr 6, 2025
5c86442
Merge branch 'main' into moe-wna16-marlin-kernel
mgoin Apr 8, 2025
1b752a9
generate kernels when building
jinzhen-lin Apr 9, 2025
3ebe04a
Merge branch 'main' into moe-wna16-marlin-kernel
mgoin Apr 11, 2025
18c9e92
fix kernel config
jinzhen-lin Apr 11, 2025
36b240b
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Apr 11, 2025
b981337
Signed-off-by: Jinzhen Lin <[email protected]>
jinzhen-lin Apr 11, 2025
634b337
[DONT MERGE] debug
jinzhen-lin Apr 12, 2025
49c0d11
[DONT MERGE] fix
jinzhen-lin Apr 12, 2025
529ceb9
move zero_output to python code
jinzhen-lin Apr 12, 2025
2e8972b
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin Apr 12, 2025
6600eb8
optimize the case that k is very small
jinzhen-lin Apr 12, 2025
f0ea290
Merge branch 'main' into moe-wna16-marlin-kernel
mgoin Apr 14, 2025
16 changes: 4 additions & 12 deletions CMakeLists.txt
@@ -608,21 +608,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_moe_ops.cu")

file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
SRCS "${MOE_WNAA16_MARLIN_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")

list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})

message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
91 changes: 91 additions & 0 deletions csrc/moe/marlin_moe_wna16/generate_kernels.py
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
import os

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")

# The int8 case with zero points (vllm::kU8) is also supported,
# but it is not generated here in order to reduce the wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]

for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
    settings = []
    bit = int(scalar_type[8])
    has_zp = "B" not in scalar_type
    all_template_str_list = []

    for group_blocks, m_blocks, thread_configs in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

        has_act_order = group_blocks == 0
        if has_zp and has_act_order:
            continue
        if thread_configs[2] == 256:
            if m_blocks <= 1 and thread_configs[0] != 128:
                continue
            if m_blocks > 1 and thread_configs[0] != 64:
                continue

        k_blocks = thread_configs[0] // 16
        n_blocks = thread_configs[1] // 16
        threads = thread_configs[2]

        c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

        template_str = jinja2.Template(TEMPLATE).render(
            scalar_t=c_dtype,
            w_type_id=scalar_type + ".id()",
            threads=threads,
            thread_m_blocks=max(m_blocks, 1),
            thread_n_blocks=n_blocks,
            thread_k_blocks=k_blocks,
            m_block_size_8=m_blocks == 0.5,
            stages="pipe_stages",
            has_act_order=has_act_order,
            has_zp=has_zp,
            group_blocks=group_blocks,
            is_zp_float=False,
        )

        all_template_str_list.append(template_str)

    file_content = FILE_HEAD + "\n\n"
    file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
    filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

    with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
        f.write(file_content)
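The filtering rules in the generator loop decide how many explicit instantiations end up in each generated `.cu` file. The following is a minimal sketch that replays only those rules, without jinja2 and without writing files, to count the kept configurations per weight type. The constants are copied from `generate_kernels.py`; the helper names `enumerate_configs` and `output_filename`, and the reading of each thread config tuple as `(thread_k, thread_n, threads)`, are assumptions made for illustration and are not part of the PR.

```python
import itertools

# Constants copied from generate_kernels.py.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
GROUP_BLOCKS = [0, -1, 2, 4, 8]


def enumerate_configs(scalar_type):
    """Hypothetical helper: list the combinations the generator would keep."""
    has_zp = "B" not in scalar_type
    kept = []
    for group_blocks, m_blocks, cfg in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
        has_act_order = group_blocks == 0
        # Zero-point types are never combined with act order.
        if has_zp and has_act_order:
            continue
        # Assumed tuple layout: (thread_k, thread_n, threads).
        thread_k, thread_n, threads = cfg
        if threads == 256:
            # Small m uses the 128-wide k tile, larger m the 64-wide one.
            if m_blocks <= 1 and thread_k != 128:
                continue
            if m_blocks > 1 and thread_k != 64:
                continue
        kept.append((group_blocks, m_blocks, cfg))
    return kept


def output_filename(dtype, scalar_type):
    """Hypothetical helper mirroring the generator's filename expression."""
    return f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"


counts = {s: len(enumerate_configs(s)) for s in SCALAR_TYPES}
print(counts)
print(output_filename("bf16", "vllm::kU4B8"))
```

Under these rules `vllm::kU4` keeps 40 combinations, because the act-order case is skipped when zero points are present, while the two biased types keep 50 each. This agrees with the number of instantiation lines in `kernel_bf16_ku4.cu` and `kernel_bf16_ku4b8.cu` in this diff.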
44 changes: 44 additions & 0 deletions csrc/moe/marlin_moe_wna16/kernel.h
@@ -0,0 +1,44 @@

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_bfloat16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}
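Because the `Marlin` template takes twelve positional parameters, the generated instantiation lines are hard to read at a glance. The sketch below maps the positional arguments of one such line back to the parameter names declared in `kernel.h`. The helpers `decode_instantiation` and `describe_group_blocks` are hypothetical reading aids, not part of the PR; the interpretation of `group_blocks` follows the comment in `generate_kernels.py`.

```python
import re

# Parameter names in the order of the Marlin template declaration in kernel.h.
PARAM_NAMES = [
    "scalar_t", "w_type_id", "threads", "thread_m_blocks",
    "thread_n_blocks", "thread_k_blocks", "m_block_size_8", "stages",
    "has_act_order", "has_zp", "group_blocks", "is_zp_float",
]


def decode_instantiation(line):
    """Hypothetical helper: name the template arguments of one line."""
    match = re.search(r"Marlin<(.*)>\(", line)
    if match is None:
        raise ValueError("not a Marlin instantiation line")
    args = [arg.strip() for arg in match.group(1).split(",")]
    if len(args) != len(PARAM_NAMES):
        raise ValueError(f"expected {len(PARAM_NAMES)} arguments, got {len(args)}")
    return dict(zip(PARAM_NAMES, args))


def describe_group_blocks(group_blocks):
    """Hypothetical helper: spell out the group_blocks convention."""
    if group_blocks == 0:
        return "act order"
    if group_blocks == -1:
        return "channelwise quantization"
    return f"group_size={16 * group_blocks}"


line = ("template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, "
        "8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );")
decoded = decode_instantiation(line)
for name, value in decoded.items():
    print(f"{name:16s} {value}")
print(describe_group_blocks(int(decoded["group_blocks"])))
```

All values are kept as strings, since some arguments (`pipe_stages`, `vllm::kU4.id()`) are C++ expressions rather than literals.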
89 changes: 89 additions & 0 deletions csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu
@@ -0,0 +1,89 @@
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );

}
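Each group of lines in the generated file contains two instantiations with `thread_m_blocks` equal to 1 that differ only in the `m_block_size_8` flag. This comes from the generator's `m_blocks = 0.5` entry, which is clamped to 1 with `max(m_blocks, 1)` and sets `m_block_size_8` to `true`. The sketch below rebuilds a single line with plain string joining instead of jinja2 to make that mapping visible. The function `render_instantiation` and the `(thread_k, thread_n, threads)` naming are assumptions for illustration, not code from the PR.

```python
def render_instantiation(c_dtype, scalar_type, group_blocks, m_blocks,
                         thread_config):
    """Hypothetical helper: rebuild one explicit instantiation line."""
    # Assumed tuple layout: (thread_k, thread_n, threads).
    thread_k, thread_n, threads = thread_config
    has_zp = "B" not in scalar_type
    args = [
        c_dtype,
        scalar_type + ".id()",
        str(threads),
        str(max(m_blocks, 1)),                    # 0.5 is clamped to 1
        str(thread_n // 16),                      # thread_n_blocks
        str(thread_k // 16),                      # thread_k_blocks
        "true" if m_blocks == 0.5 else "false",   # m_block_size_8
        "pipe_stages",
        "true" if group_blocks == 0 else "false", # has_act_order
        "true" if has_zp else "false",
        str(group_blocks),
        "false",                                  # is_zp_float
    ]
    return ("template __global__ void Marlin<" + ", ".join(args) +
            ">( MARLIN_KERNEL_PARAMS );")


first = render_instantiation("nv_bfloat16", "vllm::kU4", -1, 0.5,
                             (128, 128, 256))
last = render_instantiation("nv_bfloat16", "vllm::kU4", 8, 4,
                            (64, 128, 128))
print(first)
print(last)
```

With these inputs the two printed lines correspond to the first and last instantiations in `kernel_bf16_ku4.cu`.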
109 changes: 109 additions & 0 deletions csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu
@@ -0,0 +1,109 @@
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );

}