[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -4,141 +4,292 @@ import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import jinja2
|
||||
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
# only SM89 and SM120 fully support
|
||||
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
|
||||
# SM90 and SM100 can use this PTX, but it’s simulated
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
// clang-format off
|
||||
""".lstrip()
|
||||
|
||||
FILE_HEAD = (
|
||||
FILE_HEAD_COMMENT
|
||||
+ """
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
""".strip()
|
||||
"""
|
||||
)
|
||||
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{a_type_id}}, "
|
||||
"{{b_type_id}}, "
|
||||
"{{c_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||
"{{m_block_size_8}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{'true' if is_zp_float else 'false'}}>"
|
||||
"{{is_zp_float}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
# int8 with zero point case (vllm::kU8) is also supported,
|
||||
# we don't add it to reduce wheel size.
|
||||
SCALAR_TYPES = [
|
||||
"vllm::kU4",
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
"vllm::kFE4M3fn",
|
||||
"vllm::kFE2M1f",
|
||||
]
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
# group_blocks:
|
||||
# = 0 : act order case
|
||||
# = -1 : channelwise quantization
|
||||
# > 0 : group_size=16*group_blocks
|
||||
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
|
||||
DTYPES = ["fp16", "bf16"]
|
||||
|
||||
QUANT_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# HQQ
|
||||
{
|
||||
"a_type": ["kFloat16"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [4],
|
||||
"is_zp_float": True,
|
||||
},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 0, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT8
|
||||
{
|
||||
"b_type": "kU8B128",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 0, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{
|
||||
"b_type": "kFE4M3fn",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 8],
|
||||
},
|
||||
# NVFP4
|
||||
{
|
||||
"b_type": "kFE2M1f",
|
||||
"s_type": "kFE4M3fn",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [1],
|
||||
},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE2M1f",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kFE2M1f",
|
||||
"c_type": ["kBFloat16"],
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
filename = os.path.dirname(__file__) + "/kernel_selector.h"
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
|
||||
b_type = quant_config["b_type"]
|
||||
is_zp_float = quant_config.get("is_zp_float", False)
|
||||
all_group_blocks = quant_config["group_blocks"]
|
||||
all_m_blocks = quant_config["thread_m_blocks"]
|
||||
all_thread_configs = quant_config["thread_configs"]
|
||||
|
||||
for a_type, c_type in itertools.product(a_types, c_types):
|
||||
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
|
||||
continue
|
||||
if "16" in a_type and "16" in c_type and a_type != c_type:
|
||||
continue
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
):
|
||||
thread_k, thread_n, threads = thread_configs
|
||||
|
||||
if threads == 256:
|
||||
# for small batch (m_blocks == 1),
|
||||
# we only need (128, 128, 256)
|
||||
# for large batch (m_blocks > 1),
|
||||
# we only need (64, 256, 256)
|
||||
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
|
||||
continue
|
||||
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
|
||||
continue
|
||||
|
||||
config = {
|
||||
"threads": threads,
|
||||
"s_type": s_type,
|
||||
"thread_m_blocks": max(m_blocks, 1),
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "true" if is_zp_float else "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
|
||||
):
|
||||
# act order case only support gptq-int4 and gptq-int8
|
||||
if group_blocks == 0 and scalar_type not in [
|
||||
"vllm::kU4B8",
|
||||
"vllm::kU8B128",
|
||||
]:
|
||||
continue
|
||||
if thread_configs[2] == 256:
|
||||
# for small batch (m_blocks == 1), we only need (128, 128, 256)
|
||||
# for large batch (m_blocks > 1), we only need (64, 256, 256)
|
||||
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||
continue
|
||||
if m_blocks > 1 and thread_configs[0] != 64:
|
||||
continue
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
# we only support channelwise quantization and group_size == 128
|
||||
# for fp8
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
# mxfp4 only supports group_size == 32
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
continue
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
k_blocks = thread_configs[0] // 16
|
||||
n_blocks = thread_configs[1] // 16
|
||||
threads = thread_configs[2]
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
is_zp_float_list = [False]
|
||||
if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
|
||||
# HQQ (is_zp_float = true) only supports
|
||||
# 4bit quantization and fp16
|
||||
is_zp_float_list.append(True)
|
||||
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||
s_type = "vllm::kFE4M3fn"
|
||||
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||
s_type = "vllm::kFE8M0fnu"
|
||||
if dtype == "fp16":
|
||||
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||
continue
|
||||
elif dtype == "fp16":
|
||||
s_type = "vllm::kFloat16"
|
||||
elif dtype == "bf16":
|
||||
s_type = "vllm::kBFloat16"
|
||||
|
||||
for is_zp_float in is_zp_float_list:
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
s_type_id=s_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
thread_k_blocks=k_blocks,
|
||||
m_block_size_8=m_blocks == 0.5,
|
||||
stages="pipe_stages",
|
||||
group_blocks=group_blocks,
|
||||
is_zp_float=is_zp_float,
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
|
||||
all_template_str_list.append(template_str)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
"else if (a_type == vllm::kFE4M3fn)\n"
|
||||
" TORCH_CHECK(false, "
|
||||
'"marlin kernel with fp8 activation is not built.");'
|
||||
)
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
|
||||
f.write(kernel_selector_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
|
||||
Reference in New Issue
Block a user