Add support for ModelOpt MXFP8 dense models (#33786)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
@@ -17,6 +17,7 @@ following `quantization.quant_algo` values:
- `FP8_PER_CHANNEL_PER_TOKEN`: per-channel weight scale and dynamic per-token activation quantization.
- `FP8_PB_WO` (ModelOpt may emit `fp8_pb_wo`): block-scaled FP8 weight-only quantization (typically 128×128 blocks).
- `NVFP4`: ModelOpt NVFP4 checkpoints (use `quantization="modelopt_fp4"`).
- `MXFP8`: ModelOpt MXFP8 checkpoints (use `quantization="modelopt_mxfp8"`).
## Quantizing HuggingFace Models with PTQ
|
||||
|
||||
|
||||
@@ -878,6 +878,7 @@ class ModelConfig:
|
||||
"moe_wna16",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"modelopt_mxfp8",
|
||||
"petit_nvfp4",
|
||||
# Ensure heavy backends are probed last to avoid unnecessary
|
||||
# imports during override detection (e.g., MXFP4 imports Triton)
|
||||
|
||||
@@ -494,6 +494,7 @@ class FusedMoEQuantConfig:
|
||||
"mxfp4",
|
||||
"mxfp6_e3m2",
|
||||
"mxfp6_e2m3",
|
||||
"mxfp8",
|
||||
}
|
||||
assert not isinstance(weight_dtype, str) or weight_dtype in {
|
||||
"nvfp4",
|
||||
@@ -501,6 +502,7 @@ class FusedMoEQuantConfig:
|
||||
"mxfp6_e3m2",
|
||||
"mxfp6_e2m3",
|
||||
"int4",
|
||||
"mxfp8",
|
||||
}
|
||||
|
||||
if weight_dtype is None:
|
||||
|
||||
@@ -17,6 +17,7 @@ QuantizationMethods = Literal[
|
||||
"fp_quant",
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"modelopt_mxfp8",
|
||||
"gguf",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
@@ -119,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .inc import INCConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .petit import PetitNvFp4Config
|
||||
@@ -133,6 +134,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"fp_quant": FPQuantConfig,
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"modelopt_mxfp8": ModelOptMxFp8Config,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
|
||||
@@ -63,6 +63,13 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
get_marlin_input_dtype,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_BLOCK_SIZE,
|
||||
MXFP8_SCALE_DTYPE,
|
||||
MXFP8_VALUE_DTYPE,
|
||||
Mxfp8LinearBackend,
|
||||
Mxfp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
apply_nvfp4_linear,
|
||||
convert_to_nvfp4_linear_kernel_format,
|
||||
@@ -103,6 +110,8 @@ QUANT_ALGOS = [
|
||||
"FP8_PB_WO",
|
||||
# FP4
|
||||
"NVFP4",
|
||||
# MXFP8
|
||||
"MXFP8",
|
||||
]
|
||||
KV_CACHE_QUANT_ALGOS = ["FP8"]
|
||||
|
||||
@@ -386,12 +395,12 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
|
||||
quant_config = hf_quant_cfg["quantization"]
|
||||
if isinstance(quant_config, dict):
|
||||
quant_algo = str(quant_config.get("quant_algo", ""))
|
||||
if "FP8" in quant_algo.upper():
|
||||
if quant_algo.upper() == "FP8":
|
||||
return "modelopt"
|
||||
else:
|
||||
# Check for compressed-tensors style config with specific quant_algo
|
||||
quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
|
||||
if "FP8" in quant_algo.upper():
|
||||
if quant_algo.upper() == "FP8":
|
||||
return "modelopt"
|
||||
|
||||
return None
|
||||
@@ -1547,3 +1556,239 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
# Register the method classes for ModelOptNvFp4Config so the base config
# can instantiate the right quantize method per layer type.
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE
ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
|
||||
|
||||
class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    """Config class for ModelOpt MXFP8.

    Represents a checkpoint quantized by NVIDIA ModelOpt with the MXFP8
    (microscaling FP8) scheme: FP8 E4M3 values with one shared E8M0 scale
    per block of elements along K.
    """

    def __init__(
        self,
        is_checkpoint_mxfp8_serialized: bool,
        kv_cache_quant_algo: str | None,
        exclude_modules: list[str],
    ) -> None:
        """Create the config.

        Args:
            is_checkpoint_mxfp8_serialized: whether the checkpoint already
                contains MXFP8-quantized tensors. Must be True; dynamic
                (on-the-fly) quantization is rejected below.
            kv_cache_quant_algo: KV-cache quantization algorithm name from
                the checkpoint config, or None.
            exclude_modules: module-name patterns excluded from quantization
                (handled by the base class).

        Raises:
            ValueError: if the checkpoint is not MXFP8-serialized.
        """
        super().__init__(exclude_modules)
        self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized

        # Only pre-quantized checkpoints are supported; fail fast otherwise.
        if not is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization requires a serialized checkpoint. "
                "Dynamic quantization is not supported."
            )

        logger.warning(
            "Detected ModelOpt MXFP8 checkpoint. Please note that "
            "the format is experimental and could change in future."
        )

        self.kv_cache_quant_algo = kv_cache_quant_algo

    def get_name(self) -> QuantizationMethods:
        """Return the registered quantization-method name."""
        return "modelopt_mxfp8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        """Only BF16 activations are supported for MXFP8 layers."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # MXFP8 hardware acceleration requires Blackwell (SM100) or newer
        return 100

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return the quantize method for *layer*, rejecting MoE layers.

        Raises:
            NotImplementedError: if *layer* is a FusedMoE layer.
        """
        # MXFP8 does not yet support MoE models
        if isinstance(layer, FusedMoE):
            raise NotImplementedError(
                "MXFP8 quantization does not yet support MoE models. "
                "Please use FP8 or NVFP4 quantization for MoE models."
            )
        return super().get_quant_method(layer, prefix)

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt MXFP8 config should be used based on
        quantization config.

        Returns "modelopt_mxfp8" when the HF quant config declares
        quant_method "modelopt" and a quant_algo containing "MXFP8",
        otherwise None. NOTE(review): user_quant is accepted but unused here;
        presumably override precedence is handled by the caller — confirm.
        """
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure: a nested
        # "quantization" dict carrying "quant_algo".
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = str(quant_config.get("quant_algo", "")).upper()
                if "MXFP8" in quant_algo:
                    return "modelopt_mxfp8"
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
            if "MXFP8" in quant_algo:
                return "modelopt_mxfp8"

        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        **kwargs: Any,
    ) -> "ModelOptMxFp8Config":
        """Build the config from a parsed hf_quant_config.json.

        Raises:
            ValueError: if required MXFP8 fields are missing from the
                checkpoint's "quantization" section (or, via __init__, if
                the checkpoint is not MXFP8-serialized).
        """
        is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper()

        # For MXFP8, validate required fields in the config
        if is_checkpoint_mxfp8_serialized and "quantization" in original_config:
            quant_config = original_config["quantization"]
            required_fields = ["kv_cache_quant_algo", "exclude_modules"]
            missing_fields = [
                field for field in required_fields if field not in quant_config
            ]
            if missing_fields:
                raise ValueError(
                    f"MXFP8 quantization requires the following fields in "
                    f"hf_quant_config.json: {missing_fields}"
                )

        return cls(
            is_checkpoint_mxfp8_serialized,
            kv_cache_quant_method,
            exclude_modules,
        )
