[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
This commit is contained in:
committed by
GitHub
parent
1ebdff412a
commit
fcb9df99bd
@@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
|
||||
def scaled_fp4_quant(
|
||||
input: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
is_sf_swizzled_layout: bool = True,
|
||||
backend: str = "none",
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -1577,22 +1578,26 @@ def scaled_fp4_quant(
|
||||
else:
|
||||
# Two fp4 values will be packed into an uint8.
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
if is_sf_swizzled_layout:
|
||||
# We use the rounded values to store the swizzled values. Due to the
|
||||
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
|
||||
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
|
||||
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
|
||||
|
||||
# We use the rounded values to store the swizzled values. Due to the
|
||||
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
|
||||
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
|
||||
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // block_size
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
output, input, output_scale, input_global_scale, is_sf_swizzled_layout
|
||||
)
|
||||
|
||||
torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)
|
||||
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
@@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
input=result_silu_mul,
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
|
||||
@@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
@@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
|
||||
@@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
mxfp8_e4m3_quantize,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.flashinfer import flashinfer_fp4_quantize
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
@@ -117,9 +116,7 @@ def _nvfp4_quantize(
|
||||
A_scale: torch.Tensor | None,
|
||||
is_sf_swizzled_layout: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return flashinfer_fp4_quantize(
|
||||
A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
|
||||
)
|
||||
return ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
|
||||
@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(
|
||||
x, layer.input_global_scale, self.backend
|
||||
x,
|
||||
layer.input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
backend=self.backend,
|
||||
)
|
||||
|
||||
mm_args = (
|
||||
|
||||
@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
output_shape = [x.shape[0], layer.weight.shape[0]]
|
||||
|
||||
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)
|
||||
x_fp4, x_blockscale = scaled_fp4_quant(
|
||||
x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
|
||||
)
|
||||
|
||||
# validate dtypes of quantized input, input block scale,
|
||||
# weight and weight_blockscale
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# hidden_states is the already quantized
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
layer.a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
|
||||
x, layer.a1_gscale, is_sf_swizzled_layout=False
|
||||
)
|
||||
|
||||
# Determine routing method type
|
||||
@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
|
||||
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||
else:
|
||||
# Quantize input to FP4
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
x,
|
||||
layer.a1_gscale,
|
||||
is_sf_swizzled_layout=False,
|
||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
|
||||
x, layer.a1_gscale, is_sf_swizzled_layout=False
|
||||
)
|
||||
|
||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||
|
||||
Reference in New Issue
Block a user