[Performance] Fused blockwise quant RMS norm (#27883)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -8,6 +8,12 @@ import torch
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8,
|
||||
)
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
|
||||
@@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
|
||||
|
||||
ADD_RESIDUAL = [False, True]
|
||||
SCALE_UBS = [True, False]
|
||||
GROUP_SIZES = [None, [1, 64], [1, 128]]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
@@ -45,12 +52,13 @@ def ref_rms_norm(
|
||||
return out, residual
|
||||
|
||||
|
||||
def ref_dynamic_per_token_quant(
|
||||
def ref_dynamic_per_token_or_block_quant(
|
||||
rms_norm_layer: RMSNorm,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
@@ -59,13 +67,24 @@ def ref_dynamic_per_token_quant(
|
||||
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
|
||||
|
||||
# Quant
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = ops.scaled_fp8_quant(
|
||||
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
if group_size is not None:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = per_token_group_quant_fp8(
|
||||
torch_out, group_size=group_size[1], use_ue8m0=False
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales = per_token_group_quant_int8(
|
||||
torch_out, group_size=group_size[1]
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
torch_out, scales = ops.scaled_fp8_quant(
|
||||
torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
else:
|
||||
assert quant_dtype == torch.int8
|
||||
torch_out, scales, _ = ops.scaled_int8_quant(torch_out)
|
||||
|
||||
return torch_out, scales, residual
|
||||
|
||||
@@ -76,24 +95,32 @@ def ref_impl(
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
return ref_dynamic_per_token_quant(
|
||||
rms_norm_layer, x, quant_dtype, residual, scale_ub
|
||||
return ref_dynamic_per_token_or_block_quant(
|
||||
rms_norm_layer, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
|
||||
def ops_dynamic_per_token_quant(
|
||||
def ops_dynamic_per_token_or_block_quant(
|
||||
weight: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
if residual is not None:
|
||||
residual = residual.clone()
|
||||
out, scales = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, weight, EPS, quant_dtype, scale_ub, residual
|
||||
)
|
||||
if group_size is not None:
|
||||
out, scales = ops.rms_norm_per_block_quant(
|
||||
x, weight, EPS, quant_dtype, group_size, scale_ub, residual, True
|
||||
)
|
||||
scales = scales.contiguous()
|
||||
else:
|
||||
out, scales = ops.rms_norm_dynamic_per_token_quant(
|
||||
x, weight, EPS, quant_dtype, scale_ub, residual
|
||||
)
|
||||
return out, scales, residual
|
||||
|
||||
|
||||
@@ -103,8 +130,11 @@ def ops_impl(
|
||||
quant_dtype: torch.dtype,
|
||||
residual: torch.Tensor | None,
|
||||
scale_ub: torch.Tensor | None,
|
||||
group_size: list[int] | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
|
||||
return ops_dynamic_per_token_or_block_quant(
|
||||
weight, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
|
||||
@@ -112,6 +142,7 @@ def ops_impl(
|
||||
@pytest.mark.parametrize("has_scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("group_size", GROUP_SIZES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
@@ -122,6 +153,7 @@ def test_rms_norm(
|
||||
has_scale_ub: bool,
|
||||
dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
group_size: list[int] | None,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
@@ -130,6 +162,14 @@ def test_rms_norm(
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
if group_size is not None and hidden_size % group_size[1] != 0:
|
||||
# skip
|
||||
return
|
||||
|
||||
if group_size is not None and has_scale_ub:
|
||||
# blockwise baseline doesn't support scale_ub
|
||||
return
|
||||
|
||||
if has_scale_ub and quant_dtype != torch.float8_e4m3fn:
|
||||
# skip
|
||||
return
|
||||
@@ -150,10 +190,10 @@ def test_rms_norm(
|
||||
scale_ub = None
|
||||
|
||||
ref_out, ref_scales, ref_residual = ref_impl(
|
||||
layer, x, quant_dtype, residual, scale_ub
|
||||
layer, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
ops_out, ops_scales, ops_residual = ops_impl(
|
||||
layer.weight, x, quant_dtype, residual, scale_ub
|
||||
layer.weight, x, quant_dtype, residual, scale_ub, group_size
|
||||
)
|
||||
|
||||
assert ref_out.dtype == quant_dtype
|
||||
@@ -166,11 +206,15 @@ def test_rms_norm(
|
||||
assert torch.allclose(ref_scales, ops_scales)
|
||||
a = ref_out.to(dtype=torch.float32)
|
||||
b = ops_out.to(dtype=torch.float32)
|
||||
ok = torch.allclose(a, b)
|
||||
ok = torch.allclose(a, b, atol=1e-6)
|
||||
if not ok:
|
||||
# fallback: compare dequantized values with relaxed tolerance
|
||||
a_deq = a * ref_scales.view(-1, 1)
|
||||
b_deq = b * ops_scales.view(-1, 1)
|
||||
if group_size is None:
|
||||
a_deq = a * ref_scales.view(-1, 1)
|
||||
b_deq = b * ops_scales.view(-1, 1)
|
||||
else:
|
||||
a_deq = a * ref_scales.repeat_interleave(group_size[1], dim=1)
|
||||
b_deq = b * ops_scales.repeat_interleave(group_size[1], dim=1)
|
||||
# NOTE: It is possible that some future test cases trigger this
|
||||
# max diff due to precision issues. If such an error is
|
||||
# encountered, it's recommended to inspect the differences between
|
||||
|
||||
Reference in New Issue
Block a user