[ROCm][Quantization] add quark w4a8 mxfp4_fp8 for LinearLayer (#35316)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
Divakar Verma
2026-03-13 15:44:24 -04:00
committed by GitHub
parent 7afe0faab1
commit 6341d43043
4 changed files with 311 additions and 1 deletions

View File

@@ -861,6 +861,39 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
return out, residual_out
def _rocm_aiter_gemm_a8wfp4_impl(
x: torch.Tensor,
w: torch.Tensor,
x_scales: torch.Tensor,
w_scales: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
from aiter.ops.triton.gemm_a8wfp4 import gemm_a8wfp4
M, N = x.shape[0], w.shape[0]
y = torch.empty(M, N, dtype=out_dtype, device=x.device)
gemm_a8wfp4(
x=x,
w=w,
y=y,
x_scales=x_scales,
w_scales=w_scales,
dtype=out_dtype,
config=None,
)
return y
def _rocm_aiter_gemm_a8wfp4_fake(
x: torch.Tensor,
w: torch.Tensor,
x_scales: torch.Tensor,
w_scales: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(x.shape[0], w.shape[0], dtype=out_dtype, device=x.device)
def _triton_rotary_embedding_impl(
positions: torch.Tensor,
query: torch.Tensor,
@@ -1337,6 +1370,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_a8wfp4",
op_func=_rocm_aiter_gemm_a8wfp4_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8wfp4_fake,
dispatch_key=current_platform.dispatch_key,
)
# Register rocm aiter rotary embedding custom op
direct_register_custom_op(
op_name="rocm_aiter_triton_rotary_embedding",
@@ -1646,6 +1687,18 @@ class rocm_aiter_ops:
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
@staticmethod
def gemm_a8wfp4(
x: torch.Tensor,
w: torch.Tensor,
x_scales: torch.Tensor,
w_scales: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_a8wfp4(
x, w, x_scales, w_scales, out_dtype
)
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,

View File

@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E
from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkOCP_MX,
QuarkScheme,
QuarkW4A8_MXFP4_FP8,
QuarkW8A8Fp8,
QuarkW8A8Int8,
)
@@ -350,6 +351,31 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_w4a8_mxfp4_fp8(
self,
weight_quant: dict[str, Any] | None,
input_quant: dict[str, Any] | None,
) -> bool:
if weight_quant is None or input_quant is None:
return False
is_weight_mxfp4 = (
weight_quant.get("dtype") == "fp4"
and weight_quant.get("qscheme") == "per_group"
and weight_quant.get("group_size") == 32
and weight_quant.get("scale_format") == "e8m0"
and not weight_quant.get("is_dynamic")
)
is_input_fp8 = (
input_quant.get("dtype") == "fp8_e4m3"
and input_quant.get("qscheme") == "per_tensor"
and not input_quant.get("is_dynamic") # Static per-tensor
and input_quant.get("symmetric") is True # Symmetric quantization
)
return is_weight_mxfp4 and is_input_fp8
def _is_w_ocp_mx_a_x(
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
) -> bool:
@@ -504,6 +530,12 @@ class QuarkConfig(QuantizationConfig):
is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"),
)
elif self._is_w4a8_mxfp4_fp8(weight_config, input_config):
is_w4a8_supported = self._check_scheme_supported(
QuarkW4A8_MXFP4_FP8.get_min_capability(), error=False
)
if is_w4a8_supported:
return QuarkW4A8_MXFP4_FP8(weight_config, input_config)
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant

View File

@@ -3,7 +3,14 @@
from .quark_ocp_mx import QuarkOCP_MX
from .quark_scheme import QuarkScheme
from .quark_w4a8_mxfp4_fp8 import QuarkW4A8_MXFP4_FP8
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"]
__all__ = [
"QuarkScheme",
"QuarkW8A8Fp8",
"QuarkW8A8Int8",
"QuarkOCP_MX",
"QuarkW4A8_MXFP4_FP8",
]

View File

@@ -0,0 +1,218 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from fractions import Fraction
from typing import Any
import torch
import torch.nn.functional as F
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
)
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
__all__ = ["QuarkW4A8_MXFP4_FP8"]
OCP_MX_BLOCK_SIZE = 32
class QuarkW4A8_MXFP4_FP8(QuarkScheme):
"""
- Weights: MXFP4 with E8M0 scales per block of 32
- Activations: FP8 E4M3 (static per-tensor quantization)
Uses the AITER Triton kernel and falls back to emulation if AITER not available.
"""
def __init__(
self,
weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any],
):
self.out_dtype = None
self.weight_dtype = "mxfp4"
self.packed_factor: Fraction = Fraction(2, 1) # 2 FP4 values per byte
self.weight_block_size = OCP_MX_BLOCK_SIZE
self.is_static_input_scheme = not input_quant_spec.get("is_dynamic")
self.input_qscheme = input_quant_spec.get("qscheme") # "per_tensor"
self.fp8_min, self.fp8_max = get_fp8_min_max()
self.fp8_dtype = current_platform.fp8_dtype()
if not self.is_static_input_scheme:
raise NotImplementedError(
"Dynamic FP8 activation quantization is not yet supported "
"for W4A8. The current implementation expects static per-tensor "
"FP8 scales stored in the checkpoint."
)
kernel_supported_gpu = False
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx950
kernel_supported_gpu = on_gfx950()
self.use_aiter_kernel = (
is_aiter_found_and_supported()
and self.is_static_input_scheme
and kernel_supported_gpu
)
if not self.use_aiter_kernel:
logger.warning_once(
"[W4A8 MXFP4+FP8] Aiter Triton kernel not found. Using emulation mode."
)
@classmethod
def get_min_capability(cls) -> int:
return 70
def get_packed_dim(self, dim: int) -> int:
assert dim % 2 == 0, f"Dimension {dim} must be even for MXFP4 packing"
return dim // 2
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
# MXFP4 WEIGHT (packed, 2 values per byte)
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE (E8M0 format, per block of 32)
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.weight_block_size,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (FP8 per-tensor static scale)
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(
data=torch.empty(
len(output_partition_sizes),
dtype=torch.float32,
),
weight_loader=weight_loader,
)
# Initialize to avoid NaN
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Ensuring weights & scales are non-trainable
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
if self.is_static_input_scheme:
input_scale = layer.input_scale.data
# For fused modules (QKV), take the max scale
if input_scale.numel() != 1:
input_scale = input_scale.max()
layer.input_scale = torch.nn.Parameter(
torch.tensor(input_scale, dtype=torch.float32),
requires_grad=False,
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if self.use_aiter_kernel:
return self._apply_aiter_kernel(layer, x, bias)
else:
return self._apply_emulation(layer, x, bias)
def _apply_aiter_kernel(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
M = x.shape[0]
out_dtype = x.dtype if self.out_dtype is None else self.out_dtype
input_scale = layer.input_scale
x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)
# Broadcast per-tensor scale to per-row (M, 1) for Aiter kernel
x_scales = input_scale.expand(M, 1).to(dtype=torch.float32, device=x.device)
y = rocm_aiter_ops.gemm_a8wfp4(
x_fp8, layer.weight, x_scales, layer.weight_scale, out_dtype
)
if bias is not None:
y = y + bias
return y
def _apply_emulation(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4,
)
weight_dq = dequant_mxfp4(
layer.weight,
layer.weight_scale,
x.dtype,
)
input_scale = layer.input_scale
x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)
x_dq = (x_fp8.to(x.dtype) * input_scale).to(x.dtype)
return F.linear(x_dq, weight_dq, bias)