[Performance] Fused blockwise quant RMS norm (#27883)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
ElizaWszola
2025-12-07 17:38:04 +01:00
committed by GitHub
parent 0044c4038c
commit af0444bf40
14 changed files with 949 additions and 157 deletions

View File

@@ -8,6 +8,12 @@ import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8,
)
DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
@@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
GROUP_SIZES = [None, [1, 64], [1, 128]]
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@@ -45,12 +52,13 @@ def ref_rms_norm(
return out, residual
def ref_dynamic_per_token_quant(
def ref_dynamic_per_token_or_block_quant(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
@@ -59,13 +67,24 @@ def ref_dynamic_per_token_quant(
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
# Quant
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
if group_size is not None:
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)
else:
assert quant_dtype == torch.int8
torch_out, scales = per_token_group_quant_int8(
torch_out, group_size=group_size[1]
)
else:
assert quant_dtype == torch.int8
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
else:
assert quant_dtype == torch.int8
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
return torch_out, scales, residual
@@ -76,24 +95,32 @@ def ref_impl(
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ref_dynamic_per_token_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub
return ref_dynamic_per_token_or_block_quant(
rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
)
def ops_dynamic_per_token_quant(
def ops_dynamic_per_token_or_block_quant(
weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(
x, weight, EPS, quant_dtype, scale_ub, residual
)
if group_size is not None:
out, scales = ops.rms_norm_per_block_quant(
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True
)
scales = scales.contiguous()
else:
out, scales = ops.rms_norm_dynamic_per_token_quant(
x, weight, EPS, quant_dtype, scale_ub, residual
)
return out, scales, residual
@@ -103,8 +130,11 @@ def ops_impl(
quant_dtype: torch.dtype,
residual: torch.Tensor | None,
scale_ub: torch.Tensor | None,
group_size: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
return ops_dynamic_per_token_or_block_quant(
weight, x, quant_dtype, residual, scale_ub, group_size
)
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@@ -112,6 +142,7 @@ def ops_impl(
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
@@ -122,6 +153,7 @@ def test_rms_norm(
has_scale_ub: bool,
dtype: torch.dtype,
quant_dtype: torch.dtype,
group_size: list[int] | None,
seed: int,
device: str,
) -> None:
@@ -130,6 +162,14 @@ def test_rms_norm(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
if group_size is not None and hidden_size % group_size[1] != 0:
# skip
return
if group_size is not None and has_scale_ub:
# blockwise baseline doesn't support scale_ub
return
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
# skip
return
@@ -150,10 +190,10 @@ def test_rms_norm(
scale_ub = None
ref_out, ref_scales, ref_residual = ref_impl(
layer, x, quant_dtype, residual, scale_ub
layer, x, quant_dtype, residual, scale_ub, group_size
)
ops_out, ops_scales, ops_residual = ops_impl(
layer.weight, x, quant_dtype, residual, scale_ub
layer.weight, x, quant_dtype, residual, scale_ub, group_size
)
assert ref_out.dtype == quant_dtype
@@ -166,11 +206,15 @@ def test_rms_norm(
assert torch.allclose(ref_scales, ops_scales)
a = ref_out.to(dtype=torch.float32)
b = ops_out.to(dtype=torch.float32)
ok = torch.allclose(a, b)
ok = torch.allclose(a, b, atol=1e-6)
if not ok:
# fallback: compare dequantized values with relaxed tolerance
a_deq = a * ref_scales.view(-1, 1)
b_deq = b * ops_scales.view(-1, 1)
if group_size is None:
a_deq = a * ref_scales.view(-1, 1)
b_deq = b * ops_scales.view(-1, 1)
else:
a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
# NOTE: It is possible that some future test cases trigger this
# max diff due to precision issues. If such an error is
# encountered, it's recommended to inspect the differences between