# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
mxfp4_mxfp8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4,
_swizzle_mxfp4,
get_padding_alignment,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
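# Background (general OCP MX format summary, not vLLM-specific): MXFP4 stores
# each weight element as a 4-bit e2m1 value, packs two elements per uint8, and
# shares one e8m0 scale across every block of 32 elements. Both the packed
# values and the scales are therefore carried in uint8 tensors below.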
# enum for mxfp4 backend
class Mxfp4Backend(Enum):
NONE = 0
# FlashInfer Backend
SM100_FI_MXFP4_MXFP8_TRTLLM = 1
SM100_FI_MXFP4_MXFP8_CUTLASS = 2
SM100_FI_MXFP4_BF16 = 3
SM90_FI_MXFP4_BF16 = 4
# Marlin Backend
MARLIN = 5
# Triton Backend
TRITON = 6
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
"""
Not all MXFP4 backends support LoRA. Select backends that are known to
have LoRA support.
"""
if not current_platform.is_cuda():
return Mxfp4Backend.NONE
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend")
return Mxfp4Backend.TRITON
logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
return Mxfp4Backend.MARLIN
def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
# Backend Selection
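# Selection order, as implemented by the branches below: a LoRA requirement
# first narrows the choice to Marlin/Triton; otherwise, on SM90/SM100 the
# FlashInfer variants are chosen according to the
# VLLM_USE_FLASHINFER_MOE_MXFP4_* env vars, with Marlin/Triton as the
# fallback when FlashInfer is unavailable. XPU and ROCm have their own
# terminal branches.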
if with_lora_support:
return get_mxfp4_backend_with_lora()
if current_platform.is_cuda():
if (
current_platform.is_device_capability(90)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
):
logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90")
return Mxfp4Backend.SM90_FI_MXFP4_BF16
elif (
current_platform.is_device_capability_family(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
):
logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100")
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
elif (
current_platform.is_device_capability_family(100)
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
elif current_platform.is_device_capability_family(100) and has_flashinfer():
logger.info_once(
"Using FlashInfer MXFP4 BF16 backend for SM100, "
"For faster performance on SM100, consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact "
"accuracy."
)
return Mxfp4Backend.SM100_FI_MXFP4_BF16
elif (
current_platform.is_device_capability_family(100)
or current_platform.is_device_capability(90)
) and not has_flashinfer():
logger.warning_once(
"MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer "
"is not available. This may result in degraded performance. "
"Please `pip install vllm[flashinfer]` for best results."
)
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
and (9, 0) <= current_platform.get_device_capability() < (11, 0)
)
if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported:
logger.info_once("Using Marlin backend")
return Mxfp4Backend.MARLIN
else:
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
return Mxfp4Backend.NONE
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: list[str] | None = None):
super().__init__()
self.ignored_layers = ignored_layers
@classmethod
def from_config(cls, config):
return cls()
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp4"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
):
return UnquantizedLinearMethod()
# TODO: Add support for MXFP4 Linear Method.
# MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation
# if you are interested in enabling MXFP4 here.
logger.debug_once(
"MXFP4 linear layer is not implemented - falling back to "
"UnquantizedLinearMethod.",
scope="local",
)
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
if current_platform.is_xpu():
return XpuMxfp4MoEMethod(layer.moe_config)
else:
quant_method = Mxfp4MoEMethod(layer.moe_config)
quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, Attention):
# TODO: Add support for MXFP4 Attention.
logger.debug_once(
"MXFP4 attention layer is not implemented. "
"Skipping quantization for this layer.",
scope="local",
)
return None
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""MXFP4 config always uses MXFP4 quantization."""
return True
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
assert self.mxfp4_backend != Mxfp4Backend.NONE, (
f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found"
"no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)."
"Please check your environment and try again."
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
# FIXME (zyongye): ship after torch and safetensors support mxfp4
# is_torch_mxfp4_available = (
# hasattr(torch, "float4_e2m1fn_x2") and
# hasattr(torch, "float8_e8m0fnu"))
# if is_torch_mxfp4_available:
# weight_dtype = torch.float4_e2m1fn_x2
# scale_dtype = torch.float8_e8m0fnu
mxfp4_block = 32
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
# The moe marlin kernel requires that for each linear
# n % 256 == 0 and k % 128 == 0.
# In gate_up_proj:
# n = 2 * intermediate_size_per_partition_after_pad
# k = hidden_size
# In down_proj
# n = hidden_size
# k = intermediate_size_per_partition_after_pad
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128
)
if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128)
else:
hidden_size = round_up(hidden_size, 256)
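# Illustration for the non-XPU case (hypothetical sizes, not tied to any
# specific model): with hidden_size=2880 and
# intermediate_size_per_partition=2880, the intermediate size is padded to
# round_up(2880, 128) == 2944 and the hidden size to
# round_up(2880, 256) == 3072.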
layer.params_dtype = params_dtype
layer.num_experts = num_experts
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = (
intermediate_size_per_partition_after_pad
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
# Pad the intermediate size to a multiple of 2 * mxfp4_block so it can
# hold non-uniformly sharded tensors and be swizzled; additional padding
# is applied to improve performance.
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256
)
hidden_size = round_up(hidden_size, 256)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128
)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
pad_align = get_padding_alignment()
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, pad_align
)
hidden_size = round_up(hidden_size, pad_align)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64
)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
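# Resulting parameter shapes (E = num_experts, I = padded intermediate size
# per partition, H = padded hidden size): w13_weight is (E, 2*I, H // 2) and
# w2_weight is (E, H, I // 2) because two e2m1 values are packed per uint8
# along the reduction dimension; the scale tensors use one e8m0 scale per
# 32-element block, giving (E, 2*I, H // mxfp4_block) and
# (E, H, I // mxfp4_block). Biases are bf16, one value per output channel.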
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
prepare_moe_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
)
assert (
layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2
)
assert (
layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size
)
w13_weight_scale = layer.w13_weight_scale.data
w2_weight_scale = layer.w2_weight_scale.data
w13_weight = layer.w13_weight.data
w2_weight = layer.w2_weight.data
w13_bias = layer.w13_bias.data.to(torch.float32)
w2_bias = layer.w2_bias.data.to(torch.float32)
# Swap w1 and w3 because the definition of
# swiglu differs in trtllm-gen.
def swap_every_two_rows(x, axis=-1):
shape = x.shape
if axis < 0:
axis = len(shape) + axis
# Create a new shape with pairs swapped along specified axis
new_shape = list(shape)
new_shape[axis] = shape[axis] // 2
new_shape.insert(axis + 1, 2)
# Reshape to expose pairs, swap them, and reshape back
x = x.reshape(*new_shape)
x = x.flip(axis + 1)
new_shape = list(shape)
return x.reshape(*new_shape)
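# Effect of swap_every_two_rows: rows [r0, r1, r2, r3, ...] along `axis`
# become [r1, r0, r3, r2, ...], i.e. each adjacent pair is flipped, which
# swaps the interleaved w1/w3 (gate/up) rows. Small illustration:
#   >>> t = torch.arange(4).reshape(4, 1)   # rows r0..r3
#   >>> swap_every_two_rows(t, axis=-2).squeeze(-1)
#   tensor([1, 0, 3, 2])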
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
w13_weight = swap_every_two_rows(w13_weight, -2)
w13_bias = swap_every_two_rows(w13_bias, -1)
# Do not interleave as the checkpoint is already interleaved
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_mxfp4_shuffled = []
gemm1_scales_mxfp4_shuffled = []
gemm2_weights_mxfp4_shuffled = []
gemm2_scales_mxfp4_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
# w13 weight shuffling
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm1_weights_mxfp4_shuffled.append(
w13_weight[i]
.view(torch.uint8)[permute_indices.to(w13_weight.device)]
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm1_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(
w13_weight_scale[i]
.view(torch.uint8)[
permute_sf_indices.to(w13_weight_scale.device)
]
.contiguous()
)
)
# w13 bias shuffling
permute_bias_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm1_bias_shuffled.append(
w13_bias[i]
.clone()
.reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
.contiguous()
)
# w2 weight shuffling
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
gemm2_weights_mxfp4_shuffled.append(
w2_weight[i]
.view(torch.uint8)[permute_indices.to(w2_weight.device)]
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
num_elts_per_sf=16,
)
gemm2_scales_mxfp4_shuffled.append(
nvfp4_block_scale_interleave(
w2_weight_scale[i]
.view(torch.uint8)[
permute_sf_indices.to(w2_weight_scale.device)
]
.contiguous()
)
)
# w2 bias shuffling
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
)
gemm2_bias_shuffled.append(
w2_bias[i]
.clone()
.reshape(-1, 1)[permute_indices.to(w2_bias.device)]
.contiguous()
)
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = (
torch.stack(gemm1_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
2 * self.intermediate_size,
self.hidden_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
w2_weight_scale = (
torch.stack(gemm2_scales_mxfp4_shuffled)
.reshape(
self.num_experts,
self.hidden_size,
self.intermediate_size // sf_block_size,
)
.view(torch.float8_e4m3fn)
)
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
layer.w13_bias = Parameter(
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False,
)
layer.w2_bias = Parameter(
torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False,
)
elif (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
):
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size
# Common shape assertions
assert (
layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2
)
assert (
layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
)
assert (
layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size
and layer.w2_weight.shape[2] == self.intermediate_size // 2
)
assert (
layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size
)
assert (
layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2
)
assert (
layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size
)
# De-interleave and swap for w13 weight, bias, and scales
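# The checkpoint stores gate/up rows interleaved as [g0, u0, g1, u1, ...];
# the even rows are gathered as the gate (w1) half, the odd rows as the up
# (w3) half, and the halves are re-concatenated in (w3, w1) order,
# presumably the layout the FlashInfer CUTLASS/SM90 kernels expect. The
# same reordering is applied to the biases and block scales below.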
w13_w = layer.w13_weight.data
gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :]
deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1)
w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
w13_b = layer.w13_bias.data.to(torch.float32)
gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2]
deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1)
b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1)
w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
w13_s = layer.w13_weight_scale.data
gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :]
deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1)
s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1)
w13_scale_swapped = torch.cat([s3, s1], dim=1)
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import block_scale_interleave
orig_shape = w13_scale_swapped.shape
w13_scale_interleaved = block_scale_interleave(
w13_scale_swapped.view(torch.uint8)
).reshape(orig_shape)
w2_s = layer.w2_weight_scale.data
orig_shape = w2_s.shape
w2_scale_interleaved = block_scale_interleave(
w2_s.view(torch.uint8)
).reshape(orig_shape)
layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False)
layer.w13_weight_scale = Parameter(
w13_scale_interleaved, requires_grad=False
)
layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False)
layer.w2_weight_scale = Parameter(
w2_scale_interleaved, requires_grad=False
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
def _interleave_mxfp4_cutlass_sm90(w):
w_shape = w.shape
w_interleaved = w.reshape(
w_shape[0], w_shape[1], (w_shape[2] // 4), 4
)
w_interleaved = w_interleaved.permute(0, 2, 1, 3)
w_interleaved = w_interleaved.reshape(
w_shape[0], w_shape[2] // 4, w_shape[1] * 4
)
return w_interleaved
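# _interleave_mxfp4_cutlass_sm90 regroups a scale tensor from shape
# (E, N, K) to (E, K // 4, N * 4): the last dimension is split into groups
# of 4 and each group is laid out contiguously across the N rows. This is
# assumed (from its use here, not the kernel docs) to match the scale
# layout consumed by the SM90 FlashInfer CUTLASS MoE kernel.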
w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8)
w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales)
w2_weight_scale = layer.w2_weight_scale.data
w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8)
w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales)
layer.w13_weight = torch.nn.Parameter(
torch.cat([w3_w, w1_w], dim=1), requires_grad=False
)
layer.w13_bias = torch.nn.Parameter(
w13_bias_swapped, requires_grad=False
)
layer.w13_weight_scale = torch.nn.Parameter(
w31_scales_interleaved, requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False
)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight = w13_weight
self.w2_weight = w2_weight
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
else:
raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
f"should be one of: {list(Mxfp4Backend)}."
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
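# Maps the selected backend to its quantization config: Marlin, Triton and
# the SM100 BF16 FlashInfer path use mxfp4 weights with bf16 activations
# (w4a16); the SM100 TRTLLM/CUTLASS paths additionally quantize activations
# to mxfp8; any remaining backend (e.g. the SM90 BF16 path) falls through
# to the generic OCP MX mxfp4 config.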
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM,
Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
]:
return mxfp4_mxfp8_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
return ocp_mx_moe_quant_config(
quant_dtype="mxfp4",
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
assert self.moe_quant_config is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
moe_config=self.moe,
)
else:
raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
"EP batched experts format"
)
else:
assert self.moe_quant_config is not None
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
# B200 code-path
kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
# TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe, self.moe_quant_config)
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
if self.moe.is_lora_enabled:
return UnfusedOAITritonExperts(self.moe, self.moe_quant_config)
return OAITritonExperts(self.moe, self.moe_quant_config)
else:
raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)
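# "Monolithic" backends fuse routing and the expert GEMMs into a single
# kernel call; for these, FusedMoE is expected to invoke apply_monolithic()
# with the raw router logits instead of apply() with precomputed topk
# weights/ids.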
@property
def is_monolithic(self) -> bool:
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_bias,
layer.w2_bias,
layer.w13_weight_scale,
layer.w2_weight_scale,
topk_weights,
topk_ids,
global_scale1=None,
global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
activation=layer.activation,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
inplace=not self.moe.disable_inplace,
)
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,
layer.num_expert_group,
layer.expert_map,
layer.custom_routing_function,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
assert (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
)
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
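# The MXFP4xMXFP8 CUTLASS path carries its scales in the mx block-scale
# tensors, so the per-expert input scales expected in the quant_scales
# list are filled with 1.0 placeholders (hence "fake").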
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16)
flashinfer_cutlass_fused_moe(
input=fi_input,
token_selected_experts=topk_ids.to(torch.int).contiguous(),
token_final_scales=topk_weights,
output_dtype=torch.bfloat16,
output=output,
quant_scales=quant_scales,
fc1_expert_biases=layer.w13_bias,
fc2_expert_biases=layer.w2_bias,
swiglu_alpha=layer.gemm1_alpha,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
)
return output
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
assert _can_support_mxfp4(
layer.use_grouped_topk,
layer.topk_group,
layer.num_expert_group,
layer.expert_map,
layer.custom_routing_function,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.scoring_func,
layer.activation,
layer.eplb_state.expert_load_view,
layer.eplb_state.logical_to_physical_map,
layer.eplb_state.logical_replica_count,
), "MXFP4 are not supported with this configuration."
if (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer import trtllm_fp4_block_scale_moe
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
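# trtllm_fp4_block_scale_moe performs routing (top_k over the bf16 router
# logits) and both expert GEMMs, including the swiglu epilogue controlled
# by gemm1_alpha/beta/clamp_limit, in a single fused launch; do_finalize=True
# is assumed to make the kernel apply the routing weights and reduce to the
# final hidden states rather than returning per-expert partials.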
trtllm_gen_output = trtllm_fp4_block_scale_moe(
routing_logits=router_logits.to(torch.bfloat16),
routing_bias=None,
hidden_states=x_quant,
hidden_states_scale=x_scale,
gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
gemm1_bias=layer.w13_bias, # fp32 per expert per channel
gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
gemm1_beta=layer.gemm1_beta, # fp32 per expert
gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
gemm2_bias=layer.w2_bias, # fp32 per expert per channel
output1_scale_scalar=None,
output1_scale_gate_scalar=None,
output2_scale_scalar=None,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size, # padded to multiple of 256
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=self.num_experts,
routed_scaling_factor=None,
routing_method_type=1 if layer.renormalize else 0,
do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
self.moe_config = moe_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
super().create_weights(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
params_dtype,
**extra_weight_attrs,
)
self.original_hidden_size = hidden_size
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert layer.activation == "swigluoai", (
"Only swiglu_oai activation is supported for XPU MXFP4 MoE"
)
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
M, _ = x.size()
routing_weights = torch.empty(
M, layer.top_k, dtype=torch.float32, device=x.device
)
selected_experts = torch.empty(
M, layer.top_k, dtype=torch.int32, device=x.device
)
token_expert_indices = torch.empty(
M, layer.top_k, dtype=torch.int32, device=x.device
)
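# token_expert_indices is a scratch output consumed only by topk_softmax;
# only routing_weights and selected_experts are passed to xpu_fused_moe.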
if layer.use_grouped_topk:
routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
x,
router_logits,
layer.top_k,
layer.renormalize,
n_expert_group=layer.num_expert_group,
n_topk_group=layer.topk_group,
scoring_func=layer.scoring_func,
routed_scaling_factor=layer.routed_scaling_factor,
bias=layer.e_score_correction_bias,
)
else:
torch.ops._moe_C.topk_softmax(
routing_weights,
selected_experts,
token_expert_indices,
router_logits,
layer.renormalize,
layer.e_score_correction_bias,
)
return xpu_fused_moe(
hidden_states=x,
w13=layer.w13_weight,
w13_bias=layer.w13_bias if self.moe.has_bias else None,
w13_scales=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_bias=layer.w2_bias if self.moe.has_bias else None,
w2_scales=layer.w2_weight_scale,
topk_weights=routing_weights,
topk_ids=selected_experts,
n_experts_per_token=layer.top_k,
activation=layer.activation,
num_experts=layer.local_num_experts,
is_mxfp4=True,
)