[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
nvjullin
2025-08-26 21:54:04 +08:00
committed by GitHub
parent b78bed1bc5
commit f66673a39d
9 changed files with 198 additions and 36 deletions

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
Fp8LinearOp)
from vllm.platforms import current_platform
from .backend import TestBackend
@@ -20,7 +20,7 @@ from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.silu_and_mul = SiluAndMul()
@@ -32,7 +32,7 @@ class TestModel(torch.nn.Module):
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
@@ -48,12 +48,11 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
cutlass_fp8_enabled):
force_fp8_e4m3fnuz):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
@@ -64,7 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
model = TestModel(hidden_size, cutlass_fp8_enabled)
model = TestModel(hidden_size, force_fp8_e4m3fnuz)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size * 2)