[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin
2025-11-29 23:19:33 +08:00
committed by GitHub
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions

View File

@@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
@@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
logger.debug_once(
@@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
@@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer)
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
@@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
)
assert _can_support_mxfp4(