[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: vllmellm
Date: 2026-01-20 14:48:20 +08:00
Committed by: GitHub
parent e9c83cdc51
commit 148117ea2e
30 changed files with 1467 additions and 1038 deletions


@@ -25,19 +25,30 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
PerTensorTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
)
from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported
from ..utils import TestFP8Layer
from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
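For context on what the imported kernels implement: FP8_DTYPE together with the per-tensor quant key (kFp8StaticTensorSym) describes the contract of these scaled-mm kernels, where weight and activation are stored in an FP8 dtype with one float32 scale each. Below is an illustrative, standalone sketch of that contract only, using torch.float8_e4m3fn as a stand-in for current_platform.fp8_dtype(); it is not vLLM code.

import torch

# Illustrative only: the contract behind a per-tensor FP8 scaled-mm kernel is
# roughly y = (x_q @ w_q.T) * x_scale * w_scale, with x_q/w_q stored in an FP8
# dtype and the scales kept in float32.
hidden = 128
w = torch.rand(hidden, hidden)
w_scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max
w_q = (w / w_scale).to(torch.float8_e4m3fn)

x = torch.rand(4, hidden)
x_scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
x_q = (x / x_scale).to(torch.float8_e4m3fn)

# Reference result in float32; the real kernels (CUTLASS, FlashInfer, ROCm,
# plain PyTorch) compute this on-device without the fp32 round trip.
y_ref = (x_q.float() @ w_q.float().t()) * x_scale * w_scale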
@@ -49,25 +60,27 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
quant_key = kFp8StaticTensorSym
def __init__(
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
force_kernel=force_kernel,
)
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
x2 = self.fp8_linear(y)
return x2
def ops_in_model_before(self):
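Assembled from the hunk above, the change to TestSiluMulFp8QuantModel amounts to swapping a hand-built Fp8LinearOp for a TestFP8Layer pinned to one concrete kernel. A condensed before/after sketch follows; keyword arguments are copied from the diff, and this is a fragment rather than a standalone script.

# Before: build Fp8LinearOp directly, optionally forcing the torch fallback,
# and pass the weight, weight scale, and input scale on every call.
with override_cutlass_fp8_supported(not cuda_force_torch):
    fp8_linear = Fp8LinearOp(
        act_quant_static=True,
        act_quant_group_shape=GroupShape.PER_TENSOR,
    )
out = fp8_linear.apply(y, w, wscale, input_scale=wscale)

# After: TestFP8Layer owns the weight and scales, selects the requested
# kernel explicitly, and the forward pass is just a call on the activations.
fp8_linear = TestFP8Layer(
    weight_shape=(hidden_size, hidden_size),
    activation_quant_key=kFp8StaticTensorSym,
    weight_quant_key=kFp8StaticTensorSym,
    force_kernel=CutlassFP8ScaledMMLinearKernel,  # any entry from TEST_KERNELS
)
out = fp8_linear(y)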
@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
CUDA_KERNELS = [
FlashInferFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
]
TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
"model_class, enable_quant_fp8_custom_op, force_kernel",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
+ [
(TestSiluMulNvfp4QuantModel, False, False),
(TestSiluMulGroupFp8QuantModel, False, False),
(TestSiluMulNvfp4QuantModel, False, None),
(TestSiluMulGroupFp8QuantModel, False, None),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
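The new parametrization builds the FP8 cases with itertools.product over the platform's kernel list and then appends the NVFP4 and group-FP8 models as fixed cases with force_kernel=None. A tiny standalone illustration of that pattern, with placeholder strings instead of the real classes:

import itertools

kernels = ["cutlass", "flashinfer", "torch"]  # stands in for TEST_KERNELS
cases = list(itertools.product(["fp8_model"], [True, False], kernels)) + [
    ("nvfp4_model", False, None),
    ("group_fp8_model", False, None),
]
# Each tuple is (model_class, enable_quant_fp8_custom_op, force_kernel);
# pytest.mark.parametrize unpacks one tuple per test case.
print(len(cases))  # 3 * 2 + 2 = 8 combinations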
@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant(
],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
force_kernel: FP8ScaledMMLinearKernel | None,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant(
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2)
@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant(
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
)
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
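torch._dynamo.mark_dynamic(x, 0) asks the compiler to treat the first (token) dimension as symbolic, so the fused graph checked by these passes is not specialized to one batch size. A minimal standalone illustration of the mechanism, unrelated to the model under test:

import torch

def f(x):
    return torch.nn.functional.silu(x) * 2.0

x = torch.rand(32, 256)
torch._dynamo.mark_dynamic(x, 0)   # first dimension may vary between calls
compiled = torch.compile(f)
compiled(x)                        # compiles with a symbolic batch dimension
compiled(torch.rand(64, 256))      # typically reuses the graph without recompiling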