[Kernel] moe wna16 cuda kernel (#13321)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
@@ -719,6 +719,33 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=topk_ids.numel(),
            group_size=block_shape[1],
            num_experts=B.shape[0],
            bit=4 if use_int4_w4a16 else 8)
        config = config.copy()
        config.update(
            get_moe_wna16_block_config(config=config,
                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
                                       num_valid_tokens=topk_ids.numel(),
                                       size_k=A.shape[1],
                                       size_n=B.shape[1],
                                       num_experts=B.shape[1],
                                       group_size=block_shape[1],
                                       real_top_k=topk_ids.shape[1],
                                       block_size_m=config["BLOCK_SIZE_M"]))

        if use_moe_wna16_cuda:
            bit = 4 if use_int4_w4a16 else 8
            ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                               topk_weights if mul_routed_weight else None,
                               sorted_token_ids, expert_ids,
                               num_tokens_post_padded, top_k,
                               config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
                               config["BLOCK_SIZE_K"], bit)
            return

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
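For context, this path prefers the new CUDA kernel (`ops.moe_wna16_gemm`) only for small, decode-like batches and otherwise falls back to the Triton `fused_moe_kernel_gptq_awq`. A minimal sketch of that decision, restating the heuristic defined in the next hunk; the concrete shapes are illustrative assumptions, not values from the diff:

# Editor's sketch (not part of the diff): the dispatch heuristic with made-up shapes.
def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int) -> bool:
    # Same condition as the helper added in the next hunk.
    return (bit == 4 and group_size in [32, 64, 128]
            and num_valid_tokens / num_experts <= 6)

# Decode-like case: 2 tokens * top-2 experts = 4 valid token slots across 8 experts.
print(should_moe_wna16_use_cuda(4, 128, 8, bit=4))      # True  -> ops.moe_wna16_gemm
# Prefill-like case: 2048 tokens * top-2 experts across 8 experts.
print(should_moe_wna16_use_cuda(4096, 128, 8, bit=4))   # False -> Triton fallback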
@@ -852,6 +879,70 @@ def get_moe_configs(
    return None


def get_moe_wna16_block_config(config: Dict[str,
                                            int], use_moe_wna16_cuda: bool,
                               num_valid_tokens: int, size_k: int, size_n: int,
                               num_experts: int, group_size: int,
                               real_top_k: int, block_size_m: int):
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # optimal block config is set
        return {}
    if not use_moe_wna16_cuda:
        # triton moe wna16 kernel
        if num_valid_tokens // real_top_k == 1:
            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        # cuda moe wna16 kernel
        # set default block sizes to 128, and increase them when num_blocks
        # is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \
            num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and \
                block_size_k < 256:
            block_size_k = 256
            num_blocks = num_blocks // (256 // block_size_k)

        if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \
                size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \
                num_blocks >= 512:
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            # The kernel performance got much better with BLOCK_SIZE_N=1024
            # when num_blocks is large, even when N is small.
            # Not sure why; maybe it forces each CUDA SM to process only one
            # block at a time.
            block_size_n = 1024

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
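A hedged usage sketch of `get_moe_wna16_block_config`: the module path and the shapes below are assumptions for illustration, not taken from this commit. On the Triton path with more than one token per expert slot, only the missing block sizes are filled in:

# Editor's sketch (not part of the diff); import path and shapes are assumed.
from vllm.model_executor.layers.fused_moe.fused_moe import (
    get_moe_wna16_block_config)

config = {"BLOCK_SIZE_M": 16}        # BLOCK_SIZE_N / BLOCK_SIZE_K not set yet
update = get_moe_wna16_block_config(config=config,
                                    use_moe_wna16_cuda=False,  # Triton path
                                    num_valid_tokens=8,
                                    size_k=4096,
                                    size_n=4096,
                                    num_experts=8,
                                    group_size=128,
                                    real_top_k=4,
                                    block_size_m=16)
# num_valid_tokens // real_top_k == 2, so the larger-batch branch applies.
print(update)                        # {'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}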


def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int):
    return bit == 4 and group_size in [32, 64, 128] and \
        num_valid_tokens / num_experts <= 6


def get_default_config(
    M: int,
    E: int,
@@ -873,6 +964,21 @@ def get_default_config(
            "num_warps": 4,
            "num_stages": 3,
        }
    elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
        # moe wna16 kernels
        # only set BLOCK_SIZE_M
        # BLOCK_SIZE_N and BLOCK_SIZE_K will be set later
        bit = 4 if dtype == "int4_w4a16" else 8
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
                                                       block_shape[1], E, bit)
        if use_moe_wna16_cuda:
            config = {"BLOCK_SIZE_M": min(16, M)}
        elif M <= 20:
            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
        elif M <= 40:
            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
        else:
            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
    else:
        config = {
            "BLOCK_SIZE_M": 64,
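Worked through for a single decode token (a hedged example; the numbers are illustrative, not from the commit): with M=1, topk=2, E=8, 4-bit weights, and group size 128, `should_moe_wna16_use_cuda(2, 128, 8, 4)` is true because 2 / 8 <= 6, so the CUDA branch picks the smallest tile:

# Editor's sketch (not part of the diff): BLOCK_SIZE_M selection for M=1.
M = 1
config = {"BLOCK_SIZE_M": min(16, M)}
print(config)                        # {'BLOCK_SIZE_M': 1}
# A Triton-path prefill with M=32 would instead hit the M <= 40 branch and
# use {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}.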
@@ -907,6 +1013,8 @@ def try_get_optimal_moe_config(
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        if dtype == "int4_w4a16":
            N = N * 2
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)
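The `N = N * 2` above compensates for 4-bit weight storage: for `"int4_w4a16"` the stored last dimension of `w2` is half the logical N used for the config lookup, presumably because two 4-bit weights share one stored element (that packing explanation is an inference, not stated in the diff). A small illustration with made-up shapes:

# Editor's sketch (not part of the diff); shapes are made up.
dtype = "int4_w4a16"
w2_shape = (8, 4096, 704)            # (E, K, N_packed)
E, _, N = w2_shape
if dtype == "int4_w4a16":
    N = N * 2                        # logical N = 1408 is passed to get_moe_configs
print(E, N)                          # 8 1408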
@@ -1027,7 +1135,7 @@ def get_config_dtype_str(dtype: torch.dtype,
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif use_int4_w4a16:
-        return "int4_w8a16"
+        return "int4_w4a16"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
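This one-line change fixes the dtype tag returned by `get_config_dtype_str` so it matches the `"int4_w4a16"` string checked elsewhere in this commit (e.g. in `get_default_config` and `try_get_optimal_moe_config`); the old `"int4_w8a16"` value could never match those checks:

# Editor's sketch (not part of the diff): why the corrected string matters.
assert "int4_w4a16" in ["int4_w4a16", "int8_w8a16"]          # matches the new checks
assert "int4_w8a16" not in ["int4_w4a16", "int8_w8a16"]      # the old typo never did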