[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
Roberto L. Castro
2026-01-25 02:45:27 +01:00
committed by GitHub
parent 1ebdff412a
commit fcb9df99bd
18 changed files with 508 additions and 151 deletions

View File

@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import flashinfer_fp4_quantize
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -117,9 +116,7 @@ def _nvfp4_quantize(
A_scale: torch.Tensor | None,
is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return flashinfer_fp4_quantize(
A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
return ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout)
def _fp8_quantize(

View File

@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_global_scale, self.backend
x,
layer.input_global_scale,
is_sf_swizzled_layout=True,
backend=self.backend,
)
mm_args = (

View File

@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)
x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
)
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale

View File

@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
layer.a1_gscale,
is_sf_swizzled_layout=False,
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Determine routing method type
@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
layer.a1_gscale,
is_sf_swizzled_layout=False,
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, layer.a1_gscale, is_sf_swizzled_layout=False
)
# Call TRT-LLM FP4 block-scale MoE kernel