[Perf][fp8] Use CustomOp abstraction for fp8 quant for better perf (#19830)

Signed-off-by: Luka Govedic <lgovedic@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Author: Luka Govedič (committed via GitHub)
Date: 2025-07-11 00:56:28 -04:00
Commit: 31d5c1797f (parent: 35514b682a)
18 changed files with 368 additions and 104 deletions

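For context before the diff: the change this test exercises replaces direct calls to scaled_fp8_quant with the Fp8LinearOp wrapper, which is configured once (static, per-tensor activation quantization here) and then performs activation quant plus the fp8 GEMM inside a single apply() call. Below is a minimal sketch of that call pattern, built only from names that appear in the diff; the shapes, scales, and token count are illustrative and assume a CUDA device with fp8 support, not taken from the commit:

    import torch

    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        GroupShape)
    from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
        CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
    from vllm.platforms import current_platform

    torch.set_default_device("cuda")

    # Static per-tensor activation quantization, matching the test below.
    fp8_linear = Fp8LinearOp(
        cutlass_fp8_supported=CUTLASS_FP8_SUPPORTED,
        act_quant_static=True,
        act_quant_group_shape=GroupShape.PER_TENSOR,
    )

    hidden_size = 64  # illustrative size, as in the test's parametrization
    # fp8 weight, transposed as in the test model.
    w = torch.rand(hidden_size, hidden_size).to(
        dtype=current_platform.fp8_dtype()).t()
    wscale = torch.rand(1, dtype=torch.float32)
    x = torch.rand(16, hidden_size, dtype=torch.float16)

    # One call quantizes x with the static input scale and runs the fp8 GEMM.
    y = fp8_linear.apply(x, w, wscale, input_scale=wscale)

Presumably, routing everything through one wrapper also gives the compilation passes a consistent op sequence to pattern-match, which is what the updated fusion test checks.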

@@ -4,33 +4,56 @@ import pytest
 import torch
 
 import vllm.envs as envs
 from vllm._custom_ops import scaled_fp8_quant
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
 from vllm.platforms import current_platform
 
 from .backend import TestBackend
 
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
         self.scale = torch.rand(1, dtype=torch.float32)
 
+        self.w = (torch.rand(
+            hidden_size,
+            hidden_size).to(dtype=current_platform.fp8_dtype()).t())
+
+        self.fp8_linear = Fp8LinearOp(
+            cutlass_fp8_supported=cutlass_fp8_enabled,
+            act_quant_static=True,
+            act_quant_group_shape=GroupShape.PER_TENSOR,
+        )
+
     def forward(self, x):
         y = self.silu_and_mul(x)
-        x2 = scaled_fp8_quant(y, self.scale)
+        x2 = self.fp8_linear.apply(y,
+                                   self.w,
+                                   self.wscale,
+                                   input_scale=self.wscale)
         return x2
 
 
 @pytest.mark.parametrize("num_tokens", [256])
 @pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
-def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
+def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
+                                   cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)
@@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
         pass_config=PassConfig(enable_fusion=True, enable_noop=True))
     fusion_pass = ActivationQuantFusionPass(config)
 
-    backend = TestBackend(fusion_pass)
-    model = TestModel()
+    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
+    model = TestModel(hidden_size, cutlass_fp8_enabled)
 
     # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
+    x = torch.rand(num_tokens, hidden_size * 2)
     torch._dynamo.mark_dynamic(x, 0)
 
     result = model(x)
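The excerpt cuts off after the eager-mode call result = model(x). A sketch of how such a fusion test typically concludes: compile the model through the test backend, compare against the eager result, then inspect the post-pass graph for the fused kernel. The tolerance values and the fused-op name torch.ops._C.silu_and_mul_quant below are assumptions, not taken from this commit:

    # Compile through the backend that runs NoOpEliminationPass + fusion_pass.
    model_fused = torch.compile(model, backend=backend)
    result_fused = model_fused(x)

    # Fused and eager paths should agree within fp8-appropriate tolerances
    # (atol/rtol here are assumptions).
    torch.testing.assert_close(result.to(torch.float32),
                               result_fused.to(torch.float32),
                               atol=1e-3, rtol=1e-3)

    # find_auto_fn (imported above) can then locate the fused op in the graph;
    # the op name and the graph_post_pass attribute are assumptions about the
    # test harness, not shown in this excerpt.
    fused_node = find_auto_fn(backend.graph_post_pass.nodes,
                              torch.ops._C.silu_and_mul_quant.default)
    assert fused_node is not None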