diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c8366ecce..c4ba8053c 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -861,6 +861,39 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake( return out, residual_out +def _rocm_aiter_gemm_a8wfp4_impl( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + from aiter.ops.triton.gemm_a8wfp4 import gemm_a8wfp4 + + M, N = x.shape[0], w.shape[0] + y = torch.empty(M, N, dtype=out_dtype, device=x.device) + gemm_a8wfp4( + x=x, + w=w, + y=y, + x_scales=x_scales, + w_scales=w_scales, + dtype=out_dtype, + config=None, + ) + return y + + +def _rocm_aiter_gemm_a8wfp4_fake( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: + return torch.empty(x.shape[0], w.shape[0], dtype=out_dtype, device=x.device) + + def _triton_rotary_embedding_impl( positions: torch.Tensor, query: torch.Tensor, @@ -1337,6 +1370,14 @@ class rocm_aiter_ops: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_gemm_a8wfp4", + op_func=_rocm_aiter_gemm_a8wfp4_impl, + mutates_args=[], + fake_impl=_rocm_aiter_gemm_a8wfp4_fake, + dispatch_key=current_platform.dispatch_key, + ) + # Register rocm aiter rotary embedding custom op direct_register_custom_op( op_name="rocm_aiter_triton_rotary_embedding", @@ -1646,6 +1687,18 @@ class rocm_aiter_ops: ) -> tuple[torch.Tensor, torch.Tensor]: return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale) + @staticmethod + def gemm_a8wfp4( + x: torch.Tensor, + w: torch.Tensor, + x_scales: torch.Tensor, + w_scales: torch.Tensor, + out_dtype: torch.dtype, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_gemm_a8wfp4( + x, w, x_scales, w_scales, out_dtype + ) + @staticmethod def triton_fp4_gemm_dynamic_qaunt( x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index dedc7db38..1ca28fbf0 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E from vllm.model_executor.layers.quantization.quark.schemes import ( QuarkOCP_MX, QuarkScheme, + QuarkW4A8_MXFP4_FP8, QuarkW8A8Fp8, QuarkW8A8Int8, ) @@ -350,6 +351,31 @@ class QuarkConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + def _is_w4a8_mxfp4_fp8( + self, + weight_quant: dict[str, Any] | None, + input_quant: dict[str, Any] | None, + ) -> bool: + if weight_quant is None or input_quant is None: + return False + + is_weight_mxfp4 = ( + weight_quant.get("dtype") == "fp4" + and weight_quant.get("qscheme") == "per_group" + and weight_quant.get("group_size") == 32 + and weight_quant.get("scale_format") == "e8m0" + and not weight_quant.get("is_dynamic") + ) + + is_input_fp8 = ( + input_quant.get("dtype") == "fp8_e4m3" + and input_quant.get("qscheme") == "per_tensor" + and not input_quant.get("is_dynamic") # Static per-tensor + and input_quant.get("symmetric") is True # Symmetric quantization + ) + + return is_weight_mxfp4 and is_input_fp8 + def _is_w_ocp_mx_a_x( self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None ) -> bool: @@ -504,6 +530,12 @@ class QuarkConfig(QuantizationConfig): is_static_input_scheme=True, input_symmetric=input_config.get("symmetric"), ) + elif self._is_w4a8_mxfp4_fp8(weight_config, input_config): + is_w4a8_supported = self._check_scheme_supported( + QuarkW4A8_MXFP4_FP8.get_min_capability(), error=False + ) + if is_w4a8_supported: + return QuarkW4A8_MXFP4_FP8(weight_config, input_config) elif self._is_w_ocp_mx_a_x(weight_config, input_config): return QuarkOCP_MX( weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index 7620d6e41..a5e33a044 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -3,7 +3,14 @@ from .quark_ocp_mx import QuarkOCP_MX from .quark_scheme import QuarkScheme +from .quark_w4a8_mxfp4_fp8 import QuarkW4A8_MXFP4_FP8 from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"] +__all__ = [ + "QuarkScheme", + "QuarkW8A8Fp8", + "QuarkW8A8Int8", + "QuarkOCP_MX", + "QuarkW4A8_MXFP4_FP8", +] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py new file mode 100644 index 000000000..29283c7bb --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable +from fractions import Fraction +from typing import Any + +import torch +import torch.nn.functional as F + +from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, +) +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + PackedvLLMParameter, + PerTensorScaleParameter, +) +from vllm.platforms import current_platform + +from .quark_scheme import QuarkScheme + +logger = init_logger(__name__) + + +__all__ = ["QuarkW4A8_MXFP4_FP8"] + +OCP_MX_BLOCK_SIZE = 32 + + +class QuarkW4A8_MXFP4_FP8(QuarkScheme): + """ + - Weights: MXFP4 with E8M0 scales per block of 32 + - Activations: FP8 E4M3 (static per-tensor quantization) + + Uses the AITER Triton kernel and falls back to emulation if AITER not available. + """ + + def __init__( + self, + weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any], + ): + self.out_dtype = None + + self.weight_dtype = "mxfp4" + self.packed_factor: Fraction = Fraction(2, 1) # 2 FP4 values per byte + self.weight_block_size = OCP_MX_BLOCK_SIZE + + self.is_static_input_scheme = not input_quant_spec.get("is_dynamic") + self.input_qscheme = input_quant_spec.get("qscheme") # "per_tensor" + + self.fp8_min, self.fp8_max = get_fp8_min_max() + self.fp8_dtype = current_platform.fp8_dtype() + + if not self.is_static_input_scheme: + raise NotImplementedError( + "Dynamic FP8 activation quantization is not yet supported " + "for W4A8. The current implementation expects static per-tensor " + "FP8 scales stored in the checkpoint." + ) + + kernel_supported_gpu = False + if current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx950 + + kernel_supported_gpu = on_gfx950() + + self.use_aiter_kernel = ( + is_aiter_found_and_supported() + and self.is_static_input_scheme + and kernel_supported_gpu + ) + + if not self.use_aiter_kernel: + logger.warning_once( + "[W4A8 MXFP4+FP8] Aiter Triton kernel not found. Using emulation mode." + ) + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_packed_dim(self, dim: int) -> int: + assert dim % 2 == 0, f"Dimension {dim} must be even for MXFP4 packing" + return dim // 2 + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # MXFP4 WEIGHT (packed, 2 values per byte) + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + self.get_packed_dim(input_size_per_partition), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.packed_factor, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE (E8M0 format, per block of 32) + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.weight_block_size, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE (FP8 per-tensor static scale) + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter( + data=torch.empty( + len(output_partition_sizes), + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + # Initialize to avoid NaN + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Ensuring weights & scales are non-trainable + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) + + if self.is_static_input_scheme: + input_scale = layer.input_scale.data + # For fused modules (QKV), take the max scale + if input_scale.numel() != 1: + input_scale = input_scale.max() + + layer.input_scale = torch.nn.Parameter( + torch.tensor(input_scale, dtype=torch.float32), + requires_grad=False, + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.use_aiter_kernel: + return self._apply_aiter_kernel(layer, x, bias) + else: + return self._apply_emulation(layer, x, bias) + + def _apply_aiter_kernel( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + M = x.shape[0] + out_dtype = x.dtype if self.out_dtype is None else self.out_dtype + + input_scale = layer.input_scale + x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype) + + # Broadcast per-tensor scale to per-row (M, 1) for Aiter kernel + x_scales = input_scale.expand(M, 1).to(dtype=torch.float32, device=x.device) + + y = rocm_aiter_ops.gemm_a8wfp4( + x_fp8, layer.weight, x_scales, layer.weight_scale, out_dtype + ) + + if bias is not None: + y = y + bias + + return y + + def _apply_emulation( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + dequant_mxfp4, + ) + + weight_dq = dequant_mxfp4( + layer.weight, + layer.weight_scale, + x.dtype, + ) + + input_scale = layer.input_scale + x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype) + x_dq = (x_fp8.to(x.dtype) * input_scale).to(x.dtype) + + return F.linear(x_dq, weight_dq, bias)