[MoE Refactor] Migrate Unquantized to Full Oracle Flow (#36286)

Signed-off-by: Yifan Zong <yzong@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: yzong-rh <yzong@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
yzong-rh
2026-03-31 15:43:33 -04:00
committed by GitHub
parent 598190aac3
commit d9b90a07ac
11 changed files with 618 additions and 514 deletions

View File

@@ -1664,7 +1664,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
intermediate_size_per_partition=n,
num_local_experts=e,
num_logical_experts=e,
activation="silu",
activation=MoEActivation.SILU,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype,
@@ -1695,13 +1695,25 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
layer.topk_group = 1
layer.intermediate_size_per_partition = n
layer.ep_rank = 0
layer.activation = "silu"
layer.activation = MoEActivation.SILU
layer.e_score_correction_bias = None
layer.routing_method_type = RoutingMethodType.Renormalize
layer.expert_map = None
layer.apply_router_weight_on_input = False
layer.routed_scaling_factor = None
layer.shared_experts = None
layer._maybe_init_expert_routing_tables = lambda: None
quant_method.process_weights_after_loading(layer)
trtllm_output = quant_method.forward_monolithic_cuda(
assert quant_method.moe_kernel is not None, (
"moe_kernel should be set after process_weights_after_loading"
)
assert quant_method.supports_internal_mk, (
"supports_internal_mk should be True after setup"
)
trtllm_output = quant_method.apply_monolithic(
layer=layer,
x=a,
router_logits=router_logits,

View File

@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
],
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
"vllm.utils.flashinfer.has_flashinfer",
return_value=False,
)
@patch(
@@ -54,13 +54,29 @@ def test_select_default_backend_by_platform(
# Set only the specified platform to True
getattr(mock_platform, platform_method).return_value = True
with (
patch.object(current_platform, "is_cuda", return_value=False),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, platform_method, return_value=True),
):
moe_config = make_dummy_moe_config()
selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_dp=False,
selected_backend, expert_cls = select_unquantized_moe_backend(
moe_config=moe_config
)
assert selected_backend == expected_backend
if expected_backend in [
UnquantizedMoeBackend.CPU,
UnquantizedMoeBackend.OOT,
UnquantizedMoeBackend.TPU,
]:
assert expert_cls is None
else:
assert expert_cls is not None
@patch(
@@ -87,88 +103,90 @@ def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer):
mock_platform.is_out_of_tree.return_value = False
moe_config = make_dummy_moe_config()
selected_backend = select_unquantized_moe_backend(
selected_backend, expert_cls = select_unquantized_moe_backend(
moe_config=moe_config,
use_dp=False,
)
assert selected_backend == UnquantizedMoeBackend.AITER
assert expert_cls is not None
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
"vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
return_value=(True, None),
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
)
def test_select_cuda_flashinfer_trtllm_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
):
def test_select_cuda_flashinfer_trtllm_backend(mock_is_supported_trtllm, monkeypatch):
"""Test CUDA backend selection when FlashInfer TRTLLM is available and enabled."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set as CUDA platform
mock_platform.is_cuda.return_value = True
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False
with (
patch.object(current_platform, "is_cuda", return_value=True),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, "has_device_capability", return_value=True),
):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
moe_config = make_dummy_moe_config()
# TRTLLM requires EP and does not support DP
moe_config.moe_parallel_config.use_ep = True
moe_config.moe_parallel_config.use_dp = False
selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_dp=False,
selected_backend, experts_cls = select_unquantized_moe_backend(
moe_config=moe_config
)
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
assert experts_cls is not None
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
"vllm.utils.flashinfer.has_flashinfer",
return_value=True,
)
@patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
"vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
return_value=(False, None),
)
@patch(
"vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts.is_supported_config",
return_value=(True, None),
)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
)
def test_select_cuda_flashinfer_cutlass_backend(
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
mock_has_flashinfer,
mock_is_supported_trtllm,
mock_is_supported_cutlass,
monkeypatch,
):
"""Test CUDA backend selection when FlashInfer TRTLLM is not available
and FlashInfer CUTLASS is available."""
with patch(
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
) as mock_platform:
# Set as CUDA platform with Hopper capability
mock_platform.is_cuda.return_value = True
mock_platform.is_rocm.return_value = False
mock_platform.is_cpu.return_value = False
mock_platform.is_xpu.return_value = False
mock_platform.is_tpu.return_value = False
mock_platform.is_out_of_tree.return_value = False
mock_platform.has_device_capability.return_value = True # SM90+
with (
patch.object(current_platform, "is_cuda", return_value=True),
patch.object(current_platform, "is_rocm", return_value=False),
patch.object(current_platform, "is_cpu", return_value=False),
patch.object(current_platform, "is_xpu", return_value=False),
patch.object(current_platform, "is_tpu", return_value=False),
patch.object(current_platform, "is_out_of_tree", return_value=False),
patch.object(current_platform, "has_device_capability", return_value=True),
):
# Enable FlashInfer via env var
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
moe_config = make_dummy_moe_config()
# CUTLASS requires EP and does not support DP
moe_config.moe_parallel_config.use_ep = True
moe_config.moe_parallel_config.use_dp = False
selected_backend = select_unquantized_moe_backend(
moe_config=moe_config,
use_dp=False, # CUTLASS doesn't support DP
selected_backend, experts_cls = select_unquantized_moe_backend(
moe_config=moe_config
)
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
assert experts_cls is not None

View File

@@ -210,6 +210,13 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
## Qwen3 Next ##
@pytest.mark.skip(
reason=(
"FLASHINFER TRTLLM MoE has a bug with all negative router logits "
"for models with RENORMALIZE. This will be re-enabled once the "
"issue is fixed in flashinfer."
)
)
def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
can_initialize(
"Qwen/Qwen3-Next-80B-A3B-Instruct",

View File

@@ -49,6 +49,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
assert not self.base_layer.quant_method.is_monolithic, (
"Monolithic kernels are not supported for Fused MoE LoRA."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer)

View File

@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
"""
BF16 unquantized TRTLLM-Gen MoE kernels. Supports monolithic interface.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return (
p.is_cuda()
and p.is_device_capability_family(100)
and has_flashinfer_trtllm_fused_moe()
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
"""BF16 kernels do not support non-gated MoE"""
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports only unquantized inputs."""
return weight_key is None and activation_key is None
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU]
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
return routing_method in [
RoutingMethodType.Default,
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
# NOTE: TRTLLM Kernel has issue with Qwen3.5 router.
# Re-enable once the issue is resolved.
# https://github.com/vllm-project/vllm/issues/37591
# RoutingMethodType.Renormalize,
# RoutingMethodType.RenormalizeNaive
]
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
return (
not moe_parallel_config.use_all2all_kernels
or moe_parallel_config.use_ag_rs_all2all_kernels
) and not moe_parallel_config.enable_eplb
@staticmethod
def _supports_router_logits_dtype(
router_logits_dtype: torch.dtype | None,
routing_method: RoutingMethodType,
) -> bool:
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
@property
def expects_unquantized_inputs(self) -> bool:
return True
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
apply_router_weight_on_input: bool,
num_expert_group: int | None = None,
e_score_correction_bias: torch.Tensor | None = None,
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
import flashinfer
return flashinfer.fused_moe.trtllm_bf16_moe(
routing_logits=router_logits,
routing_bias=e_score_correction_bias,
hidden_states=hidden_states,
gemm1_weights=w1,
gemm2_weights=w2,
num_experts=global_num_experts,
top_k=self.topk,
n_group=num_expert_group,
topk_group=topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routing_method_type=self.routing_method_type,
)

View File

@@ -1,141 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
RoutingMethodType,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
#
# Methods used by the oracle for kernel selection.
#
def _supports_current_device() -> bool:
"""Supports only Blackwell-family GPUs."""
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
def _supports_no_act_and_mul() -> bool:
"""BF16 kernels do not support non-gated MoE"""
return False
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU]
def _supports_routing_method_bf16(
routing_method: RoutingMethodType,
) -> bool:
return routing_method in [
RoutingMethodType.Default,
RoutingMethodType.Renormalize,
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.RenormalizeNaive,
]
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Supports TRTLLM Kernel does not support EPLB."""
return not moe_parallel_config.enable_eplb
def is_supported_config_trtllm_bf16(
moe_config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
"""
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
for BF16 unquantized kernels.
"""
def _make_reason(reason: str) -> str:
return f"kernel does not support {reason}"
if not _supports_current_device():
return False, _make_reason(f"current device {current_platform.device_name}")
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
return False, _make_reason("no act_and_mul MLP layer")
elif not _supports_activation(moe_config.activation):
return False, _make_reason(f"{moe_config.activation} activation")
elif not _supports_parallel_config(moe_config.moe_parallel_config):
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
elif not _supports_routing_method_bf16(moe_config.routing_method):
return False, _make_reason(f"routing method {moe_config.routing_method}")
elif activation_format != mk.FusedMoEActivationFormat.Standard:
return False, _make_reason(f"activation format {activation_format}")
return True, None
def flashinfer_fused_moe_bf16(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routing_method_type: int,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe
return flashinfer_trtllm_bf16_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
gemm1_weights=gemm1_weights,
gemm2_weights=gemm2_weights,
num_experts=num_experts,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=local_expert_offset,
local_num_experts=local_num_experts,
routing_method_type=routing_method_type,
tune_max_num_tokens=tune_max_num_tokens,
)
def flashinfer_fused_moe_bf16_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm2_weights: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int | None,
topk_group: int | None,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routing_method_type: int = RoutingMethodType.Renormalize,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="flashinfer_fused_moe_bf16",
op_func=flashinfer_fused_moe_bf16,
fake_impl=flashinfer_fused_moe_bf16_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)

View File

@@ -1967,6 +1967,10 @@ class TritonExperts(mk.FusedMoEExpertsModular):
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
@staticmethod
def _supports_batch_invariance():
return True
def supports_expert_map(self) -> bool:
return True

View File

@@ -9,6 +9,7 @@ from typing import final
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
@@ -563,6 +564,8 @@ class FusedMoEExperts(ABC):
)
elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format")
elif envs.VLLM_BATCH_INVARIANT and not cls._supports_batch_invariance():
return False, _make_reason("batch invariance")
return True, None
@staticmethod
@@ -645,6 +648,15 @@ class FusedMoEExperts(ABC):
"""
return True
@staticmethod
def _supports_batch_invariance() -> bool:
"""
Whether the kernel supports batch invariance, i.e. the output does not
depend on the order of the tokens in the input batch. This is useful
for determining if the kernel can used with VLLM_BATCH_INVARIANT=1.
"""
return False
#
# Various helpers for accessing quantization parameters from the
# quant_config.

View File

@@ -11,21 +11,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
convert_moe_weights_to_flashinfer_trtllm_block_layout,
get_flashinfer_moe_backend,
swap_w13_to_w31,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__)
@@ -35,21 +34,96 @@ class UnquantizedMoeBackend(Enum):
FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
AITER = "ROCm AITER"
TRITON = "TRITON"
BATCHED_TRITON = "BATCHED_TRITON"
CPU = "CPU"
XPU = "XPU"
TPU = "TPU"
OOT = "OOT"
# NOTE(zyongye): Unsupported backend means backend
# that is not conform with Modular kernel format.
# We will directly call the kernel for those backend
UNSUPPORTED_BACKEND = [
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
UnquantizedMoeBackend.CPU,
UnquantizedMoeBackend.TPU,
UnquantizedMoeBackend.OOT,
]
def _get_priority_backends(moe_config: FusedMoEConfig) -> list[UnquantizedMoeBackend]:
"""
Get available backends in priority order based on platform and config.
This function can be extended to become more complex as needed.
"""
def _move_to_back(
backends: list[UnquantizedMoeBackend],
backend: UnquantizedMoeBackend,
) -> None:
backends.append(backends.pop(backends.index(backend)))
if current_platform.is_rocm():
_AVAILABLE_BACKENDS = [
UnquantizedMoeBackend.AITER,
UnquantizedMoeBackend.TRITON,
UnquantizedMoeBackend.BATCHED_TRITON,
]
elif current_platform.is_cuda():
_AVAILABLE_BACKENDS = [
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
UnquantizedMoeBackend.FLASHINFER_CUTLASS,
UnquantizedMoeBackend.TRITON,
UnquantizedMoeBackend.BATCHED_TRITON,
]
# HACK: Qwen3.5 has crash with FLASHINFER_CUTLASS BF16 if DEP.
# Updating the oracle querying logic is out of the scope of this
# PR. Need to fix the kernel or update structure in follow up.
if moe_config.moe_parallel_config.dp_size > 1:
_move_to_back(_AVAILABLE_BACKENDS, UnquantizedMoeBackend.FLASHINFER_CUTLASS)
elif current_platform.is_xpu():
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.XPU]
elif current_platform.is_cpu():
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.CPU]
return _AVAILABLE_BACKENDS
def backend_to_kernel_cls(
backend: UnquantizedMoeBackend,
) -> type[mk.FusedMoEExperts]:
if backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
from vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe import (
TrtLlmBf16Experts,
)
return TrtLlmBf16Experts
elif backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return FlashInferExperts
elif backend == UnquantizedMoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
return AiterExperts
elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
return TritonExperts
elif backend == UnquantizedMoeBackend.BATCHED_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
return BatchedTritonExperts
elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExperts
return XPUExperts
else:
raise ValueError(f"Unknown unquantized MoE backend: {backend.value}")
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
@@ -70,194 +144,224 @@ def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend
def select_unquantized_moe_backend(
moe_config: FusedMoEConfig,
use_dp: bool,
) -> UnquantizedMoeBackend:
) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]:
"""
Select the primary Unquantized MoE backend
Select the primary Unquantized MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
"""
def _make_log_backend(backend: UnquantizedMoeBackend):
return f"Using {backend.value} backend for Unquantized MoE"
if current_platform.is_cpu():
# TODO: migrate to MK structure.
return UnquantizedMoeBackend.CPU, None
if current_platform.is_tpu():
return UnquantizedMoeBackend.TPU, None
if current_platform.is_out_of_tree():
return UnquantizedMoeBackend.OOT, None
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(moe_config)
# NOTE(rob): We need to peak into the P/F selection to determine
# if we are using the batched or standard expert format, which
# if not ideal. Once we unify TP + DP/EP, we can select P/F first.
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if moe_config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
# Check if FlashInfer TRTLLM BF16 MoE is supported
trtllm_supported, _ = is_supported_config_trtllm_bf16(
moe_config=moe_config,
activation_format=activation_format,
)
flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
flashinfer_cutlass_available = (
has_flashinfer_cutlass_fused_moe()
and (not use_dp)
and current_platform.has_device_capability(90)
)
flashinfer_trtllm_moe_enabled = (
flashinfer_trtllm_available
and envs.VLLM_USE_FLASHINFER_MOE_FP16
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
)
flashinfer_cutlass_moe_enabled = (
flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
)
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
def _make_log_backend(backend: UnquantizedMoeBackend) -> str:
available_strs = [b.value for b in AVAILABLE_BACKENDS]
return (
f"Using {backend.value} Unquantized MoE backend out "
f"of potential backends: {available_strs}."
)
def _make_log_unsupported(
backend: UnquantizedMoeBackend, reason: str | None
) -> str:
if reason:
return (
f"Unquantized MoE backend {backend.value} does not support the "
f"deployment configuration since {reason}."
)
return (
f"Unquantized MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: UnquantizedMoeBackend,
config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, None, None, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = moe_config.moe_backend
if runner_backend != "auto":
requested_backend = map_unquantized_backend(runner_backend)
if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
if not flashinfer_trtllm_available:
raise ValueError(
"FlashInfer TRTLLM MoE backend is not available for this "
"configuration."
)
elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
if not flashinfer_cutlass_available:
raise ValueError(
"FlashInfer CUTLASS MoE backend is not available for this "
"configuration."
)
elif requested_backend == UnquantizedMoeBackend.AITER and not (
current_platform.is_rocm() and rocm_aiter_moe_enabled
if (
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
and requested_backend == UnquantizedMoeBackend.TRITON
):
raise ValueError(
"ROCm AITer MoE backend is not available for this configuration."
requested_backend = UnquantizedMoeBackend.BATCHED_TRITON
return _return_or_raise(requested_backend, moe_config, activation_format)
# Handle explicit FlashInfer FP16 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP16:
if UnquantizedMoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS:
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_TRTLLM)
if UnquantizedMoeBackend.FLASHINFER_CUTLASS in AVAILABLE_BACKENDS:
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_CUTLASS)
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.CUTLASS:
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
else:
raise ValueError(
f"FlashInfer MOE backend {fi_backend} "
"does not support unquantized MoE."
)
k_cls = backend_to_kernel_cls(backend)
return _return_or_raise(backend, moe_config, activation_format)
else:
# If the user is not explicit about the backend, try both.
for backend in [
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
UnquantizedMoeBackend.FLASHINFER_CUTLASS,
]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, moe_config, None, None, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP16=1, but no "
"FlashInfer unquantized MoE backend supports the configuration."
)
logger.info_once(_make_log_backend(requested_backend), scope="local")
return requested_backend
if current_platform.is_rocm():
if rocm_aiter_moe_enabled:
# Handle explicit AITER FP8 configuration.
if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"):
if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE:
if UnquantizedMoeBackend.AITER in AVAILABLE_BACKENDS:
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.AITER)
else:
backend = UnquantizedMoeBackend.AITER
else:
backend = UnquantizedMoeBackend.TRITON
if current_platform.is_cuda():
if flashinfer_trtllm_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
elif flashinfer_cutlass_moe_enabled:
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
if trtllm_supported:
logger.info_once(
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_FLASHINFER_MOE_BACKEND=latency "
"to enable it for better performance.",
scope="local",
)
else:
if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
logger.info_once(
"FlashInfer TRTLLM MoE is available but not enabled, "
"consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
"and VLLM_FLASHINFER_MOE_BACKEND=latency "
"to enable it for better performance.",
scope="local",
)
elif not use_dp and flashinfer_cutlass_available:
logger.info_once(
"FlashInfer CUTLASS MoE is available"
" but not enabled, consider setting"
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
scope="local",
)
elif use_dp:
logger.info_once(
"FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local",
)
backend = UnquantizedMoeBackend.TRITON
if current_platform.is_xpu():
backend = UnquantizedMoeBackend.XPU
if current_platform.is_cpu():
backend = UnquantizedMoeBackend.CPU
if current_platform.is_tpu():
backend = UnquantizedMoeBackend.TPU
if current_platform.is_out_of_tree():
backend = UnquantizedMoeBackend.OOT
return _return_or_raise(backend, moe_config, activation_format)
logger.info_once(_make_log_backend(backend), scope="local")
return backend
for backend in AVAILABLE_BACKENDS:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, moe_config, None, None, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No Unquantized MoE backend supports the deployment configuration."
)
def convert_to_unquantized_kernel_format(
unquantized_backend: UnquantizedMoeBackend,
layer: Module,
w13_weight: torch.Tensor | None = None,
w2_weight: torch.Tensor | None = None,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if unquantized_backend == UnquantizedMoeBackend.AITER:
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(w13_weight, w2_weight)
elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
if layer.moe_config.is_act_and_mul:
# Swap halves to arrange as [w3; w1] (kernel expectation)
# Non-gated MoE: w13 is a single projection, no need to swap.
w13_weight = swap_w13_to_w31(w13_weight)
return w13_weight, w2_weight
elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w13_weight = swap_w13_to_w31(w13_weight)
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
w13_weight, w2_weight = convert_moe_weights_to_flashinfer_trtllm_block_layout(
_cache_permute_indices,
w13_weight,
w2_weight,
)
return w13_weight.contiguous(), w2_weight.contiguous()
def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
) -> mk.FusedMoEKernel | None:
if backend in UNSUPPORTED_BACKEND:
return None
backend: UnquantizedMoeBackend,
experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
# Create Prepare/Finalize
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
use_monolithic=is_monolithic,
)
assert prepare_finalize is not None
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")
# Create Experts
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
quant_config=quant_config,
)
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
shared_experts
if moe_config.moe_parallel_config.use_deepep_ll_kernels
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=(not moe_config.disable_inplace and not is_monolithic),
)
elif backend == UnquantizedMoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
AiterExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
TritonExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts
kernel = mk.FusedMoEKernel(
MoEPrepareAndFinalizeNoDPEPModular(),
XPUExperts(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
return kernel

View File

@@ -6,11 +6,8 @@ from collections.abc import Callable
import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.config import (
@@ -23,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEExpertsModular,
FusedMoEPrepareAndFinalizeModular,
)
@@ -33,20 +29,10 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
make_unquantized_moe_kernel,
select_unquantized_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
convert_moe_weights_to_flashinfer_trtllm_block_layout,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
if current_platform.is_cuda_alike() or current_platform.is_xpu():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts
else:
TritonExperts = None # type: ignore
logger = init_logger(__name__)
@@ -59,45 +45,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.unquantized_backend = select_unquantized_moe_backend(
self.unquantized_backend, self.experts_cls = select_unquantized_moe_backend(
moe_config=self.moe,
use_dp=self.moe.moe_parallel_config.dp_size > 1,
)
# AITER only supports gated activations (silu/gelu), so disable it
# for non-gated MoE (is_act_and_mul=False)
self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
self.kernel: mk.FusedMoEKernel | None = None
self._is_monolithic = (
current_platform.is_cpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
)
if self.is_monolithic:
self.apply_monolithic: Callable = self._select_monolithic()
def _select_monolithic(self) -> Callable:
"""Select the monolithic implementation based on platform."""
if current_platform.is_cpu():
return self.forward_monolithic_cpu
else:
return self.forward_monolithic_cuda
def forward_native(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
@property
def is_monolithic(self) -> bool:
return self._is_monolithic
# Escape hatch for CPU, which stays on the old monolithic path.
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
return True
return super().is_monolithic
@property
def supports_eplb(self) -> bool:
@@ -106,43 +63,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalizeModular | None:
return super().maybe_make_prepare_finalize(routing_tables)
):
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic for all but the CPU backend. CPU backend is monolithic. "
"So this function should not be called."
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
) -> FusedMoEExpertsModular:
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("BatchedTritonExperts %s", self.moe)
return BatchedTritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
elif (
self.unquantized_backend == UnquantizedMoeBackend.AITER
and rocm_aiter_ops.is_fused_moe_enabled()
):
from .rocm_aiter_fused_moe import AiterExperts
logger.debug("AiterExperts %s", self.moe)
return AiterExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
def create_weights(
self,
@@ -227,14 +163,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
# Setup Modular Kernel for TP Case
# Setup moe kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.kernel = make_unquantized_moe_kernel(
backend=self.unquantized_backend,
assert self.experts_cls is not None
self.moe_kernel = make_unquantized_moe_kernel(
quant_config=self.moe_quant_config,
moe_config=self.moe,
backend=self.unquantized_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -244,22 +183,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
w13_weights_shuffled, w2_weights_shuffled = (
convert_moe_weights_to_flashinfer_trtllm_block_layout(
_cache_permute_indices,
layer.w13_weight.data,
layer.w2_weight.data,
)
)
layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
if self.unquantized_backend in [
UnquantizedMoeBackend.TPU,
UnquantizedMoeBackend.OOT,
]:
# OOT handles internally.
return
elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
# CPU stays on the old path — no oracle, no moe_kernel.
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
@@ -290,13 +222,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else:
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike() or current_platform.is_xpu():
else:
self._setup_kernel(
layer=layer,
w13=layer.w13_weight,
w2=layer.w2_weight,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def apply(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
@@ -313,16 +254,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_experts_input=shared_experts_input,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
if self.moe.has_bias:
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG
def forward_cuda(
def forward_native(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
@@ -330,9 +262,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
return self.kernel.apply(
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -345,53 +276,58 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shared_experts_input=shared_experts_input,
)
def forward_monolithic_cuda(
def forward_cuda(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(
layer, x, topk_weights, topk_ids, shared_experts_input
)
def apply_monolithic(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401
assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
return torch.ops.vllm.flashinfer_fused_moe_bf16(
routing_logits=router_logits,
routing_bias=layer.e_score_correction_bias,
hidden_states=x,
gemm1_weights=layer.w13_weight,
gemm2_weights=layer.w2_weight,
num_experts=layer.global_num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
intermediate_size=layer.intermediate_size_per_partition,
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routing_method_type=layer.routing_method_type,
)
def forward_monolithic_cpu(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.cpu_fused_moe(
layer,
x,
layer.use_grouped_topk,
layer.top_k,
router_logits,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
layer.global_num_experts,
layer.expert_map,
layer.custom_routing_function,
layer.scoring_func,
layer.routed_scaling_factor,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.activation,
)
assert self.is_monolithic
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
assert self.moe_kernel is None
return self.cpu_fused_moe(
layer,
x,
layer.use_grouped_topk,
layer.top_k,
router_logits,
layer.renormalize,
layer.topk_group,
layer.num_expert_group,
layer.global_num_experts,
layer.expert_map,
layer.custom_routing_function,
layer.scoring_func,
layer.routed_scaling_factor,
layer.e_score_correction_bias,
layer.apply_router_weight_on_input,
layer.activation,
)
else:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)

View File

@@ -202,6 +202,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool:
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_bf16_moe"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)