[Quantization] Quark MXFP4 format loading (#16943)

This commit is contained in:
Bowen Bao
2025-05-07 12:05:05 -07:00
committed by GitHub
parent f98e307588
commit db593aa67f
9 changed files with 289 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, cast
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
@@ -15,13 +16,15 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkMoEMethod)
from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
from vllm.platforms import current_platform
__all__ = ["QuarkLinearMethod"]
logger = init_logger(__name__)
class QuarkConfig(QuantizationConfig):
@@ -67,6 +70,7 @@ class QuarkConfig(QuantizationConfig):
return QuarkLinearMethod(self)
if isinstance(layer, Attention):
return QuarkKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return QuarkMoEMethod.get_moe_method(self,
module=layer,
@@ -205,6 +209,54 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
logger.debug("Quark model is not in MX-FP4 format: "
"weight_quant or input_quant not set")
return False
# Input and weight dtype needs to be fp4.
if weight_quant.get("dtype") != "fp4" or input_quant.get(
"dtype") != "fp4":
logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
return False
# Input and weight qscheme needs to be per group.
if weight_quant.get("qscheme") != "per_group" or input_quant.get(
"qscheme") != "per_group":
logger.debug("Quark model is not in MX-FP4 format: not per_group")
return False
# Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get(
"group_size") != 32:
logger.debug(
"Quark model is not in MX-FP4 format: not group_size=32")
return False
# Weights need to use static quantization.
if weight_quant.get("is_dynamic") is True:
logger.debug(
"Quark model is not in MX-FP4 format: not weight static")
return False
# Activations need to use dynamic quantization.
if input_quant.get("is_dynamic") is False:
logger.debug(
"Quark model is not in MX-FP4 format: not activation dynamic")
return False
# Activations and weight scales need to be in e8m0 format.
if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
"scale_format") != "e8m0":
logger.debug(
"Quark model is not in MX-FP4 format: not scale_format e8m0")
return False
return True
def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:
@@ -269,6 +321,8 @@ class QuarkConfig(QuantizationConfig):
return QuarkW8A8Int8(qscheme=weight_qscheme,
is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"))
elif self._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFP4(weight_config, input_config)
raise NotImplementedError("No quark compatible scheme was found. "
f"Weight config: {weight_config}, "

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]

View File

@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
__all__ = ["QuarkW4A4MXFP4"]
class QuarkW4A4MXFP4(QuarkScheme):
def __init__(self, weight_quant_spec: Dict[str, Any],
input_quant_spec: Dict[str, Any]):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()
@classmethod
def get_min_capability(cls) -> int:
return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.emulate:
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
if not envs.VLLM_QUARK_EMU_MEM_OPT:
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
else:
self.weight_quantizer = weight_quantizer
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=2,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.emulate:
if envs.VLLM_QUARK_EMU_MEM_OPT:
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
else:
dq_w = layer.weight
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
return F.linear(qdq_x, dq_w, bias)
else:
raise NotImplementedError()

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Tuple
import torch
OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> Tuple[torch.Tensor, torch.Tensor]:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
from quark.torch.quantization.utils import (even_round,
reshape_to_blocks)
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
axis = -1
block_x = reshape_to_blocks(x, block_k, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)
# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP4 quantization")
scale = even_round(amax, "fp4")
# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=block_k,
quant_dtype="fp4",
)
return x, scale