diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py
index 570ce1133..282208502 100644
--- a/vllm/model_executor/kernels/linear/__init__.py
+++ b/vllm/model_executor/kernels/linear/__init__.py
@@ -72,6 +72,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
 from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
     FlashInferFP8ScaledMMLinearKernel,
 )
+from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
+    MarlinFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
     ChannelWiseTorchFP8ScaledMMLinearKernel,
     PerTensorTorchFP8ScaledMMLinearKernel,
@@ -104,6 +107,7 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
 # in priority/performance order (when available)
 _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
     PlatformEnum.CUDA: [
+        MarlinFP8ScaledMMLinearKernel,
         FlashInferFP8ScaledMMLinearKernel,
         CutlassFP8ScaledMMLinearKernel,
         PerTensorTorchFP8ScaledMMLinearKernel,
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/__init__.py b/vllm/model_executor/kernels/linear/scaled_mm/__init__.py
index 3056d5d0f..2323a02ba 100644
--- a/vllm/model_executor/kernels/linear/scaled_mm/__init__.py
+++ b/vllm/model_executor/kernels/linear/scaled_mm/__init__.py
@@ -14,6 +14,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
 from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
     FlashInferFP8ScaledMMLinearKernel,
 )
+from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
+    MarlinFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
     ChannelWiseTorchFP8ScaledMMLinearKernel,
     PerTensorTorchFP8ScaledMMLinearKernel,
@@ -46,6 +49,7 @@ __all__ = [
     "CutlassFP8ScaledMMLinearKernel",
     "CutlassInt8ScaledMMLinearKernel",
     "FlashInferFP8ScaledMMLinearKernel",
+    "MarlinFP8ScaledMMLinearKernel",
     "ChannelWiseTorchFP8ScaledMMLinearKernel",
     "PerTensorTorchFP8ScaledMMLinearKernel",
     "RowWiseTorchFP8ScaledMMLinearKernel",
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py
new file mode 100644
index 000000000..e79809037
--- /dev/null
+++ b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+import torch
+
+import vllm.envs as envs
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    process_fp8_weight_block_strategy,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    is_fp8_marlin_supported,
+    prepare_fp8_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    kFp8Static128BlockSym,
+)
+from vllm.model_executor.utils import replace_parameter
+from vllm.platforms import current_platform
+
+from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
+)
+
+
+class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    """
+    FP8 Marlin kernel for GPUs that lack FP8 hardware support.
+    Leverages the Marlin kernel for fast weight-only FP8 quantization.
+    """
+
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_cuda():
+            return False, "requires CUDA."
+        # Check if platform supports FP8 Marlin
+        if not is_fp8_marlin_supported():
+            return False, "FP8 Marlin requires compute capability 7.5 or higher"
+        if vllm_is_batch_invariant():
+            return False, "FP8 Marlin not supported for batch invariant execution."
+        if (
+            compute_capability is not None
+            and compute_capability >= 89
+            and not envs.VLLM_TEST_FORCE_FP8_MARLIN
+        ):
+            return (
+                False,
+                "To apply FP8 Marlin on high-capability GPUs, please set "
+                "VLLM_TEST_FORCE_FP8_MARLIN=1",
+            )
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        return True, None
+
+    def __init__(
+        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
+    ) -> None:
+        super().__init__(c, layer_param_names)
+        self.marlin_input_dtype = None
+        self.block_quant = self.config.weight_quant_key in {kFp8Static128BlockSym}
+        self.size_k_first = not self.block_quant
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.block_quant:
+            weight, weight_scale_inv = process_fp8_weight_block_strategy(
+                layer.weight, layer.weight_scale_inv
+            )
+            # Update layer with new values
+            replace_parameter(layer, "weight", weight.data)
+            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
+        else:
+            weight = layer.weight.t()
+            replace_parameter(layer, "weight", weight.data)
+            layer.input_scale = None
+        prepare_fp8_layer_for_marlin(
+            layer, self.size_k_first, input_dtype=self.marlin_input_dtype
+        )
+        del layer.input_scale
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if self.block_quant:
+            weight_scale = layer.weight_scale_inv
+        else:
+            weight_scale = layer.weight_scale
+        return apply_fp8_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=weight_scale,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            input_dtype=self.marlin_input_dtype,
+            bias=bias,
+        )
+
+    def apply_scaled_mm(
+        self,
+        *,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        out_dtype: torch.dtype,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        bias: torch.Tensor | None,
+        output_shape: list,
+    ) -> torch.Tensor:
+        pass
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index cca3b58eb..c952b7690 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
     prepare_fp8_layer_for_marlin,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -177,15 +176,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if self.quant_config.use_marlin:
-            return apply_fp8_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                bias=bias,
-            )
-
         return self.fp8_linear.apply_weights(layer, x, bias)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5101347cd..d2a23bcf2 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -7,7 +7,6 @@ import torch
 from torch.nn import Module
 from torch.utils._python_dispatch import TorchDispatchMode
 
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
@@ -16,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.kernels.linear import (
     init_fp8_linear_kernel,
 )
+from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
@@ -61,10 +61,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     get_marlin_input_dtype,
 )
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
-    prepare_fp8_layer_for_marlin,
-)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     is_layer_skipped,
@@ -280,15 +276,6 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.marlin_input_dtype = None
-        self.use_marlin = (
-            not current_platform.has_device_capability(89)
-            or envs.VLLM_TEST_FORCE_FP8_MARLIN
-        )
-        # Disable marlin for rocm
-        if current_platform.is_rocm() or current_platform.is_xpu():
-            self.use_marlin = False
-        if vllm_is_batch_invariant():
-            self.use_marlin = False
 
         self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
         self.use_deep_gemm = is_deep_gemm_supported()
@@ -297,7 +284,28 @@
         self.block_quant = self.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
 
-        if self.block_quant:
+        # Use per-token quantization for better perf if dynamic and cutlass
+        if self.act_q_static:
+            activation_quant_key = kFp8StaticTensorSym
+        elif cutlass_fp8_supported():
+            activation_quant_key = kFp8DynamicTokenSym
+        else:
+            activation_quant_key = kFp8DynamicTensorSym
+
+        if self.block_quant:
+            weight_quant_key = kFp8Static128BlockSym
+        else:
+            weight_quant_key = kFp8StaticTensorSym
+
+        self.fp8_linear = init_fp8_linear_kernel(
+            activation_quant_key=activation_quant_key,
+            weight_quant_key=weight_quant_key,
+            out_dtype=torch.get_default_dtype(),
+            module_name=self.__class__.__name__,
+        )
+        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
+
+        if self.block_quant and not self.use_marlin:
             assert not self.act_q_static
             assert self.weight_block_size is not None
             self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
@@ -306,21 +314,6 @@
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
-        else:
-            # Use per-token quantization for better perf if dynamic and cutlass
-            if self.act_q_static:
-                activation_quant_key = kFp8StaticTensorSym
-            elif cutlass_fp8_supported():
-                activation_quant_key = kFp8DynamicTokenSym
-            else:
-                activation_quant_key = kFp8DynamicTensorSym
-
-            self.fp8_linear = init_fp8_linear_kernel(
-                activation_quant_key=activation_quant_key,
-                weight_quant_key=kFp8StaticTensorSym,
-                out_dtype=torch.get_default_dtype(),
-                module_name=self.__class__.__name__,
-            )
 
     def create_weights(
         self,
@@ -387,12 +380,18 @@
         layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        size_k_first = True
+        if self.use_marlin:
+            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
+            # AttributeError if backend selection changes.
+            if hasattr(self.fp8_linear, "marlin_input_dtype"):
+                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
+            self.fp8_linear.process_weights_after_loading(layer)
+            return
+
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
             assert not self.act_q_static
-            size_k_first = False
 
             weight, weight_scale_inv = process_fp8_weight_block_strategy(
                 layer.weight, layer.weight_scale_inv
@@ -411,16 +410,15 @@
 
         # If using w8a8, torch._scaled_mm needs per tensor, so
         # requantize the logical shards as a single weight.
-        if not self.use_marlin:
-            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
-                weight,
-                weight_scale,
-                layer.logical_widths,
-                getattr(layer, "input_scale", None),
-            )
-            if self.act_q_static:
-                assert input_scale is not None
-                input_scale = input_scale.max()
+        weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
+            weight,
+            weight_scale,
+            layer.logical_widths,
+            getattr(layer, "input_scale", None),
+        )
+        if self.act_q_static:
+            assert input_scale is not None
+            input_scale = input_scale.max()
         weight = weight.t()
 
         # Update layer with new values.
@@ -432,14 +430,6 @@
         else:
             layer.input_scale = None
 
-        if self.use_marlin:
-            prepare_fp8_layer_for_marlin(
-                layer, size_k_first, input_dtype=self.marlin_input_dtype
-            )
-            # Activations not quantized for marlin.
-            del layer.input_scale
-            return
-
         if self.block_quant:
             maybe_post_process_fp8_weight_block(layer)
 
@@ -486,21 +476,7 @@
             return torch.nn.functional.linear(x, weight_bf16.t(), bias)
 
         if self.use_marlin:
-            if self.block_quant:
-                weight_scale = layer.weight_scale_inv
-            else:
-                weight_scale = layer.weight_scale
-
-            return apply_fp8_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=weight_scale,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                input_dtype=self.marlin_input_dtype,
-                bias=bias,
-            )
+            return self.fp8_linear.apply_weights(layer, x, bias)
 
         if self.block_quant:
             assert self.weight_block_size is not None
@@ -623,18 +599,20 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
         layer.input_scale = None
 
         qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-        weight = qweight.t()
 
         # Update layer with new values.
-        replace_parameter(layer, "weight", weight.data)
+        replace_parameter(layer, "weight", qweight.data)
         replace_parameter(layer, "weight_scale", weight_scale.data)
 
         if self.use_marlin:
-            size_k_first = True
-            prepare_fp8_layer_for_marlin(
-                layer, size_k_first, input_dtype=self.marlin_input_dtype
-            )
-            # Activations not quantized for marlin.
+            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
+            # AttributeError if backend selection changes.
+            if hasattr(self.fp8_linear, "marlin_input_dtype"):
+                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
+            self.fp8_linear.process_weights_after_loading(layer)
+        else:
+            weight = qweight.t()
+            replace_parameter(layer, "weight", weight.data)
 
         # Prevent duplicate processing (e.g., during weight reload)
         layer._already_called_process_weights_after_loading = True
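With the env-var and capability checks removed from Fp8LinearMethod, use_marlin now simply reflects whichever kernel init_fp8_linear_kernel returns. Assuming the selector walks _POSSIBLE_FP8_KERNELS for the current platform in priority order (a hypothetical sketch; the actual helper may differ), the resolution logic is roughly:

# Hypothetical sketch of the selection assumed to happen inside
# init_fp8_linear_kernel: take the first kernel in the platform's priority
# list that is supported and can implement the requested config.
def pick_fp8_kernel(config, kernel_classes):
    rejections = []
    for kernel_cls in kernel_classes:
        supported, why = kernel_cls.is_supported()
        if supported:
            ok, why = kernel_cls.can_implement(config)
            if ok:
                return kernel_cls
        rejections.append(f"{kernel_cls.__name__}: {why}")
    raise ValueError("no usable FP8 scaled_mm kernel: " + "; ".join(rejections))

# With MarlinFP8ScaledMMLinearKernel first in the CUDA list, a pre-SM89 GPU
# resolves to Marlin, so Fp8LinearMethod.use_marlin becomes a consequence of
# kernel selection rather than of the old env-var / capability checks.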