|
||||
|
||||
|
||||
class ModelOptMxFp8LinearMethod(LinearMethodBase):
    """Linear method for ModelOpt MXFP8 quantization.

    Weights are FP8 E4M3 values of shape [N, K] with one E8M0 scale
    (stored as uint8) per block of MXFP8_BLOCK_SIZE elements along K,
    i.e. a scale tensor of shape [N, K // MXFP8_BLOCK_SIZE].
    """

    def __init__(self, quant_config: ModelOptMxFp8Config) -> None:
        self.quant_config = quant_config

        # Only pre-quantized checkpoints are supported; there is no
        # on-the-fly (dynamic) MXFP8 quantization path.
        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 currently only supports serialized checkpoints. "
                "Dynamic quantization is not supported."
            )

        # Emulation backend: dequantize weights to BF16 and run a plain GEMM.
        backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
        self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
        logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the MXFP8 weight and per-block scale parameters on *layer*.

        Raises:
            ValueError: if the checkpoint is not MXFP8-serialized, or if the
                input (K) dimension is not a multiple of MXFP8_BLOCK_SIZE.
        """
        del input_size, output_size

        if not self.quant_config.is_checkpoint_mxfp8_serialized:
            raise ValueError(
                "MXFP8 quantization was selected, but checkpoint is not "
                "MXFP8 serialized. Dynamic quantization is not supported."
            )

        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Every scale must cover a full block; ragged tails are unsupported.
        if input_size_per_partition % MXFP8_BLOCK_SIZE != 0:
            raise ValueError(
                f"MXFP8 requires input dimension to be divisible by "
                f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}"
            )

        # Weight tensor: FP8 E4M3 format
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=MXFP8_VALUE_DTYPE,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # Weight scale tensor (E8M0 encoded as uint8), one scale per block of 32 along K
        weight_scale = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // MXFP8_BLOCK_SIZE,
                dtype=MXFP8_SCALE_DTYPE,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Validate loaded tensors and trim scale padding to match the weight.

        Raises:
            ValueError: if the loaded weight is not a 2D FP8 E4M3 tensor.
        """
        if layer.weight.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
                f"with shape {tuple(layer.weight.shape)}"
            )

        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), "
                f"got {layer.weight.dtype}. The checkpoint may not be properly "
                f"quantized with MXFP8."
            )

        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        scale_k = K // MXFP8_BLOCK_SIZE

        # Slice weight_scale to match weight dimensions (handles padding)
        weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()

        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear layer output via the MXFP8 linear op.

        Output dtype matches x.dtype.

        Raises:
            ValueError: if the layer's weight or scale dtypes are not the
                expected MXFP8 formats.
        """
        if layer.weight.dtype != MXFP8_VALUE_DTYPE:
            raise ValueError(
                f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
            )
        if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"Weight scale dtype {layer.weight_scale.dtype} != "
                f"expected {MXFP8_SCALE_DTYPE}"
            )

        return self.mxfp8_linear_op.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=x.dtype,
            bias=bias,
        )
|
||||
|
||||
|
||||
# Register the method classes for ModelOptMxFp8Config.
# No FusedMoEMethodCls is registered: MXFP8 MoE is unsupported and
# get_quant_method raises NotImplementedError for FusedMoE layers.
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
|
||||
|
||||
@@ -1,24 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def mxfp8_e4m3_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
try:
|
||||
from flashinfer import mxfp8_quantize as mxfp8_e4m3_quantize
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"The package `flashinfer` is required to do "
|
||||
"MX-FP8 quantization. Please install it with"
|
||||
"`pip install flashinfer`"
|
||||
) from err
|
||||
class Mxfp8LinearBackend(Enum):
|
||||
EMULATION = "emulation"
|
||||
|
||||
x_q, x_scales = mxfp8_e4m3_quantize(x, is_sf_swizzled_layout=False)
|
||||
if x_scales.ndim == 1:
|
||||
|
||||
# MXFP8 constants
|
||||
MXFP8_VALUE_DTYPE = torch.float8_e4m3fn
|
||||
MXFP8_SCALE_DTYPE = torch.uint8
|
||||
MXFP8_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def _mxfp8_e4m3_quantize_impl(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation of the MXFP8 quantize custom op.

    Returns (quantized values, block scales) as produced by flashinfer.
    """
    # Imported lazily so this module can be loaded without flashinfer installed.
    from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize

    x_q, x_scales = flashinfer_mxfp8_quantize(
        x, is_sf_swizzled_layout=is_sf_swizzled_layout
    )
    # In the row-major (non-swizzled) layout flashinfer can return the
    # scales flattened to 1D for a 2D input; reshape to [rows, blocks] so
    # callers always see a 2D scale tensor — TODO confirm against flashinfer.
    if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
        x_scales = x_scales.view(x.size(0), -1)
    return x_q, x_scales
