Add support for ModelOpt MXFP8 dense models (#33786)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
danisereb
2026-02-08 21:16:48 +02:00
committed by GitHub
parent 1ecfabe525
commit 084aa19f02
6 changed files with 375 additions and 14 deletions

View File

@@ -17,6 +17,7 @@ following `quantization.quant_algo` values:
- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization.
- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only (typically 128×128 blocks).
- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`).
- `MXFP8`: ModelOpt MXFP8 checkpoints (use `quantization="modelopt_mxfp8"`).
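For example, a dense MXFP8 ModelOpt checkpoint can be loaded by passing the new quantization name explicitly (a minimal sketch; the model path below is a placeholder for an MXFP8-quantized checkpoint):

from vllm import LLM

# Placeholder model path; the method should also be auto-detected from the
# checkpoint's hf_quant_config.json.
llm = LLM(model="nvidia/example-model-mxfp8", quantization="modelopt_mxfp8")
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)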
## Quantizing HuggingFace Models with PTQ

View File

@@ -878,6 +878,7 @@ class ModelConfig:
"moe_wna16",
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"petit_nvfp4",
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)

View File

@@ -494,6 +494,7 @@ class FusedMoEQuantConfig:
"mxfp4",
"mxfp6_e3m2",
"mxfp6_e2m3",
"mxfp8",
}
assert not isinstance(weight_dtype, str) or weight_dtype in {
"nvfp4",
@@ -501,6 +502,7 @@ class FusedMoEQuantConfig:
"mxfp6_e3m2",
"mxfp6_e2m3",
"int4",
"mxfp8",
}
if weight_dtype is None:

View File

@@ -17,6 +17,7 @@ QuantizationMethods = Literal[
"fp_quant",
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"gguf",
"gptq_marlin",
"awq_marlin",
@@ -119,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .inc import INCConfig
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config
@@ -133,6 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fp_quant": FPQuantConfig,
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"modelopt_mxfp8": ModelOptMxFp8Config,
"gguf": GGUFConfig,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,

View File

@@ -63,6 +63,13 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE,
MXFP8_VALUE_DTYPE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
@@ -103,6 +110,8 @@ QUANT_ALGOS = [
"FP8_PB_WO",
# FP4
"NVFP4",
# MXFP8
"MXFP8",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]
@@ -386,12 +395,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", ""))
if "FP8" in quant_algo.upper():
if quant_algo.upper() == "FP8":
return "modelopt"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
if "FP8" in quant_algo.upper():
if quant_algo.upper() == "FP8":
return "modelopt"
return None
@@ -1547,3 +1556,239 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
class ModelOptMxFp8Config(ModelOptQuantConfigBase):
"""Config class for ModelOpt MXFP8."""
def __init__(
self,
is_checkpoint_mxfp8_serialized: bool,
kv_cache_quant_algo: str | None,
exclude_modules: list[str],
) -> None:
super().__init__(exclude_modules)
self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized
if not is_checkpoint_mxfp8_serialized:
raise ValueError(
"MXFP8 quantization requires a serialized checkpoint. "
"Dynamic quantization is not supported."
)
        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and may change in the future."
        )
self.kv_cache_quant_algo = kv_cache_quant_algo
def get_name(self) -> QuantizationMethods:
return "modelopt_mxfp8"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
return 100
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
# MXFP8 does not yet support MoE models
if isinstance(layer, FusedMoE):
raise NotImplementedError(
"MXFP8 quantization does not yet support MoE models. "
"Please use FP8 or NVFP4 quantization for MoE models."
)
return super().get_quant_method(layer, prefix)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""Detect if this ModelOpt MXFP8 config should be used based on
quantization config."""
if hf_quant_cfg is None:
return None
# Use the community standard 'quant_method'
quant_method = hf_quant_cfg.get("quant_method", "").lower()
# Only proceed if the method is explicitly "modelopt"
if quant_method != "modelopt":
return None
# Look for ModelOpt-specific config structure
if "quantization" in hf_quant_cfg:
quant_config = hf_quant_cfg["quantization"]
if isinstance(quant_config, dict):
quant_algo = str(quant_config.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
else:
# Check for compressed-tensors style config with specific quant_algo
quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
if "MXFP8" in quant_algo:
return "modelopt_mxfp8"
return None
@classmethod
def _from_config(
cls,
*,
quant_method: str,
kv_cache_quant_method: str | None,
exclude_modules: list[str],
original_config: dict[str, Any],
**kwargs: Any,
) -> "ModelOptMxFp8Config":
is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper()
# For MXFP8, validate required fields in the config
if is_checkpoint_mxfp8_serialized and "quantization" in original_config:
quant_config = original_config["quantization"]
required_fields = ["kv_cache_quant_algo", "exclude_modules"]
missing_fields = [
field for field in required_fields if field not in quant_config
]
if missing_fields:
raise ValueError(
f"MXFP8 quantization requires the following fields in "
f"hf_quant_config.json: {missing_fields}"
)
return cls(
is_checkpoint_mxfp8_serialized,
kv_cache_quant_method,
exclude_modules,
)
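# Illustration (not part of this commit): with a ModelOpt hf_quant_config along
# the lines of
#     {"quant_method": "modelopt",
#      "quantization": {"quant_algo": "MXFP8",
#                       "kv_cache_quant_algo": None,
#                       "exclude_modules": ["lm_head"]}}
# override_quantization_method above returns "modelopt_mxfp8", and _from_config
# requires kv_cache_quant_algo and exclude_modules to be present; the exact
# surrounding keys are an assumption based on the checks above.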
class ModelOptMxFp8LinearMethod(LinearMethodBase):
"""Linear method for ModelOpt MXFP8 quantization."""
def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
self.quant_config = quant_config
if not self.quant_config.is_checkpoint_mxfp8_serialized:
raise ValueError(
"MXFP8 currently only supports serialized checkpoints. "
"Dynamic quantization is not supported."
)
backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_mxfp8_serialized:
raise ValueError(
"MXFP8 quantization was selected, but checkpoint is not "
"MXFP8 serialized. Dynamic quantization is not supported."
)
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
raise ValueError(
f"MXFP8 requires input dimension to be divisible by "
f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
)
# Weight tensor: FP8 E4M3 format
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=MXFP8_VALUE_DTYPE,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# Weight scale tensor (E8M0 encoded as uint8), one scale per block of 32 along K
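        # e.g. for input_size_per_partition=4096, each output row carries
        # 4096 // 32 = 128 E8M0 scale bytes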
weight_scale = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // MXFP8_BLOCK_SIZE,
dtype=MXFP8_SCALE_DTYPE,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if layer.weight.ndim != 2:
raise ValueError(
f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
f"with shape {tuple(layer.weight.shape)}"
)
if layer.weight.dtype != MXFP8_VALUE_DTYPE:
raise ValueError(
f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
f"got {layer.weight.dtype}. The checkpoint may not be properly "
f"quantized with MXFP8."
)
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
# Slice weight_scale to match weight dimensions (handles padding)
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if layer.weight.dtype != MXFP8_VALUE_DTYPE:
raise ValueError(
f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
)
if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
raise ValueError(
f"Weight scale dtype {layer.weight_scale.dtype} != "
f"expected {MXFP8_SCALE_DTYPE}"
)
return self.mxfp8_linear_op.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
)
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod

View File

@@ -1,24 +1,134 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
class Mxfp8LinearBackend(Enum):
    EMULATION = "emulation"
# MXFP8 constants
MXFP8_VALUE_DTYPE = torch.float8_e4m3fn
MXFP8_SCALE_DTYPE = torch.uint8
MXFP8_BLOCK_SIZE = 32
def _mxfp8_e4m3_quantize_impl(
x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
x_q, x_scales = flashinfer_mxfp8_quantize(
x, is_sf_swizzled_layout=is_sf_swizzled_layout
)
if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
x_scales = x_scales.view(x.size(0), -1)
return x_q, x_scales
def mxfp8_e4m3_quantize(
x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout)
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
"""Dequantize MXFP8 tensor to BF16."""
x_float = x.to(torch.float32)
num_blocks = x.shape[-1] // MXFP8_BLOCK_SIZE
x_blocked = x_float.view(*x.shape[:-1], num_blocks, MXFP8_BLOCK_SIZE)
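    # E8M0 scales are biased power-of-two exponents: descale = 2.0 ** (byte - 127)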
descale = torch.exp2(scales.to(torch.float32) - 127.0)
dequantized = x_blocked * descale.unsqueeze(-1)
dequantized = dequantized.view(*x.shape)
return dequantized.to(torch.bfloat16)
def mxfp8_e4m3_quantize_fake(
x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fake implementation for torch.compile tracing."""
fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)
block_size = MXFP8_BLOCK_SIZE
if x.ndim == 2:
M, N = x.shape
K = (N + block_size - 1) // block_size
if is_sf_swizzled_layout:
M_padded = ((M + 127) // 128) * 128
K_padded = ((K + 3) // 4) * 4
scales = torch.empty(
M_padded * K_padded, dtype=MXFP8_SCALE_DTYPE, device=x.device
)
else:
scales = torch.empty((M, K), dtype=MXFP8_SCALE_DTYPE, device=x.device)
elif x.ndim == 3:
B, M, N = x.shape
K = (N + block_size - 1) // block_size
if is_sf_swizzled_layout:
M_padded = ((M + 127) // 128) * 128
K_padded = ((K + 3) // 4) * 4
scales = torch.empty(
B * M_padded * K_padded, dtype=MXFP8_SCALE_DTYPE, device=x.device
)
else:
scales = torch.empty((B, M, K), dtype=MXFP8_SCALE_DTYPE, device=x.device)
else:
scale_shape = list(x.shape)
scale_shape[-1] = (x.shape[-1] + block_size - 1) // block_size
scales = torch.empty(scale_shape, dtype=MXFP8_SCALE_DTYPE, device=x.device)
return fp_data, scales
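# Registering as a custom op keeps the flashinfer call opaque to torch.compile;
# during tracing, the fake impl above supplies only output shapes and dtypes.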
direct_register_custom_op(
op_name="mxfp8_quantize",
op_func=_mxfp8_e4m3_quantize_impl,
fake_impl=mxfp8_e4m3_quantize_fake,
)
class Mxfp8LinearOp:
def __init__(self, backend: Mxfp8LinearBackend):
if backend not in Mxfp8LinearBackend:
raise ValueError(f"Unsupported backend: {backend}")
self.backend = backend
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
        # Validate weight_scale dtype and shape (weight_scale must be 2D after
        # process_weights_after_loading)
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Mxfp8LinearOp requires {MXFP8_SCALE_DTYPE} weight_scale dtype, "
                f"got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"Mxfp8LinearOp requires a 2D weight_scale, got {weight_scale.ndim}D. "
                f"Ensure process_weights_after_loading was called."
            )
weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale)
output = torch.nn.functional.linear(input, weight_bf16, bias)
return output.to(out_dtype)
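# Illustration only (not part of this commit): a quick sanity check of the
# BF16 emulation path on hand-built tensors. Assumes a torch build with
# float8_e4m3fn support; call it manually or from a test.
def _mxfp8_dequant_sanity_check() -> None:
    w = torch.full((4, 64), 1.0).to(MXFP8_VALUE_DTYPE)    # [N=4, K=64] FP8 values
    s = torch.full((4, 2), 130, dtype=MXFP8_SCALE_DTYPE)  # E8M0: 2 ** (130 - 127) = 8
    w_bf16 = dequant_mxfp8_to_bf16(w, s)                   # expect every entry == 8.0
    assert torch.allclose(w_bf16.float(), torch.full((4, 64), 8.0))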