[Kernel] moe wna16 cuda kernel (#13321)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Jinzhen Lin
2025-03-11 08:12:40 +08:00
committed by GitHub
parent 04421dff8a
commit 90e88ab756
7 changed files with 698 additions and 1 deletion


@@ -719,6 +719,33 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=topk_ids.numel(),
            group_size=block_shape[1],
            num_experts=B.shape[0],
            bit=4 if use_int4_w4a16 else 8)
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(config=config,
                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
                                       num_valid_tokens=topk_ids.numel(),
                                       size_k=A.shape[1],
                                       size_n=B.shape[1],
                                       num_experts=B.shape[1],
                                       group_size=block_shape[1],
                                       real_top_k=topk_ids.shape[1],
                                       block_size_m=config["BLOCK_SIZE_M"]))

        if use_moe_wna16_cuda:
            bit = 4 if use_int4_w4a16 else 8
            ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                               topk_weights if mul_routed_weight else None,
                               sorted_token_ids, expert_ids,
                               num_tokens_post_padded, top_k,
                               config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
                               config["BLOCK_SIZE_K"], bit)
            return

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
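
For orientation, the restatement below spells out the likely role of each argument to the new ops.moe_wna16_gemm op. It reuses the in-scope names from the hunk above, so it is a reading aid rather than a standalone snippet, and the roles are inferred from the surrounding fused-MoE code rather than documented by this commit.

# Argument roles inferred from the surrounding code, not documented here:
ops.moe_wna16_gemm(
    A,                        # activations
    C,                        # output buffer the kernel writes into
    B,                        # packed 4-/8-bit expert weights
    B_scale,                  # per-group dequantization scales
    B_zp,                     # per-group zero points (None for symmetric quant)
    topk_weights if mul_routed_weight else None,  # routing weights to fold in
    sorted_token_ids,         # token order produced by moe_align_block_size
    expert_ids,               # expert id per block, from moe_align_block_size
    num_tokens_post_padded,   # padded token count, from moe_align_block_size
    top_k,
    config["BLOCK_SIZE_M"],
    config["BLOCK_SIZE_N"],
    config["BLOCK_SIZE_K"],
    bit)                      # 4 for int4_w4a16, 8 for int8_w8a16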
@@ -852,6 +879,70 @@ def get_moe_configs(
    return None


def get_moe_wna16_block_config(config: Dict[str, int],
                               use_moe_wna16_cuda: bool,
                               num_valid_tokens: int, size_k: int, size_n: int,
                               num_experts: int, group_size: int,
                               real_top_k: int, block_size_m: int):
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # optimal block config is already set
        return {}
    if not use_moe_wna16_cuda:
        # triton moe wna16 kernel
        if num_valid_tokens // real_top_k == 1:
            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        # cuda moe wna16 kernel
        # set default block sizes to 128, and increase them when num_blocks
        # is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \
            num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and \
                block_size_k < 256:
            block_size_k = 256
            num_blocks = num_blocks // (256 // block_size_k)

        if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \
                block_size_k <= 512 and num_blocks >= 512:
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            # Kernel performance improves a lot with BLOCK_SIZE_N=1024 when
            # num_blocks is large, even when N is small. Not sure why; maybe
            # it forces each CUDA SM to process only one block at a time.
            block_size_n = 1024

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}


def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int):
    return bit == 4 and group_size in [32, 64, 128] and \
        num_valid_tokens / num_experts <= 6
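
A few illustrative calls (shapes made up, same assumed import path as above) showing when the predicate picks the CUDA kernel:

should_moe_wna16_use_cuda(num_valid_tokens=4, group_size=128,
                          num_experts=8, bit=4)   # True: 0.5 tokens per expert
should_moe_wna16_use_cuda(num_valid_tokens=512, group_size=128,
                          num_experts=8, bit=4)   # False: 64 tokens per expert > 6
should_moe_wna16_use_cuda(num_valid_tokens=4, group_size=96,
                          num_experts=8, bit=4)   # False: unsupported group size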


def get_default_config(
    M: int,
    E: int,
@@ -873,6 +964,21 @@ def get_default_config(
"num_warps": 4,
"num_stages": 3,
}
elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
# moe wna16 kernels
# only set BLOCK_SIZE_M
# BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
bit = 4 if dtype == "int4_w4a16" else 8
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
block_shape[1], E, bit)
if use_moe_wna16_cuda:
config = {"BLOCK_SIZE_M": min(16, M)}
elif M <= 20:
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
elif M <= 40:
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
else:
config = {
"BLOCK_SIZE_M": 64,
@@ -907,6 +1013,8 @@ def try_get_optimal_moe_config(
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        if dtype == "int4_w4a16":
            N = N * 2
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)
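
The N doubling accounts for int4 packing, where the stored tensor holds two 4-bit values per byte along that dimension; the shape below is a hypothetical illustration, not from this commit.

w2_shape = (8, 4096, 7168)    # hypothetical stored shape of a packed int4 w2
E, _, N = w2_shape
dtype = "int4_w4a16"
if dtype == "int4_w4a16":
    N = N * 2                 # logical N = 14336, the key used for config lookup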
@@ -1027,7 +1135,7 @@ def get_config_dtype_str(dtype: torch.dtype,
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
-       return "int4_w8a16"
+       return "int4_w4a16"
    elif dtype == torch.float:
        # avoid cases where the kernel fails when float32 MoE
        # uses fp16/bfloat16 configs
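
This one-character fix matters because the returned dtype string is what get_default_config branches on and what get_moe_configs uses to look up tuned configs. A quick illustrative call, assuming the other quantization flags of get_config_dtype_str default to False:

import torch
get_config_dtype_str(torch.bfloat16, use_int4_w4a16=True)   # now "int4_w4a16"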