|
||||
|
||||
|
||||
def mxfp8_e4m3_quantize(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize *x* to MXFP8 via the registered vllm custom op.

    Going through torch.ops makes the call traceable by torch.compile
    (the fake impl provides shapes during tracing).
    """
    return torch.ops.vllm.mxfp8_quantize(x, is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def dequant_mxfp8_to_bf16(x: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize an MXFP8 tensor to BF16.

    Each block of MXFP8_BLOCK_SIZE elements along the last dimension of *x*
    is multiplied by the power-of-two factor encoded by its E8M0 scale byte.
    """
    blocks = x.shape[-1] // MXFP8_BLOCK_SIZE

    # Group the last dimension into [blocks, MXFP8_BLOCK_SIZE] in float32.
    grouped = x.to(torch.float32).view(*x.shape[:-1], blocks, MXFP8_BLOCK_SIZE)

    # E8M0 scale byte -> 2**(byte - 127) multiplier, broadcast over each block.
    factors = torch.exp2(scales.to(torch.float32) - 127.0).unsqueeze(-1)

    restored = (grouped * factors).view(*x.shape)
    return restored.to(torch.bfloat16)
|
||||
|
||||
|
||||
def mxfp8_e4m3_quantize_fake(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile tracing.

    Allocates outputs with the same shapes/dtypes the real kernel produces,
    without computing any values:
      - values: same shape as *x*, dtype MXFP8_VALUE_DTYPE;
      - scales: one uint8 per block of MXFP8_BLOCK_SIZE elements along the
        last dim. For 2D/3D inputs in the swizzled layout the scales are a
        flat buffer with rows padded to 128 and blocks padded to 4; otherwise
        they mirror x's shape with the last dim replaced by the block count.

    The original had near-identical 2D and 3D branches; they are unified
    here with identical resulting shapes (2D behaves as batch == 1).
    """
    fp_data = torch.empty_like(x, dtype=MXFP8_VALUE_DTYPE)

    # Ceil-divide the last dimension into scale blocks.
    num_blocks = (x.shape[-1] + MXFP8_BLOCK_SIZE - 1) // MXFP8_BLOCK_SIZE

    if is_sf_swizzled_layout and x.ndim in (2, 3):
        # Swizzled layout: flat buffer sized with padded rows/blocks.
        batch = x.shape[0] if x.ndim == 3 else 1
        rows = x.shape[-2]
        rows_padded = ((rows + 127) // 128) * 128
        blocks_padded = ((num_blocks + 3) // 4) * 4
        scales = torch.empty(
            batch * rows_padded * blocks_padded,
            dtype=MXFP8_SCALE_DTYPE,
            device=x.device,
        )
    else:
        # Row-major layout (and any other rank): last dim -> block count.
        scales = torch.empty(
            (*x.shape[:-1], num_blocks), dtype=MXFP8_SCALE_DTYPE, device=x.device
        )

    return fp_data, scales
|
||||
|
||||
|
||||
# Register the quantize kernel as the torch custom op "vllm::mxfp8_quantize";
# the fake impl supplies output shapes/dtypes for torch.compile tracing.
direct_register_custom_op(
    op_name="mxfp8_quantize",
    op_func=_mxfp8_e4m3_quantize_impl,
    fake_impl=mxfp8_e4m3_quantize_fake,
)
|
||||
|
||||
|
||||
class Mxfp8LinearOp:
    """Executes an MXFP8 linear layer for the selected backend.

    Only Mxfp8LinearBackend.EMULATION is implemented: the MXFP8 weight is
    dequantized to BF16 and the GEMM runs via torch.nn.functional.linear.
    """

    def __init__(self, backend: Mxfp8LinearBackend):
        """Create the op for *backend*.

        Raises:
            ValueError: if *backend* is not a known Mxfp8LinearBackend.
        """
        if backend not in Mxfp8LinearBackend:
            raise ValueError(f"Unsupported backend: {backend}")

        self.backend = backend

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute input @ weight.T + bias, returning *out_dtype*.

        Raises:
            ValueError: if *weight_scale* has the wrong dtype or is not 2D.
        """
        # Validate weight_scale dtype and shape (must be 2D [N, K/block]).
        # Note: the original messages blamed a nonexistent "TORCH backend";
        # reworded to name MXFP8 emulation, the only backend defined here.
        if weight_scale.dtype != MXFP8_SCALE_DTYPE:
            raise ValueError(
                f"MXFP8 emulation requires {MXFP8_SCALE_DTYPE} weight_scale "
                f"dtype, got {weight_scale.dtype}."
            )
        if weight_scale.ndim != 2:
            raise ValueError(
                f"MXFP8 emulation requires 2D weight_scale, got "
                f"{weight_scale.ndim}D. "
                f"Ensure process_weights_after_loading was called."
            )

        # Dequantize once per call and fall back to a dense BF16 GEMM.
        weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale)

        output = torch.nn.functional.linear(input, weight_bf16, bias)
        return output.to(out_dtype)
|
||||
|
||||
Reference in New Issue
Block a user