[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin
2025-11-29 23:19:33 +08:00
committed by GitHub
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions

View File

@@ -69,6 +69,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
@@ -316,7 +319,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
quant_method = Fp8LinearMethod(self)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
@@ -324,7 +329,9 @@ class Fp8Config(QuantizationConfig):
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Fp8MoEMethod(self, layer)
moe_quant_method = Fp8MoEMethod(self, layer)
moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
@@ -375,6 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
@@ -552,7 +560,9 @@ class Fp8LinearMethod(LinearMethodBase):
)
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer, size_k_first)
prepare_fp8_layer_for_marlin(
layer, size_k_first, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.input_scale
return
@@ -610,6 +620,7 @@ class Fp8LinearMethod(LinearMethodBase):
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
input_dtype=self.marlin_input_dtype,
bias=bias,
)
@@ -657,6 +668,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.block_quant, layer.moe_parallel_config
)
self.marlin_input_dtype = None
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
@@ -1031,7 +1043,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight.data = w13_weight.data
if self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
prepare_moe_fp8_layer_for_marlin(
layer, False, input_dtype=self.marlin_input_dtype
)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
@@ -1270,6 +1284,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
input_dtype=self.marlin_input_dtype,
workspace=layer.workspace,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: