Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 (#35053)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
@@ -70,6 +70,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
MXFP8_VALUE_DTYPE,
|
||||
Mxfp8LinearBackend,
|
||||
Mxfp8LinearOp,
|
||||
swizzle_mxfp8_scale,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
apply_nvfp4_linear,
|
||||
@@ -1689,9 +1690,9 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
"Dynamic quantization is not supported."
|
||||
)
|
||||
|
||||
backend: Mxfp8LinearBackend = Mxfp8LinearBackend.EMULATION
|
||||
self.mxfp8_linear_op = Mxfp8LinearOp(backend=backend)
|
||||
logger.info_once("Using %s backend for MXFP8 GEMM", backend.value)
|
||||
self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
|
||||
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -1749,7 +1750,38 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
|
||||
"""Not swizzled - MXFP8 GEMM emulation"""
|
||||
weight = layer.weight.data # [N, K]
|
||||
N, K = weight.shape
|
||||
scale_k = K // MXFP8_BLOCK_SIZE
|
||||
|
||||
# Slice weight_scale to match weight dimensions (handles padding)
|
||||
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
|
||||
|
||||
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
|
||||
"""Swizzled - MXFP8 GEMM Flashinfer CUTLASS"""
|
||||
weight = layer.weight.data # [N, K]
|
||||
N, K = weight.shape
|
||||
|
||||
# 2D weight scale
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
# Swizzle the weight scales
|
||||
scale_k = K // MXFP8_BLOCK_SIZE
|
||||
weight_scale_2d = weight_scale[:N, :scale_k].contiguous()
|
||||
weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
|
||||
|
||||
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(
|
||||
weight_scale_swizzled.contiguous(), requires_grad=False
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Validate weight tensor
|
||||
if layer.weight.ndim != 2:
|
||||
raise ValueError(
|
||||
f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D "
|
||||
@@ -1763,15 +1795,23 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
|
||||
f"quantized with MXFP8."
|
||||
)
|
||||
|
||||
weight = layer.weight.data # [N, K]
|
||||
N, K = weight.shape
|
||||
scale_k = K // MXFP8_BLOCK_SIZE
|
||||
# Validate weight scale tensor (should be 2D, not swizzled)
|
||||
assert layer.weight_scale.ndim == 2, (
|
||||
f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D"
|
||||
)
|
||||
assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, (
|
||||
f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE},"
|
||||
f" got {layer.weight_scale.dtype}"
|
||||
)
|
||||
|
||||
# Slice weight_scale to match weight dimensions (handles padding)
|
||||
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
|
||||
if self.backend == Mxfp8LinearBackend.EMULATION:
|
||||
# Swizzled layout is not used
|
||||
self._process_weights_after_loading_scale_2d(layer)
|
||||
return
|
||||
|
||||
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
# Swizzled layout is required for Flashinfer CUTLASS
|
||||
self._process_weights_after_loading_scale_1d(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
||||
@@ -6,6 +6,7 @@ from enum import Enum
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import flashinfer as vllm_flashinfer
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -13,6 +14,7 @@ logger = init_logger(__name__)
|
||||
|
||||
class Mxfp8LinearBackend(Enum):
|
||||
EMULATION = "emulation"
|
||||
FLASHINFER_CUTLASS = "flashinfer-cutlass"
|
||||
|
||||
|
||||
# MXFP8 constants
|
||||
@@ -21,6 +23,30 @@ MXFP8_SCALE_DTYPE = torch.uint8
|
||||
MXFP8_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
|
||||
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
|
||||
scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8
|
||||
factor = scaling_vector_size * 4 # 128
|
||||
|
||||
num_m_tiles = (M + 127) // 128
|
||||
num_k_tiles = (K + factor - 1) // factor
|
||||
|
||||
m_padded = num_m_tiles * 128
|
||||
k_scale_padded = num_k_tiles * 4
|
||||
|
||||
scale_cols = K // scaling_vector_size
|
||||
sf_padded = torch.zeros(
|
||||
(m_padded, k_scale_padded), dtype=sf.dtype, device=sf.device
|
||||
)
|
||||
sf_padded[:M, :scale_cols] = sf
|
||||
|
||||
sf_reshaped = sf_padded.view(num_m_tiles, 4, 32, num_k_tiles, 4)
|
||||
|
||||
sf_swizzled = sf_reshaped.transpose(1, 3)
|
||||
|
||||
return sf_swizzled.contiguous().view(-1)
|
||||
|
||||
|
||||
def _mxfp8_e4m3_quantize_impl(
|
||||
x: torch.Tensor, is_sf_swizzled_layout: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -108,7 +134,7 @@ class Mxfp8LinearOp:
|
||||
|
||||
self.backend = backend
|
||||
|
||||
def apply(
|
||||
def _apply_emulation(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -132,3 +158,79 @@ class Mxfp8LinearOp:
|
||||
|
||||
output = torch.nn.functional.linear(input, weight_bf16, bias)
|
||||
return output.to(out_dtype)
|
||||
|
||||
def _apply_flashinfer_cutlass(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
N, K = weight.shape
|
||||
|
||||
input_shape = input.shape
|
||||
input_2d = input.view(-1, K)
|
||||
M_orig = input_2d.shape[0]
|
||||
|
||||
# Minimum dimension size for F8_128x4 block scaling layout
|
||||
min_dim = 128
|
||||
|
||||
assert min_dim <= K, (
|
||||
f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
|
||||
f"in_features is too small for mm_mxfp8."
|
||||
)
|
||||
assert K % MXFP8_BLOCK_SIZE == 0, (
|
||||
f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
|
||||
)
|
||||
assert min_dim <= N, (
|
||||
f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
|
||||
f"out_features is too small for mm_mxfp8."
|
||||
)
|
||||
|
||||
M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
|
||||
if M_padded != M_orig:
|
||||
pad_rows = M_padded - M_orig
|
||||
input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))
|
||||
|
||||
input_mxfp8, input_scale = mxfp8_e4m3_quantize(
|
||||
input_2d,
|
||||
is_sf_swizzled_layout=True, # Swizzled for best accuracy
|
||||
)
|
||||
|
||||
if not weight.is_contiguous():
|
||||
weight = weight.contiguous()
|
||||
|
||||
output = vllm_flashinfer.mm_mxfp8(
|
||||
input_mxfp8,
|
||||
weight.t(),
|
||||
input_scale,
|
||||
weight_scale,
|
||||
out_dtype=out_dtype,
|
||||
backend="cutlass",
|
||||
)
|
||||
|
||||
if M_padded != M_orig:
|
||||
output = output[:M_orig, :]
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
output_shape = (*input_shape[:-1], N)
|
||||
return output.view(output_shape)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.backend == Mxfp8LinearBackend.EMULATION:
|
||||
return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)
|
||||
|
||||
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
|
||||
return self._apply_flashinfer_cutlass(
|
||||
input, weight, weight_scale, out_dtype, bias
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user