[MoE Refactor] Migrate Unquantized to Full Oracle Flow (#36286)
Signed-off-by: Yifan Zong <yzong@redhat.com> Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: yzong-rh <yzong@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -1664,7 +1664,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
|
||||
intermediate_size_per_partition=n,
|
||||
num_local_experts=e,
|
||||
num_logical_experts=e,
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
device="cuda",
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
in_dtype=dtype,
|
||||
@@ -1695,13 +1695,25 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
|
||||
layer.topk_group = 1
|
||||
layer.intermediate_size_per_partition = n
|
||||
layer.ep_rank = 0
|
||||
layer.activation = "silu"
|
||||
layer.activation = MoEActivation.SILU
|
||||
layer.e_score_correction_bias = None
|
||||
layer.routing_method_type = RoutingMethodType.Renormalize
|
||||
layer.expert_map = None
|
||||
layer.apply_router_weight_on_input = False
|
||||
layer.routed_scaling_factor = None
|
||||
layer.shared_experts = None
|
||||
layer._maybe_init_expert_routing_tables = lambda: None
|
||||
|
||||
quant_method.process_weights_after_loading(layer)
|
||||
|
||||
trtllm_output = quant_method.forward_monolithic_cuda(
|
||||
assert quant_method.moe_kernel is not None, (
|
||||
"moe_kernel should be set after process_weights_after_loading"
|
||||
)
|
||||
assert quant_method.supports_internal_mk, (
|
||||
"supports_internal_mk should be True after setup"
|
||||
)
|
||||
|
||||
trtllm_output = quant_method.apply_monolithic(
|
||||
layer=layer,
|
||||
x=a,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -24,7 +24,7 @@ from vllm.platforms import current_platform
|
||||
],
|
||||
)
|
||||
@patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
|
||||
"vllm.utils.flashinfer.has_flashinfer",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
@@ -54,13 +54,29 @@ def test_select_default_backend_by_platform(
|
||||
# Set only the specified platform to True
|
||||
getattr(mock_platform, platform_method).return_value = True
|
||||
|
||||
with (
|
||||
patch.object(current_platform, "is_cuda", return_value=False),
|
||||
patch.object(current_platform, "is_rocm", return_value=False),
|
||||
patch.object(current_platform, "is_cpu", return_value=False),
|
||||
patch.object(current_platform, "is_xpu", return_value=False),
|
||||
patch.object(current_platform, "is_tpu", return_value=False),
|
||||
patch.object(current_platform, "is_out_of_tree", return_value=False),
|
||||
patch.object(current_platform, platform_method, return_value=True),
|
||||
):
|
||||
moe_config = make_dummy_moe_config()
|
||||
selected_backend = select_unquantized_moe_backend(
|
||||
moe_config=moe_config,
|
||||
use_dp=False,
|
||||
selected_backend, expert_cls = select_unquantized_moe_backend(
|
||||
moe_config=moe_config
|
||||
)
|
||||
|
||||
assert selected_backend == expected_backend
|
||||
if expected_backend in [
|
||||
UnquantizedMoeBackend.CPU,
|
||||
UnquantizedMoeBackend.OOT,
|
||||
UnquantizedMoeBackend.TPU,
|
||||
]:
|
||||
assert expert_cls is None
|
||||
else:
|
||||
assert expert_cls is not None
|
||||
|
||||
|
||||
@patch(
|
||||
@@ -87,88 +103,90 @@ def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer):
|
||||
mock_platform.is_out_of_tree.return_value = False
|
||||
|
||||
moe_config = make_dummy_moe_config()
|
||||
selected_backend = select_unquantized_moe_backend(
|
||||
selected_backend, expert_cls = select_unquantized_moe_backend(
|
||||
moe_config=moe_config,
|
||||
use_dp=False,
|
||||
)
|
||||
|
||||
assert selected_backend == UnquantizedMoeBackend.AITER
|
||||
assert expert_cls is not None
|
||||
|
||||
|
||||
@patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
|
||||
return_value=True,
|
||||
)
|
||||
@patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
|
||||
"vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
|
||||
return_value=(True, None),
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
|
||||
)
|
||||
def test_select_cuda_flashinfer_trtllm_backend(
|
||||
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
|
||||
):
|
||||
def test_select_cuda_flashinfer_trtllm_backend(mock_is_supported_trtllm, monkeypatch):
|
||||
"""Test CUDA backend selection when FlashInfer TRTLLM is available and enabled."""
|
||||
with patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
|
||||
) as mock_platform:
|
||||
# Set as CUDA platform
|
||||
mock_platform.is_cuda.return_value = True
|
||||
mock_platform.is_rocm.return_value = False
|
||||
mock_platform.is_cpu.return_value = False
|
||||
mock_platform.is_xpu.return_value = False
|
||||
mock_platform.is_tpu.return_value = False
|
||||
mock_platform.is_out_of_tree.return_value = False
|
||||
|
||||
with (
|
||||
patch.object(current_platform, "is_cuda", return_value=True),
|
||||
patch.object(current_platform, "is_rocm", return_value=False),
|
||||
patch.object(current_platform, "is_cpu", return_value=False),
|
||||
patch.object(current_platform, "is_xpu", return_value=False),
|
||||
patch.object(current_platform, "is_tpu", return_value=False),
|
||||
patch.object(current_platform, "is_out_of_tree", return_value=False),
|
||||
patch.object(current_platform, "has_device_capability", return_value=True),
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
|
||||
|
||||
moe_config = make_dummy_moe_config()
|
||||
# TRTLLM requires EP and does not support DP
|
||||
moe_config.moe_parallel_config.use_ep = True
|
||||
moe_config.moe_parallel_config.use_dp = False
|
||||
|
||||
selected_backend = select_unquantized_moe_backend(
|
||||
moe_config=moe_config,
|
||||
use_dp=False,
|
||||
selected_backend, experts_cls = select_unquantized_moe_backend(
|
||||
moe_config=moe_config
|
||||
)
|
||||
|
||||
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
assert experts_cls is not None
|
||||
|
||||
|
||||
@patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.has_flashinfer",
|
||||
"vllm.utils.flashinfer.has_flashinfer",
|
||||
return_value=True,
|
||||
)
|
||||
@patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.is_supported_config_trtllm_bf16",
|
||||
"vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe.TrtLlmBf16Experts.is_supported_config",
|
||||
return_value=(False, None),
|
||||
)
|
||||
@patch(
|
||||
"vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts.is_supported_config",
|
||||
return_value=(True, None),
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="Only supported on NVIDIA platforms."
|
||||
)
|
||||
def test_select_cuda_flashinfer_cutlass_backend(
|
||||
mock_has_flashinfer, mock_is_supported_trtllm, monkeypatch
|
||||
mock_has_flashinfer,
|
||||
mock_is_supported_trtllm,
|
||||
mock_is_supported_cutlass,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Test CUDA backend selection when FlashInfer TRTLLM is not available
|
||||
and FlashInfer CUTLASS is available."""
|
||||
with patch(
|
||||
"vllm.model_executor.layers.fused_moe.oracle.unquantized.current_platform"
|
||||
) as mock_platform:
|
||||
# Set as CUDA platform with Hopper capability
|
||||
mock_platform.is_cuda.return_value = True
|
||||
mock_platform.is_rocm.return_value = False
|
||||
mock_platform.is_cpu.return_value = False
|
||||
mock_platform.is_xpu.return_value = False
|
||||
mock_platform.is_tpu.return_value = False
|
||||
mock_platform.is_out_of_tree.return_value = False
|
||||
mock_platform.has_device_capability.return_value = True # SM90+
|
||||
|
||||
with (
|
||||
patch.object(current_platform, "is_cuda", return_value=True),
|
||||
patch.object(current_platform, "is_rocm", return_value=False),
|
||||
patch.object(current_platform, "is_cpu", return_value=False),
|
||||
patch.object(current_platform, "is_xpu", return_value=False),
|
||||
patch.object(current_platform, "is_tpu", return_value=False),
|
||||
patch.object(current_platform, "is_out_of_tree", return_value=False),
|
||||
patch.object(current_platform, "has_device_capability", return_value=True),
|
||||
):
|
||||
# Enable FlashInfer via env var
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP16", "1")
|
||||
|
||||
moe_config = make_dummy_moe_config()
|
||||
# CUTLASS requires EP and does not support DP
|
||||
moe_config.moe_parallel_config.use_ep = True
|
||||
moe_config.moe_parallel_config.use_dp = False
|
||||
|
||||
selected_backend = select_unquantized_moe_backend(
|
||||
moe_config=moe_config,
|
||||
use_dp=False, # CUTLASS doesn't support DP
|
||||
selected_backend, experts_cls = select_unquantized_moe_backend(
|
||||
moe_config=moe_config
|
||||
)
|
||||
|
||||
assert selected_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS
|
||||
assert experts_cls is not None
|
||||
|
||||
@@ -210,6 +210,13 @@ def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
|
||||
## Qwen3 Next ##
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=(
|
||||
"FLASHINFER TRTLLM MoE has a bug with all negative router logits "
|
||||
"for models with RENORMALIZE. This will be re-enabled once the "
|
||||
"issue is fixed in flashinfer."
|
||||
)
|
||||
)
|
||||
def test_qwen3_next_bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
|
||||
can_initialize(
|
||||
"Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
|
||||
@@ -49,6 +49,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
assert not self.base_layer.use_ep, (
|
||||
"EP support for Fused MoE LoRA is not implemented yet."
|
||||
)
|
||||
assert not self.base_layer.quant_method.is_monolithic, (
|
||||
"Monolithic kernels are not supported for Fused MoE LoRA."
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.device = _get_lora_device(base_layer)
|
||||
|
||||
148
vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py
Normal file
148
vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
|
||||
|
||||
|
||||
class TrtLlmBf16Experts(mk.FusedMoEExpertsMonolithic):
|
||||
"""
|
||||
BF16 unquantized TRTLLM-Gen MoE kernels. Supports monolithic interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(moe_config, quant_config)
|
||||
self.routing_method_type = moe_config.routing_method
|
||||
self.topk = moe_config.experts_per_token
|
||||
self.intermediate_size_per_partition = (
|
||||
moe_config.intermediate_size_per_partition
|
||||
)
|
||||
self.hidden_dim = moe_config.hidden_dim
|
||||
self.local_num_experts = moe_config.num_local_experts
|
||||
self.ep_rank = moe_config.moe_parallel_config.ep_rank
|
||||
|
||||
@staticmethod
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
@staticmethod
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return (
|
||||
p.is_cuda()
|
||||
and p.is_device_capability_family(100)
|
||||
and has_flashinfer_trtllm_fused_moe()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""BF16 kernels do not support non-gated MoE"""
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _supports_quant_scheme(
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
"""Supports only unquantized inputs."""
|
||||
return weight_key is None and activation_key is None
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU]
|
||||
|
||||
@staticmethod
|
||||
def _supports_routing_method(
|
||||
routing_method: RoutingMethodType,
|
||||
weight_key: QuantKey | None,
|
||||
activation_key: QuantKey | None,
|
||||
) -> bool:
|
||||
return routing_method in [
|
||||
RoutingMethodType.Default,
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
# NOTE: TRTLLM Kernel has issue with Qwen3.5 router.
|
||||
# Re-enable once the issue is resolved.
|
||||
# https://github.com/vllm-project/vllm/issues/37591
|
||||
# RoutingMethodType.Renormalize,
|
||||
# RoutingMethodType.RenormalizeNaive
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _supports_parallel_config(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
) -> bool:
|
||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||
return (
|
||||
not moe_parallel_config.use_all2all_kernels
|
||||
or moe_parallel_config.use_ag_rs_all2all_kernels
|
||||
) and not moe_parallel_config.enable_eplb
|
||||
|
||||
@staticmethod
|
||||
def _supports_router_logits_dtype(
|
||||
router_logits_dtype: torch.dtype | None,
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def expects_unquantized_inputs(self) -> bool:
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
activation: MoEActivation,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
num_expert_group: int | None = None,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
routed_scaling_factor: float | None = None,
|
||||
topk_group: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
import flashinfer
|
||||
|
||||
return flashinfer.fused_moe.trtllm_bf16_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
gemm1_weights=w1,
|
||||
gemm2_weights=w2,
|
||||
num_experts=global_num_experts,
|
||||
top_k=self.topk,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=self.intermediate_size_per_partition,
|
||||
local_expert_offset=self.ep_rank * self.local_num_experts,
|
||||
local_num_experts=self.local_num_experts,
|
||||
routing_method_type=self.routing_method_type,
|
||||
)
|
||||
@@ -1,141 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
#
|
||||
# Methods used by the oracle for kernel selection.
|
||||
#
|
||||
|
||||
|
||||
def _supports_current_device() -> bool:
|
||||
"""Supports only Blackwell-family GPUs."""
|
||||
p = current_platform
|
||||
return p.is_cuda() and p.is_device_capability_family(100)
|
||||
|
||||
|
||||
def _supports_no_act_and_mul() -> bool:
|
||||
"""BF16 kernels do not support non-gated MoE"""
|
||||
return False
|
||||
|
||||
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
return activation in [MoEActivation.SILU]
|
||||
|
||||
|
||||
def _supports_routing_method_bf16(
|
||||
routing_method: RoutingMethodType,
|
||||
) -> bool:
|
||||
return routing_method in [
|
||||
RoutingMethodType.Default,
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
]
|
||||
|
||||
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
"""Supports TRTLLM Kernel does not support EPLB."""
|
||||
return not moe_parallel_config.enable_eplb
|
||||
|
||||
|
||||
def is_supported_config_trtllm_bf16(
|
||||
moe_config: FusedMoEConfig,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
|
||||
for BF16 unquantized kernels.
|
||||
"""
|
||||
|
||||
def _make_reason(reason: str) -> str:
|
||||
return f"kernel does not support {reason}"
|
||||
|
||||
if not _supports_current_device():
|
||||
return False, _make_reason(f"current device {current_platform.device_name}")
|
||||
elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
|
||||
return False, _make_reason("no act_and_mul MLP layer")
|
||||
elif not _supports_activation(moe_config.activation):
|
||||
return False, _make_reason(f"{moe_config.activation} activation")
|
||||
elif not _supports_parallel_config(moe_config.moe_parallel_config):
|
||||
return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
|
||||
elif not _supports_routing_method_bf16(moe_config.routing_method):
|
||||
return False, _make_reason(f"routing method {moe_config.routing_method}")
|
||||
elif activation_format != mk.FusedMoEActivationFormat.Standard:
|
||||
return False, _make_reason(f"activation format {activation_format}")
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def flashinfer_fused_moe_bf16(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
n_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
routing_method_type: int,
|
||||
tune_max_num_tokens: int = 8192,
|
||||
) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_bf16_moe
|
||||
|
||||
return flashinfer_trtllm_bf16_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=hidden_states,
|
||||
gemm1_weights=gemm1_weights,
|
||||
gemm2_weights=gemm2_weights,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
n_group=n_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=local_expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routing_method_type=routing_method_type,
|
||||
tune_max_num_tokens=tune_max_num_tokens,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_bf16_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor | None,
|
||||
hidden_states: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
n_group: int | None,
|
||||
topk_group: int | None,
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
routing_method_type: int = RoutingMethodType.Renormalize,
|
||||
tune_max_num_tokens: int = 8192,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_bf16",
|
||||
op_func=flashinfer_fused_moe_bf16,
|
||||
fake_impl=flashinfer_fused_moe_bf16_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order,),
|
||||
)
|
||||
@@ -1967,6 +1967,10 @@ class TritonExperts(mk.FusedMoEExpertsModular):
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _supports_batch_invariance():
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
@@ -563,6 +564,8 @@ class FusedMoEExperts(ABC):
|
||||
)
|
||||
elif activation_format != cls.activation_format():
|
||||
return False, _make_reason(f"{activation_format.value} activation format")
|
||||
elif envs.VLLM_BATCH_INVARIANT and not cls._supports_batch_invariance():
|
||||
return False, _make_reason("batch invariance")
|
||||
return True, None
|
||||
|
||||
@staticmethod
|
||||
@@ -645,6 +648,15 @@ class FusedMoEExperts(ABC):
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_batch_invariance() -> bool:
|
||||
"""
|
||||
Whether the kernel supports batch invariance, i.e. the output does not
|
||||
depend on the order of the tokens in the input batch. This is useful
|
||||
for determining if the kernel can used with VLLM_BATCH_INVARIANT=1.
|
||||
"""
|
||||
return False
|
||||
|
||||
#
|
||||
# Various helpers for accessing quantization parameters from the
|
||||
# quant_config.
|
||||
|
||||
@@ -11,21 +11,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
|
||||
is_supported_config_trtllm_bf16,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
convert_moe_weights_to_flashinfer_trtllm_block_layout,
|
||||
get_flashinfer_moe_backend,
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -35,21 +34,96 @@ class UnquantizedMoeBackend(Enum):
|
||||
FLASHINFER_CUTLASS = "FlashInfer CUTLASS"
|
||||
AITER = "ROCm AITER"
|
||||
TRITON = "TRITON"
|
||||
BATCHED_TRITON = "BATCHED_TRITON"
|
||||
CPU = "CPU"
|
||||
XPU = "XPU"
|
||||
TPU = "TPU"
|
||||
OOT = "OOT"
|
||||
|
||||
|
||||
# NOTE(zyongye): Unsupported backend means backend
|
||||
# that is not conform with Modular kernel format.
|
||||
# We will directly call the kernel for those backend
|
||||
UNSUPPORTED_BACKEND = [
|
||||
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
|
||||
UnquantizedMoeBackend.CPU,
|
||||
UnquantizedMoeBackend.TPU,
|
||||
UnquantizedMoeBackend.OOT,
|
||||
]
|
||||
def _get_priority_backends(moe_config: FusedMoEConfig) -> list[UnquantizedMoeBackend]:
|
||||
"""
|
||||
Get available backends in priority order based on platform and config.
|
||||
|
||||
This function can be extended to become more complex as needed.
|
||||
"""
|
||||
|
||||
def _move_to_back(
|
||||
backends: list[UnquantizedMoeBackend],
|
||||
backend: UnquantizedMoeBackend,
|
||||
) -> None:
|
||||
backends.append(backends.pop(backends.index(backend)))
|
||||
|
||||
if current_platform.is_rocm():
|
||||
_AVAILABLE_BACKENDS = [
|
||||
UnquantizedMoeBackend.AITER,
|
||||
UnquantizedMoeBackend.TRITON,
|
||||
UnquantizedMoeBackend.BATCHED_TRITON,
|
||||
]
|
||||
elif current_platform.is_cuda():
|
||||
_AVAILABLE_BACKENDS = [
|
||||
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
|
||||
UnquantizedMoeBackend.FLASHINFER_CUTLASS,
|
||||
UnquantizedMoeBackend.TRITON,
|
||||
UnquantizedMoeBackend.BATCHED_TRITON,
|
||||
]
|
||||
|
||||
# HACK: Qwen3.5 has crash with FLASHINFER_CUTLASS BF16 if DEP.
|
||||
# Updating the oracle querying logic is out of the scope of this
|
||||
# PR. Need to fix the kernel or update structure in follow up.
|
||||
if moe_config.moe_parallel_config.dp_size > 1:
|
||||
_move_to_back(_AVAILABLE_BACKENDS, UnquantizedMoeBackend.FLASHINFER_CUTLASS)
|
||||
|
||||
elif current_platform.is_xpu():
|
||||
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.XPU]
|
||||
elif current_platform.is_cpu():
|
||||
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.CPU]
|
||||
return _AVAILABLE_BACKENDS
|
||||
|
||||
|
||||
def backend_to_kernel_cls(
|
||||
backend: UnquantizedMoeBackend,
|
||||
) -> type[mk.FusedMoEExperts]:
|
||||
if backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
|
||||
from vllm.model_executor.layers.fused_moe.experts.trtllm_bf16_moe import (
|
||||
TrtLlmBf16Experts,
|
||||
)
|
||||
|
||||
return TrtLlmBf16Experts
|
||||
|
||||
elif backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
return FlashInferExperts
|
||||
|
||||
elif backend == UnquantizedMoeBackend.AITER:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
return AiterExperts
|
||||
|
||||
elif backend == UnquantizedMoeBackend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
|
||||
return TritonExperts
|
||||
|
||||
elif backend == UnquantizedMoeBackend.BATCHED_TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
|
||||
return BatchedTritonExperts
|
||||
|
||||
elif backend == UnquantizedMoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExperts
|
||||
|
||||
return XPUExperts
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown unquantized MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend:
|
||||
@@ -70,194 +144,224 @@ def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend
|
||||
|
||||
def select_unquantized_moe_backend(
|
||||
moe_config: FusedMoEConfig,
|
||||
use_dp: bool,
|
||||
) -> UnquantizedMoeBackend:
|
||||
) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]:
|
||||
"""
|
||||
Select the primary Unquantized MoE backend
|
||||
Select the primary Unquantized MoE backend.
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
|
||||
def _make_log_backend(backend: UnquantizedMoeBackend):
|
||||
return f"Using {backend.value} backend for Unquantized MoE"
|
||||
if current_platform.is_cpu():
|
||||
# TODO: migrate to MK structure.
|
||||
return UnquantizedMoeBackend.CPU, None
|
||||
|
||||
if current_platform.is_tpu():
|
||||
return UnquantizedMoeBackend.TPU, None
|
||||
|
||||
if current_platform.is_out_of_tree():
|
||||
return UnquantizedMoeBackend.OOT, None
|
||||
|
||||
# NOTE: the kernels are selected in the following order.
|
||||
AVAILABLE_BACKENDS = _get_priority_backends(moe_config)
|
||||
|
||||
# NOTE(rob): We need to peak into the P/F selection to determine
|
||||
# if we are using the batched or standard expert format, which
|
||||
# if not ideal. Once we unify TP + DP/EP, we can select P/F first.
|
||||
activation_format = (
|
||||
mk.FusedMoEActivationFormat.BatchedExperts
|
||||
if moe_config.moe_parallel_config.use_batched_activation_format
|
||||
else mk.FusedMoEActivationFormat.Standard
|
||||
)
|
||||
|
||||
# Check if FlashInfer TRTLLM BF16 MoE is supported
|
||||
trtllm_supported, _ = is_supported_config_trtllm_bf16(
|
||||
moe_config=moe_config,
|
||||
activation_format=activation_format,
|
||||
)
|
||||
flashinfer_trtllm_available = has_flashinfer() and trtllm_supported
|
||||
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
|
||||
flashinfer_cutlass_available = (
|
||||
has_flashinfer_cutlass_fused_moe()
|
||||
and (not use_dp)
|
||||
and current_platform.has_device_capability(90)
|
||||
)
|
||||
flashinfer_trtllm_moe_enabled = (
|
||||
flashinfer_trtllm_available
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP16
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "latency"
|
||||
)
|
||||
flashinfer_cutlass_moe_enabled = (
|
||||
flashinfer_cutlass_available and envs.VLLM_USE_FLASHINFER_MOE_FP16
|
||||
)
|
||||
rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
def _make_log_backend(backend: UnquantizedMoeBackend) -> str:
|
||||
available_strs = [b.value for b in AVAILABLE_BACKENDS]
|
||||
return (
|
||||
f"Using {backend.value} Unquantized MoE backend out "
|
||||
f"of potential backends: {available_strs}."
|
||||
)
|
||||
|
||||
def _make_log_unsupported(
|
||||
backend: UnquantizedMoeBackend, reason: str | None
|
||||
) -> str:
|
||||
if reason:
|
||||
return (
|
||||
f"Unquantized MoE backend {backend.value} does not support the "
|
||||
f"deployment configuration since {reason}."
|
||||
)
|
||||
return (
|
||||
f"Unquantized MoE backend '{backend.value}' does not support the "
|
||||
"deployment configuration."
|
||||
)
|
||||
|
||||
def _return_or_raise(
|
||||
backend: UnquantizedMoeBackend,
|
||||
config: FusedMoEConfig,
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
) -> tuple[UnquantizedMoeBackend, type[mk.FusedMoEExperts] | None]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, config, None, None, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
raise ValueError(_make_log_unsupported(backend, reason))
|
||||
|
||||
# Handle explicit moe_backend from user.
|
||||
runner_backend = moe_config.moe_backend
|
||||
if runner_backend != "auto":
|
||||
requested_backend = map_unquantized_backend(runner_backend)
|
||||
if requested_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
|
||||
if not flashinfer_trtllm_available:
|
||||
raise ValueError(
|
||||
"FlashInfer TRTLLM MoE backend is not available for this "
|
||||
"configuration."
|
||||
)
|
||||
elif requested_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
|
||||
if not flashinfer_cutlass_available:
|
||||
raise ValueError(
|
||||
"FlashInfer CUTLASS MoE backend is not available for this "
|
||||
"configuration."
|
||||
)
|
||||
elif requested_backend == UnquantizedMoeBackend.AITER and not (
|
||||
current_platform.is_rocm() and rocm_aiter_moe_enabled
|
||||
if (
|
||||
activation_format == mk.FusedMoEActivationFormat.BatchedExperts
|
||||
and requested_backend == UnquantizedMoeBackend.TRITON
|
||||
):
|
||||
raise ValueError(
|
||||
"ROCm AITer MoE backend is not available for this configuration."
|
||||
requested_backend = UnquantizedMoeBackend.BATCHED_TRITON
|
||||
|
||||
return _return_or_raise(requested_backend, moe_config, activation_format)
|
||||
|
||||
# Handle explicit FlashInfer FP16 configuration.
|
||||
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP16"):
|
||||
if not envs.VLLM_USE_FLASHINFER_MOE_FP16:
|
||||
if UnquantizedMoeBackend.FLASHINFER_TRTLLM in AVAILABLE_BACKENDS:
|
||||
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_TRTLLM)
|
||||
if UnquantizedMoeBackend.FLASHINFER_CUTLASS in AVAILABLE_BACKENDS:
|
||||
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.FLASHINFER_CUTLASS)
|
||||
|
||||
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
|
||||
# If user is explicit about backend, validate it.
|
||||
fi_backend = get_flashinfer_moe_backend()
|
||||
if fi_backend == FlashinferMoeBackend.CUTLASS:
|
||||
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
|
||||
elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
raise ValueError(
|
||||
f"FlashInfer MOE backend {fi_backend} "
|
||||
"does not support unquantized MoE."
|
||||
)
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
return _return_or_raise(backend, moe_config, activation_format)
|
||||
else:
|
||||
# If the user is not explicit about the backend, try both.
|
||||
for backend in [
|
||||
UnquantizedMoeBackend.FLASHINFER_TRTLLM,
|
||||
UnquantizedMoeBackend.FLASHINFER_CUTLASS,
|
||||
]:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, moe_config, None, None, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
else:
|
||||
logger.debug_once(
|
||||
_make_log_unsupported(backend, reason), scope="local"
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
"Found VLLM_USE_FLASHINFER_MOE_FP16=1, but no "
|
||||
"FlashInfer unquantized MoE backend supports the configuration."
|
||||
)
|
||||
logger.info_once(_make_log_backend(requested_backend), scope="local")
|
||||
return requested_backend
|
||||
|
||||
if current_platform.is_rocm():
|
||||
if rocm_aiter_moe_enabled:
|
||||
# Handle explicit AITER FP8 configuration.
|
||||
if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"):
|
||||
if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE:
|
||||
if UnquantizedMoeBackend.AITER in AVAILABLE_BACKENDS:
|
||||
AVAILABLE_BACKENDS.remove(UnquantizedMoeBackend.AITER)
|
||||
else:
|
||||
backend = UnquantizedMoeBackend.AITER
|
||||
else:
|
||||
backend = UnquantizedMoeBackend.TRITON
|
||||
if current_platform.is_cuda():
|
||||
if flashinfer_trtllm_moe_enabled:
|
||||
backend = UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
elif flashinfer_cutlass_moe_enabled:
|
||||
backend = UnquantizedMoeBackend.FLASHINFER_CUTLASS
|
||||
if trtllm_supported:
|
||||
logger.info_once(
|
||||
"FlashInfer TRTLLM MoE is available but not enabled, "
|
||||
"consider setting VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"to enable it for better performance.",
|
||||
scope="local",
|
||||
)
|
||||
else:
|
||||
if not envs.VLLM_USE_FLASHINFER_MOE_FP16 and trtllm_supported:
|
||||
logger.info_once(
|
||||
"FlashInfer TRTLLM MoE is available but not enabled, "
|
||||
"consider setting VLLM_USE_FLASHINFER_MOE_FP16=1 "
|
||||
"and VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"to enable it for better performance.",
|
||||
scope="local",
|
||||
)
|
||||
elif not use_dp and flashinfer_cutlass_available:
|
||||
logger.info_once(
|
||||
"FlashInfer CUTLASS MoE is available"
|
||||
" but not enabled, consider setting"
|
||||
" VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.",
|
||||
scope="local",
|
||||
)
|
||||
elif use_dp:
|
||||
logger.info_once(
|
||||
"FlashInfer CUTLASS MoE is currently not available for DP.",
|
||||
scope="local",
|
||||
)
|
||||
backend = UnquantizedMoeBackend.TRITON
|
||||
if current_platform.is_xpu():
|
||||
backend = UnquantizedMoeBackend.XPU
|
||||
if current_platform.is_cpu():
|
||||
backend = UnquantizedMoeBackend.CPU
|
||||
if current_platform.is_tpu():
|
||||
backend = UnquantizedMoeBackend.TPU
|
||||
if current_platform.is_out_of_tree():
|
||||
backend = UnquantizedMoeBackend.OOT
|
||||
return _return_or_raise(backend, moe_config, activation_format)
|
||||
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend
|
||||
for backend in AVAILABLE_BACKENDS:
|
||||
k_cls = backend_to_kernel_cls(backend)
|
||||
supported, reason = k_cls.is_supported_config(
|
||||
k_cls, moe_config, None, None, activation_format
|
||||
)
|
||||
if supported:
|
||||
logger.info_once(_make_log_backend(backend), scope="local")
|
||||
return backend, k_cls
|
||||
|
||||
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
|
||||
|
||||
raise NotImplementedError(
|
||||
"No Unquantized MoE backend supports the deployment configuration."
|
||||
)
|
||||
|
||||
|
||||
def convert_to_unquantized_kernel_format(
|
||||
unquantized_backend: UnquantizedMoeBackend,
|
||||
layer: Module,
|
||||
w13_weight: torch.Tensor | None = None,
|
||||
w2_weight: torch.Tensor | None = None,
|
||||
w13_weight: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if unquantized_backend == UnquantizedMoeBackend.AITER:
|
||||
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(w13_weight, w2_weight)
|
||||
|
||||
elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if layer.moe_config.is_act_and_mul:
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
# Non-gated MoE: w13 is a single projection, no need to swap.
|
||||
w13_weight = swap_w13_to_w31(w13_weight)
|
||||
|
||||
return w13_weight, w2_weight
|
||||
elif unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w13_weight = swap_w13_to_w31(w13_weight)
|
||||
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
w13_weight, w2_weight = convert_moe_weights_to_flashinfer_trtllm_block_layout(
|
||||
_cache_permute_indices,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
)
|
||||
|
||||
return w13_weight.contiguous(), w2_weight.contiguous()
|
||||
|
||||
|
||||
def make_unquantized_moe_kernel(
|
||||
backend: UnquantizedMoeBackend,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
) -> mk.FusedMoEKernel | None:
|
||||
if backend in UNSUPPORTED_BACKEND:
|
||||
return None
|
||||
backend: UnquantizedMoeBackend,
|
||||
experts_cls: type[mk.FusedMoEExperts],
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEKernel:
|
||||
# Create Prepare/Finalize
|
||||
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
quant_config=quant_config,
|
||||
routing_tables=routing_tables,
|
||||
allow_new_interface=True,
|
||||
use_monolithic=is_monolithic,
|
||||
)
|
||||
assert prepare_finalize is not None
|
||||
|
||||
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
logger.info_once("Using %s", prepare_finalize.__class__.__name__, scope="local")
|
||||
|
||||
# Create Experts
|
||||
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
|
||||
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens is not None
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
)
|
||||
else:
|
||||
experts = experts_cls(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
FlashInferExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts=(
|
||||
shared_experts
|
||||
if moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
else None
|
||||
),
|
||||
moe_parallel_config=moe_config.moe_parallel_config,
|
||||
inplace=(not moe_config.disable_inplace and not is_monolithic),
|
||||
)
|
||||
|
||||
elif backend == UnquantizedMoeBackend.AITER:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
AiterExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
TritonExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe import XPUExperts
|
||||
|
||||
kernel = mk.FusedMoEKernel(
|
||||
MoEPrepareAndFinalizeNoDPEPModular(),
|
||||
XPUExperts(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
return kernel
|
||||
|
||||
@@ -6,11 +6,8 @@ from collections.abc import Callable
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -23,7 +20,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat,
|
||||
FusedMoEExpertsModular,
|
||||
FusedMoEPrepareAndFinalizeModular,
|
||||
)
|
||||
@@ -33,20 +29,10 @@ from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
|
||||
make_unquantized_moe_kernel,
|
||||
select_unquantized_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
convert_moe_weights_to_flashinfer_trtllm_block_layout,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
if current_platform.is_cuda_alike() or current_platform.is_xpu():
|
||||
from .fused_batched_moe import BatchedTritonExperts
|
||||
from .fused_moe import TritonExperts
|
||||
else:
|
||||
TritonExperts = None # type: ignore
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,45 +45,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.unquantized_backend = select_unquantized_moe_backend(
|
||||
self.unquantized_backend, self.experts_cls = select_unquantized_moe_backend(
|
||||
moe_config=self.moe,
|
||||
use_dp=self.moe.moe_parallel_config.dp_size > 1,
|
||||
)
|
||||
|
||||
# AITER only supports gated activations (silu/gelu), so disable it
|
||||
# for non-gated MoE (is_act_and_mul=False)
|
||||
self.rocm_aiter_moe_enabled = (
|
||||
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
|
||||
)
|
||||
self.kernel: mk.FusedMoEKernel | None = None
|
||||
self._is_monolithic = (
|
||||
current_platform.is_cpu()
|
||||
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
)
|
||||
|
||||
if self.is_monolithic:
|
||||
self.apply_monolithic: Callable = self._select_monolithic()
|
||||
|
||||
def _select_monolithic(self) -> Callable:
|
||||
"""Select the monolithic implementation based on platform."""
|
||||
if current_platform.is_cpu():
|
||||
return self.forward_monolithic_cpu
|
||||
else:
|
||||
return self.forward_monolithic_cuda
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward_cuda(layer, x, topk_weights, topk_ids, shared_experts_input)
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self._is_monolithic
|
||||
# Escape hatch for CPU, which stays on the old monolithic path.
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
|
||||
return True
|
||||
return super().is_monolithic
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
@@ -106,43 +63,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
):
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic for all but the CPU backend. CPU backend is monolithic. "
|
||||
"So this function should not be called."
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalizeModular,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEExpertsModular:
|
||||
assert self.moe_quant_config is not None
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== FusedMoEActivationFormat.BatchedExperts
|
||||
):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
return BatchedTritonExperts(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
max_num_tokens=self.moe.max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
)
|
||||
elif (
|
||||
self.unquantized_backend == UnquantizedMoeBackend.AITER
|
||||
and rocm_aiter_ops.is_fused_moe_enabled()
|
||||
):
|
||||
from .rocm_aiter_fused_moe import AiterExperts
|
||||
|
||||
logger.debug("AiterExperts %s", self.moe)
|
||||
return AiterExperts(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
return TritonExperts(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} uses the new modular kernel initialization "
|
||||
"logic. This function should not be called."
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -227,14 +163,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
replace_parameter(layer, "w13_weight", w13)
|
||||
replace_parameter(layer, "w2_weight", w2)
|
||||
|
||||
# Setup Modular Kernel for TP Case
|
||||
# Setup moe kernel.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
self.kernel = make_unquantized_moe_kernel(
|
||||
backend=self.unquantized_backend,
|
||||
assert self.experts_cls is not None
|
||||
self.moe_kernel = make_unquantized_moe_kernel(
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
backend=self.unquantized_backend,
|
||||
experts_cls=self.experts_cls,
|
||||
routing_tables=layer._maybe_init_expert_routing_tables(),
|
||||
shared_experts=layer.shared_experts,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
@@ -244,22 +183,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM:
|
||||
_cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
|
||||
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
|
||||
layer.w13_weight.data = w13_weight_swapped.contiguous()
|
||||
w13_weights_shuffled, w2_weights_shuffled = (
|
||||
convert_moe_weights_to_flashinfer_trtllm_block_layout(
|
||||
_cache_permute_indices,
|
||||
layer.w13_weight.data,
|
||||
layer.w2_weight.data,
|
||||
)
|
||||
)
|
||||
layer.w13_weight = Parameter(w13_weights_shuffled, requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weights_shuffled, requires_grad=False)
|
||||
if self.unquantized_backend in [
|
||||
UnquantizedMoeBackend.TPU,
|
||||
UnquantizedMoeBackend.OOT,
|
||||
]:
|
||||
# OOT handles internally.
|
||||
return
|
||||
|
||||
elif self.unquantized_backend == UnquantizedMoeBackend.CPU:
|
||||
# CPU stays on the old path — no oracle, no moe_kernel.
|
||||
from vllm.model_executor.layers.fused_moe import cpu_fused_moe
|
||||
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
@@ -290,13 +222,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
else:
|
||||
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
elif current_platform.is_cuda_alike() or current_platform.is_xpu():
|
||||
else:
|
||||
self._setup_kernel(
|
||||
layer=layer,
|
||||
w13=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
if self.moe.has_bias:
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
)
|
||||
else:
|
||||
return FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
@@ -313,16 +254,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
|
||||
if self.moe.has_bias:
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
)
|
||||
else:
|
||||
return FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
def forward_cuda(
|
||||
def forward_native(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
@@ -330,9 +262,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.kernel is not None
|
||||
|
||||
return self.kernel.apply(
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -345,53 +276,58 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
shared_experts_input=shared_experts_input,
|
||||
)
|
||||
|
||||
def forward_monolithic_cuda(
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.forward_native(
|
||||
layer, x, topk_weights, topk_ids, shared_experts_input
|
||||
)
|
||||
|
||||
def apply_monolithic(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: F401
|
||||
|
||||
assert self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
|
||||
|
||||
return torch.ops.vllm.flashinfer_fused_moe_bf16(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=layer.e_score_correction_bias,
|
||||
hidden_states=x,
|
||||
gemm1_weights=layer.w13_weight,
|
||||
gemm2_weights=layer.w2_weight,
|
||||
num_experts=layer.global_num_experts,
|
||||
top_k=layer.top_k,
|
||||
n_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
intermediate_size=layer.intermediate_size_per_partition,
|
||||
local_expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
routing_method_type=layer.routing_method_type,
|
||||
)
|
||||
|
||||
def forward_monolithic_cpu(
|
||||
self,
|
||||
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.cpu_fused_moe(
|
||||
layer,
|
||||
x,
|
||||
layer.use_grouped_topk,
|
||||
layer.top_k,
|
||||
router_logits,
|
||||
layer.renormalize,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
layer.global_num_experts,
|
||||
layer.expert_map,
|
||||
layer.custom_routing_function,
|
||||
layer.scoring_func,
|
||||
layer.routed_scaling_factor,
|
||||
layer.e_score_correction_bias,
|
||||
layer.apply_router_weight_on_input,
|
||||
layer.activation,
|
||||
)
|
||||
assert self.is_monolithic
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.CPU:
|
||||
assert self.moe_kernel is None
|
||||
return self.cpu_fused_moe(
|
||||
layer,
|
||||
x,
|
||||
layer.use_grouped_topk,
|
||||
layer.top_k,
|
||||
router_logits,
|
||||
layer.renormalize,
|
||||
layer.topk_group,
|
||||
layer.num_expert_group,
|
||||
layer.global_num_experts,
|
||||
layer.expert_map,
|
||||
layer.custom_routing_function,
|
||||
layer.scoring_func,
|
||||
layer.routed_scaling_factor,
|
||||
layer.e_score_correction_bias,
|
||||
layer.apply_router_weight_on_input,
|
||||
layer.activation,
|
||||
)
|
||||
else:
|
||||
assert self.moe_kernel is not None
|
||||
return self.moe_kernel.apply_monolithic(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
num_expert_group=layer.num_expert_group,
|
||||
topk_group=layer.topk_group,
|
||||
e_score_correction_bias=layer.e_score_correction_bias,
|
||||
routed_scaling_factor=layer.routed_scaling_factor,
|
||||
)
|
||||
|
||||
@@ -202,6 +202,7 @@ def has_flashinfer_trtllm_fused_moe() -> bool:
|
||||
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
|
||||
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
|
||||
("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
|
||||
("flashinfer.fused_moe", "trtllm_bf16_moe"),
|
||||
]
|
||||
for module_name, attr_name in required_functions:
|
||||
mod = _get_submodule(module_name)
|
||||
|
||||
Reference in New Issue
Block a user