Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 (#35053)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Author: danisereb
Date: 2026-02-24 17:45:13 +02:00
Committed by: GitHub
Parent: a0c7081695
Commit: 9609b1f18d
3 changed files with 230 additions and 11 deletions

vllm/model_executor/layers/quantization/modelopt.py

@@ -70,6 +70,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    MXFP8_VALUE_DTYPE,
    Mxfp8LinearBackend,
    Mxfp8LinearOp,
    swizzle_mxfp8_scale,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
    apply_nvfp4_linear,
@@ -1689,9 +1690,9 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
"Dynamic quantization is not supported."
)
backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)
self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)
def create_weights(
self,
@@ -1749,7 +1750,38 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
        )
        layer.register_parameter("weight_scale", weight_scale)

    def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
        """Not swizzled: keep the row-major 2D weight scale for MXFP8 GEMM emulation."""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        scale_k = K // MXFP8_BLOCK_SIZE
        # Slice weight_scale to match weight dimensions (handles padding)
        weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
        """Swizzled: F8_128x4 weight-scale layout for the FlashInfer CUTLASS MXFP8 GEMM."""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        # 2D weight scale
        weight_scale = layer.weight_scale.data
        # Swizzle the weight scales
        scale_k = K // MXFP8_BLOCK_SIZE
        weight_scale_2d = weight_scale[:N, :scale_k].contiguous()
        weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(
            weight_scale_swizzled.contiguous(), requires_grad=False
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Validate weight tensor
        if layer.weight.ndim != 2:
            raise ValueError(
                f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
@@ -1763,15 +1795,23 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
f"quantized with MXFP8."
)
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
# Validate weight scale tensor (should be 2D, not swizzled)
assert layer.weight_scale.ndim == 2, (
f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D"
)
assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, (
f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
f" got {layer.weight_scale.dtype}"
)
# Slice weight_scale to match weight dimensions (handles padding)
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
if self.backend == Mxfp8LinearBackend.EMULATION:
# Swizzled layout is not used
self._process_weights_after_loading_scale_2d(layer)
return
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
# Swizzled layout is required for Flashinfer CUTLASS
self._process_weights_after_loading_scale_1d(layer)

    def apply(
        self,

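For orientation, a rough sketch (not from this diff) of the two scale layouts the helper methods above leave on the layer, using illustrative shapes; each uint8 scale code covers one 32-element block along K:

import torch

N, K = 4096, 4096
scale_cols = K // 32  # one scale code per 32-element block along K

# _process_weights_after_loading_scale_2d keeps a sliced, row-major 2D scale.
emulation_scale = torch.empty(N, scale_cols, dtype=torch.uint8)  # [N, K // 32]

# _process_weights_after_loading_scale_1d swizzles to F8_128x4 and flattens:
# rows pad to a multiple of 128, scale columns to a multiple of 4.
rows_padded = ((N + 127) // 128) * 128
cols_padded = ((scale_cols + 3) // 4) * 4
flashinfer_scale_numel = rows_padded * cols_padded  # 4096 * 128 here (already aligned)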
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py

@@ -6,6 +6,7 @@ from enum import Enum
import torch
from vllm.logger import init_logger
from vllm.utils import flashinfer as vllm_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -13,6 +14,7 @@ logger = init_logger(__name__)
class Mxfp8LinearBackend(Enum):
    EMULATION = "emulation"
    FLASHINFER_CUTLASS = "flashinfer-cutlass"


# MXFP8 constants
@@ -21,6 +23,30 @@ MXFP8_SCALE_DTYPE = torch.uint8
MXFP8_BLOCK_SIZE = 32


def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
    """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
    scaling_vector_size = MXFP8_BLOCK_SIZE  # 32 for MXFP8
    factor = scaling_vector_size * 4  # 128
    num_m_tiles = (M + 127) // 128
    num_k_tiles = (K + factor - 1) // factor
    m_padded = num_m_tiles * 128
    k_scale_padded = num_k_tiles * 4
    scale_cols = K // scaling_vector_size
    sf_padded = torch.zeros(
        (m_padded, k_scale_padded), dtype=sf.dtype, device=sf.device
    )
    sf_padded[:M, :scale_cols] = sf
    sf_reshaped = sf_padded.view(num_m_tiles, 4, 32, num_k_tiles, 4)
    sf_swizzled = sf_reshaped.transpose(1, 3)
    return sf_swizzled.contiguous().view(-1)
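A standalone shape check of this swizzle, as a sketch that assumes both names are importable from this module (consistent with the import hunk in the first file):

import torch

from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    MXFP8_BLOCK_SIZE,
    swizzle_mxfp8_scale,
)

N, K = 200, 512
scale = torch.randint(0, 255, (N, K // MXFP8_BLOCK_SIZE), dtype=torch.uint8)

swizzled = swizzle_mxfp8_scale(scale, M=N, K=K)
# Rows pad to 256 (two 128-row tiles), scale columns pad to 16 (four 4-column tiles),
# then everything is flattened, so the 200 x 16 codes become a 1-D tensor of 256 * 16.
assert swizzled.shape == (256 * 16,)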


def _mxfp8_e4m3_quantize_impl(
    x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -108,7 +134,7 @@ class Mxfp8LinearOp:
        self.backend = backend

    def apply(
    def _apply_emulation(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
@@ -132,3 +158,79 @@ class Mxfp8LinearOp:
        output = torch.nn.functional.linear(input, weight_bf16, bias)
        return output.to(out_dtype)
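The emulation path dequantizes the MXFP8 weight to weight_bf16 before the torch.nn.functional.linear call above. A hand-written sketch of that blockwise dequantization, assuming each uint8 scale code is an E8M0 exponent (value = 2 ** (code - 127)) shared by 32 consecutive elements along K; illustrative only, not the exact helper used here:

import torch

MXFP8_BLOCK_SIZE = 32

def dequant_mxfp8_weight(weight_fp8: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    # weight_fp8: [N, K] float8_e4m3fn, scale_u8: [N, K // 32] uint8 (assumed E8M0 codes)
    _, k = weight_fp8.shape
    scale = torch.exp2(scale_u8.to(torch.float32) - 127.0)    # decode the exponent code
    scale = scale.repeat_interleave(MXFP8_BLOCK_SIZE, dim=1)  # one scale per 32-wide block
    return (weight_fp8.to(torch.float32) * scale[:, :k]).to(torch.bfloat16)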

    def _apply_flashinfer_cutlass(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        N, K = weight.shape
        input_shape = input.shape
        input_2d = input.view(-1, K)
        M_orig = input_2d.shape[0]
        # Minimum dimension size for F8_128x4 block scaling layout
        min_dim = 128
        assert min_dim <= K, (
            f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
            f"in_features is too small for mm_mxfp8."
        )
        assert K % MXFP8_BLOCK_SIZE == 0, (
            f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
        )
        assert min_dim <= N, (
            f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
            f"out_features is too small for mm_mxfp8."
        )
        M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
        if M_padded != M_orig:
            pad_rows = M_padded - M_orig
            input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))
        input_mxfp8, input_scale = mxfp8_e4m3_quantize(
            input_2d,
            is_sf_swizzled_layout=True,  # Swizzled for best accuracy
        )
        if not weight.is_contiguous():
            weight = weight.contiguous()
        output = vllm_flashinfer.mm_mxfp8(
            input_mxfp8,
            weight.t(),
            input_scale,
            weight_scale,
            out_dtype=out_dtype,
            backend="cutlass",
        )
        if M_padded != M_orig:
            output = output[:M_orig, :]
        if bias is not None:
            output = output + bias
        output_shape = (*input_shape[:-1], N)
        return output.view(output_shape)
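The padding above rounds the activation row count up to the next multiple of 128, presumably so the quantized input matches the 128-row granularity of the F8_128x4 scale layout, and the extra rows are sliced back off the GEMM output. The rounding itself is just:

def round_up_to_128(m: int) -> int:
    return ((m + 127) // 128) * 128

assert round_up_to_128(1) == 128
assert round_up_to_128(128) == 128
assert round_up_to_128(129) == 256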

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.backend == Mxfp8LinearBackend.EMULATION:
            return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)
        assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
        return self._apply_flashinfer_cutlass(
            input, weight, weight_scale, out_dtype, bias
        )
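Finally, a hedged end-to-end sketch of driving the new backend directly. The names follow this diff, but treating mxfp8_e4m3_quantize as importable here, its non-swizzled scale coming back as a row-major [N, K // 32] tensor, and the CUDA device requirement are assumptions; in the ModelOpt path the weight and its scale arrive pre-quantized from the checkpoint instead:

import torch

from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    Mxfp8LinearBackend,
    Mxfp8LinearOp,
    mxfp8_e4m3_quantize,
    swizzle_mxfp8_scale,
)

op = Mxfp8LinearOp(backend=Mxfp8LinearBackend.FLASHINFER_CUTLASS)

N, K = 256, 512  # the CUTLASS path asserts N >= 128, K >= 128 and K % 32 == 0
w_bf16 = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
w_fp8, w_scale_2d = mxfp8_e4m3_quantize(w_bf16)      # assumed [N, K // 32] row-major scale
w_scale = swizzle_mxfp8_scale(w_scale_2d, M=N, K=K)  # F8_128x4, flattened

x = torch.randn(8, K, dtype=torch.bfloat16, device="cuda")
y = op.apply(x, w_fp8, w_scale, out_dtype=torch.bfloat16)  # -> [8, N]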