[Kernel] Adding split_K implementation for fused_moe_lora (#27291)
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in: (branch listing unavailable)
Committed by: GitHub
Parent commit: 2d631d28c6
This commit: 9932ed6a83
@@ -121,6 +121,7 @@ def fused_moe_kernel_gptq_awq(
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
@@ -356,6 +357,7 @@ def fused_moe_kernel(
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
@@ -646,7 +648,6 @@ def invoke_fused_moe_kernel(
|
||||
bit,
|
||||
)
|
||||
return
|
||||
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
@@ -686,6 +687,7 @@ def invoke_fused_moe_kernel(
|
||||
)
|
||||
else:
|
||||
config = config.copy()
|
||||
config["SPLIT_K"] = 1
|
||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
||||
if block_shape is not None:
|
||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
||||
@@ -983,6 +985,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
@@ -996,6 +999,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"SPLIT_K": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3 if not current_platform.is_rocm() else 2,
|
||||
}
|
||||
@@ -1006,19 +1010,20 @@ def get_default_config(
|
||||
bit = 4 if dtype == "int4_w4a16" else 8
|
||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
|
||||
if use_moe_wna16_cuda:
|
||||
config = {"BLOCK_SIZE_M": min(16, M)}
|
||||
config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
|
||||
elif M <= 20:
|
||||
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
|
||||
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
elif M <= 40:
|
||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
|
||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
else:
|
||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
|
||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
elif M <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
@@ -1026,6 +1031,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user