[NVIDIA] Support SiluMul + NVFP4 quant fusion (#23671)
Signed-off-by: jindih <jindih@nvidia.com> Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: jindih <jindih@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Luka Govedic <lgovedic@redhat.com>
This commit is contained in:
126
tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
Normal file
126
tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
|
||||
global_scale: torch.Tensor,
|
||||
ref_output_scale: torch.Tensor) -> torch.Tensor:
|
||||
silu_and_mul_out = silu_and_mul.forward_native(x)
|
||||
assert not current_platform.is_rocm()
|
||||
assert silu_and_mul_out.ndim >= 1, (
|
||||
f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.')
|
||||
other_dims = 1 if silu_and_mul_out.ndim == 1 else -1
|
||||
silu_and_mul_out = silu_and_mul_out.reshape(other_dims,
|
||||
silu_and_mul_out.shape[-1])
|
||||
m, n = silu_and_mul_out.shape
|
||||
device = silu_and_mul_out.device
|
||||
|
||||
# Two fp4 values will be packed into an uint8.
|
||||
out = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
|
||||
output_scale = ref_output_scale
|
||||
|
||||
torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale,
|
||||
global_scale)
|
||||
|
||||
return out, output_scale
|
||||
|
||||
|
||||
def ops_impl(x: torch.Tensor, global_scale: torch.Tensor,
|
||||
ref_output_scale: torch.Tensor) -> torch.Tensor:
|
||||
out_shape = (x.shape[0], x.shape[1] // 4)
|
||||
output_scale = ref_output_scale
|
||||
out = torch.empty(out_shape, dtype=torch.uint8, device=x.device)
|
||||
torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale)
|
||||
return out, output_scale
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
m, n = shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
|
||||
block_size = 16
|
||||
|
||||
assert n % block_size == 0, (
|
||||
f'last dim has to be multiple of 16, but got {n}.')
|
||||
assert x.dtype in (torch.float16, torch.bfloat16), (
|
||||
f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.')
|
||||
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(x.shape[0], 128)
|
||||
scale_n = x.shape[1] // (2 * block_size)
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
output_scale = torch.empty((rounded_m, rounded_n // 4),
|
||||
device=x.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
layer = SiluAndMul()
|
||||
|
||||
ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale)
|
||||
|
||||
fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale)
|
||||
|
||||
assert ref_out.dtype == torch.uint8
|
||||
assert fusion_out.dtype == torch.uint8
|
||||
assert ref_out.shape == fusion_out.shape
|
||||
|
||||
assert ref_out_scale.dtype == torch.int32
|
||||
assert fusion_out_scale.dtype == torch.int32
|
||||
assert ref_out_scale.shape == fusion_out_scale.shape
|
||||
|
||||
# Allow up to 2% of mismatched values since BF16 has accuracy issues.
|
||||
mis_threshold = 0.02
|
||||
atol = 0.4
|
||||
rtol = 0.4
|
||||
ref_logits = ref_out[-1]
|
||||
fusion_logits = fusion_out[-1]
|
||||
|
||||
mis_count = torch.sum(
|
||||
torch.abs(fusion_logits - ref_logits) > (atol +
|
||||
rtol * torch.abs(ref_logits)))
|
||||
mis_ratio = mis_count / fusion_logits.numel()
|
||||
|
||||
assert mis_ratio < mis_threshold, \
|
||||
f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}"
|
||||
|
||||
torch.testing.assert_close(ref_out_scale, fusion_out_scale)
|
||||
|
||||
opcheck(torch.ops._C.silu_and_mul_nvfp4_quant,
|
||||
(fusion_out, fusion_out_scale, x, global_scale))
|
||||
Reference in New Issue
Block a user