[torch.compile] Enable silu_mul_fp8_quant fusion without custom ops enabled (#27146)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu
2025-10-22 12:22:39 +08:00
committed by GitHub
parent ceacedc1f9
commit ab3e80042e
4 changed files with 134 additions and 76 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import itertools
import pytest
import torch
@@ -16,7 +16,13 @@ from vllm.compilation.activation_quant_fusion import (
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CompilationMode,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -25,7 +31,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
@@ -54,6 +60,8 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x):
y = self.silu_and_mul(x)
@@ -61,7 +69,14 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
return x2
def ops_in_model_before(self):
return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
(
QUANT_OPS[kFp8StaticTensorSym]
if self.enable_quant_fp8_custom_op
else torch.ops.aten.reciprocal
),
]
def ops_in_model_after(self):
return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
assert silu_and_mul_nvfp4_quant_supported
self.silu_and_mul = SiluAndMul()
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
# create nvfp4 weight
w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
return out
def ops_in_model_before(self):
return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
QUANT_OPS[kNvfp4Quant],
]
def ops_in_model_after(self):
return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class",
cast(
list[type],
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported()
else [TestSiluMulFp8QuantModel],
),
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2)
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
)
fusion_pass = ActivationQuantFusionPass(config)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
custom_ops = []
if enable_silu_mul_custom_op:
custom_ops.append("+silu_and_mul")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
)
assert fusion_pass.matched_count == 1
with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
)
# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
assert fusion_pass.matched_count == 1
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())