[Kernel] moe wna16 marlin kernel (#14447)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin
2025-04-15 11:05:22 +08:00
committed by GitHub
parent 6b40996ae8
commit d06ba4ed3f
16 changed files with 3477 additions and 329 deletions

View File

@@ -17,14 +17,13 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
check_marlin_supports_layer, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
check_marlin_supports_layer, check_moe_marlin_supports_layer,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, moe_awq_to_marlin_zero_points,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@@ -136,12 +135,15 @@ class AWQMarlinConfig(QuantizationConfig):
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
if layer.local_num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return AWQMoEMethod(self)
return AWQMoEMethod(self)
return None
@classmethod
@@ -391,6 +393,13 @@ class AWQMoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
@@ -473,10 +482,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
if expert_map is not None:
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
@@ -503,7 +509,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
num_bits=self.quant_config.weight_bits,
)

View File

@@ -15,13 +15,13 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
check_marlin_supported, check_moe_marlin_supports_layer,
marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@@ -153,12 +153,15 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE):
if layer.local_num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return GPTQMarlinMoEMethod(self)
return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)
@@ -408,7 +411,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts,
scales_size13,
2 * intermediate_size_per_partition,
dtype=torch.half),
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
@@ -418,7 +421,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
torch.empty(num_experts,
scales_size2,
hidden_size,
dtype=torch.half),
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
@@ -493,6 +496,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
layer.workspace = torch.zeros((sms * 4, ),
dtype=torch.int,
device=device,
requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Process act_order
@@ -601,10 +611,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype
x = x.half()
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@@ -626,9 +632,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
router_logits,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.quant_type.size_bits,
is_k_full=self.is_k_full).to(orig_dtype)
workspace=layer.workspace,
is_k_full=self.is_k_full)

View File

@@ -151,6 +151,19 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
group_size=group_size)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
    -> bool:
    """Return True if the fused-MoE marlin kernel can handle this layer.

    Weight shapes as seen by the kernel:
      gate-up proj: (n, k) = (2 * intermediate_size_per_partition, hidden_size)
      down proj:    (n, k) = (hidden_size, intermediate_size_per_partition)
    The moe marlin kernel requires n % 128 == 0 and k % 64 == 0 and only
    supports a fixed set of quantization group sizes.
    """
    k_dim = layer.hidden_size
    half_n = layer.intermediate_size_per_partition
    # Only these group sizes are implemented by the kernel (-1 = per-channel).
    if group_size not in (-1, 32, 64, 128):
        return False
    # hidden_size appears both as n (down proj) and k (gate-up proj), so the
    # stricter n-constraint (128) must hold for it.
    if k_dim % 128 != 0:
        return False
    # intermediate size must satisfy the k-constraint (64) and also align to
    # the quantization group size when that is larger.
    return half_n % max(64, group_size) == 0
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //