Add thread_n=64 support to Marlin MoE (#32360)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -58,7 +58,7 @@ TEMPLATE = (
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
-THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
+THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
|
||||
|
||||
@@ -126,14 +126,16 @@ thread_config_t small_batch_thread_configs[] = {
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{128, 128, 256},
|
||||
-{64, 128, 128}};
+{64, 128, 128},
+{128, 64, 128}};
|
||||
|
||||
thread_config_t large_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{64, 256, 256},
|
||||
-{64, 128, 128}};
+{64, 128, 128},
+{128, 64, 128}};
|
||||
|
||||
typedef struct {
|
||||
int blocks_per_sm;
|
||||
|
||||
@@ -226,6 +226,7 @@ def prepare_fp8_moe_layer_for_marlin(
|
||||
e = layer.num_experts
|
||||
k = layer.hidden_size
|
||||
n = layer.intermediate_size_per_partition
|
||||
+w13_n = w13_weight.size(1)
|
||||
weight_block_size = getattr(layer, "weight_block_size", None)
|
||||
|
||||
# WORKSPACE
|
||||
@@ -240,7 +241,7 @@ def prepare_fp8_moe_layer_for_marlin(
|
||||
def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
|
||||
tensor_list = []
|
||||
if "w13" in name:
|
||||
-size_n, size_k = n * 2, k
+size_n, size_k = w13_n, k
|
||||
else:
|
||||
size_n, size_k = k, n
|
||||
|
||||
@@ -268,7 +269,7 @@ def prepare_fp8_moe_layer_for_marlin(
|
||||
scales = scales.to(layer.orig_dtype)
|
||||
tensor_list = []
|
||||
if "w13" in name:
|
||||
-size_n, size_k = n * 2, k
+size_n, size_k = w13_n, k
|
||||
else:
|
||||
size_n, size_k = k, n
|
||||
|
||||
|
||||
Reference in New Issue
Block a user