# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum import torch from torch.nn.parameter import Parameter from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, MoEActivation, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.all2all_utils import ( maybe_make_prepare_finalize, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, mxfp4_mxfp8_moe_quant_config, mxfp4_w4a16_moe_quant_config, ocp_mx_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( BatchedMarlinExperts, MarlinExperts, ) from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( OAITritonExperts, UnfusedOAITritonExperts, ) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( CK_MXFP4_MOE_DIM_ALIGNMENT, _can_support_mxfp4, _swizzle_mxfp4, get_padding_alignment, ) from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.utils.import_utils import has_triton_kernels from vllm.utils.math_utils import round_up logger = init_logger(__name__) # enum for mxfp4 backend class Mxfp4Backend(Enum): NONE = 0 # FlashInfer Backend SM100_FI_MXFP4_MXFP8_TRTLLM = 1 SM100_FI_MXFP4_MXFP8_CUTLASS = 2 SM100_FI_MXFP4_BF16 = 3 SM90_FI_MXFP4_BF16 = 4 # Marlin Backend MARLIN = 5 # Triton Backend TRITON = 6 CK = 7 def get_mxfp4_backend_with_lora() -> Mxfp4Backend: """ Not all MXFP4 backends support LoRA. Select backends that are known to have LoRA support. """ if not current_platform.is_cuda(): return Mxfp4Backend.NONE # If FlashInfer is not available, try either Marlin or Triton triton_kernels_supported = ( has_triton_kernels() # NOTE: triton_kernels are only confirmed to work on SM90 and SM100 # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 and (9, 0) <= current_platform.get_device_capability() < (11, 0) ) if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend") return Mxfp4Backend.TRITON logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend") return Mxfp4Backend.MARLIN def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: # Backend Selection if with_lora_support: return get_mxfp4_backend_with_lora() if current_platform.is_cuda(): if ( current_platform.is_device_capability(90) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 ): logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") return Mxfp4Backend.SM90_FI_MXFP4_BF16 elif ( current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS ): logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS elif ( current_platform.is_device_capability_family(100) and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 ): logger.info_once( "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local" ) return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM elif current_platform.is_device_capability_family(100) and has_flashinfer(): logger.info_once( "Using FlashInfer MXFP4 BF16 backend for SM100, " "For faster performance on SM100, consider setting " "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " "accuracy." ) return Mxfp4Backend.SM100_FI_MXFP4_BF16 elif ( current_platform.is_device_capability_family(100) or current_platform.is_device_capability(90) ) and not has_flashinfer(): logger.warning_once( "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " "is not available. This may result in degraded performance. " "Please `pip install vllm[flashinfer]` for best results." ) # If FlashInfer is not available, try either Marlin or Triton triton_kernels_supported = ( has_triton_kernels() # NOTE: triton_kernels are only confirmed to work on SM90 and SM100 # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 and (9, 0) <= current_platform.get_device_capability() < (11, 0) ) if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported: logger.info_once("Using Marlin backend") return Mxfp4Backend.MARLIN else: logger.info_once("Using Triton backend") return Mxfp4Backend.TRITON elif current_platform.is_xpu(): logger.info_once("Using xpu backend on XPU") return Mxfp4Backend.MARLIN elif current_platform.is_rocm(): from vllm.platforms.rocm import on_gfx950 if rocm_aiter_ops.is_enabled() and on_gfx950(): logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)") return Mxfp4Backend.CK elif has_triton_kernels(): logger.info_once("Using Triton backend") return Mxfp4Backend.TRITON return Mxfp4Backend.NONE class Mxfp4Config(QuantizationConfig): def __init__(self, ignored_layers: list[str] | None = None): super().__init__() self.ignored_layers = ignored_layers @classmethod def from_config(cls, config): return cls() @classmethod def get_min_capability(cls) -> int: return 80 @classmethod def get_name(cls) -> QuantizationMethods: return "mxfp4" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16] @classmethod def get_config_filenames(cls) -> list[str]: return [] def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": if isinstance(layer, LinearBase): if self.ignored_layers and is_layer_skipped( prefix=prefix, ignored_layers=self.ignored_layers, fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() # TODO: Add support for MXFP4 Linear Method. # MXFP4 LinearMethod is available in AMD-Quark, refer to that implementation # if you are interested in enabling MXFP4 here. logger.debug_once( "MXFP4 linear layer is not implemented - falling back to " "UnquantizedLinearMethod.", scope="local", ) return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): if current_platform.is_xpu(): return XpuMxfp4MoEMethod(layer.moe_config) else: quant_method = Mxfp4MoEMethod(layer.moe_config) return quant_method elif isinstance(layer, Attention): # TODO: Add support for MXFP4 Attention. logger.debug_once( "MXFP4 attention layer is not implemented. " "Skipping quantization for this layer.", scope="local", ) return None def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: """MXFP4 config always uses MXFP4 quantization.""" return True class Mxfp4MoEMethod(FusedMoEMethodBase): """MXFP4 MoE quantization method.""" def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.weight_dtype = "mxfp4" self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) # CK's pre-compiled MXFP4 MoE GEMM kernel instances have dimension # alignment requirements. Fall back to Triton when not met. if ( self.mxfp4_backend == Mxfp4Backend.CK and moe.intermediate_size_per_partition % CK_MXFP4_MOE_DIM_ALIGNMENT != 0 ): if has_triton_kernels(): logger.warning_once( "CK MXFP4 MoE GEMM does not support " "intermediate_size_per_partition=%d (not a multiple of " "%d). Falling back to Triton backend.", moe.intermediate_size_per_partition, CK_MXFP4_MOE_DIM_ALIGNMENT, ) self.mxfp4_backend = Mxfp4Backend.TRITON else: raise ValueError( f"CK MXFP4 MoE GEMM does not support " f"intermediate_size_per_partition=" f"{moe.intermediate_size_per_partition} (not a multiple " f"of {CK_MXFP4_MOE_DIM_ALIGNMENT}) and no Triton " f"fallback is available. Use a compatible " f"tensor_parallel_size." ) assert self.mxfp4_backend != Mxfp4Backend.NONE, ( f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found" "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton)." "Please check your environment and try again." ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} # Initialized in process_weights_after_loading for CUTLASS/SM90 backends self.moe_kernel: mk.FusedMoEKernel | None = None def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): self.num_experts = num_experts weight_dtype = torch.uint8 scale_dtype = torch.uint8 # FIXME (zyongye): ship after torch and safetensors support mxfp4 # is_torch_mxfp4_available = ( # hasattr(torch, "float4_e2m1fn_x2") and # hasattr(torch, "float8_e8m0fnu")) # if is_torch_mxfp4_available: # weight_dtype = torch.float4_e2m1fn_x2 # scale_dtype = torch.float8_e8m0fnu mxfp4_block = 32 intermediate_size_per_partition_after_pad = intermediate_size_per_partition if self.mxfp4_backend == Mxfp4Backend.MARLIN: # The moe marlin kernel requires that for each linear # n % 256 == 0 and k % 128 == 0. # In gate_up_proj: # n = 2 * intermediate_size_per_partition_after_pad # k = hidden_size # In down_proj # n = hidden_size # k = intermediate_size_per_partition_after_pad intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128 ) if current_platform.is_xpu(): hidden_size = round_up(hidden_size, 128) else: hidden_size = round_up(hidden_size, 256) layer.params_dtype = params_dtype layer.num_experts = num_experts layer.hidden_size = hidden_size layer.intermediate_size_per_partition = ( intermediate_size_per_partition_after_pad ) elif ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 ): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 256 ) hidden_size = round_up(hidden_size, 256) elif ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 ): intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 128 ) hidden_size = round_up(hidden_size, 128) elif current_platform.is_rocm(): pad_align = get_padding_alignment() intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, pad_align ) hidden_size = round_up(hidden_size, pad_align) else: intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, 64 ) self.intermediate_size = intermediate_size_per_partition_after_pad self.hidden_size = hidden_size self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0) self.intermediate_pad = ( intermediate_size_per_partition_after_pad - intermediate_size_per_partition ) # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition_after_pad, hidden_size // 2, dtype=weight_dtype, ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w13_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition_after_pad, hidden_size // mxfp4_block, dtype=scale_dtype, ), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) w13_bias = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition_after_pad, dtype=torch.bfloat16, ), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter( torch.zeros( num_experts, hidden_size, intermediate_size_per_partition_after_pad // 2, dtype=weight_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) w2_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, hidden_size, intermediate_size_per_partition_after_pad // mxfp4_block, dtype=scale_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w2_weight_scale, extra_weight_attrs) w2_bias = torch.nn.Parameter( torch.zeros( num_experts, hidden_size, dtype=torch.bfloat16, ), requires_grad=False, ) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): if self.mxfp4_backend == Mxfp4Backend.MARLIN: prepare_moe_fp4_layer_for_marlin( layer, input_dtype=get_marlin_input_dtype() ) self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.moe_quant_config is not None prepare_finalize = maybe_make_prepare_finalize( moe=self.moe, quant_config=self.moe_quant_config, routing_tables=layer._maybe_init_expert_routing_tables(), allow_new_interface=True, ) assert prepare_finalize is not None self.moe_kernel = mk.FusedMoEKernel( prepare_finalize, MarlinExperts( self.moe, self.moe_quant_config, ), inplace=not self.moe.disable_inplace, shared_experts=None, ) elif ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 ): from flashinfer.fp4_quantization import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache layer.gemm1_alpha = Parameter( torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False, ) layer.gemm1_beta = Parameter( torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False, ) layer.gemm1_clamp_limit = Parameter( torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False, ) sf_block_size = 32 # mxfp4 block size assert ( layer.w13_weight.dim() == 3 and layer.w13_weight.shape[0] == self.num_experts and layer.w13_weight.shape[1] == self.intermediate_size * 2 and layer.w13_weight.shape[2] == self.hidden_size // 2 ) assert ( layer.w13_weight_scale.dim() == 3 and layer.w13_weight_scale.shape[0] == self.num_experts and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size ) assert ( layer.w2_weight.dim() == 3 and layer.w2_weight.shape[0] == self.num_experts and layer.w2_weight.shape[1] == self.hidden_size and layer.w2_weight.shape[2] == self.intermediate_size // 2 ) assert ( layer.w2_weight_scale.dim() == 3 and layer.w2_weight_scale.shape[1] == self.hidden_size and layer.w2_weight_scale.shape[2] == self.intermediate_size // sf_block_size ) assert ( layer.w13_bias.dim() == 2 and layer.w13_bias.shape[0] == self.num_experts and layer.w13_bias.shape[1] == self.intermediate_size * 2 ) assert ( layer.w2_bias.dim() == 2 and layer.w2_bias.shape[0] == self.num_experts and layer.w2_bias.shape[1] == self.hidden_size ) w13_weight_scale = layer.w13_weight_scale.data w2_weight_scale = layer.w2_weight_scale.data w13_weight = layer.w13_weight.data w2_weight = layer.w2_weight.data w13_bias = layer.w13_bias.data.to(torch.float32) w2_bias = layer.w2_bias.data.to(torch.float32) # Swap w1 and w3 as the definition of # swiglu is different in the trtllm-gen def swap_every_two_rows(x, axis=-1): shape = x.shape if axis < 0: axis = len(shape) + axis # Create a new shape with pairs swapped along specified axis new_shape = list(shape) new_shape[axis] = shape[axis] // 2 new_shape.insert(axis + 1, 2) # Reshape to expose pairs, swap them, and reshape back x = x.reshape(*new_shape) x = x.flip(axis + 1) new_shape = list(shape) return x.reshape(*new_shape) w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2) w13_weight = swap_every_two_rows(w13_weight, -2) w13_bias = swap_every_two_rows(w13_bias, -1) # Do not interleave as the checkpoint is already interleaved # Shuffle weights and scaling factors for transposed mma output gemm1_weights_mxfp4_shuffled = [] gemm1_scales_mxfp4_shuffled = [] gemm2_weights_mxfp4_shuffled = [] gemm2_scales_mxfp4_shuffled = [] gemm1_bias_shuffled = [] gemm2_bias_shuffled = [] epilogue_tile_m = 128 # FIXME: this depends on the kernel internals for i in range(self.num_experts): # w13 weight shuffling permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w13_weight[i].view(torch.uint8), epilogue_tile_m, ) gemm1_weights_mxfp4_shuffled.append( w13_weight[i] .view(torch.uint8)[permute_indices.to(w13_weight.device)] .contiguous() ) # w13 scale shuffling permute_sf_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w13_weight_scale[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, ) gemm1_scales_mxfp4_shuffled.append( nvfp4_block_scale_interleave( w13_weight_scale[i] .view(torch.uint8)[ permute_sf_indices.to(w13_weight_scale.device) ] .contiguous() ) ) # w13 bias shuffling permute_bias_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m, ) gemm1_bias_shuffled.append( w13_bias[i] .clone() .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] .contiguous() ) # w2 weight shuffling permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w2_weight[i].view(torch.uint8), epilogue_tile_m, ) gemm2_weights_mxfp4_shuffled.append( w2_weight[i] .view(torch.uint8)[permute_indices.to(w2_weight.device)] .contiguous() ) # w2 scale shuffling permute_sf_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w2_weight_scale[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, ) gemm2_scales_mxfp4_shuffled.append( nvfp4_block_scale_interleave( w2_weight_scale[i] .view(torch.uint8)[ permute_sf_indices.to(w2_weight_scale.device) ] .contiguous() ) ) # w2 bias shuffling permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m, ) gemm2_bias_shuffled.append( w2_bias[i] .clone() .reshape(-1, 1)[permute_indices.to(w2_bias.device)] .contiguous() ) w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) w13_weight_scale = ( torch.stack(gemm1_scales_mxfp4_shuffled) .reshape( self.num_experts, 2 * self.intermediate_size, self.hidden_size // sf_block_size, ) .view(torch.float8_e4m3fn) ) w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) w2_weight_scale = ( torch.stack(gemm2_scales_mxfp4_shuffled) .reshape( self.num_experts, self.hidden_size, self.intermediate_size // sf_block_size, ) .view(torch.float8_e4m3fn) ) layer.w13_weight = Parameter(w13_weight, requires_grad=False) layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False) layer.w2_weight = Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False) layer.w13_bias = Parameter( torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), requires_grad=False, ) layer.w2_bias = Parameter( torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1), requires_grad=False, ) elif ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 ): sf_block_size = 32 # mxfp4 block size # Common shape assertions assert ( layer.w13_weight.dim() == 3 and layer.w13_weight.shape[0] == self.num_experts and layer.w13_weight.shape[1] == self.intermediate_size * 2 and layer.w13_weight.shape[2] == self.hidden_size // 2 ) assert ( layer.w13_weight_scale.dim() == 3 and layer.w13_weight_scale.shape[0] == self.num_experts and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size ) assert ( layer.w2_weight.dim() == 3 and layer.w2_weight.shape[0] == self.num_experts and layer.w2_weight.shape[1] == self.hidden_size and layer.w2_weight.shape[2] == self.intermediate_size // 2 ) assert ( layer.w2_weight_scale.dim() == 3 and layer.w2_weight_scale.shape[1] == self.hidden_size and layer.w2_weight_scale.shape[2] == self.intermediate_size // sf_block_size ) assert ( layer.w13_bias.dim() == 2 and layer.w13_bias.shape[0] == self.num_experts and layer.w13_bias.shape[1] == self.intermediate_size * 2 ) assert ( layer.w2_bias.dim() == 2 and layer.w2_bias.shape[0] == self.num_experts and layer.w2_bias.shape[1] == self.hidden_size ) # De-interleave and swap for w13 weight, bias, and scales w13_w = layer.w13_weight.data gate_w, up_w = w13_w[:, ::2, :], w13_w[:, 1::2, :] deinterleaved_w13_w = torch.cat([gate_w, up_w], dim=1) w1_w, w3_w = torch.chunk(deinterleaved_w13_w, 2, dim=1) w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) w13_b = layer.w13_bias.data.to(torch.float32) gate_b, up_b = w13_b[:, ::2], w13_b[:, 1::2] deinterleaved_w13_b = torch.cat([gate_b, up_b], dim=1) b1, b3 = torch.chunk(deinterleaved_w13_b, 2, dim=-1) w13_bias_swapped = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) w13_s = layer.w13_weight_scale.data gate_s, up_s = w13_s[:, ::2, :], w13_s[:, 1::2, :] deinterleaved_w13_s = torch.cat([gate_s, up_s], dim=1) s1, s3 = torch.chunk(deinterleaved_w13_s, 2, dim=1) w13_scale_swapped = torch.cat([s3, s1], dim=1) if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: from flashinfer import block_scale_interleave orig_shape = w13_scale_swapped.shape w13_scale_interleaved = block_scale_interleave( w13_scale_swapped.view(torch.uint8) ).reshape(orig_shape) w2_s = layer.w2_weight_scale.data orig_shape = w2_s.shape w2_scale_interleaved = block_scale_interleave( w2_s.view(torch.uint8) ).reshape(orig_shape) layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False) layer.w13_weight_scale = Parameter( w13_scale_interleaved, requires_grad=False ) layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False) layer.w2_weight_scale = Parameter( w2_scale_interleaved, requires_grad=False ) elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: def _interleave_mxfp4_cutlass_sm90(w): w_shape = w.shape w_interleaved = w.reshape( w_shape[0], w_shape[1], (w_shape[2] // 4), 4 ) w_interleaved = w_interleaved.permute(0, 2, 1, 3) w_interleaved = w_interleaved.reshape( w_shape[0], w_shape[2] // 4, w_shape[1] * 4 ) return w_interleaved w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8) w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales) w2_weight_scale = layer.w2_weight_scale.data w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales) layer.w13_weight = torch.nn.Parameter( torch.cat([w3_w, w1_w], dim=1), requires_grad=False ) layer.w13_bias = torch.nn.Parameter( w13_bias_swapped, requires_grad=False ) layer.w13_weight_scale = torch.nn.Parameter( w31_scales_interleaved, requires_grad=False ) layer.w2_weight_scale = torch.nn.Parameter( w2_scales_interleaved, requires_grad=False ) # theses two kernels go through the `flashinfer_cutlass_fused_moe` path from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.moe_quant_config is not None prepare_finalize = maybe_make_prepare_finalize( moe=self.moe, quant_config=self.moe_quant_config, routing_tables=layer._maybe_init_expert_routing_tables(), allow_new_interface=True, ) assert prepare_finalize is not None self.moe_kernel = mk.FusedMoEKernel( prepare_finalize, FlashInferExperts( moe_config=self.moe, quant_config=self.moe_quant_config, ), shared_experts=None, ) elif self.mxfp4_backend == Mxfp4Backend.CK: if layer.w13_bias is not None: layer.w13_bias.data = layer.w13_bias.data.to(torch.float32) if layer.w2_bias.data is not None: layer.w2_bias.data = layer.w2_bias.data.to(torch.float32) e, n, k = layer.w13_weight.shape layer.w13_weight.view(torch.uint8).copy_( layer.w13_weight.data.view(torch.uint8) .view(e, n // 2, 2, k) .permute(0, 2, 1, 3) .contiguous() .view(e, n, k) ) layer.w13_weight_scale.data = ( layer.w13_weight_scale.data.view(e, n // 2, 2, -1) .permute(0, 2, 1, 3) .contiguous() .view(e, n, -1) ) layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2) layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2) layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4( layer.w13_weight, 16, True ) shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]), self.num_experts, True, ) layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4( layer.w2_weight, 16, False ) shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]), self.num_experts, False, ) layer.w13_bias.data = ( layer.w13_bias.data.view(-1, n // 2, 2) .permute(0, 2, 1) .contiguous() .view(-1, n) ) layer.w13_weight_scale = torch.nn.Parameter( shuffled_w13_scale, requires_grad=False ) layer.w2_weight_scale = torch.nn.Parameter( shuffled_w2_scale, requires_grad=False ) # replace_parameter(layer, "w13_bias", w13_bias) # replace_parameter(layer, "w13_weight_scale", w13_weight_scale) # replace_parameter(layer, "w2_weight_scale", w2_weight_scale) # replace_parameter(layer, "w13_weight", w13_weight) # replace_parameter(layer, "w2_weight", w2_weight) elif self.mxfp4_backend == Mxfp4Backend.TRITON: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig w13_bias = layer.w13_bias.to(torch.float32) w2_bias = layer.w2_bias.to(torch.float32) layer.w13_bias = Parameter(w13_bias, requires_grad=False) layer.w2_bias = Parameter(w2_bias, requires_grad=False) # Ideally we'd use FusedMoEModularKernel.prepare_finalize object # (stored in self.fused_experts) to determine if the MoE has a # batched activation format. As self.fused_experts is not # initialized at this point, we resort to checking the MoE config # directly. is_batched_moe = self.moe.use_deepep_ll_kernels if is_batched_moe: num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 else: num_warps = 8 w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( layer.w13_weight, layer.w13_weight_scale, num_warps ) w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( layer.w2_weight, layer.w2_weight_scale, num_warps ) self.w13_precision_config = PrecisionConfig( weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) ) self.w2_precision_config = PrecisionConfig( weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) ) self.w13_weight = w13_weight self.w2_weight = w2_weight del layer.w13_weight del layer.w2_weight layer.w13_weight = w13_weight layer.w2_weight = w2_weight else: raise ValueError( f"Unsupported mxfp4_backend: {self.mxfp4_backend}: " f"should be one of: {list(Mxfp4Backend)}." ) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if self.mxfp4_backend == Mxfp4Backend.MARLIN: return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, ) elif self.mxfp4_backend == Mxfp4Backend.TRITON: w1_scale = self.w13_precision_config w2_scale = self.w2_precision_config return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, w1_scale=w1_scale, w2_scale=w2_scale, ) elif self.mxfp4_backend in [ Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS, ]: return mxfp4_mxfp8_moe_quant_config( w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, ) elif self.mxfp4_backend in [ Mxfp4Backend.SM100_FI_MXFP4_BF16, Mxfp4Backend.SM90_FI_MXFP4_BF16, Mxfp4Backend.CK, ]: return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, ) else: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale return ocp_mx_moe_quant_config( quant_dtype="mxfp4", w1_bias=layer.w13_bias, w2_bias=layer.w2_bias, w1_scale=w1_scale, w2_scale=w2_scale, ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, ) -> mk.FusedMoEExpertsModular: if ( prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts ): if self.mxfp4_backend == Mxfp4Backend.MARLIN: max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None assert self.moe_quant_config is not None return BatchedMarlinExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), quant_config=self.moe_quant_config, moe_config=self.moe, ) else: raise NotImplementedError( f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for " "EP batched experts format" ) else: assert self.moe_quant_config is not None if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 ): # B200 code-path kwargs = { # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) elif self.mxfp4_backend == Mxfp4Backend.MARLIN: return MarlinExperts(self.moe, self.moe_quant_config) elif self.mxfp4_backend == Mxfp4Backend.TRITON: if self.moe.is_lora_enabled: return UnfusedOAITritonExperts(self.moe, self.moe_quant_config) return OAITritonExperts(self.moe, self.moe_quant_config) else: raise NotImplementedError( f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" ) @property def is_monolithic(self) -> bool: if self.moe.is_lora_enabled: return False return ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.TRITON or self.mxfp4_backend == Mxfp4Backend.CK ) def apply( self, layer: FusedMoE, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic if layer.enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") assert _can_support_mxfp4( layer.use_grouped_topk, layer.topk_group, layer.num_expert_group, layer.expert_map, layer.custom_routing_function, layer.e_score_correction_bias, layer.apply_router_weight_on_input, layer.scoring_func, layer.activation, layer.eplb_state.expert_load_view, layer.eplb_state.logical_to_physical_map, layer.eplb_state.logical_replica_count, ), "MXFP4 are not supported with this configuration." assert ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.MARLIN ) assert self.moe_kernel is not None return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, shared_experts_input=shared_experts_input, ) def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic if layer.enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") assert _can_support_mxfp4( layer.use_grouped_topk, layer.topk_group, layer.num_expert_group, layer.expert_map, layer.custom_routing_function, layer.e_score_correction_bias, layer.apply_router_weight_on_input, layer.scoring_func, layer.activation, layer.eplb_state.expert_load_view, layer.eplb_state.logical_to_physical_map, layer.eplb_state.logical_replica_count, ), "MXFP4 are not supported with this configuration." if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 ): from flashinfer import trtllm_fp4_block_scale_moe if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16: assert x.dtype == torch.bfloat16 x_quant = x x_scale = None elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: from flashinfer import mxfp8_quantize x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) trtllm_gen_output = trtllm_fp4_block_scale_moe( routing_logits=router_logits.to(torch.bfloat16), routing_bias=None, hidden_states=x_quant, hidden_states_scale=x_scale, gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2) gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2) gemm1_bias=layer.w13_bias, # fp32 per expert per channel gemm1_alpha=layer.gemm1_alpha, # fp32 per expert gemm1_beta=layer.gemm1_beta, # fp32 per expert gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2) gemm2_weights_scale=layer.w2_weight_scale, # ue8m0 gemm2_bias=layer.w2_bias, # fp32 per expert per channel output1_scale_scalar=None, output1_scale_gate_scalar=None, output2_scale_scalar=None, num_experts=layer.global_num_experts, top_k=layer.top_k, n_group=None, topk_group=None, intermediate_size=self.intermediate_size, # padded to multiple of 256 local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=self.num_experts, routed_scaling_factor=None, routing_method_type=1 if layer.renormalize else 0, do_finalize=True, tune_max_num_tokens=max(self.max_capture_size, 1), )[0] return trtllm_gen_output elif self.mxfp4_backend == Mxfp4Backend.CK: topk_weights, topk_ids = rocm_aiter_ops.fused_topk( x, router_logits, layer.top_k, True ) output = rocm_aiter_ops.fused_moe( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"), quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"), w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, doweight_stage1=False, hidden_pad=self.hidden_pad // 128 * 128, intermediate_pad=self.intermediate_pad // 64 * 64 * 2, bias1=layer.w13_bias, bias2=layer.w2_bias, ) return output elif self.mxfp4_backend == Mxfp4Backend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 triton_kernel_moe_forward, ) return triton_kernel_moe_forward( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, gating_output=router_logits, topk=layer.top_k, renormalize=layer.renormalize, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, quant_config=self.moe_quant_config, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") class XpuMxfp4MoEMethod(Mxfp4MoEMethod): def __init__(self, moe_config: FusedMoEConfig): super().__init__(moe_config) self.moe_config = moe_config def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): super().create_weights( layer, num_experts, hidden_size, intermediate_size_per_partition, params_dtype, **extra_weight_attrs, ) self.original_hidden_size = hidden_size def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass @property def is_monolithic(self) -> bool: return True def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor: assert layer.activation == MoEActivation.SWIGLUOAI, ( "Only swiglu_oai activation is supported for " f"XPU MXFP4 MoE, not {layer.activation}." ) from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe M, _ = x.size() routing_weights = torch.empty( M, layer.top_k, dtype=torch.float32, device=x.device ) selected_experts = torch.empty( M, layer.top_k, dtype=torch.int32, device=x.device ) token_expert_indices = torch.empty( M, layer.top_k, dtype=torch.int32, device=x.device ) if layer.use_grouped_topk: routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk( x, router_logits, layer.top_k, layer.renormalize, n_expert_group=layer.num_expert_group, n_topk_group=layer.topk_group, scoring_func=layer.scoring_func, routed_scaling_factor=layer.routed_scaling_factor, bias=layer.e_score_correction_bias, ) else: torch.ops._moe_C.topk_softmax( routing_weights, selected_experts, token_expert_indices, router_logits, layer.renormalize, layer.e_score_correction_bias, ) return xpu_fused_moe( hidden_states=x, w13=layer.w13_weight, w13_bias=layer.w13_bias if self.moe.has_bias else None, w13_scales=layer.w13_weight_scale, w2=layer.w2_weight, w2_bias=layer.w2_bias if self.moe.has_bias else None, w2_scales=layer.w2_weight_scale, topk_weights=routing_weights, topk_ids=selected_experts, n_experts_per_token=layer.top_k, activation=layer.activation.value, num_experts=layer.local_num_experts, is_mxfp4=True, )