# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum import torch from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.math_utils import round_up logger = init_logger(__name__) class FlashinferMoeBackend(Enum): TENSORRT_LLM = "TensorRT-LLM" CUTLASS = "CUTLASS" CUTEDSL = "CUTEDSL" def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: return ( x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape) ) def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor ): """Shuffle weights for for FI TRT-LLM Format""" from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a epilogue_tile_m = 128 num_experts = gemm1_weights.shape[0] hidden_size = gemm1_weights.shape[-1] intermediate_size = gemm1_weights.shape[1] // 2 # Reorder rows of W1 for fused gated activation gemm1_weights_fp8_interleaved = [] for i in range(num_experts): gemm1_weights_fp8_interleaved.append( reorder_rows_for_gated_act_gemm(gemm1_weights[i]) ) # Stack weights and scales for all experts gemm1_weights_fp8_interleaved = torch.stack(gemm1_weights_fp8_interleaved).reshape( num_experts, 2 * intermediate_size, hidden_size ) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] gemm2_weights_fp8_shuffled = [] for i in range(num_experts): gemm1_weights_fp8_shuffled.append( shuffle_matrix_a( gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m ) ) gemm2_weights_fp8_shuffled.append( shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), epilogue_tile_m) ) # Stack weights for all experts gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view( torch.float8_e4m3fn ) gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view( torch.float8_e4m3fn ) def register_scales_for_trtllm_fp8_per_tensor_moe( layer: torch.nn.Module, w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_scale: torch.Tensor, w2_input_scale: torch.Tensor, ) -> None: """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( w13_scale=w13_scale, w13_input_scale=w13_input_scale, w2_scale=w2_scale, w2_input_scale=w2_input_scale, ) layer.w2_input_scale_inv = 1.0 / w2_input_scale layer.output1_scales_gate_scalar = g1_alphas layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv layer.output2_scales_scalar = g2_alphas def apply_fi_trtllm_fp8_per_tensor_moe( layer: torch.nn.Module, hidden_states: torch.Tensor, router_logits: torch.Tensor, routing_bias: torch.Tensor | None, top_k: int, num_expert_group: int | None, topk_group: int | None, global_num_experts: int, apply_router_weight_on_input: bool, ) -> torch.Tensor: from flashinfer.fused_moe import RoutingMethodType import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 from vllm.model_executor.models.llama4 import Llama4MoE # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe assert ( hasattr(layer, "output1_scales_scalar") and hasattr(layer, "output1_scales_gate_scalar") and hasattr(layer, "output2_scales_scalar") ) if layer.routing_method_type == RoutingMethodType.Llama4: assert ( not layer.renormalize and layer.custom_routing_function == Llama4MoE.custom_routing_function ), ( "FusedMoE flashinfer kernels with Llama4 routing method are only " "supported for Llama4" ) else: assert layer.custom_routing_function is None, ( "Custom routing function is only supported for Llama4" ) return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( routing_logits=router_logits, routing_bias=routing_bias, hidden_states=hidden_states, input_scale=layer.w13_input_scale, gemm1_weights=layer.w13_weight, gemm2_weights=layer.w2_weight, output1_scales_scalar=layer.output1_scales_scalar, output1_scales_gate_scalar=layer.output1_scales_gate_scalar, output2_scales_scalar=layer.output2_scales_scalar, num_experts=global_num_experts, top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, intermediate_size=layer.intermediate_size_per_partition, local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, use_routing_scales_on_input=apply_router_weight_on_input, routing_method_type=layer.routing_method_type, ) def make_fp8_moe_alpha_scales_for_fi( w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_scale: torch.Tensor, w2_input_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: g1_alphas = (w13_scale * w13_input_scale).squeeze() g2_alphas = (w2_scale * w2_input_scale).squeeze() return g1_alphas, g2_alphas def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, "latency": FlashinferMoeBackend.TENSORRT_LLM, "masked_gemm": FlashinferMoeBackend.CUTEDSL, } flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND if flashinfer_moe_backend in backend_map: if ( flashinfer_moe_backend == "latency" and not current_platform.is_device_capability_family(100) ): logger.info_once( "Flashinfer TRTLLM MOE backend is only supported on " "SM100 and later, using CUTLASS backend instead", scope="local", ) return FlashinferMoeBackend.CUTLASS return backend_map[flashinfer_moe_backend] elif current_platform.is_device_capability(90): return FlashinferMoeBackend.CUTLASS raise ValueError( f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. " f"Expected one of {list(backend_map.keys())}." ) def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool: # TODO(shuw@nvidia): Update when new backends are added. backends_supporting_global_sf = ( FlashinferMoeBackend.CUTLASS, FlashinferMoeBackend.TENSORRT_LLM, FlashinferMoeBackend.CUTEDSL, ) return backend in backends_supporting_global_sf def convert_moe_weights_to_flashinfer_trtllm_block_layout( cache_permute_indices: dict[torch.Size, torch.Tensor], w13_weight: torch.Tensor, w2_weight: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Convert expert weights to FlashInfer's block layout. This reorders W13 and W2 into the expected epilogue-tiled block layout and returns the shuffled weight tensors. """ if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16: raise ValueError( "Unquantized Moe Backend FlashInfer TRTLLM requires bfloat16 weights" ) from flashinfer.fused_moe.core import ( _maybe_get_cached_w3_w1_permute_indices, convert_to_block_layout, get_w2_permute_indices_with_cache, ) epilogue_tile_m = 128 block_k = 128 # Reorder rows of W13 and W2 for fused gated activation and convert to the # block layout expected by the FlashInfer kernel. num_experts = w13_weight.shape[0] device_w13 = w13_weight.device device_w2 = w2_weight.device w13_weights_shuffled: list[torch.Tensor] = [] w2_weights_shuffled: list[torch.Tensor] = [] for i in range(num_experts): permute_indices = _maybe_get_cached_w3_w1_permute_indices( cache_permute_indices, w13_weight[i].view(torch.uint8), epilogue_tile_m, ) tmp_weights1 = ( w13_weight[i] .clone() .view(torch.uint8)[permute_indices.to(device_w13)] .contiguous() ) permute_indices = get_w2_permute_indices_with_cache( cache_permute_indices, w2_weight[i].view(torch.uint8), epilogue_tile_m, ) tmp_weights2 = ( w2_weight[i] .clone() .view(torch.uint8)[permute_indices.to(device_w2)] .contiguous() ) tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k) tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k) w13_weights_shuffled.append(tmp_weights1.view(torch.bfloat16)) w2_weights_shuffled.append(tmp_weights2.view(torch.bfloat16)) # Stack weights for all experts and return as BF16 tensors. w13_weights_shuffled_tensor = ( torch.stack(w13_weights_shuffled).view(torch.bfloat16).contiguous() ) w2_weights_shuffled_tensor = ( torch.stack(w2_weights_shuffled).view(torch.bfloat16).contiguous() ) return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor def align_fp8_moe_weights_for_fi( w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool ) -> tuple[torch.Tensor, torch.Tensor, int]: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. Some FlashInfer FP8 MoE kernels require the (gated) intermediate size used for GEMM to be divisible by a small alignment value. When this is not satisfied (e.g. with certain tensor-parallel sizes), we pad the gate/up and down projection weights along the intermediate dim. """ # Current local intermediate size (per partition) is the K dimension of # the down projection. num_experts, hidden_size, intermediate = w2.shape min_alignment = 16 padded_intermediate = round_up(intermediate, min_alignment) if padded_intermediate == intermediate: return w13, w2, intermediate logger.info_once( "Padding intermediate size from %d to %d for up/down projection weights.", intermediate, padded_intermediate, scope="local", ) up_mult = 2 if is_act_and_mul else 1 padded_gate_up_dim = up_mult * padded_intermediate # Pad w13 and w2 along its intermediate dimension. padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size)) padded_w13[:, : w13.shape[1], :] = w13 padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate)) padded_w2[:, :, :intermediate] = w2 return padded_w13, padded_w2, padded_intermediate def prepare_fp8_moe_layer_for_fi( layer: torch.nn.Module, w13: torch.Tensor, w2: torch.Tensor, w13_scale: torch.Tensor, w13_input_scale: torch.Tensor | None, w2_scale: torch.Tensor, w2_input_scale: torch.Tensor | None, is_trtllm: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Convert Fp8 MoE weights to flashinfer kernel format Note that for trtllm we update the model state dict with the scale format needed for these kernels. Note that for per-tensor, we update the layer's intermediate size if the weights needed padding. """ assert hasattr(layer.moe_config, "is_act_and_mul") block_quant = ( hasattr(layer, "weight_block_size") and layer.weight_block_size is not None ) # Some FI MoE kernels require internal alignment of 16 # for the gate-up proj. Pad the weights to respect this. if not block_quant: w13, w2, new_intermediate = align_fp8_moe_weights_for_fi( w13, w2, layer.moe_config.is_act_and_mul, ) layer.intermediate_size_per_partition = new_intermediate # FI kernels require W31 layout rather than W13. if layer.moe_config.is_act_and_mul: w13 = swap_w13_to_w31(w13) if block_quant: w13_scale = swap_w13_to_w31(w13_scale) # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle # and registration of alpha scales. Note that we do not register # as nn.Parameters since they are not needed for weight-reloading. if is_trtllm and not block_quant: assert w13_input_scale is not None assert w2_input_scale is not None rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2) register_scales_for_trtllm_fp8_per_tensor_moe( layer, w13_scale=w13_scale, w13_input_scale=w13_input_scale, w2_scale=w2_scale, w2_input_scale=w2_input_scale, ) return w13, w2, w13_scale