diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 647108cc4..2215fbb03 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -77,6 +77,18 @@ def _supports_routing_method( raise ValueError("Unsupported quantization scheme.") +def _supports_routing_method_bf16( + routing_method: RoutingMethodType, +) -> bool: + return routing_method in [ + RoutingMethodType.Default, + RoutingMethodType.Renormalize, + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.RenormalizeNaive, + ] + + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """Supports TRTLLM Kernel does not support EPLB.""" return not moe_parallel_config.enable_eplb @@ -115,6 +127,34 @@ def is_supported_config_trtllm( return True, None +def is_supported_config_trtllm_bf16( + moe_config: FusedMoEConfig, + activation_format: mk.FusedMoEActivationFormat, +) -> tuple[bool, str | None]: + """ + This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config + for BF16 unquantized kernels. + """ + + def _make_reason(reason: str) -> str: + return f"kernel does not support {reason}" + + if not _supports_current_device(): + return False, _make_reason("current device") + elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): + return False, _make_reason("no act_and_mul MLP layer") + elif not _supports_activation(moe_config.activation): + return False, _make_reason(f"{moe_config.activation} activation") + elif not _supports_parallel_config(moe_config.moe_parallel_config): + return False, _make_reason("parallel config") + elif not _supports_routing_method_bf16(moe_config.routing_method): + return False, _make_reason("routing method") + elif activation_format != mk.FusedMoEActivationFormat.Standard: + return False, _make_reason("activation format") + + return True, None + + def flashinfer_fused_moe_blockscale_fp8( routing_logits: torch.Tensor, routing_bias: torch.Tensor, @@ -287,3 +327,66 @@ direct_register_custom_op( fake_impl=fi_trtllm_fp8_per_tensor_moe_fake, tags=(torch.Tag.needs_fixed_stride_order,), ) + + +def flashinfer_fused_moe_bf16( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor | None, + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int | None, + topk_group: int | None, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routing_method_type: int, + tune_max_num_tokens: int = 8192, +) -> torch.Tensor: + from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe + + return flashinfer_trtllm_bf16_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + gemm1_weights=gemm1_weights, + gemm2_weights=gemm2_weights, + num_experts=num_experts, + top_k=top_k, + n_group=n_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routing_method_type=routing_method_type, + tune_max_num_tokens=tune_max_num_tokens, + ) + + +def flashinfer_fused_moe_bf16_fake( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor | None, + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int | None, + topk_group: int | None, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routing_method_type: int = RoutingMethodType.Renormalize, + tune_max_num_tokens: int = 8192, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="flashinfer_fused_moe_bf16", + op_func=flashinfer_fused_moe_bf16, + fake_impl=flashinfer_fused_moe_bf16_fake, + tags=(torch.Tag.needs_fixed_stride_order,), +) diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 14c3f84e6..e79670f9d 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( + is_supported_config_trtllm_bf16, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) @@ -21,12 +24,13 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( swap_w13_to_w31, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe logger = init_logger(__name__) class UnquantizedMoeBackend(Enum): + FLASHINFER_TRTLLM = "FlashInfer TRTLLM" FLASHINFER_CUTLASS = "FlashInfer CUTLASS" AITER = "ROCm AITER" TRITON = "TRITON" @@ -40,6 +44,7 @@ class UnquantizedMoeBackend(Enum): # that is not conform with Modular kernel format. # We will directly call the kernel for those backend UNSUPPORTED_BACKEND = [ + UnquantizedMoeBackend.FLASHINFER_TRTLLM, UnquantizedMoeBackend.CPU, UnquantizedMoeBackend.XPU, UnquantizedMoeBackend.TPU, @@ -48,11 +53,12 @@ UNSUPPORTED_BACKEND = [ def select_unquantized_moe_backend( + moe_config: FusedMoEConfig, use_ep: bool, use_dp: bool, ) -> UnquantizedMoeBackend: """ - Select the primary FP8 MoE backend + Select the primary Unquantized MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ @@ -61,13 +67,27 @@ def select_unquantized_moe_backend( rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if moe_config.moe_parallel_config.use_batched_activation_format + else mk.FusedMoEActivationFormat.Standard + ) + + # Check if FlashInfer TRTLLM BF16 MoE is supported + trtllm_supported, _ = is_supported_config_trtllm_bf16( + moe_config=moe_config, + activation_format=activation_format, + ) + flashinfer_trtllm_moe_enabled = ( + has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported + ) # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS flashinfer_cutlass_moe_enabled = ( has_flashinfer_cutlass_fused_moe() and envs.VLLM_USE_FLASHINFER_MOE_FP16 and use_ep and (not use_dp) - and current_platform.get_device_capability()[0] >= 9 + and current_platform.has_device_capability(90) ) if current_platform.is_rocm(): if rocm_aiter_moe_enabled: @@ -75,12 +95,21 @@ def select_unquantized_moe_backend( else: backend = UnquantizedMoeBackend.TRITON if current_platform.is_cuda(): - if flashinfer_cutlass_moe_enabled: + if flashinfer_trtllm_moe_enabled: + backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM + elif flashinfer_cutlass_moe_enabled: backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS else: - if use_ep and (not use_dp): + if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported: logger.info_once( - "FlashInfer CUTLASS MoE is available for EP" + "FlashInfer TRTLLM MoE is available but not enabled, " + "consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 " + "to enable it for better performance.", + scope="local", + ) + elif use_ep and (not use_dp): + logger.info_once( + "FlashInfer MoE is available for EP" " but not enabled, consider setting" " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", scope="local", diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 4b85cc5c2..2ddaf272b 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -6,6 +6,7 @@ from collections.abc import Callable import torch import torch.nn.functional as F from torch.nn import Module +from torch.nn.parameter import Parameter import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -32,6 +33,9 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( make_unquantized_moe_kernel, select_unquantized_moe_backend, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + convert_moe_weights_to_flashinfer_trtllm_block_layout, +) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -54,6 +58,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.unquantized_backend = select_unquantized_moe_backend( + moe_config=self.moe, use_ep=self.moe.moe_parallel_config.use_ep, use_dp=self.moe.moe_parallel_config.dp_size > 1, ) @@ -64,7 +69,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul ) self.kernel: mk.FusedMoEModularKernel | None = None - self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu() + self._is_monolithic = ( + current_platform.is_cpu() + or current_platform.is_xpu() + or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM + ) + + if self.is_monolithic: + self.apply_monolithic: Callable = self._select_monolithic() + + def _select_monolithic(self) -> Callable: + """Select the monolithic implementation based on platform.""" + if current_platform.is_cpu(): + return self.forward_monolithic_cpu + elif current_platform.is_xpu(): + return self.forward_monolithic_xpu + else: + return self.forward_monolithic_cuda + + def forward_native( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + return self.forward_cuda(layer, x, topk_weights, topk_ids) @property def is_monolithic(self) -> bool: @@ -211,7 +241,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) - if self.unquantized_backend == UnquantizedMoeBackend.XPU: + if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM: + _cache_permute_indices: dict[torch.Size, torch.Tensor] = {} + # Swap halves to arrange as [w3; w1] (kernel expectation) + w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1) + w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1) + layer.w13_weight.data = w13_weight_swapped.contiguous() + w13_weights_shuffled, w2_weights_shuffled = ( + convert_moe_weights_to_flashinfer_trtllm_block_layout( + _cache_permute_indices, + layer.w13_weight.data, + layer.w2_weight.data, + ) + ) + layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False) + layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False) + elif self.unquantized_backend == UnquantizedMoeBackend.XPU: import intel_extension_for_pytorch as ipex ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts @@ -290,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.kernel is not None + return self.kernel( hidden_states=x, w1=layer.w13_weight, @@ -303,6 +349,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=layer.expert_map, ) + def forward_monolithic_cuda( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + x: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401 + + assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM + + return torch.ops.vllm.flashinfer_fused_moe_bf16( + routing_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + hidden_states=x, + gemm1_weights=layer.w13_weight, + gemm2_weights=layer.w2_weight, + num_experts=layer.global_num_experts, + top_k=layer.top_k, + n_group=layer.num_expert_group, + topk_group=layer.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + routing_method_type=layer.routing_method_type, + ) + def forward_monolithic_cpu( self, layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 @@ -344,12 +416,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.num_expert_group, custom_routing_function=layer.custom_routing_function, ) - - if current_platform.is_cpu(): - forward_native: Callable = forward_monolithic_cpu - apply_monolithic = forward_monolithic_cpu - elif current_platform.is_xpu(): - forward_native = forward_monolithic_xpu - apply_monolithic = forward_monolithic_xpu - else: - forward_native = forward_cuda diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index cd82b5432..c9c186ef9 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -195,6 +195,81 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> return backend in backends_supporting_global_sf +def convert_moe_weights_to_flashinfer_trtllm_block_layout( + cache_permute_indices: dict[torch.Size, torch.Tensor], + w13_weight: torch.Tensor, + w2_weight: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Convert expert weights to FlashInfer's block layout. + + This reorders W13 and W2 into the expected epilogue-tiled block layout and + returns the shuffled weight tensors. + """ + if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16: + raise ValueError( + "Unquantized Moe Backend FlashInfer TRTLLM requires bfloat16 weights" + ) + + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w3_w1_permute_indices, + convert_to_block_layout, + get_w2_permute_indices_with_cache, + ) + + epilogue_tile_m = 128 + block_k = 128 + + # Reorder rows of W13 and W2 for fused gated activation and convert to the + # block layout expected by the FlashInfer kernel. + num_experts = w13_weight.shape[0] + device_w13 = w13_weight.device + device_w2 = w2_weight.device + + w13_weights_shuffled: list[torch.Tensor] = [] + w2_weights_shuffled: list[torch.Tensor] = [] + + for i in range(num_experts): + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + cache_permute_indices, + w13_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + tmp_weights1 = ( + w13_weight[i] + .clone() + .view(torch.uint8)[permute_indices.to(device_w13)] + .contiguous() + ) + + permute_indices = get_w2_permute_indices_with_cache( + cache_permute_indices, + w2_weight[i].view(torch.uint8), + epilogue_tile_m, + ) + tmp_weights2 = ( + w2_weight[i] + .clone() + .view(torch.uint8)[permute_indices.to(device_w2)] + .contiguous() + ) + + tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k) + tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k) + + w13_weights_shuffled.append(tmp_weights1.view(torch.bfloat16)) + w2_weights_shuffled.append(tmp_weights2.view(torch.bfloat16)) + + # Stack weights for all experts and return as BF16 tensors. + w13_weights_shuffled_tensor = ( + torch.stack(w13_weights_shuffled).view(torch.bfloat16).contiguous() + ) + w2_weights_shuffled_tensor = ( + torch.stack(w2_weights_shuffled).view(torch.bfloat16).contiguous() + ) + + return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor + + def align_fp8_moe_weights_for_fi( w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool ) -> tuple[torch.Tensor, torch.Tensor, int]: diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 067c6fb3e..44fdff247 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -105,6 +105,9 @@ def _lazy_import_wrapper( # Create lazy wrappers for each function +flashinfer_trtllm_bf16_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "trtllm_bf16_moe" +) flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe" )