Feature/silu block quant fusion v1 (#32996)

Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
This commit is contained in:
Monishver
2026-04-01 11:50:43 -07:00
committed by GitHub
parent c9a9db0e02
commit c09ad767cd
11 changed files with 830 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from functools import partial
import pytest
import torch
@@ -34,13 +35,16 @@ from vllm.model_executor.kernels.linear import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -165,6 +169,48 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
class TestSiluMulBlockQuantModel(torch.nn.Module):
    """SiLU-and-mul followed by per-group (1x128) dynamic FP8 quantization.

    The forward pass dequantizes its own output so the test can compare
    numerics before and after the fusion pass rewrites the graph.
    """

    quant_key = kFp8Dynamic128Sym

    def __init__(self, hidden_size: int, is_scale_transposed: bool = False, **kwargs):
        super().__init__()
        self.silu_and_mul = SiluAndMul()
        self.is_scale_transposed = is_scale_transposed
        # Dynamic per-group quantization; column-major scale layout is
        # toggled to cover both scale orientations in the fusion pass.
        self.quant_fp8 = QuantFP8(
            static=False,
            group_shape=GroupShape(1, 128),
            column_major_scales=is_scale_transposed,
            compile_native=False,
        )
        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
        self.enable_quant_fp8_custom_op = self.quant_fp8.enabled()

    def forward(self, x):
        activated = self.silu_and_mul(x)
        quantized, scales = self.quant_fp8(activated)
        # Broadcast each per-group scale across its group so the FP8
        # values can be dequantized elementwise for comparison.
        group_width = self.quant_key.scale.group_shape[1]
        per_elem_scales = scales.repeat_interleave(group_width, dim=1)
        dequantized = quantized.to(dtype=torch.float32) * per_elem_scales
        return (dequantized,)

    def ops_in_model_before(self):
        """Ops expected in the graph before the fusion pass runs."""
        expected = [SILU_MUL_OP] if self.enable_silu_mul_custom_op else []
        # When silu custom op is disabled, aten.mul.Tensor also appears
        # in dequant code, so we skip checking it to avoid false positives.
        if self.enable_quant_fp8_custom_op:
            expected.append(QUANT_OPS[self.quant_key])
        else:
            expected.append(torch.ops.aten.reciprocal.default)
        return expected

    def ops_in_model_after(self):
        """The single fused op expected after the fusion pass."""
        return [FUSED_OPS[self.quant_key]]
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
CUDA_KERNELS = [
FlashInferFP8ScaledMMLinearKernel,
@@ -200,6 +246,19 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
not current_platform.is_rocm(), reason="ROCm only"
),
),
# Block quant fusion for per-group FP8 (CUDA only).
*[
pytest.param(
partial(TestSiluMulBlockQuantModel, is_scale_transposed=transposed),
True,
None,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUDA only"
),
id=f"TestSiluMulBlockQuant-transposed={transposed}",
)
for transposed in [False, True]
],
],
)
@pytest.mark.skipif(
@@ -213,6 +272,7 @@ def test_fusion_silu_and_mul_quant(
TestSiluMulFp8QuantModel
| TestSiluMulNvfp4QuantModel
| TestSiluMulGroupFp8QuantModel
| TestSiluMulBlockQuantModel
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
@@ -223,6 +283,12 @@ def test_fusion_silu_and_mul_quant(
pytest.skip("NVFP4 is not supported on this GPU.")
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
pytest.skip("AITER is not supported on this GPU.")
if (
isinstance(model_class, partial)
and model_class.func is TestSiluMulBlockQuantModel
and is_deep_gemm_supported()
):
pytest.skip("SiluMul+BlockQuant fusion not applicable with DeepGemm")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
@@ -269,11 +335,13 @@ def test_fusion_silu_and_mul_quant(
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
if isinstance(model, TestSiluMulFp8QuantModel):
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
elif isinstance(model, TestSiluMulNvfp4QuantModel):
atol, rtol = 1e-1, 1e-1
elif model_class == TestSiluMulGroupFp8QuantModel:
elif isinstance(
model, (TestSiluMulGroupFp8QuantModel, TestSiluMulBlockQuantModel)
):
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(