[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -38,6 +38,9 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
+ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
+ get_marlin_input_dtype,
|
||||
+ )
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_moe_fp4_layer_for_marlin,
|
||||
)
|
||||
@@ -205,7 +208,9 @@ class Mxfp4Config(QuantizationConfig):
|
||||
if current_platform.is_xpu():
|
||||
return IpexMxfp4MoEMethod(layer.moe_config)
|
||||
else:
|
||||
- return Mxfp4MoEMethod(layer.moe_config)
|
||||
+ quant_method = Mxfp4MoEMethod(layer.moe_config)
|
||||
+ quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
|
||||
+ return quant_method
|
||||
elif isinstance(layer, Attention):
|
||||
# TODO: Add support for MXFP4 Attention.
|
||||
logger.debug_once(
|
||||
@@ -220,6 +225,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
|
||||
|
||||
+ self.marlin_input_dtype = None
|
||||
self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
|
||||
self.max_capture_size = (
|
||||
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
|
||||
@@ -385,7 +392,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
- prepare_moe_fp4_layer_for_marlin(layer)
|
||||
+ prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
|
||||
elif (
|
||||
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
@@ -914,6 +921,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=global_num_experts,
|
||||
activation=activation,
|
||||
expert_map=expert_map,
|
||||
+ input_dtype=self.marlin_input_dtype,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
|
||||
Reference in New Issue
Block a user