diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 1ccfa93c5..0d3b7a294 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -655,6 +655,7 @@ steps:
     - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
     - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
     - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
+    - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
     - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
     - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
     - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index 5cfad935a..c4229f934 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
+    Fp8LinearOp, maybe_create_device_identity)
 from vllm.platforms import current_platform
 
 from .backend import TestBackend
@@ -26,9 +26,9 @@ FP8_DTYPE = current_platform.fp8_dtype()
 class TestModel(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float, static: bool,
-                 cutlass_fp8_enabled: bool, *args, **kwargs):
+                 force_fp8_e4m3fnuz: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.cutlass_fp8_enabled = cutlass_fp8_enabled
+        self.force_fp8_e4m3fnuz = force_fp8_e4m3fnuz
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
@@ -43,7 +43,7 @@ class TestModel(torch.nn.Module):
             for _ in range(2)
         ]
         self.fp8_linear = Fp8LinearOp(
-            cutlass_fp8_supported=cutlass_fp8_enabled,
+            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
             act_quant_static=static,
             act_quant_group_shape=group_shape,
         )
@@ -81,12 +81,11 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
-@pytest.mark.parametrize("cutlass_fp8_enabled",
-                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
+@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              cutlass_fp8_enabled):
+                              force_fp8_e4m3fnuz):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
@@ -103,7 +102,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     fusion_pass = FusionPass.instance(vllm_config)
 
     backend = TestBackend(noop_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
+    model = TestModel(hidden_size, eps, static, force_fp8_e4m3fnuz)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)
diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py
index a6baa97fe..fb9f9dde2 100644
--- a/tests/compile/test_sequence_parallelism.py
+++ b/tests/compile/test_sequence_parallelism.py
@@ -104,8 +104,7 @@ class TestQuantModel(torch.nn.Module):
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)
 
-        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
-                                      use_per_token_if_dynamic=False)
+        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
 
         self.scale = torch.rand(1, dtype=torch.float32)
         # Create a weight that is compatible with torch._scaled_mm,
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 5351a3cf3..0e1059e65 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
+    Fp8LinearOp)
 from vllm.platforms import current_platform
 
 from .backend import TestBackend
@@ -20,7 +20,7 @@ from .backend import TestBackend
 
 class TestModel(torch.nn.Module):
 
-    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
+    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
                  **kwargs):
         super().__init__(*args, **kwargs)
         self.silu_and_mul = SiluAndMul()
@@ -32,7 +32,7 @@ class TestModel(torch.nn.Module):
             hidden_size).to(dtype=current_platform.fp8_dtype()).t())
 
         self.fp8_linear = Fp8LinearOp(
-            cutlass_fp8_supported=cutlass_fp8_enabled,
+            force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
             act_quant_static=True,
             act_quant_group_shape=GroupShape.PER_TENSOR,
         )
@@ -48,12 +48,11 @@ class TestModel(torch.nn.Module):
 
 @pytest.mark.parametrize("num_tokens", [256])
 @pytest.mark.parametrize("hidden_size", [64])
-@pytest.mark.parametrize("cutlass_fp8_enabled",
-                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
+@pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
 def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
-                                   cutlass_fp8_enabled):
+                                   force_fp8_e4m3fnuz):
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)
 
@@ -64,7 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
     fusion_pass = ActivationQuantFusionPass(config)
 
     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = TestModel(hidden_size, cutlass_fp8_enabled)
+    model = TestModel(hidden_size, force_fp8_e4m3fnuz)
 
     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)
diff --git a/tests/kernels/quantization/test_flashinfer_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_scaled_mm.py
new file mode 100644
index 000000000..9f669c6df
--- /dev/null
+++ b/tests/kernels/quantization/test_flashinfer_scaled_mm.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(
+        reason=
+        "Flashinfer FP8 gemms require compute capability of 10.0 or above.",
+        allow_module_level=True,
+    )
+
+DTYPES = [torch.float16, torch.bfloat16]
+# m, n, k
+SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
+PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
+SHAPES.extend(PAD_SHAPES)
+
+SEEDS = [42]
+CUDA_DEVICES = ["cuda:0"]
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("autotune", [False, True]) +@torch.inference_mode() +def test_flashinfer_fp8_gemm( + dtype: torch.dtype, + shape: tuple[int, int, int], + use_bias: bool, + seed: int, + device: str, + autotune: bool, +) -> None: + current_platform.seed_everything(seed) + m, n, k = shape + a = torch.randn((m, k), dtype=dtype, device=device) + b = torch.randn((n, k), dtype=dtype, device=device) / k + + a_fp8, a_scale = ops.scaled_fp8_quant(a) + b_fp8, b_scale = ops.scaled_fp8_quant(b) + + expected_out = torch.mm( + a_scale * a_fp8.to(dtype=torch.float32), + b_scale * b_fp8.to(dtype=torch.float32).t(), + ).to(dtype=dtype) + + if use_bias: + bias = torch.randn((n, ), dtype=dtype, device=device) + expected_out = expected_out + bias + else: + bias = None + + import flashinfer + + with flashinfer.autotune(autotune): + out = flashinfer_scaled_fp8_mm( + a_fp8, + b_fp8.t(), + a_scale, + b_scale, + dtype, + bias=bias, + ) + + torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a4de4d709..d45d368b5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -223,8 +223,7 @@ class Fp8LinearMethod(LinearMethodBase): self.fp8_linear = Fp8LinearOp( act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape, - cutlass_fp8_supported=cutlass_fp8_supported()) + act_quant_group_shape=self.act_q_group_shape) def create_weights( self, @@ -376,6 +375,8 @@ class Fp8LinearMethod(LinearMethodBase): # Update the layer with the new values. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) + # layer.input_scale is None indicates dynamic quant and scale is + # computed from input. layer.input_scale = None # If checkpoint is fp8, handle that there are N scales for N diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index d11cba2ca..466fd5fba 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( act_quant_static=False, - cutlass_fp8_supported=False, - act_quant_group_shape=GroupShape.PER_TOKEN) + act_quant_group_shape=GroupShape.PER_TOKEN, + force_fp8_e4m3fnuz=True) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 36d16960e..5333bbd31 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale @@ -157,6 +158,19 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, return output.view(*output_shape) +def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: list, **kwargs) -> torch.Tensor: + + return flashinfer_scaled_fp8_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + def rocm_per_tensor_w8a8_scaled_mm_impl( qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, @@ -231,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, out_dtype: torch.dtype, scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: + input_2d: torch.Tensor, output_shape: list, + **kwargs) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. @@ -303,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( - cutlass_fp8_supported: bool, per_tensor_weights: bool, + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool) -> Callable[..., torch.Tensor]: - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if cutlass_fp8_supported: - return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: - if current_platform.is_rocm(): + if preferred_backend == "rocm": return rocm_per_tensor_w8a8_scaled_mm + if preferred_backend == "flashinfer": + return flashinfer_w8a8_scaled_mm + if preferred_backend == "cutlass": + return cutlass_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if preferred_backend == "cutlass" or preferred_backend == "flashinfer": + return cutlass_w8a8_scaled_mm + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) if not per_tensor_weights and not per_tensor_activations \ and USE_ROWWISE_TORCH_SCALED_MM: @@ -334,10 +354,20 @@ class Fp8LinearOp: def __init__(self, act_quant_static: bool, - cutlass_fp8_supported: bool = cutlass_fp8_supported(), act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None): - self.cutlass_fp8_supported = cutlass_fp8_supported + pad_output: Optional[bool] = None, + force_fp8_e4m3fnuz: bool = False): + if current_platform.is_rocm(): + self.preferred_backend = "rocm" + elif current_platform.is_cuda( + ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported(): + if has_flashinfer() and current_platform.has_device_capability( + 100): + self.preferred_backend = "flashinfer" + else: + self.preferred_backend = "cutlass" + else: + self.preferred_backend = "torch" # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. 
@@ -347,8 +377,7 @@ class Fp8LinearOp:
         if pad_output is None:
             config = get_current_vllm_config().compilation_config
             pad_output = config.level < CompilationLevel.PIECEWISE and \
-                not cutlass_fp8_supported and \
-                not current_platform.is_rocm()
+                self.preferred_backend == "torch"
         self.output_padding = 17 if pad_output else None
 
         self.act_quant_static = act_quant_static
@@ -393,9 +422,9 @@ class Fp8LinearOp:
         per_tensor_activations = (x_scale.numel() == 1)
 
         # TODO(luka) do this dispatch during init (after ScaledMM refactor)
-        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
-            self.cutlass_fp8_supported, per_tensor_weights,
-            per_tensor_activations)
+        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend,
+                                                      per_tensor_weights,
+                                                      per_tensor_activations)
 
         return w8a8_scaled_mm_func(qinput=qinput,
                                    weight=weight,
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 5dd239c50..fab134733 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -265,6 +265,37 @@ if has_flashinfer():
                            dtype=dtype,
                            device=A.device)
 
+    @torch.library.custom_op(
+        "vllm::bmm_fp8",
+        mutates_args=[],
+        device_types="cuda",
+    )
+    def bmm_fp8(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        A_scale: torch.Tensor,
+        B_scale: torch.Tensor,
+        dtype: torch.dtype,
+        backend: str,
+    ) -> torch.Tensor:
+        from flashinfer import bmm_fp8 as bmm_fp8_
+        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)
+
+    @torch.library.register_fake("vllm::bmm_fp8", )
+    def bmm_fp8_fake(
+        A: torch.Tensor,
+        B: torch.Tensor,
+        A_scale: torch.Tensor,
+        B_scale: torch.Tensor,
+        dtype: torch.dtype,
+        backend: str,
+    ) -> torch.Tensor:
+        return torch.empty(A.shape[0],
+                           A.shape[1],
+                           B.shape[2],
+                           dtype=dtype,
+                           device=A.device)
+
 
 def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                              block_scale_a: torch.Tensor,
@@ -293,6 +324,35 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
     )
 
 
+def flashinfer_scaled_fp8_mm(
+        a: torch.Tensor,
+        b: torch.Tensor,
+        scale_a: torch.Tensor,
+        scale_b: torch.Tensor,
+        out_dtype: torch.dtype,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    assert a.ndim == 2 and b.ndim == 2
+    assert a.shape[1] == b.shape[0]
+    assert scale_a.numel() == 1 and scale_b.numel() == 1
+    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
+    assert a.device.type == "cuda" and b.device.type == "cuda"
+    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
+    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"
+
+    output = bmm_fp8(
+        a.unsqueeze(0),
+        b.unsqueeze(0),
+        scale_a,
+        scale_b,
+        out_dtype,
+        "auto",
+    ).view(a.shape[0], b.shape[1])
+
+    if bias is not None:
+        output = output + bias
+    return output
+
+
 __all__ = [
     "has_flashinfer",
     "flashinfer_trtllm_fp8_block_scale_moe",
@@ -307,4 +367,5 @@ __all__ = [
     "supports_trtllm_attention",
     "use_trtllm_attention",
     "flashinfer_scaled_fp4_mm",
+    "flashinfer_scaled_fp8_mm",
 ]
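Usage note (reviewer sketch, not part of the patch): with this change the FP8 GEMM backend is chosen once in Fp8LinearOp.__init__ and stored as preferred_backend, which dispatch_w8a8_scaled_mm then consumes per call. The snippet below is a hypothetical illustration of that selection under the assumption of a CUDA or ROCm build of vLLM with this patch applied; the helper name show_preferred_backend is invented for the example.

# Minimal sketch, assuming this patch is applied; it only prints which
# w8a8 scaled_mm backend Fp8LinearOp would dispatch to on this machine.
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp


def show_preferred_backend() -> None:
    # pad_output=False skips the compilation-config lookup, so this can run
    # outside a full vLLM config context.
    op = Fp8LinearOp(act_quant_static=True,
                     act_quant_group_shape=GroupShape.PER_TENSOR,
                     pad_output=False)
    # Expected values: "rocm" on ROCm, "flashinfer" on SM100+ CUDA with
    # FlashInfer installed, "cutlass" on other cutlass-capable CUDA GPUs,
    # and "torch" otherwise (including force_fp8_e4m3fnuz=True on CUDA).
    print(op.preferred_backend)


if __name__ == "__main__":
    show_preferred_backend()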