[NVIDIA] [feat] Integrate flashinfer Trtllmgen bf16 moe (#32954)
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
@@ -77,6 +77,18 @@ def _supports_routing_method(
     raise ValueError("Unsupported quantization scheme.")


+def _supports_routing_method_bf16(
+    routing_method: RoutingMethodType,
+) -> bool:
+    return routing_method in [
+        RoutingMethodType.Default,
+        RoutingMethodType.Renormalize,
+        RoutingMethodType.DeepSeekV3,
+        RoutingMethodType.Llama4,
+        RoutingMethodType.RenormalizeNaive,
+    ]
+
+
 def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
     """TRTLLM kernel does not support EPLB."""
     return not moe_parallel_config.enable_eplb
@@ -115,6 +127,34 @@ def is_supported_config_trtllm(
     return True, None


+def is_supported_config_trtllm_bf16(
+    moe_config: FusedMoEConfig,
+    activation_format: mk.FusedMoEActivationFormat,
+) -> tuple[bool, str | None]:
+    """
+    This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
+    for BF16 unquantized kernels.
+    """
+
+    def _make_reason(reason: str) -> str:
+        return f"kernel does not support {reason}"
+
+    if not _supports_current_device():
+        return False, _make_reason("current device")
+    elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
+        return False, _make_reason("no act_and_mul MLP layer")
+    elif not _supports_activation(moe_config.activation):
+        return False, _make_reason(f"{moe_config.activation} activation")
+    elif not _supports_parallel_config(moe_config.moe_parallel_config):
+        return False, _make_reason("parallel config")
+    elif not _supports_routing_method_bf16(moe_config.routing_method):
+        return False, _make_reason("routing method")
+    elif activation_format != mk.FusedMoEActivationFormat.Standard:
+        return False, _make_reason("activation format")
+
+    return True, None
+
+
 def flashinfer_fused_moe_blockscale_fp8(
     routing_logits: torch.Tensor,
     routing_bias: torch.Tensor,
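The helper returns a (supported, reason) tuple instead of raising, so the backend oracle can log why the TRTLLM BF16 path was skipped. A minimal usage sketch (illustrative only, not part of this commit; assumes a module-level logger and the names used above):

    supported, reason = is_supported_config_trtllm_bf16(
        moe_config=moe_config,
        activation_format=mk.FusedMoEActivationFormat.Standard,
    )
    if not supported:
        logger.debug("FlashInfer TRTLLM BF16 MoE not used: %s", reason)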
@@ -287,3 +327,66 @@ direct_register_custom_op(
     fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
     tags=(torch.Tag.needs_fixed_stride_order,),
 )
+
+
+def flashinfer_fused_moe_bf16(
+    routing_logits: torch.Tensor,
+    routing_bias: torch.Tensor | None,
+    hidden_states: torch.Tensor,
+    gemm1_weights: torch.Tensor,
+    gemm2_weights: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    n_group: int | None,
+    topk_group: int | None,
+    intermediate_size: int,
+    local_expert_offset: int,
+    local_num_experts: int,
+    routing_method_type: int,
+    tune_max_num_tokens: int = 8192,
+) -> torch.Tensor:
+    from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe
+
+    return flashinfer_trtllm_bf16_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=hidden_states,
+        gemm1_weights=gemm1_weights,
+        gemm2_weights=gemm2_weights,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=n_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=local_expert_offset,
+        local_num_experts=local_num_experts,
+        routing_method_type=routing_method_type,
+        tune_max_num_tokens=tune_max_num_tokens,
+    )
+
+
+def flashinfer_fused_moe_bf16_fake(
+    routing_logits: torch.Tensor,
+    routing_bias: torch.Tensor | None,
+    hidden_states: torch.Tensor,
+    gemm1_weights: torch.Tensor,
+    gemm2_weights: torch.Tensor,
+    num_experts: int,
+    top_k: int,
+    n_group: int | None,
+    topk_group: int | None,
+    intermediate_size: int,
+    local_expert_offset: int,
+    local_num_experts: int,
+    routing_method_type: int = RoutingMethodType.Renormalize,
+    tune_max_num_tokens: int = 8192,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_bf16",
+    op_func=flashinfer_fused_moe_bf16,
+    fake_impl=flashinfer_fused_moe_bf16_fake,
+    tags=(torch.Tag.needs_fixed_stride_order,),
+)
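The fake_impl only has to return a tensor of the correct shape and dtype so torch.compile can trace the op; the real FlashInfer kernel is dispatched at runtime. A purely illustrative sketch of calling the registered op directly follows (dummy sizes and values, not from this commit; actually running it needs a supported NVIDIA GPU, FlashInfer installed, and weights already swapped into [w3; w1] order and the TRTLLM block layout):

    import torch

    num_tokens, hidden, intermediate, num_experts, top_k = 4, 1024, 2048, 32, 8
    x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
    router_logits = torch.randn(num_tokens, num_experts, device="cuda")
    w13 = torch.randn(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(num_experts, hidden, intermediate, dtype=torch.bfloat16, device="cuda")

    out = torch.ops.vllm.flashinfer_fused_moe_bf16(
        routing_logits=router_logits,
        routing_bias=None,
        hidden_states=x,
        gemm1_weights=w13,
        gemm2_weights=w2,
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routing_method_type=RoutingMethodType.Renormalize,  # enum already used in this file
    )

The hunks below move to a different file: the unquantized MoE backend oracle (vllm.model_executor.layers.fused_moe.oracle.unquantized, per the import shown later in this diff).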
@@ -14,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
+    is_supported_config_trtllm_bf16,
+)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP,
 )
@@ -21,12 +24,13 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     swap_w13_to_w31,
 )
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe

 logger = init_logger(__name__)


 class UnquantizedMoeBackend(Enum):
+    FLASHINFER_TRTLLM = "FlashInfer TRTLLM"
     FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
     AITER = "ROCm AITER"
     TRITON = "TRITON"
@@ -40,6 +44,7 @@ class UnquantizedMoeBackend(Enum):
 # that does not conform to the Modular kernel format.
 # We will directly call the kernel for those backends.
 UNSUPPORTED_BACKEND = [
+    UnquantizedMoeBackend.FLASHINFER_TRTLLM,
     UnquantizedMoeBackend.CPU,
     UnquantizedMoeBackend.XPU,
     UnquantizedMoeBackend.TPU,
@@ -48,11 +53,12 @@ UNSUPPORTED_BACKEND = [


 def select_unquantized_moe_backend(
+    moe_config: FusedMoEConfig,
     use_ep: bool,
     use_dp: bool,
 ) -> UnquantizedMoeBackend:
     """
-    Select the primary FP8 MoE backend
+    Select the primary Unquantized MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """

@@ -61,13 +67,27 @@ def select_unquantized_moe_backend(

     rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

+    activation_format = (
+        mk.FusedMoEActivationFormat.BatchedExperts
+        if moe_config.moe_parallel_config.use_batched_activation_format
+        else mk.FusedMoEActivationFormat.Standard
+    )
+
+    # Check if FlashInfer TRTLLM BF16 MoE is supported
+    trtllm_supported, _ = is_supported_config_trtllm_bf16(
+        moe_config=moe_config,
+        activation_format=activation_format,
+    )
+    flashinfer_trtllm_moe_enabled = (
+        has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported
+    )
     # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUs
     flashinfer_cutlass_moe_enabled = (
         has_flashinfer_cutlass_fused_moe()
         and envs.VLLM_USE_FLASHINFER_MOE_FP16
         and use_ep
         and (not use_dp)
-        and current_platform.get_device_capability()[0] >= 9
+        and current_platform.has_device_capability(90)
     )
     if current_platform.is_rocm():
         if rocm_aiter_moe_enabled:
@@ -75,12 +95,21 @@ def select_unquantized_moe_backend(
         else:
             backend = UnquantizedMoeBackend.TRITON
     if current_platform.is_cuda():
-        if flashinfer_cutlass_moe_enabled:
+        if flashinfer_trtllm_moe_enabled:
+            backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
+        elif flashinfer_cutlass_moe_enabled:
             backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
         else:
-            if use_ep and (not use_dp):
+            if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
                 logger.info_once(
-                    "FlashInfer CUTLASS MoE is available for EP"
+                    "FlashInfer TRTLLM MoE is available but not enabled, "
+                    "consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
+                    "to enable it for better performance.",
+                    scope="local",
+                )
+            elif use_ep and (not use_dp):
+                logger.info_once(
+                    "FlashInfer MoE is available for EP"
                     " but not enabled, consider setting"
                     " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
                     scope="local",
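In short, the CUDA selection order after this change is: FlashInfer TRTLLM (when VLLM_USE_FLASHINFER_MOE_FP16 is set and is_supported_config_trtllm_bf16 passes), then FlashInfer CUTLASS, then the pre-existing default path; when neither FlashInfer backend is enabled, a one-time hint suggests setting VLLM_USE_FLASHINFER_MOE_FP16=1. The remaining hunks are against the UnquantizedFusedMoEMethod implementation, which consumes this oracle.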
@@ -6,6 +6,7 @@ from collections.abc import Callable
 import torch
 import torch.nn.functional as F
 from torch.nn import Module
+from torch.nn.parameter import Parameter

 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -32,6 +33,9 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
     make_unquantized_moe_kernel,
     select_unquantized_moe_backend,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    convert_moe_weights_to_flashinfer_trtllm_block_layout,
+)
 from vllm.model_executor.utils import replace_parameter, set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -54,6 +58,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
         self.unquantized_backend = select_unquantized_moe_backend(
+            moe_config=self.moe,
             use_ep=self.moe.moe_parallel_config.use_ep,
             use_dp=self.moe.moe_parallel_config.dp_size > 1,
         )
@@ -64,7 +69,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
         )
         self.kernel: mk.FusedMoEModularKernel | None = None
-        self._is_monolithic = current_platform.is_cpu() or current_platform.is_xpu()
+        self._is_monolithic = (
+            current_platform.is_cpu()
+            or current_platform.is_xpu()
+            or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
+        )
+
+        if self.is_monolithic:
+            self.apply_monolithic: Callable = self._select_monolithic()
+
+    def _select_monolithic(self) -> Callable:
+        """Select the monolithic implementation based on platform."""
+        if current_platform.is_cpu():
+            return self.forward_monolithic_cpu
+        elif current_platform.is_xpu():
+            return self.forward_monolithic_xpu
+        else:
+            return self.forward_monolithic_cuda
+
+    def forward_native(
+        self,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_cuda(layer, x, topk_weights, topk_ids)

     @property
     def is_monolithic(self) -> bool:
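Treating the FlashInfer TRTLLM backend as monolithic means the layer skips the modular-kernel pipeline (MoEPrepareAndFinalize plus a FusedMoEPermuteExpertsUnpermute kernel) and calls a single fused custom op instead, so self.kernel is expected to stay None for this backend; the new assert in forward_cuda below and the removal of the class-level forward_native/apply_monolithic assignments at the end of the file follow from that.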
@@ -211,7 +241,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
         layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)

-        if self.unquantized_backend == UnquantizedMoeBackend.XPU:
+        if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
+            _cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
+            # Swap halves to arrange as [w3; w1] (kernel expectation)
+            w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
+            w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
+            layer.w13_weight.data = w13_weight_swapped.contiguous()
+            w13_weights_shuffled, w2_weights_shuffled = (
+                convert_moe_weights_to_flashinfer_trtllm_block_layout(
+                    _cache_permute_indices,
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                )
+            )
+            layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
+            layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
+        elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
             import intel_extension_for_pytorch as ipex

             ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
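The fused w13 weight stores the two input projections (w1 and w3) stacked along dim 1; the TRTLLM kernel expects them in [w3; w1] order, so the halves are swapped before the block-layout shuffle. A minimal, self-contained sketch of that swap with dummy sizes (illustrative only, not from this commit):

    import torch

    num_experts, intermediate, hidden = 4, 8, 16
    w13 = torch.randn(num_experts, 2 * intermediate, hidden, dtype=torch.bfloat16)
    w1, w3 = torch.chunk(w13, 2, dim=1)                    # original [w1; w3] halves
    w13_swapped = torch.cat([w3, w1], dim=1).contiguous()  # [w3; w1], as the kernel expects
    assert w13_swapped.shape == w13.shape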
@@ -290,6 +335,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         topk_ids: torch.Tensor,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        assert self.kernel is not None

         return self.kernel(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -303,6 +349,32 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             expert_map=layer.expert_map,
         )

+    def forward_monolithic_cuda(
+        self,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: F401
+
+        assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
+
+        return torch.ops.vllm.flashinfer_fused_moe_bf16(
+            routing_logits=router_logits,
+            routing_bias=layer.e_score_correction_bias,
+            hidden_states=x,
+            gemm1_weights=layer.w13_weight,
+            gemm2_weights=layer.w2_weight,
+            num_experts=layer.global_num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.intermediate_size_per_partition,
+            local_expert_offset=layer.ep_rank * layer.local_num_experts,
+            local_num_experts=layer.local_num_experts,
+            routing_method_type=layer.routing_method_type,
+        )
+
     def forward_monolithic_cpu(
         self,
         layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
@@ -344,12 +416,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer.num_expert_group,
             custom_routing_function=layer.custom_routing_function,
         )
-
-    if current_platform.is_cpu():
-        forward_native: Callable = forward_monolithic_cpu
-        apply_monolithic = forward_monolithic_cpu
-    elif current_platform.is_xpu():
-        forward_native = forward_monolithic_xpu
-        apply_monolithic = forward_monolithic_xpu
-    else:
-        forward_native = forward_cuda
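The final hunk adds the weight-layout helper used above to the FlashInfer quantization utilities (vllm.model_executor.layers.quantization.utils.flashinfer_utils, matching the import added earlier in this diff).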
@@ -195,6 +195,81 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
     return backend in backends_supporting_global_sf


+def convert_moe_weights_to_flashinfer_trtllm_block_layout(
+    cache_permute_indices: dict[torch.Size, torch.Tensor],
+    w13_weight: torch.Tensor,
+    w2_weight: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Convert expert weights to FlashInfer's block layout.
+
+    This reorders W13 and W2 into the expected epilogue-tiled block layout and
+    returns the shuffled weight tensors.
+    """
+    if w13_weight.dtype != torch.bfloat16 or w2_weight.dtype != torch.bfloat16:
+        raise ValueError(
+            "Unquantized Moe Backend FlashInfer TRTLLM requires bfloat16 weights"
+        )
+
+    from flashinfer.fused_moe.core import (
+        _maybe_get_cached_w3_w1_permute_indices,
+        convert_to_block_layout,
+        get_w2_permute_indices_with_cache,
+    )
+
+    epilogue_tile_m = 128
+    block_k = 128
+
+    # Reorder rows of W13 and W2 for fused gated activation and convert to the
+    # block layout expected by the FlashInfer kernel.
+    num_experts = w13_weight.shape[0]
+    device_w13 = w13_weight.device
+    device_w2 = w2_weight.device
+
+    w13_weights_shuffled: list[torch.Tensor] = []
+    w2_weights_shuffled: list[torch.Tensor] = []
+
+    for i in range(num_experts):
+        permute_indices = _maybe_get_cached_w3_w1_permute_indices(
+            cache_permute_indices,
+            w13_weight[i].view(torch.uint8),
+            epilogue_tile_m,
+        )
+        tmp_weights1 = (
+            w13_weight[i]
+            .clone()
+            .view(torch.uint8)[permute_indices.to(device_w13)]
+            .contiguous()
+        )
+
+        permute_indices = get_w2_permute_indices_with_cache(
+            cache_permute_indices,
+            w2_weight[i].view(torch.uint8),
+            epilogue_tile_m,
+        )
+        tmp_weights2 = (
+            w2_weight[i]
+            .clone()
+            .view(torch.uint8)[permute_indices.to(device_w2)]
+            .contiguous()
+        )
+
+        tmp_weights1 = convert_to_block_layout(tmp_weights1.view(torch.uint8), block_k)
+        tmp_weights2 = convert_to_block_layout(tmp_weights2.view(torch.uint8), block_k)
+
+        w13_weights_shuffled.append(tmp_weights1.view(torch.bfloat16))
+        w2_weights_shuffled.append(tmp_weights2.view(torch.bfloat16))
+
+    # Stack weights for all experts and return as BF16 tensors.
+    w13_weights_shuffled_tensor = (
+        torch.stack(w13_weights_shuffled).view(torch.bfloat16).contiguous()
+    )
+    w2_weights_shuffled_tensor = (
+        torch.stack(w2_weights_shuffled).view(torch.bfloat16).contiguous()
+    )
+
+    return w13_weights_shuffled_tensor, w2_weights_shuffled_tensor
+
+
 def align_fp8_moe_weights_for_fi(
     w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
 ) -> tuple[torch.Tensor, torch.Tensor, int]:
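A note on the uint8 reinterpretation above: the bf16 tensors are viewed as raw bytes before the permute-index helpers and convert_to_block_layout are applied, and only reinterpreted back to bf16 at the end. Since bf16 is two bytes per element, the byte view doubles the last dimension, so block_k = 128 bytes presumably corresponds to 64 bf16 values along K. A tiny sketch of that bookkeeping (illustrative only):

    import torch

    w = torch.zeros(8, 256, dtype=torch.bfloat16)
    w_bytes = w.view(torch.uint8)
    assert w_bytes.shape == (8, 512)                        # 2 bytes per bf16 element
    assert w_bytes.view(torch.bfloat16).shape == w.shape    # byte view round-trips losslessly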