[NVIDIA] Support SiluMul + NVFP4 quant fusion (#23671)

Signed-off-by: jindih <jindih@nvidia.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: jindih <jindih@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedic <lgovedic@redhat.com>
This commit is contained in:
elvischenv
2025-08-29 03:36:50 +08:00
committed by GitHub
parent 57d4ede520
commit 16a45b3a28
11 changed files with 746 additions and 64 deletions

View File

@@ -4,32 +4,41 @@ import pytest
import torch
import vllm.envs as envs
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
# yapf conflicts with isort for this block
# yapf: disable
from vllm.compilation.activation_quant_fusion import (
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
def is_nvfp4_supported():
    """Return whether the current platform reports device capability 100.

    The result gates whether the NVFP4 model variant is included in the
    parametrized fusion test.
    """
    required_capability = 100
    return current_platform.has_device_capability(required_capability)
class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = (torch.rand(
hidden_size,
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
@@ -45,14 +54,56 @@ class TestModel(torch.nn.Module):
input_scale=self.wscale)
return x2
def ops_in_model_before(self):
    """Ops the test expects in the graph before the fusion pass runs.

    These are the standalone SiluMul op and the quant op registered for
    ``kFp8StaticTensorSym``.
    """
    expected_ops = [SILU_MUL_OP]
    expected_ops.append(QUANT_OPS[kFp8StaticTensorSym])
    return expected_ops
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
def ops_in_model_after(self):
    """Ops the test expects in the graph after the fusion pass runs.

    This is the single fused op registered for ``kFp8StaticTensorSym``.
    """
    fused_op = FUSED_OPS[kFp8StaticTensorSym]
    return [fused_op]
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
    """Toy module chaining SiluAndMul, ``scaled_fp4_quant`` and
    ``cutlass_scaled_fp4_mm``.

    The fusion test uses it to check that the SiluMul op and the NVFP4 quant
    op are replaced by a single fused op.
    """

    def __init__(self, hidden_size: int, **kwargs):
        # Extra keyword arguments (e.g. force_fp8_e4m3fnuz passed by the
        # shared test) are accepted and ignored.
        super().__init__()
        self.silu_and_mul = SiluAndMul()

        # Random weight bytes of shape (hidden_size, hidden_size // 2).
        # NOTE(review): presumably two FP4 values packed per byte - confirm.
        weight_shape = (hidden_size, hidden_size // 2)
        self.w = torch.randint(256, weight_shape, dtype=FP4_DTYPE)

        # Random weight scales of shape (hidden_size, hidden_size // 16),
        # cast to the platform FP8 dtype.
        raw_block_scales = torch.randn(hidden_size, hidden_size // 16)
        self.wscale = raw_block_scales.to(dtype=FP8_DTYPE)

        # Single-element float32 scales for the weight and the activation.
        self.wscale2 = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

    def forward(self, x):
        """Apply SiluAndMul, quantize the result, then run the FP4 matmul."""
        activated = self.silu_and_mul(x)
        # The quant op is given the reciprocal of the activation scale.
        inverse_scale = 1 / self.scale
        quantized, block_scale = scaled_fp4_quant(activated, inverse_scale)
        alpha = self.scale * self.wscale2
        return cutlass_scaled_fp4_mm(
            a=quantized,
            b=self.w,
            block_scale_a=block_scale,
            block_scale_b=self.wscale,
            alpha=alpha,
            out_dtype=activated.dtype,
        )

    def ops_in_model_before(self):
        """Ops expected before fusion: SiluMul and the NVFP4 quant op."""
        expected_ops = [SILU_MUL_OP]
        expected_ops.append(QUANT_OPS[kNvfp4Quant])
        return expected_ops

    def ops_in_model_after(self):
        """Ops expected after fusion: the single fused NVFP4 op."""
        fused_op = FUSED_OPS[kNvfp4Quant]
        return [fused_op]
@pytest.mark.parametrize("num_tokens", [64])
@pytest.mark.parametrize("hidden_size", [128])
@pytest.mark.parametrize(
"model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
force_fp8_e4m3fnuz):
if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
pytest.skip("Duplicate tests for NVFP4")
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
@@ -63,7 +114,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
model = TestModel(hidden_size, force_fp8_e4m3fnuz)
model = model_class(hidden_size=hidden_size,
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size * 2)
@@ -80,17 +132,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
atol=1e-3,
rtol=1e-3)
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, silu_and_mul_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())