[NVIDIA] [feat] Integrate flashinfer Trtllmgen bf16 moe (#32954)

Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Linda
2026-01-29 19:00:13 +01:00
committed by GitHub
parent 8c8ebeb941
commit 0493d897c4
5 changed files with 290 additions and 17 deletions
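The new path is opt-in via the existing VLLM_USE_FLASHINFER_MOE_FP16 environment variable (see the backend-selection changes below). A minimal sketch of enabling it, assuming FlashInfer is installed, the GPU is supported, and the checkpoint is a BF16 MoE model; the model id is only a placeholder:

import os

# Opt in to the FlashInfer BF16 MoE kernels; vLLM reads this flag when it
# selects the unquantized MoE backend for BF16 weights.
os.environ["VLLM_USE_FLASHINFER_MOE_FP16"] = "1"

from vllm import LLM

# Placeholder model id: any BF16 mixture-of-experts checkpoint would do here.
llm = LLM(model="org/bf16-moe-model", dtype="bfloat16")
print(llm.generate("Hello")[0].outputs[0].text)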

View File

@@ -77,6 +77,18 @@ def _supports_routing_method(
raise ValueError("Unsupported quantization scheme.")
def _supports_routing_method_bf16(
routing_method: RoutingMethodType,
) -> bool:
return routing_method in [
RoutingMethodType.Default,
RoutingMethodType.Renormalize,
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.RenormalizeNaive,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Supports TRTLLM Kernel does not support EPLB."""
return not moe_parallel_config.enable_eplb
@@ -115,6 +127,34 @@ def is_supported_config_trtllm(
return True, None
def is_supported_config_trtllm_bf16(
moe_config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
for BF16 unquantized kernels.
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason("current device")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason("parallel config")
elif not _supports_routing_method_bf16(moe_config.routing_method):
return False, _make_reason("routing method")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason("activation format")
return True, None
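A usage sketch, not part of the change: a hypothetical helper that probes the (supported, reason) pair this function returns. FusedMoEConfig and mk are already used in this module; the logger instance and the helper name are assumptions.

def _log_trtllm_bf16_support(moe_config: FusedMoEConfig) -> bool:
    # Hypothetical helper (sketch only); the real call site in the backend
    # oracle discards the reason string.
    supported, reason = is_supported_config_trtllm_bf16(
        moe_config=moe_config,
        activation_format=mk.FusedMoEActivationFormat.Standard,
    )
    if not supported:
        logger.debug("TRTLLM BF16 MoE kernel not used: %s", reason)  # assumed logger
    return supported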
def flashinfer_fused_moe_blockscale_fp8(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
@@ -287,3 +327,66 @@ direct_register_custom_op(
fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
def flashinfer_fused_moe_bf16(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routing_method_type: int,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe
return flashinfer_trtllm_bf16_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
gemm1_weights=gemm1_weights,
gemm2_weights=gemm2_weights,
num_experts=num_experts,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routing_method_type=routing_method_type,
tune_max_num_tokens=tune_max_num_tokens,
)
def flashinfer_fused_moe_bf16_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routing_method_type: int = RoutingMethodType.Renormalize,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="flashinfer_fused_moe_bf16",
op_func=flashinfer_fused_moe_bf16,
fake_impl=flashinfer_fused_moe_bf16_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)

View File

@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
@@ -21,12 +24,13 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__)
class UnquantizedMoeBackend(Enum):
FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
AITER = "ROCm AITER"
TRITON = "TRITON"
@@ -40,6 +44,7 @@ class UnquantizedMoeBackend(Enum):
# that does not conform to the Modular kernel format.
# We will directly call the kernel for those backends
UNSUPPORTED_BACKEND = [
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
UnquantizedMoeBackend.CPU,
UnquantizedMoeBackend.XPU,
UnquantizedMoeBackend.TPU,
@@ -48,11 +53,12 @@ UNSUPPORTED_BACKEND = [
def select_unquantized_moe_backend(
moe_config: FusedMoEConfig,
use_ep: bool,
use_dp: bool,
) -> UnquantizedMoeBackend:
"""
Select the primary FP8 MoE backend
Select the primary Unquantized MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
@@ -61,13 +67,27 @@ def select_unquantized_moe_backend(
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if moe_config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
# Check if FlashInfer TRTLLM BF16 MoE is supported
trtllm_supported, _ = is_supported_config_trtllm_bf16(
moe_config=moe_config,
activation_format=activation_format,
)
flashinfer_trtllm_moe_enabled = (
has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported
)
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs
flashinfer_cutlass_moe_enabled = (
has_flashinfer_cutlass_fused_moe()
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and use_ep
and (not use_dp)
and current_platform.get_device_capability()[0] >= 9
and current_platform.has_device_capability(90)
)
if current_platform.is_rocm():
if rocm_aiter_moe_enabled:
@@ -75,12 +95,21 @@ def select_unquantized_moe_backend(
else:
backend = UnquantizedMoeBackend.TRITON
if current_platform.is_cuda():
if flashinfer_cutlass_moe_enabled:
if flashinfer_trtllm_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
elif flashinfer_cutlass_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
else:
if use_ep and (not use_dp):
if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
logger.info_once(
"FlashInfer CUTLASS MoE is available for EP"
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
"to enable it for better performance.",
scope="local",
)
elif use_ep and (not use_dp):
logger.info_once(
"FlashInfer MoE is available for EP"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
scope="local",

View File

@@ -6,6 +6,7 @@ from collections.abc import Callable
import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -32,6 +33,9 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
convert_moe_weights_to_flashinfer_trtllm_block_layout,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
@@ -54,6 +58,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend(
moe_config=self.moe,
use_ep=self.moe.moe_parallel_config.use_ep,
use_dp=self.moe.moe_parallel_config.dp_size > 1,
)
@@ -64,7 +69,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
self.kernel: mk.FusedMoEModularKernel | None = None
self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu()
self._is_monolithic = (
current_platform.is_cpu()
or current_platform.is_xpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
)
if self.is_monolithic:
self.apply_monolithic: Callable = self._select_monolithic()
def _select_monolithic(self) -> Callable:
"""Select the monolithic implementation based on platform."""
if current_platform.is_cpu():
return self.forward_monolithic_cpu
elif current_platform.is_xpu():
return self.forward_monolithic_xpu
else:
return self.forward_monolithic_cuda
def forward_native(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids)
@property
def is_monolithic(self) -> bool:
@@ -211,7 +241,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
if self.unquantized_backend == UnquantizedMoeBackend.XPU:
if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
w13_weights_shuffled, w2_weights_shuffled = (
convert_moe_weights_to_flashinfer_trtllm_block_layout(
_cache_permute_indices,
layer.w13_weight.data,
layer.w2_weight.data,
)
)
layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
import intel_extension_for_pytorch as ipex
ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
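The [w3; w1] swap above can be confusing: the fused projection weight is stored as [w1; w3] along dim 1, while the TRTLLM kernel expects [w3; w1]. A toy, self-contained illustration of what the two chunk/cat lines do (shapes are made up):

import torch

num_experts, intermediate, hidden = 2, 4, 8
# Stand-in for layer.w13_weight: [num_experts, 2 * intermediate, hidden],
# stored as [w1; w3] along dim 1.
w13 = torch.arange(num_experts * 2 * intermediate * hidden, dtype=torch.bfloat16)
w13 = w13.reshape(num_experts, 2 * intermediate, hidden)

w1, w3 = torch.chunk(w13, 2, dim=1)       # split the stacked halves
w13_swapped = torch.cat([w3, w1], dim=1)  # reorder rows to [w3; w1]

assert torch.equal(w13_swapped[:, :intermediate], w3)
assert torch.equal(w13_swapped[:, intermediate:], w1)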
@@ -290,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
return self.kernel(
hidden_states=x,
w1=layer.w13_weight,
@@ -303,6 +349,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=layer.expert_map,
)
def forward_monolithic_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401
assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
return torch.ops.vllm.flashinfer_fused_moe_bf16(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
hidden_states=x,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routing_method_type=layer.routing_method_type,
)
def forward_monolithic_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
@@ -344,12 +416,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.num_expert_group,
custom_routing_function=layer.custom_routing_function,
)
if current_platform.is_cpu():
forward_native: Callable = forward_monolithic_cpu
apply_monolithic = forward_monolithic_cpu
elif current_platform.is_xpu():
forward_native = forward_monolithic_xpu
apply_monolithic = forward_monolithic_xpu
else:
forward_native = forward_cuda

View File

@@ -195,6 +195,81 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
return backend in backends_supporting_global_sf
def convert_moe_weights_to_flashinfer_trtllm_block_layout(
cache_permute_indices: dict[torch.Size, torch.Tensor],
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert expert weights to FlashInfer's block layout.
This reorders W13 and W2 into the expected epilogue-tiled block layout and
returns the shuffled weight tensors.
"""
if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16:
raise ValueError(
"Unquantized Moe Backend FlashInfer TRTLLM requires bfloat16 weights"
)
from flashinfer.fused_moe.core import (
_maybe_get_cached_w3_w1_permute_indices,
convert_to_block_layout,
get_w2_permute_indices_with_cache,
)
epilogue_tile_m = 128
block_k = 128
# Reorder rows of W13 and W2 for fused gated activation and convert to the
# block layout expected by the FlashInfer kernel.
num_experts = w13_weight.shape[0]
device_w13 = w13_weight.device
device_w2 = w2_weight.device
w13_weights_shuffled: list[torch.Tensor] = []
w2_weights_shuffled: list[torch.Tensor] = []
for i in range(num_experts):
permute_indices = _maybe_get_cached_w3_w1_permute_indices(
cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
)
tmp_weights1 = (
w13_weight[i]
.clone()
.view(torch.uint8)[permute_indices.to(device_w13)]
.contiguous()
)
permute_indices = get_w2_permute_indices_with_cache(
cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
)
tmp_weights2 = (
w2_weight[i]
.clone()
.view(torch.uint8)[permute_indices.to(device_w2)]
.contiguous()
)
tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k)
tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k)
w13_weights_shuffled.append(tmp_weights1.view(torch.bfloat16))
w2_weights_shuffled.append(tmp_weights2.view(torch.bfloat16))
# Stack weights for all experts and return as BF16 tensors.
w13_weights_shuffled_tensor = (
torch.stack(w13_weights_shuffled).view(torch.bfloat16).contiguous()
)
w2_weights_shuffled_tensor = (
torch.stack(w2_weights_shuffled).view(torch.bfloat16).contiguous()
)
return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor
def align_fp8_moe_weights_for_fi(
w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]: