[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin
2025-11-29 23:19:33 +08:00
committed by GitHub
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions

View File

@@ -1 +1,2 @@
kernel_*.cu
sm*_kernel_*.cu
kernel_selector.h

View File

@@ -4,134 +4,282 @@ import glob
import itertools
import os
import subprocess
import sys
import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
ARCHS = []
SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but its simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
"""
)
TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{a_type_id}}, "
"{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{m_block_size_8}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
QUANT_CONFIGS = [
# AWQ-INT4
{
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=False,
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
all_template_str_list.append(template_str)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__":
remove_old_kernels()

View File

@@ -11,8 +11,9 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
@@ -20,12 +21,13 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the

File diff suppressed because it is too large Load Diff

View File

@@ -37,39 +37,6 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {};
} // namespace marlin
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
template <int moe_block_size>
@@ -207,7 +174,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int is_zp_float, bool is_a_8bit) {
int pack_factor = 32 / num_bits;
// Get B size
@@ -217,8 +184,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
@@ -250,7 +217,7 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
int max_shared_mem, bool is_a_8bit) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@@ -273,188 +240,34 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
}
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size + 512 <= max_shared_mem;
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
return cache_size <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
#include "kernel_selector.h"
return kernel;
}
template <typename scalar_t>
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks,
bool m_block_size_8, int num_bits,
int group_size, bool has_act_order,
bool is_k_full, bool has_zp,
bool is_zp_float, int max_shared_mem) {
exec_config_t determine_exec_config(
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_a_8bit) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@@ -471,73 +284,69 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) {
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16);
}
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, th_config.thread_n / 16,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
group_blocks, th_config.num_threads, is_zp_float);
auto kernel =
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue;
if (thread_m_blocks > 1) {
exec_cfg = {1, th_config};
break;
} else {
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1024));
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1536));
if (thread_m_blocks == 1)
allow_count = max(min(allow_count, 4), 1);
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
else
allow_count = max(min(allow_count, 2), 1);
if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
allow_count =
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
}
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
}
return exec_cfg;
}
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, void* sorted_token_ids, void* expert_ids,
void* num_tokens_past_padded, void* topk_weights,
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
bool is_a_8bit = a_type.size_bits() == 8;
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
@@ -563,14 +372,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
}
}
int num_bits = q_type.size_bits();
int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
@@ -618,22 +428,41 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
// Set thread config
exec_config_t exec_cfg;
thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
exec_cfg = exec_config_t{1, thread_tfg};
thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
if (blocks_per_sm == -1) blocks_per_sm = 1;
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
} else {
// Auto config
exec_cfg = determine_exec_config<scalar_t>(
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem);
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
thread_tfg = exec_cfg.tb_cfg;
}
@@ -647,22 +476,29 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@@ -679,19 +515,20 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}
} // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
@@ -699,11 +536,70 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
int pack_factor = 32 / b_q_type.size_bits();
bool is_zp_float, int64_t thread_k, int64_t thread_n,
int64_t blocks_per_sm) {
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
int num_experts = b_q_weight.size(0);
if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0,
@@ -745,19 +641,27 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// sms: number of SMs to use for the kernel
int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
@@ -774,8 +678,6 @@ torch::Tensor moe_wna16_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
@@ -846,11 +748,11 @@ torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@@ -877,15 +779,15 @@ torch::Tensor moe_wna16_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
b_type == vllm::kU4 || b_type == vllm::kU8,
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"b_type must be uint4b8, uint8b128, int4, int8, "
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_type.str());
}
if (has_zp && is_zp_float) {
@@ -929,71 +831,33 @@ torch::Tensor moe_wna16_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false,
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
"scalar type of global_scale must be the same with c");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),
"scalar type of a must be the same with c for 16 bit activation");
}
MARLIN_NAMESPACE_NAME::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
return c;
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}

View File

@@ -63,16 +63,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? global_scale, Tensor? "
"Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"bool mul_topk_weights, bool is_ep, int b_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
"bool use_fp32_reduce, bool is_zp_float,"
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "

View File

@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 =
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 =
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
FType low16 = MarlinScalarType2<FType>::float2num(
C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 = MarlinScalarType2<FType>::float2num(
C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset =

View File

@@ -8,7 +8,7 @@
#include <cuda_bf16.h>
#include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType;
using marlin::MarlinScalarType2;
namespace allspark {
@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) {
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
}
C[idx] = ScalarType<FType>::float2num(sum);
C[idx] = MarlinScalarType2<FType>::float2num(sum);
}
template <typename FType>

View File

@@ -1 +1,2 @@
kernel_*.cu
sm*_kernel_*.cu
kernel_selector.h

View File

@@ -4,14 +4,16 @@
namespace marlin {
template <int const num_threads, int const num_bits>
template <int const num_threads, int const num_bits, bool is_a_8bit>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles;
@@ -33,10 +35,10 @@ __global__ void awq_marlin_repack_kernel(
extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor;
constexpr int tile_n_ints = target_tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size;
constexpr int stage_k_threads = target_tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
@@ -45,7 +47,7 @@ __global__ void awq_marlin_repack_kernel(
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n = n_tile_id * target_tile_n_size;
int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe;
@@ -54,7 +56,7 @@ __global__ void awq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k = k_tile_id * target_tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
@@ -78,11 +80,11 @@ __global__ void awq_marlin_repack_kernel(
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor;
@@ -105,23 +107,50 @@ __global__ void awq_marlin_repack_kernel(
uint32_t vals[8];
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
if constexpr (is_a_8bit) {
int cur_elem = tc_row + i;
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * cur_elem];
int packed_src_1 =
sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) +
sh_stride * (cur_elem + 16)];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
} else {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 =
sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
if constexpr (!is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0;
#pragma unroll
@@ -138,8 +167,9 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
const int ii = is_a_8bit ? i : pack_idx[i];
res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
@@ -176,18 +206,21 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
#define CALL_IF(NUM_BITS, IS_A_8BIT) \
else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
int64_t size_n, int64_t num_bits,
bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
@@ -238,10 +271,13 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
if (false) {
}
CALL_IF(4)
CALL_IF(8)
CALL_IF(4, false)
CALL_IF(8, false)
CALL_IF(4, true)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", is_a_8bit = ", is_a_8bit);
}
return out;

View File

@@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
frag_b[0] = __hmul2(frag_b[0], bias_reg);
}
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
// Constants for FP4 (E2M1) and FP16 formats
constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
constexpr int MASK = 0x70707070;
// Extract and shift FP4 values to FP16 format
int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
q <<= 4;
int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT);
// Note1: reverse indexing is intentional because weights are permuted
// Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of
// `frag_b[1]` to fit the layout of tensorcore
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
int q, int32_t* frag_b) {
constexpr int repeated_zp = 0x08080808;
constexpr int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
int q, __nv_fp8x4_e4m3* frag_b) {
int s = q & 0x08080808;
int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
q >>= 4;
s = q & 0x08080808;
int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3);
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
@@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
};
// subtract zero point in quanted format and then dequant
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
bool skip_flop = false>
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
int q, int32_t* frag_b, int zp) {
// INT4 with zp -> INT8
// see https://github.com/vllm-project/vllm/pull/24722
int repeated_zp = 0x01010101 * zp;
int MASK = 0x80808080;
frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
q >>= 4;
frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
}
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
true>(int q, __nv_fp8x4_e4m3* frag_b,
int zp) {
// INT4 with zp -> FP8
// see https://github.com/vllm-project/vllm/pull/24722
uint32_t u_q = *reinterpret_cast<uint32_t*>(&q);
uint32_t u_zp = *reinterpret_cast<uint32_t*>(&zp);
uint32_t u_zp1 = u_zp + 1;
uint32_t repeated_zp = 0x01010101 * u_zp;
uint32_t q0, s;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
u_q >>= 4;
q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
s = (q0 + repeated_zp) & 0x80808080;
uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
#endif

View File

@@ -4,141 +4,292 @@ import glob
import itertools
import os
import subprocess
import sys
import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
ARCHS = []
SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
# only SM89 and SM120 fully support
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
# SM90 and SM100 can use this PTX, but its simulated
# with FP16 MMA, so it cannot achieve any acceleration.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
"""
)
TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{a_type_id}}, "
"{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{m_block_size_8}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
"vllm::kU4",
"vllm::kU4B8",
"vllm::kU8B128",
"vllm::kFE4M3fn",
"vllm::kFE2M1f",
]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
QUANT_CONFIGS = [
# AWQ-INT4
{
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# HQQ
{
"a_type": ["kFloat16"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [4],
"is_zp_float": True,
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# GPTQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
is_zp_float = quant_config.get("is_zp_float", False)
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "true" if is_zp_float else "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8",
"vllm::kU8B128",
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
# nvfp4 only supports group_size == 16
# mxfp4 only supports group_size == 32
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
continue
# other quantization methods don't support group_size = 16
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
continue
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
is_zp_float_list = [False]
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list.append(True)
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
s_type = "vllm::kFE4M3fn"
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
s_type = "vllm::kFE8M0fnu"
if dtype == "fp16":
# we cannot safely dequantize e8m0 to fp16, so skip this
continue
elif dtype == "fp16":
s_type = "vllm::kFloat16"
elif dtype == "bf16":
s_type = "vllm::kBFloat16"
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
s_type_id=s_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__":
remove_old_kernels()

View File

@@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm(
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
@@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int cache_size = get_kernel_cache_size(
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size + 512 <= max_shared_mem;
return cache_size <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
constexpr auto S_TYPE = \
W_TYPE == vllm::kFE2M1f \
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
: vllm::kBFloat16); \
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
#define NVFP4_GET_IF(W_TYPE) \
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
#define MXFP4_GET_IF(W_TYPE) \
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
int thread_m_blocks, int thread_n_blocks,
int thread_k_blocks, bool m_block_size_8,
bool has_act_order, bool has_zp,
int group_blocks, int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
if (false) {
}
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
NVFP4_GET_IF(vllm::kFE2M1f)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
if (std::is_same<scalar_t, half>::value) {
if (false) {
}
FZP_GET_IF(vllm::kU4)
}
if (std::is_same<scalar_t, nv_bfloat16>::value) {
if (false) {
}
MXFP4_GET_IF(vllm::kFE2M1f)
}
#include "kernel_selector.h"
return kernel;
}
template <typename scalar_t>
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks,
bool m_block_size_8, int num_bits,
int group_size, bool has_act_order,
bool is_k_full, bool has_zp,
bool is_zp_float, int max_shared_mem,
int sms) {
exec_config_t determine_exec_config(
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
int num_bits, int group_size, bool has_act_order, bool is_k_full,
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
@@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) {
is_zp_float, max_shared_mem - 512)) {
continue;
}
@@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
group_blocks = group_size == -1 ? -1 : group_size / 16;
}
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, th_config.thread_n / 16,
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
group_blocks, th_config.num_threads, is_zp_float);
auto kernel =
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue;
@@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
return exec_cfg;
}
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* s, void* s2, void* zp, void* g_idx, void* perm,
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
int lda, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k_init,
int thread_n_init, int sms, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
"has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
@@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
}
}
int num_bits = q_type.size_bits();
int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const int4* s_ptr = (const int4*)s;
const uint16_t* s2_ptr = (const uint16_t*)s2;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
int* locks = (int*)workspace;
if (has_act_order) {
@@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
int max_par = 16;
if (prob_n <= 4096) max_par = 16 * 8;
int max_shared_mem_new = max_shared_mem;
@@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
int thread_n = thread_n_init;
int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
int m_block_size_8 = prob_m_split <= 8;
int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;
// Set thread config
exec_config_t exec_cfg;
@@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
" is not divisible by thread_k = ", thread_k);
} else {
// Auto config
exec_cfg = determine_exec_config<scalar_t>(
q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, sms);
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
thread_tfg = exec_cfg.tb_cfg;
if (thread_tfg.thread_n != -1) {
if (prob_n / thread_tfg.thread_n *
div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
sms) {
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem_new)) {
thread_tfg = {128, 64, 128};
exec_cfg = {1, thread_tfg};
}
}
}
if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
max_thread_m_blocks--;
continue;
@@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem_new = ", max_shared_mem_new);
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
is_zp_float);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
@@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
g_idx_ptr, num_groups,
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
use_fp32_reduce, max_shared_mem_new);
// clang-format on
A_ptr += prob_m_split * (lda / 8);
bool is_a_8bit = a_type.size_bits() == 8;
A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
a_s_ptr += prob_m_split;
C_ptr += prob_m_split * (prob_n / 8);
rest_m -= prob_m_split;
}
@@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
int pack_factor = 32 / b_q_type.size_bits();
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
@@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm(
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
@@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
@@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm(
// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce) {
int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
max_m_block_size = min(max_m_block_size, 64);
@@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
@@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm(
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
b_type == vllm::kU4 || b_type == vllm::kU8,
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
"float4_e2m1f when "
"has_zp = False. Got = ",
b_q_type.str());
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"b_type must be uint4b8, uint8b128, int4, int8, "
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_type.str());
}
if (has_zp && is_zp_float) {
@@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm(
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::Half>();
}
marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
void* scales_ptr;
if (b_q_type == vllm::kFE2M1f) {
if (group_size == 16)
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
else if (group_size == 32)
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
else
TORCH_CHECK(false,
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
"and group_size == 32 (MXFP4)");
} else {
scales_ptr = b_scales.data_ptr<at::BFloat16>();
}
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
"scalar type of global_scale must be the same with c");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),
"scalar type of a must be the same with c for 16 bit activation");
}
marlin::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
use_atomic_add, use_fp32_reduce, is_zp_float);
return c;
}

View File

@@ -4,15 +4,18 @@
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm,
bool is_a_8bit>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1);
constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1);
int k_tiles = size_k / target_tile_k_size;
int n_tiles = size_n / target_tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles;
@@ -34,7 +37,7 @@ __global__ void gptq_marlin_repack_kernel(
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
constexpr int perm_size = target_tile_k_size / 4;
int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr;
@@ -42,14 +45,14 @@ __global__ void gptq_marlin_repack_kernel(
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int tile_ints = target_tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_n_threads = target_tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? target_tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int first_k_int4 = (k_tile_id * target_tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
@@ -65,7 +68,7 @@ __global__ void gptq_marlin_repack_kernel(
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n = n_tile_id * target_tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
@@ -91,7 +94,7 @@ __global__ void gptq_marlin_repack_kernel(
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k = k_tile_id * target_tile_k_size;
int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
@@ -117,13 +120,13 @@ __global__ void gptq_marlin_repack_kernel(
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2);
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr int sh_stride = target_tile_n_size;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
@@ -134,6 +137,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t vals[8];
if constexpr (has_perm) {
static_assert(!is_a_8bit);
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
@@ -156,28 +160,49 @@ __global__ void gptq_marlin_repack_kernel(
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
if constexpr (is_a_8bit) {
b1_vals[i] =
sh_stage_int_ptr[cur_n + sh_stride * i + (warp_id % 2) * 8];
} else {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
}
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_elem = tc_row + (is_a_8bit ? i : tc_offsets[i]);
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
if constexpr (is_a_8bit)
vals[4 + i] =
(b1_vals[cur_int + tile_ints / 2] >> (cur_pos * num_bits)) & mask;
else
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
constexpr int tile_size =
target_tile_k_size * target_tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
if constexpr (!is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else if constexpr (is_a_8bit && num_bits == 4) {
int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7};
uint32_t res = 0;
#pragma unroll
@@ -194,8 +219,9 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
const int ii = is_a_8bit ? i : pack_idx[i];
res1 |= vals[ii] << (i * 8);
res2 |= vals[4 + ii] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
@@ -236,21 +262,22 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
#define CALL_IF(NUM_BITS, HAS_PERM, IS_A_8BIT) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM && \
is_a_8bit == IS_A_8BIT) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
HAS_PERM, IS_A_8BIT>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
HAS_PERM, IS_A_8BIT> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
int64_t num_bits, bool is_a_8bit) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
@@ -309,13 +336,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
CALL_IF(4, false, false)
CALL_IF(4, true, false)
CALL_IF(8, false, false)
CALL_IF(8, true, false)
CALL_IF(4, false, true)
CALL_IF(8, false, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
", has_perm = ", has_perm, ", is_a_8bit = ", is_a_8bit);
}
return out;

View File

@@ -11,17 +11,19 @@
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ scale2_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the

View File

@@ -55,6 +55,45 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async
#else
__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 4;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 8;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;

View File

@@ -2,8 +2,10 @@
#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include "core/scalar_type.hpp"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
@@ -11,14 +13,16 @@
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t>
class ScalarType {};
template <long scalar_type_id>
class MarlinScalarType {};
template <>
class ScalarType<half> {
class MarlinScalarType<vllm::kFloat16.id()> {
public:
using scalar_t = half;
using scalar_t2 = half2;
using scalar_t4 = half2;
using scalar_32bit_t = half2;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
@@ -27,6 +31,7 @@ class ScalarType<half> {
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) {
@@ -44,18 +49,25 @@ class ScalarType<half> {
static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
static __host__ __device__ float2 inline num22float2(const half2 x) {
return __half22float2(x);
}
};
template <>
class ScalarType<nv_bfloat16> {
class MarlinScalarType<vllm::kBFloat16.id()> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;
using scalar_t4 = nv_bfloat162;
using scalar_32bit_t = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragS0 = Vec<__nv_fp8x2_e4m3, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
@@ -75,9 +87,63 @@ class ScalarType<nv_bfloat16> {
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
return __bfloat1622float2(x);
}
#endif
};
template <>
class MarlinScalarType<vllm::kFE4M3fn.id()> {
public:
using scalar_t = __nv_fp8_e4m3;
using scalar_t2 = __nv_fp8x2_e4m3;
using scalar_t4 = __nv_fp8x4_e4m3;
using scalar_32bit_t = __nv_fp8x4_e4m3;
using FragA = Vec<__nv_fp8x4_e4m3, 4>;
using FragB = Vec<__nv_fp8x4_e4m3, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<__nv_fp8x2_e4m3, 4>;
static __host__ __device__
float2 inline num22float2(const __nv_fp8x2_e4m3 x) {
return (float2)x;
}
};
template <>
class MarlinScalarType<vllm::kS8.id()> {
public:
using scalar_t = int8_t;
using scalar_t2 = int16_t;
using scalar_t4 = int32_t;
using scalar_32bit_t = int32_t;
using FragA = Vec<int32_t, 4>;
using FragB = Vec<int32_t, 2>;
using FragC = Vec<float, 4>;
using FragZP = Vec<int16_t, 4>;
};
template <typename scalar_t>
class MarlinScalarType2 {};
template <>
class MarlinScalarType2<half> : public MarlinScalarType<vllm::kFloat16.id()> {};
template <>
class MarlinScalarType2<nv_bfloat16>
: public MarlinScalarType<vllm::kBFloat16.id()> {};
template <>
class MarlinScalarType2<__nv_fp8_e4m3>
: public MarlinScalarType<vllm::kFE4M3fn.id()> {};
template <>
class MarlinScalarType2<int8_t> : public MarlinScalarType<vllm::kS8.id()> {};
} // namespace MARLIN_NAMESPACE_NAME
#endif

View File

@@ -0,0 +1,106 @@
#include "marlin.cuh"
#include "core/registration.h"
// for only non-zp format (like gptq)
__global__ void marlin_int4_fp8_preprocess_kernel_without_zp(
// qweight: (size_k * size_n // 8,)
const int32_t* __restrict__ qweight,
// output: same shape with qweight
int32_t* __restrict__ output) {
int32_t val = qweight[blockIdx.x * 32 + threadIdx.x];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
single_val = single_val >= 8 ? single_val - 8 : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
}
output[blockIdx.x * 32 + threadIdx.x] = new_val;
}
// for awq format only (with zp and with awq weight layout)
__global__ void marlin_int4_fp8_preprocess_kernel_awq(
// AWQ qweight: (size_k, size_n // 8)
const int32_t* __restrict__ qweight,
// output: same shape with qweight
int32_t* __restrict__ output,
// AWQ zeros: (size_k // group_size, size_n // 8)
const int32_t* __restrict__ qzeros, int32_t size_n, int32_t size_k,
int32_t group_size) {
int32_t val =
qweight[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y];
int32_t zero =
qzeros[(blockIdx.x * 32 + threadIdx.x) / group_size * size_n / 8 +
blockIdx.y];
int32_t new_val = 0;
#pragma unroll
for (int32_t i = 0; i < 8; i++) {
int32_t single_val = val & 0xF;
int32_t single_zero = zero & 0xF;
single_val =
single_val >= single_zero ? single_val - single_zero : 15 - single_val;
new_val |= single_val << (i * 4);
val >>= 4;
zero >>= 4;
}
output[(blockIdx.x * 32 + threadIdx.x) * size_n / 8 + blockIdx.y] = new_val;
}
torch::Tensor marlin_int4_fp8_preprocess(
torch::Tensor& qweight, std::optional<torch::Tensor> qzeros_or_none,
bool inplace) {
TORCH_CHECK(qweight.device().is_cuda(), "qweight is not on GPU");
TORCH_CHECK(qweight.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
torch::Tensor output = inplace ? qweight : torch::empty_like(qweight);
if (!qzeros_or_none.has_value()) {
TORCH_CHECK(qweight.numel() * 8 % 256 == 0,
"qweight.numel() * 8 % 256 != 0");
int blocks = qweight.numel() * 8 / 256;
marlin_int4_fp8_preprocess_kernel_without_zp<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr());
} else {
int32_t size_k = qweight.size(0);
int32_t size_n = qweight.size(1) * 8;
torch::Tensor qzeros = qzeros_or_none.value();
TORCH_CHECK(size_k % 32 == 0, "size_k % 32 != 0");
TORCH_CHECK(qzeros.device().is_cuda(), "qzeros is not on GPU");
TORCH_CHECK(qzeros.scalar_type() == at::ScalarType::Int,
"qweight.dtype != torch.int32");
TORCH_CHECK(device_of(qweight) == device_of(qzeros),
"qzeros is not on the same device with qweight");
int32_t group_size = qweight.size(0) / qzeros.size(0);
TORCH_CHECK(qweight.size(1) == qzeros.size(1),
"qweight.size(1) != qzeros.size(1)");
TORCH_CHECK(qweight.size(0) % qzeros.size(0) == 0,
"qweight.size(0) % qzeros.size(0) != 0");
TORCH_CHECK(group_size % 8 == 0, "group_size % 8 != 0");
dim3 blocks(size_k / 32, size_n / 8);
marlin_int4_fp8_preprocess_kernel_awq<<<blocks, 32>>>(
(const int32_t*)qweight.data_ptr(), (int32_t*)output.data_ptr(),
(const int32_t*)qzeros.data_ptr(), size_n, size_k, group_size);
}
return output;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_int4_fp8_preprocess", &marlin_int4_fp8_preprocess);
}

File diff suppressed because it is too large Load Diff

View File

@@ -298,9 +298,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,"
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"Tensor? "
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_type_id, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file
@@ -308,13 +309,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin repack from GPTQ.
ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
"SymInt size_k, SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ.
ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
"SymInt size_n, int num_bits, bool is_a_8bit) -> Tensor");
// conditionally compiled so impl registrations are in source file
// preprocess W-int4A-fp8 weight for marlin kernel
ops.def(
"marlin_int4_fp8_preprocess(Tensor qweight, "
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM