-
-
Notifications
You must be signed in to change notification settings - Fork 8.5k
[Kernel] moe wna16 marlin kernel #14447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vllm-bot
merged 83 commits into
vllm-project:main
from
jinzhen-lin:moe-wna16-marlin-kernel
Apr 15, 2025
Merged
Changes from 72 commits
Commits
Show all changes
83 commits
Select commit
Hold shift + click to select a range
3b79aeb
moe wna16 marlin kernel
jinzhen-lin 0fcf347
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin 8d60a34
reinit marlin_moe_wna16.cu with gptq_marlin.cu
jinzhen-lin 77e301e
update marlin_moe_wna16.cu
jinzhen-lin 16b8bdc
fix format error
jinzhen-lin f6519b9
fix format error
jinzhen-lin 5dec3f5
add missing endif
jinzhen-lin b4e4a95
fix format error
jinzhen-lin 1c83210
fix format error
jinzhen-lin a0e0264
fix dimension check of c
jinzhen-lin 4b23fc0
optimize
jinzhen-lin 2d6ab1b
update
jinzhen-lin 8772450
support ep
jinzhen-lin e6896d3
support act order
jinzhen-lin 9e8aa10
fix format error
jinzhen-lin 81840fe
fix error
jinzhen-lin 731dd07
fix error
jinzhen-lin 92b7226
fix format error
jinzhen-lin 9b2f324
update test marlin moe
jinzhen-lin b8d2da5
fix is_k_full = false
jinzhen-lin 383368b
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin 6ff3292
fix format error
jinzhen-lin e4dc8b1
fix format error
jinzhen-lin 6ae80b2
fix format error
jinzhen-lin c2e7c6a
fix rare case
jinzhen-lin 58a58bd
update CMakeLists.txt
jinzhen-lin ae83b25
add workspace size check
jinzhen-lin 724a673
update test atol
jinzhen-lin e70118c
update dtype and func name
jinzhen-lin 82f4ff8
fix format error
jinzhen-lin 3508760
fix fake ops name
jinzhen-lin 60a26e7
fix int32 overflow issue
jinzhen-lin d7abbbf
fix
jinzhen-lin 8aae8ac
fix moe config
jinzhen-lin 7d74c3d
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin c828c46
split kernel
jinzhen-lin fb0d062
fix error
jinzhen-lin b9c656c
fix format
jinzhen-lin d9b43ac
fix format error
jinzhen-lin 2b8c977
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin 6d6b2cf
update topk weight loading
jinzhen-lin acb1d19
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin dc17c89
fix format error
jinzhen-lin a999ee0
m_block_size_8; fused (w-zp)*s
jinzhen-lin 2a6bd00
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin 21acabf
enable m_block_size_8
jinzhen-lin fdcd542
update
jinzhen-lin f3c5186
update
jinzhen-lin b56c70c
allow multi blocks per sm for small batch
jinzhen-lin eeece4b
fix
jinzhen-lin f24c7d8
rerun
jinzhen-lin 204b735
fix
jinzhen-lin 99fbbcc
fix
jinzhen-lin a574824
optimize and fix global reduce when use_atomic_add=false
jinzhen-lin ed47e34
set use_atomic_add=false for sm80 + bfloat16
jinzhen-lin 191c1f4
Merge remote-tracking branch 'origin/main' into moe-wna16-marlin-kernel
jinzhen-lin e12861d
fix format error
jinzhen-lin 86d5708
fix format error
jinzhen-lin f83b5f2
fix format error
jinzhen-lin 7643a11
fix typo
jinzhen-lin 8487d52
don't build old moe marlin kernel
jinzhen-lin 2d32751
remove ku8 support
jinzhen-lin bdb13fb
remove ku8 support
jinzhen-lin 264e17c
fix
jinzhen-lin d551b17
remove ununsed kernel
jinzhen-lin 3a2a80f
remove unused kernels
jinzhen-lin 2ddfc22
remove unused kernels
jinzhen-lin 8750d75
add comment
jinzhen-lin b1085e8
add shape check and fallback
jinzhen-lin 97ab9f5
fix format error
jinzhen-lin 800b407
fix typo
jinzhen-lin 5c86442
Merge branch 'main' into moe-wna16-marlin-kernel
mgoin 1b752a9
generate kernels when building
jinzhen-lin 3ebe04a
Merge branch 'main' into moe-wna16-marlin-kernel
mgoin 18c9e92
fix kernel config
jinzhen-lin 36b240b
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin b981337
Signed-off-by: Jinzhen Lin <[email protected]>
jinzhen-lin 634b337
[DONT MERGE] debug
jinzhen-lin 49c0d11
[DONT MERGE] fix
jinzhen-lin 529ceb9
move zero_output to python code
jinzhen-lin 2e8972b
Merge branch 'main' into moe-wna16-marlin-kernel
jinzhen-lin 6600eb8
optimize the case that k is very small
jinzhen-lin f0ea290
Merge branch 'main' into moe-wna16-marlin-kernel
mgoin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import itertools | ||
import os | ||
|
||
import jinja2 | ||
|
||
FILE_HEAD = """ | ||
// auto generated by generate.py | ||
// clang-format off | ||
|
||
#include "kernel.h" | ||
#include "marlin_template.h" | ||
|
||
namespace MARLIN_NAMESPACE_NAME { | ||
""".strip() | ||
|
||
TEMPLATE = ("template __global__ void Marlin<" | ||
"{{scalar_t}}, " | ||
"{{w_type_id}}, " | ||
"{{threads}}, " | ||
"{{thread_m_blocks}}, " | ||
"{{thread_n_blocks}}, " | ||
"{{thread_k_blocks}}, " | ||
"{{'true' if m_block_size_8 else 'false'}}, " | ||
"{{stages}}, " | ||
"{{'true' if has_act_order else 'false'}}, " | ||
"{{'true' if has_zp else 'false'}}, " | ||
"{{group_blocks}}, " | ||
"{{'true' if is_zp_float else 'false'}}>" | ||
"( MARLIN_KERNEL_PARAMS );") | ||
|
||
# int8 with zero point case (vllm::kU8) is also supported, | ||
# we don't add it to reduce wheel size. | ||
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] | ||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] | ||
|
||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] | ||
# group_blocks: | ||
# = 0 : act order case | ||
# = -1 : channelwise quantization | ||
# > 0 : group_size=16*group_blocks | ||
GROUP_BLOCKS = [0, -1, 2, 4, 8] | ||
jinzhen-lin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
DTYPES = ["fp16", "bf16"] | ||
|
||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): | ||
settings = [] | ||
bit = int(scalar_type[8]) | ||
has_zp = "B" not in scalar_type | ||
all_template_str_list = [] | ||
|
||
for group_blocks, m_blocks, thread_configs in itertools.product( | ||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): | ||
|
||
has_act_order = group_blocks == 0 | ||
if has_zp and has_act_order: | ||
continue | ||
if thread_configs[2] == 256: | ||
if m_blocks <= 1 and thread_configs[0] != 128: | ||
continue | ||
if m_blocks > 1 and thread_configs[0] != 64: | ||
continue | ||
|
||
k_blocks = thread_configs[0] // 16 | ||
n_blocks = thread_configs[1] // 16 | ||
threads = thread_configs[2] | ||
|
||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" | ||
|
||
template_str = jinja2.Template(TEMPLATE).render( | ||
scalar_t=c_dtype, | ||
w_type_id=scalar_type + ".id()", | ||
threads=threads, | ||
thread_m_blocks=max(m_blocks, 1), | ||
thread_n_blocks=n_blocks, | ||
thread_k_blocks=k_blocks, | ||
m_block_size_8=m_blocks == 0.5, | ||
stages="pipe_stages", | ||
has_act_order=has_act_order, | ||
has_zp=has_zp, | ||
group_blocks=group_blocks, | ||
is_zp_float=False, | ||
) | ||
|
||
all_template_str_list.append(template_str) | ||
|
||
file_content = FILE_HEAD + "\n\n" | ||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" | ||
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" | ||
|
||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: | ||
f.write(file_content) |
jinzhen-lin marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
|
||
#ifndef MARLIN_NAMESPACE_NAME | ||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 | ||
#endif | ||
|
||
#include "quantization/gptq_marlin/marlin.cuh" | ||
#include "quantization/gptq_marlin/marlin_dtypes.cuh" | ||
#include "core/scalar_type.hpp" | ||
|
||
#define MARLIN_KERNEL_PARAMS \ | ||
const int4 *__restrict__ A, const int4 *__restrict__ B, \ | ||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ | ||
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ | ||
const int *__restrict__ g_idx, \ | ||
const int32_t *__restrict__ sorted_token_ids_ptr, \ | ||
const int32_t *__restrict__ expert_ids_ptr, \ | ||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \ | ||
const float *__restrict__ topk_weights_ptr, int top_k, \ | ||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ | ||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \ | ||
bool use_fp32_reduce | ||
|
||
namespace MARLIN_NAMESPACE_NAME { | ||
template <typename scalar_t, // compute dtype, half or nv_float16 | ||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id | ||
const int threads, // number of threads in a threadblock | ||
const int thread_m_blocks, // number of 16x16 blocks in the m | ||
// dimension (batchsize) of the | ||
// threadblock | ||
const int thread_n_blocks, // same for n dimension (output) | ||
const int thread_k_blocks, // same for k dimension (reduction) | ||
const bool m_block_size_8, // whether m_block_size == 8 | ||
// only works when thread_m_blocks == 1 | ||
const int stages, // number of stages for the async global->shared | ||
// fetch pipeline | ||
const bool has_act_order, // whether act_order is enabled | ||
const bool has_zp, // whether zero-points are enabled | ||
const int group_blocks, // number of consecutive 16x16 blocks | ||
// with a separate quantization scale | ||
const bool is_zp_float // is zero point of float16 type? | ||
> | ||
__global__ void Marlin(MARLIN_KERNEL_PARAMS); | ||
|
||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// auto generated by generate.py | ||
// clang-format off | ||
|
||
#include "kernel.h" | ||
#include "marlin_template.h" | ||
|
||
namespace MARLIN_NAMESPACE_NAME { | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// auto generated by generate.py | ||
// clang-format off | ||
|
||
#include "kernel.h" | ||
#include "marlin_template.h" | ||
|
||
namespace MARLIN_NAMESPACE_NAME { | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS ); | ||
|
||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.