[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
commit fcb9df99bd
parent 1ebdff412a
committed by GitHub
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
     mxfp8_e4m3_quantize,
 )
 from vllm.triton_utils import tl, triton
-from vllm.utils.flashinfer import flashinfer_fp4_quantize
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -117,9 +116,7 @@ def _nvfp4_quantize(
     A_scale: torch.Tensor | None,
     is_sf_swizzled_layout: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    return flashinfer_fp4_quantize(
-        A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
-    )
+    return ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout)
 
 
 def _fp8_quantize(
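Note: after this hunk, _nvfp4_quantize routes directly to the CUDA custom op. A
minimal sketch of calling the op standalone (the global-scale formula and tensor
shapes are illustrative assumptions; only the ops.scaled_fp4_quant signature is
taken from the diff):

    import torch
    from vllm import _custom_ops as ops

    # Illustrative BF16 activations; NVFP4 needs the inner dim to be a
    # multiple of the 16-element quantization block.
    A = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
    # Assumed per-tensor global scale: FP8-E4M3 max (448) * FP4-E2M1 max (6)
    # over the activation amax, as is conventional for NVFP4.
    A_scale = (448.0 * 6.0) / A.abs().max().to(torch.float32)

    # Returns the packed FP4 tensor (two values per byte) plus E4M3 block
    # scales in the requested layout.
    A_fp4, A_blockscale = ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=True)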
@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
 
         # quantize BF16 or FP16 to (FP4 and interleaved block scale)
         x_fp4, x_blockscale = scaled_fp4_quant(
-            x, layer.input_global_scale, self.backend
+            x,
+            layer.input_global_scale,
+            is_sf_swizzled_layout=True,
+            backend=self.backend,
         )
 
         mm_args = (
@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         output_shape = [x.shape[0], layer.weight.shape[0]]
 
         # quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)
+        x_fp4, x_blockscale = scaled_fp4_quant(
+            x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
+        )
 
         # validate dtypes of quantized input, input block scale,
         # weight and weight_blockscale
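Note: in both dense linear paths the backend argument is now passed by keyword so
that is_sf_swizzled_layout can be stated explicitly. A sketch of what the flag
selects, assuming the usual CUTLASS swizzled-scale convention (the padding/tile
details are an assumption, not taken from this diff):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(32, 2048, dtype=torch.bfloat16, device="cuda")
    gscale = (448.0 * 6.0) / x.abs().max().to(torch.float32)

    # One E4M3 scale per 16-element block: 2048 // 16 = 128 scales per row.
    x_fp4, sf_swizzled = ops.scaled_fp4_quant(x, gscale, is_sf_swizzled_layout=True)
    x_fp4, sf_linear = ops.scaled_fp4_quant(x, gscale, is_sf_swizzled_layout=False)

    # Swizzled scales are interleaved/padded into the tiled layout the FP4
    # GEMM kernels consume directly; linear scales stay row-major, which is
    # what the TRT-LLM MoE paths below request.
    print(sf_swizzled.shape, sf_linear.shape)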
@@ -8,6 +8,7 @@ import torch
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
         hidden_states_fp4, hidden_states_scale_linear_fp4 = x
     else:
         # hidden_states is the already quantized
-        (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
-            x,
-            layer.a1_gscale,
-            is_sf_swizzled_layout=False,
+        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
+            x, layer.a1_gscale, is_sf_swizzled_layout=False
         )
 
     # Determine routing method type
@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
         hidden_states_fp4, hidden_states_scale_linear_fp4 = x
     else:
         # Quantize input to FP4
-        (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
-            x,
-            layer.a1_gscale,
-            is_sf_swizzled_layout=False,
+        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
+            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )
 
     # Call TRT-LLM FP4 block-scale MoE kernel
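Note: both MoE entry points now share the same custom op as the dense layers. A
sketch of the kind of parity check that validates such a kernel swap (the
flashinfer.fp4_quantize signature comes from the removed lines; bitwise equality
between the two kernels is an assumption a real test would need to confirm or
relax):

    import torch
    import flashinfer
    from vllm import _custom_ops as ops

    x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
    gscale = (448.0 * 6.0) / x.abs().max().to(torch.float32)

    ref_fp4, ref_sf = flashinfer.fp4_quantize(x, gscale, is_sf_swizzled_layout=False)
    out_fp4, out_sf = ops.scaled_fp4_quant(x, gscale, is_sf_swizzled_layout=False)

    assert torch.equal(out_fp4, ref_fp4)
    # Compare scale bytes directly; FP8 tensors may not support equality ops
    # on every torch version.
    assert torch.equal(out_sf.view(torch.uint8), ref_sf.view(torch.uint8))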