[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .backend import TestBackend
|
||||
@@ -20,7 +20,7 @@ from .backend import TestBackend
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
|
||||
def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@@ -32,7 +32,7 @@ class TestModel(torch.nn.Module):
|
||||
hidden_size).to(dtype=current_platform.fp8_dtype()).t())
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
cutlass_fp8_supported=cutlass_fp8_enabled,
|
||||
force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
@@ -48,12 +48,11 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("cutlass_fp8_enabled",
|
||||
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
|
||||
@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
|
||||
cutlass_fp8_enabled):
|
||||
force_fp8_e4m3fnuz):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
@@ -64,7 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
||||
model = TestModel(hidden_size, cutlass_fp8_enabled)
|
||||
model = TestModel(hidden_size, force_fp8_e4m3fnuz)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size * 2)
|
||||
|
||||
Reference in New Issue
Block a user