[ROCm][Quantization] add quark w4a8 mxfp4_fp8 for LinearLayer (#35316)
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
@@ -861,6 +861,39 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
|
||||
return out, residual_out
|
||||
|
||||
|
||||
def _rocm_aiter_gemm_a8wfp4_impl(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Run AITER's Triton ``gemm_a8wfp4`` kernel and return its output buffer.

    Args:
        x: Activation tensor; its first dimension gives the output row count.
        w: Weight tensor; its first dimension gives the output column count.
        x_scales: Activation scales, forwarded to the kernel unchanged.
        w_scales: Weight scales, forwarded to the kernel unchanged.
        out_dtype: dtype of the allocated output (also passed as ``dtype``).

    Returns:
        The ``(x.shape[0], w.shape[0])`` tensor that was handed to the kernel
        as ``y`` (presumably filled in place by the kernel).
    """
    # Imported lazily so this module can be loaded without AITER installed.
    from aiter.ops.triton.gemm_a8wfp4 import gemm_a8wfp4

    num_rows = x.shape[0]
    num_cols = w.shape[0]
    out = torch.empty((num_rows, num_cols), dtype=out_dtype, device=x.device)

    kernel_kwargs = dict(
        x=x,
        w=w,
        y=out,
        x_scales=x_scales,
        w_scales=w_scales,
        dtype=out_dtype,
        config=None,
    )
    gemm_a8wfp4(**kernel_kwargs)
    return out
|
||||
|
||||
|
||||
def _rocm_aiter_gemm_a8wfp4_fake(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Shape-only counterpart of the ``gemm_a8wfp4`` op (no computation).

    Returns an uninitialized ``(x.shape[0], w.shape[0])`` tensor of
    ``out_dtype`` on ``x``'s device. ``x_scales`` and ``w_scales`` are
    accepted only to mirror the real implementation's signature.
    """
    out_shape = (x.shape[0], w.shape[0])
    return torch.empty(out_shape, dtype=out_dtype, device=x.device)
|
||||
|
||||
|
||||
def _triton_rotary_embedding_impl(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -1337,6 +1370,14 @@ class rocm_aiter_ops:
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_a8wfp4",
|
||||
op_func=_rocm_aiter_gemm_a8wfp4_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_gemm_a8wfp4_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
# Register rocm aiter rotary embedding custom op
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_triton_rotary_embedding",
|
||||
@@ -1646,6 +1687,18 @@ class rocm_aiter_ops:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
|
||||
|
||||
@staticmethod
def gemm_a8wfp4(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Dispatch to the registered ``vllm::rocm_aiter_gemm_a8wfp4`` custom op.

    All arguments are forwarded positionally and unchanged; the op's result
    is returned as-is.
    """
    registered_op = torch.ops.vllm.rocm_aiter_gemm_a8wfp4
    return registered_op(x, w, x_scales, w_scales, out_dtype)
|
||||
|
||||
@staticmethod
|
||||
def triton_fp4_gemm_dynamic_qaunt(
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import (
|
||||
QuarkOCP_MX,
|
||||
QuarkScheme,
|
||||
QuarkW4A8_MXFP4_FP8,
|
||||
QuarkW8A8Fp8,
|
||||
QuarkW8A8Int8,
|
||||
)
|
||||
@@ -350,6 +351,31 @@ class QuarkConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_w4a8_mxfp4_fp8(
|
||||
self,
|
||||
weight_quant: dict[str, Any] | None,
|
||||
input_quant: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_weight_mxfp4 = (
|
||||
weight_quant.get("dtype") == "fp4"
|
||||
and weight_quant.get("qscheme") == "per_group"
|
||||
and weight_quant.get("group_size") == 32
|
||||
and weight_quant.get("scale_format") == "e8m0"
|
||||
and not weight_quant.get("is_dynamic")
|
||||
)
|
||||
|
||||
is_input_fp8 = (
|
||||
input_quant.get("dtype") == "fp8_e4m3"
|
||||
and input_quant.get("qscheme") == "per_tensor"
|
||||
and not input_quant.get("is_dynamic") # Static per-tensor
|
||||
and input_quant.get("symmetric") is True # Symmetric quantization
|
||||
)
|
||||
|
||||
return is_weight_mxfp4 and is_input_fp8
|
||||
|
||||
def _is_w_ocp_mx_a_x(
|
||||
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
|
||||
) -> bool:
|
||||
@@ -504,6 +530,12 @@ class QuarkConfig(QuantizationConfig):
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_config.get("symmetric"),
|
||||
)
|
||||
elif self._is_w4a8_mxfp4_fp8(weight_config, input_config):
|
||||
is_w4a8_supported = self._check_scheme_supported(
|
||||
QuarkW4A8_MXFP4_FP8.get_min_capability(), error=False
|
||||
)
|
||||
if is_w4a8_supported:
|
||||
return QuarkW4A8_MXFP4_FP8(weight_config, input_config)
|
||||
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
|
||||
return QuarkOCP_MX(
|
||||
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
|
||||
|
||||
@@ -3,7 +3,14 @@
|
||||
|
||||
from .quark_ocp_mx import QuarkOCP_MX
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w4a8_mxfp4_fp8 import QuarkW4A8_MXFP4_FP8
|
||||
from .quark_w8a8_fp8 import QuarkW8A8Fp8
|
||||
from .quark_w8a8_int8 import QuarkW8A8Int8
|
||||
|
||||
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"]
|
||||
__all__ = [
|
||||
"QuarkScheme",
|
||||
"QuarkW8A8Fp8",
|
||||
"QuarkW8A8Int8",
|
||||
"QuarkOCP_MX",
|
||||
"QuarkW4A8_MXFP4_FP8",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from fractions import Fraction
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_fp8_min_max,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
__all__ = ["QuarkW4A8_MXFP4_FP8"]
|
||||
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
class QuarkW4A8_MXFP4_FP8(QuarkScheme):
    """Quark linear scheme with MXFP4 weights and FP8 activations (W4A8).

    - Weights: MXFP4 with E8M0 scales per block of 32
    - Activations: FP8 E4M3 (static per-tensor quantization)

    Uses the AITER Triton kernel when AITER is usable and the GPU is gfx950,
    and falls back to emulation (dequantize, then ``F.linear``) otherwise.
    """

    def __init__(
        self,
        weight_quant_spec: dict[str, Any],
        input_quant_spec: dict[str, Any],
    ):
        """Record the quantization settings and pick the execution path.

        Args:
            weight_quant_spec: Quark weight quantization config. It is not
                read here; the weight layout is fixed to MXFP4.
            input_quant_spec: Quark input quantization config; ``is_dynamic``
                and ``qscheme`` are read from it.

        Raises:
            NotImplementedError: If the input spec requests dynamic
                activation quantization.
        """
        # None means "use the activation dtype" in _apply_aiter_kernel.
        self.out_dtype = None

        self.weight_dtype = "mxfp4"
        self.packed_factor: Fraction = Fraction(2, 1)  # 2 FP4 values per byte
        self.weight_block_size = OCP_MX_BLOCK_SIZE

        self.is_static_input_scheme = not input_quant_spec.get("is_dynamic")
        self.input_qscheme = input_quant_spec.get("qscheme")  # "per_tensor"

        self.fp8_min, self.fp8_max = get_fp8_min_max()
        self.fp8_dtype = current_platform.fp8_dtype()

        if not self.is_static_input_scheme:
            raise NotImplementedError(
                "Dynamic FP8 activation quantization is not yet supported "
                "for W4A8. The current implementation expects static per-tensor "
                "FP8 scales stored in the checkpoint."
            )

        kernel_supported_gpu = False
        if current_platform.is_rocm():
            # Imported lazily: the ROCm platform module is only needed here.
            from vllm.platforms.rocm import on_gfx950

            kernel_supported_gpu = on_gfx950()

        self.use_aiter_kernel = (
            is_aiter_found_and_supported()
            and self.is_static_input_scheme
            and kernel_supported_gpu
        )

        if not self.use_aiter_kernel:
            # The fast path can be disabled by either the AITER check or the
            # GPU check, so the message names both possible reasons.
            logger.warning_once(
                "[W4A8 MXFP4+FP8] AITER Triton kernel is unavailable "
                "(AITER missing/unsupported or GPU is not gfx950). "
                "Using emulation mode."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required by this scheme."""
        return 70

    def get_packed_dim(self, dim: int) -> int:
        """Return the packed (byte) length of ``dim`` FP4 values.

        Two FP4 values share one byte, so ``dim`` must be even.
        """
        assert dim % 2 == 0, f"Dimension {dim} must be even for MXFP4 packing"
        return dim // 2

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the quantized parameters of this scheme on ``layer``.

        Registers:
            - ``weight``: uint8, ``(out, in / 2)`` packed MXFP4 values.
            - ``weight_scale``: uint8, ``(out, in / 32)`` E8M0 block scales.
            - ``input_scale``: float32, one entry per output partition
              (static input scheme only).

        Args:
            layer: Module that receives the parameters and size attributes.
            output_partition_sizes: Output widths of the (possibly fused)
                logical partitions; their sum is the output size.
            input_size_per_partition: Unpacked input size of this partition.
            params_dtype: Unused; the storage dtypes are fixed by the scheme.
            weight_loader: Loader callback attached to every parameter.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``input_size_per_partition`` is not a multiple of
                the MXFP4 block size.
        """
        # A floor division below would silently drop the scale of a partial
        # trailing block, so reject sizes that do not tile into whole blocks.
        if input_size_per_partition % self.weight_block_size != 0:
            raise ValueError(
                f"input_size_per_partition ({input_size_per_partition}) must "
                f"be a multiple of the MXFP4 block size "
                f"({self.weight_block_size})."
            )

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # MXFP4 WEIGHT (packed, 2 values per byte)
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (E8M0 format, per block of 32)
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.weight_block_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (FP8 per-tensor static scale)
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(
                    len(output_partition_sizes),
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            # Initialize to avoid NaN
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Freeze the loaded parameters and collapse fused input scales.

        ``weight`` and ``weight_scale`` are re-wrapped as non-trainable
        parameters. For the static input scheme, a multi-entry
        ``input_scale`` (fused modules) is reduced to its maximum, and the
        result is stored as a non-trainable float32 parameter.
        """
        # Ensuring weights & scales are non-trainable
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data, requires_grad=False
        )

        if self.is_static_input_scheme:
            input_scale = layer.input_scale.data
            # For fused modules (QKV), take the max scale
            if input_scale.numel() != 1:
                input_scale = input_scale.max()

            # detach().clone() is the recommended way to copy an existing
            # tensor; torch.tensor(tensor) emits a copy-construct warning.
            layer.input_scale = torch.nn.Parameter(
                input_scale.detach().clone().to(torch.float32),
                requires_grad=False,
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear layer output via the kernel or the emulation."""
        if self.use_aiter_kernel:
            return self._apply_aiter_kernel(layer, x, bias)
        else:
            return self._apply_emulation(layer, x, bias)

    def _quantize_input(
        self, x: torch.Tensor, input_scale: torch.Tensor
    ) -> torch.Tensor:
        """Divide ``x`` by ``input_scale``, clamp to the FP8 range and cast."""
        scaled = (x / input_scale).clamp(self.fp8_min, self.fp8_max)
        return scaled.to(self.fp8_dtype)

    def _apply_aiter_kernel(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to FP8 and run the AITER ``gemm_a8wfp4`` op.

        The output dtype is ``self.out_dtype`` when set, else ``x.dtype``.
        ``bias`` is added after the GEMM when provided.
        """
        M = x.shape[0]
        out_dtype = x.dtype if self.out_dtype is None else self.out_dtype

        input_scale = layer.input_scale
        x_fp8 = self._quantize_input(x, input_scale)

        # Broadcast per-tensor scale to per-row (M, 1) for Aiter kernel
        x_scales = input_scale.expand(M, 1).to(dtype=torch.float32, device=x.device)

        y = rocm_aiter_ops.gemm_a8wfp4(
            x_fp8, layer.weight, x_scales, layer.weight_scale, out_dtype
        )

        if bias is not None:
            y = y + bias

        return y

    def _apply_emulation(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Emulate the W4A8 GEMM with dequantized operands and ``F.linear``.

        The weight is dequantized with ``dequant_mxfp4`` (``x.dtype`` is
        passed as its third argument), and ``x`` is round-tripped through
        FP8 using the static input scale to reproduce the quantization error.
        """
        from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
            dequant_mxfp4,
        )

        weight_dq = dequant_mxfp4(
            layer.weight,
            layer.weight_scale,
            x.dtype,
        )

        input_scale = layer.input_scale
        x_fp8 = self._quantize_input(x, input_scale)
        x_dq = (x_fp8.to(x.dtype) * input_scale).to(x.dtype)

        return F.linear(x_dq, weight_dq, bias)
|
||||
Reference in New Issue
Block a user