Feature/silu block quant fusion v1 (#32996)
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -34,13 +35,16 @@ from vllm.model_executor.kernels.linear import (
|
||||
ROCmFP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
@@ -165,6 +169,48 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
|
||||
|
||||
|
||||
class TestSiluMulBlockQuantModel(torch.nn.Module):
    """SiluAndMul followed by per-group (block) FP8 dynamic quantization.

    Used to verify that the SiluMul+BlockQuant fusion pass replaces the
    unfused op sequence with the fused kernel. ``forward`` dequantizes its
    own output so numeric results can be compared between the fused and
    unfused graphs.
    """

    # Quantization scheme under test: dynamic, symmetric, group size 128.
    quant_key = kFp8Dynamic128Sym

    def __init__(self, hidden_size: int, is_scale_transposed: bool = False, **kwargs):
        """
        Args:
            hidden_size: unused here; accepted for signature parity with the
                other test models in this file.
            is_scale_transposed: whether QuantFP8 emits column-major scales.
            **kwargs: swallowed for parametrize compatibility.
        """
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        self.is_scale_transposed = is_scale_transposed
        # Derive the group shape from quant_key instead of hard-coding
        # GroupShape(1, 128), so __init__ and forward (which reads
        # quant_key.scale.group_shape[1]) cannot drift apart.
        group_shape = self.quant_key.scale.group_shape
        self.quant_fp8 = QuantFP8(
            static=False,
            group_shape=GroupShape(group_shape[0], group_shape[1]),
            column_major_scales=is_scale_transposed,
            compile_native=False,
        )

        # Snapshot custom-op enablement; ops_in_model_before depends on it.
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
        self.enable_quant_fp8_custom_op = self.quant_fp8.enabled()

    def forward(self, x):
        """SiluAndMul -> group FP8 quant -> dequant (for numeric checks)."""
        y = self.silu_and_mul(x)
        out, scale = self.quant_fp8(y)
        # Expand each per-group scale across its group so dequantization is
        # a plain elementwise multiply.
        group_size = self.quant_key.scale.group_shape[1]
        scale_expanded = scale.repeat_interleave(group_size, dim=1)
        dequant = out.to(dtype=torch.float32) * scale_expanded
        return (dequant,)

    def ops_in_model_before(self):
        """Ops expected in the graph before fusion."""
        ops = []
        if self.enable_silu_mul_custom_op:
            ops.append(SILU_MUL_OP)
        # When silu custom op is disabled, aten.mul.Tensor also appears
        # in dequant code, so we skip checking it to avoid false positives.
        ops.append(
            QUANT_OPS[self.quant_key]
            if self.enable_quant_fp8_custom_op
            else torch.ops.aten.reciprocal.default
        )
        return ops

    def ops_in_model_after(self):
        """Fused op expected in the graph after the fusion pass runs."""
        return [FUSED_OPS[self.quant_key]]
|
||||
|
||||
|
||||
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
|
||||
CUDA_KERNELS = [
|
||||
FlashInferFP8ScaledMMLinearKernel,
|
||||
@@ -200,6 +246,19 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
|
||||
not current_platform.is_rocm(), reason="ROCm only"
|
||||
),
|
||||
),
|
||||
# Block quant fusion for per-group FP8 (CUDA only).
|
||||
*[
|
||||
pytest.param(
|
||||
partial(TestSiluMulBlockQuantModel, is_scale_transposed=transposed),
|
||||
True,
|
||||
None,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUDA only"
|
||||
),
|
||||
id=f"TestSiluMulBlockQuant-transposed={transposed}",
|
||||
)
|
||||
for transposed in [False, True]
|
||||
],
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
@@ -213,6 +272,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
TestSiluMulFp8QuantModel
|
||||
| TestSiluMulNvfp4QuantModel
|
||||
| TestSiluMulGroupFp8QuantModel
|
||||
| TestSiluMulBlockQuantModel
|
||||
],
|
||||
enable_silu_mul_custom_op: bool,
|
||||
enable_quant_fp8_custom_op: bool,
|
||||
@@ -223,6 +283,12 @@ def test_fusion_silu_and_mul_quant(
|
||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||
pytest.skip("AITER is not supported on this GPU.")
|
||||
if (
|
||||
isinstance(model_class, partial)
|
||||
and model_class.func is TestSiluMulBlockQuantModel
|
||||
and is_deep_gemm_supported()
|
||||
):
|
||||
pytest.skip("SiluMul+BlockQuant fusion not applicable with DeepGemm")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
@@ -269,11 +335,13 @@ def test_fusion_silu_and_mul_quant(
|
||||
result2 = model2(x)
|
||||
|
||||
# Check that it gives the same answer
|
||||
if model_class == TestSiluMulFp8QuantModel:
|
||||
if isinstance(model, TestSiluMulFp8QuantModel):
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
elif isinstance(model, TestSiluMulNvfp4QuantModel):
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
elif model_class == TestSiluMulGroupFp8QuantModel:
|
||||
elif isinstance(
|
||||
model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel)
|
||||
):
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
|
||||
Reference in New Issue
Block a user