diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 9db03ea14..52f266707 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -58,7 +58,7 @@ TEMPLATE = ( "( MARLIN_KERNEL_PARAMS );" ) -THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 00b17f075..6f229a4df 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -126,14 +126,16 @@ thread_config_t small_batch_thread_configs[] = { // thread_k, thread_n, num_threads {128, 128, 256}, - {64, 128, 128}}; + {64, 128, 128}, + {128, 64, 128}}; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {64, 256, 256}, - {64, 128, 128}}; + {64, 128, 128}, + {128, 64, 128}}; typedef struct { int blocks_per_sm; diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 91b93c76c..5be688265 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -226,6 +226,7 @@ def prepare_fp8_moe_layer_for_marlin( e = layer.num_experts k = layer.hidden_size n = layer.intermediate_size_per_partition + w13_n = w13_weight.size(1) weight_block_size = getattr(layer, "weight_block_size", None) # WORKSPACE @@ -240,7 +241,7 @@ def prepare_fp8_moe_layer_for_marlin( def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor: tensor_list = [] if "w13" in name: - size_n, size_k = n * 2, k + size_n, size_k = w13_n, k else: size_n, size_k = k, n @@ -268,7 +269,7 @@ def prepare_fp8_moe_layer_for_marlin( scales = scales.to(layer.orig_dtype) tensor_list = [] if "w13" in name: - size_n, size_k = n * 2, k + size_n, size_k = w13_n, k else: size_n, size_k = k, n