[Kernel] Adding split_K implementation for fused_moe_lora (#27291)

Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Danielle Robinson
2025-10-27 02:05:24 -07:00
committed by GitHub
parent 2d631d28c6
commit 9932ed6a83
4 changed files with 35 additions and 14 deletions

View File

@@ -121,6 +121,7 @@ def fused_moe_kernel_gptq_awq(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
@@ -356,6 +357,7 @@ def fused_moe_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
@@ -646,7 +648,6 @@ def invoke_fused_moe_kernel(
bit,
)
return
fused_moe_kernel_gptq_awq[grid](
A,
B,
@@ -686,6 +687,7 @@ def invoke_fused_moe_kernel(
)
else:
config = config.copy()
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
@@ -983,6 +985,7 @@ def get_default_config(
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
}
return config
@@ -996,6 +999,7 @@ def get_default_config(
"BLOCK_SIZE_N": block_shape[0],
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"SPLIT_K": 1,
"num_warps": 4,
"num_stages": 3 if not current_platform.is_rocm() else 2,
}
@@ -1006,19 +1010,20 @@ def get_default_config(
bit = 4 if dtype == "int4_w4a16" else 8
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
if use_moe_wna16_cuda:
-config = {"BLOCK_SIZE_M": min(16, M)}
+config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
elif M <= 20:
-config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
+config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
elif M <= 40:
-config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
+config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
else:
-config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
+config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
elif M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
}
else:
config = {
@@ -1026,6 +1031,7 @@ def get_default_config(
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
}
return config