diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 0e7ab7255..c239cb5d0 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -1131,7 +1131,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+ - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c2587dd81..d1a536a07 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -1017,7 +1017,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+ - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml
index 2772d69b4..8641a18b4 100644
--- a/.buildkite/test_areas/kernels.yaml
+++ b/.buildkite/test_areas/kernels.yaml
@@ -85,7 +85,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+ - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
index 4894d37c4..c1f4f0aa9 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
@@ -197,7 +197,7 @@ def bench_run(
)
kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
make_dummy_moe_config(),
quant_config=quant_config,
@@ -242,7 +242,7 @@ def bench_run(
)
kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
make_dummy_moe_config(),
quant_config=quant_config,
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 022c4f2e8..75ebee6ec 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -36,8 +36,7 @@ th {
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
 | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
 | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
-| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
-| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
+| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
| MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |
diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index 07964031f..4ee18e342 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -22,6 +22,9 @@ from vllm.distributed import (
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+ maybe_make_prepare_finalize,
+)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -40,7 +43,6 @@ from .mk_objects import (
TestMoEQuantConfig,
expert_info,
make_fused_experts,
- make_prepare_finalize,
prepare_finalize_info,
)
from .parallel_utils import ProcessGroupInfo
@@ -603,10 +605,12 @@ def make_modular_kernel(
routing_method=RoutingMethodType.DeepSeekV3,
)
- # make modular kernel
- prepare_finalize = make_prepare_finalize(
- config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
+ prepare_finalize = maybe_make_prepare_finalize(
+ moe=moe,
+ quant_config=quant_config,
+ allow_new_interface=True,
)
+ assert prepare_finalize is not None
fused_experts = make_fused_experts(
config.fused_experts_type,
diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
index f62e11641..d215d2ab6 100644
--- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py
+++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -7,9 +7,6 @@ import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts
-from vllm.model_executor.layers.fused_moe.all2all_utils import (
- maybe_make_prepare_finalize,
-)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
@@ -255,13 +252,12 @@ if has_pplx():
)
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
+ from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
+        FlashInferA2APrepareAndFinalize,
+ )
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- FlashInferCutlassMoEPrepareAndFinalize,
- create_flashinfer_prepare_finalize,
- )
register_prepare_and_finalize(
-        FlashInferCutlassMoEPrepareAndFinalize,
+        FlashInferA2APrepareAndFinalize,
@@ -429,24 +425,6 @@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
]
-def make_prepare_finalize(
- prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
- backend: str | None,
- moe: FusedMoEConfig,
- quant_config: FusedMoEQuantConfig,
-) -> mk.FusedMoEPrepareAndFinalize:
- if backend != "naive" and backend is not None:
- prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
- assert prepare_finalize is not None
- return prepare_finalize
- elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
- return create_flashinfer_prepare_finalize(
- use_dp=moe.moe_parallel_config.dp_size > 1
- )
- else:
- return MoEPrepareAndFinalizeNoEP()
-
-
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index e5752e404..a484a3f21 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -294,12 +294,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
)
kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(
- defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
- moe_config=moe_config,
- quant_config=quant_config,
- )
- ),
+ MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py
index 66ceb0c64..9bb61ddfa 100644
--- a/tests/kernels/moe/test_flashinfer_moe.py
+++ b/tests/kernels/moe/test_flashinfer_moe.py
@@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph(
)
flashinfer_experts = FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(
- defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
- moe_config=moe_config,
- quant_config=quant_config,
- )
- ),
+ MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py
index 0011149e3..a22b2088b 100644
--- a/tests/kernels/moe/test_nvfp4_moe.py
+++ b/tests/kernels/moe/test_nvfp4_moe.py
@@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph(
)
kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 97a9a2815..678cd4580 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
return buffer
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
return hidden_states, router_logits
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if extra_tensors is not None:
+ raise NotImplementedError(
+ "extra_tensors is not supported for NaiveAll2AllManager"
+ )
+ sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+ dp_metadata = get_forward_context().dp_metadata
+ assert dp_metadata is not None
+ cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
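+        # Naive all2all: replicate every rank's hidden states and topk
+        # weights/ids on all ranks via multicast.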
+ hidden_states = self.naive_multicast(
+ hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
+ )
+ topk_weights = self.naive_multicast(
+ topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
+ )
+ topk_ids = self.naive_multicast(
+ topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
+ )
+ return hidden_states, topk_weights, topk_ids
+
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+ """
+        Gather hidden_states and topk weights/ids from all DP ranks.
+ """
+ dp_metadata = get_forward_context().dp_metadata
+ assert dp_metadata is not None
+ sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+ assert sizes is not None
+ dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+ assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
+
+ tensors_to_gather = [hidden_states, topk_weights, topk_ids]
+ if extra_tensors is not None:
+ tensors_to_gather.extend(extra_tensors)
+
+ gathered_tensors = dist_group.all_gatherv(
+ tensors_to_gather,
+ dim=0,
+ sizes=sizes,
+ )
+
+ hidden_states = gathered_tensors[0]
+ topk_weights = gathered_tensors[1]
+ topk_ids = gathered_tensors[2]
+
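+        # Any extra_tensors were gathered alongside and trail the three
+        # fixed outputs in gathered_tensors.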
+ if extra_tensors is None:
+ return hidden_states, topk_weights, topk_ids
+
+ return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
+
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+ raise NotImplementedError
+
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+ raise NotImplementedError
+
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index b09d5f44d..572bac80f 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
-from typing import Any
from weakref import WeakValueDictionary
import torch
@@ -64,13 +63,32 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> Any:
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+ # Subclasses should either:
+ # - implement handling for extra_tensors, or
+ # - raise a clear error if extra_tensors is not supported.
+ raise NotImplementedError
+
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
@@ -280,7 +298,7 @@ class DeviceCommunicatorBase:
for module in moe_modules:
module.maybe_init_modular_kernel()
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -294,8 +312,29 @@ class DeviceCommunicatorBase:
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
+ if extra_tensors is not None:
+ return hidden_states, router_logits, extra_tensors
return hidden_states, router_logits
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+ """
+ Dispatch the hidden states and topk weights/ids to the appropriate device.
+ This is a no-op in the base class.
+ """
+ if extra_tensors is not None:
+ return hidden_states, topk_weights, topk_ids, extra_tensors
+ return hidden_states, topk_weights, topk_ids
+
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index 25db6a2eb..b5fbdfcc3 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
- def dispatch( # type: ignore[override]
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+        """
+        Dispatch the hidden states and router logits via the all2all manager.
+        """
+
assert self.all2all_manager is not None
- return self.all2all_manager.dispatch(
+ return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
- extra_tensors, # type: ignore[call-arg]
+ extra_tensors,
+ )
+
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+        """
+        Dispatch the hidden states and topk weights/ids via the all2all manager.
+        """
+ assert self.all2all_manager is not None
+ return self.all2all_manager.dispatch(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel,
+ extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
+        """
+        Combine the hidden states across ranks via the all2all manager.
+        """
assert self.all2all_manager is not None
- hidden_states = self.all2all_manager.combine(
- hidden_states, is_sequence_parallel
+ return self.all2all_manager.combine(
+ hidden_states,
+ is_sequence_parallel,
)
- return hidden_states
class _CPUSHMDistributed:
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index f67ddfd31..4c78871e1 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
- def dispatch( # type: ignore[override]
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
+        """
+        Dispatch the hidden states and router logits via the all2all manager.
+        """
+
assert self.all2all_manager is not None
- return self.all2all_manager.dispatch(
+ return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
- extra_tensors, # type: ignore[call-arg]
+ extra_tensors,
+ )
+
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+        """
+        Dispatch the hidden states and topk weights/ids via the all2all manager.
+        """
+ assert self.all2all_manager is not None
+ return self.all2all_manager.dispatch(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel,
+ extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
+        """
+        Combine the hidden states across ranks via the all2all manager.
+        """
assert self.all2all_manager is not None
- hidden_states = self.all2all_manager.combine(
- hidden_states, is_sequence_parallel
+ return self.all2all_manager.combine(
+ hidden_states,
+ is_sequence_parallel,
)
- return hidden_states
diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py
index 61aee2db4..81f4ae207 100644
--- a/vllm/distributed/device_communicators/mnnvl_compat.py
+++ b/vllm/distributed/device_communicators/mnnvl_compat.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group)
return gathered
+    # NOTE(rob): CommBackend is an abstract class, and bcast/barrier
+    # are unimplemented on the vLLM side. If we need these methods in
+    # the future, we can add concrete implementations.
+ def bcast(self, data: Any, root: int) -> Any:
+ raise NotImplementedError
+
+ def barrier(self) -> None:
+ raise NotImplementedError
+
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self
diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
index 6bc26b6f3..85c7f18e3 100644
--- a/vllm/distributed/device_communicators/xpu_communicator.py
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
dist.broadcast(input_, src=src, group=self.device_group)
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+        """
+        Dispatch the hidden states and router logits via the all2all manager.
+        """
+
assert self.all2all_manager is not None
- return self.all2all_manager.dispatch(
+ return self.all2all_manager.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
- extra_tensors, # type: ignore[call-arg]
+ extra_tensors,
+ )
+
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ ):
+        """
+        Dispatch the hidden states and topk weights/ids via the all2all manager.
+        """
+ assert self.all2all_manager is not None
+ return self.all2all_manager.dispatch(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel,
+ extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
+        """
+        Combine the hidden states across ranks via the all2all manager.
+        """
assert self.all2all_manager is not None
- hidden_states = self.all2all_manager.combine(
- hidden_states, is_sequence_parallel
+ return self.all2all_manager.combine(
+ hidden_states,
+ is_sequence_parallel,
)
- return hidden_states
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d8c6ceba3..8f9dc0354 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -1000,7 +1000,7 @@ class GroupCoordinator:
if self.device_communicator is not None:
self.device_communicator.prepare_communication_buffer_for_model(model)
- def dispatch(
+ def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -1011,7 +1011,7 @@ class GroupCoordinator:
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
- return self.device_communicator.dispatch( # type: ignore[call-arg]
+ return self.device_communicator.dispatch_router_logits(
hidden_states,
router_logits,
is_sequence_parallel,
@@ -1020,6 +1020,28 @@ class GroupCoordinator:
else:
return hidden_states, router_logits
+ def dispatch(
+ self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ is_sequence_parallel: bool = False,
+ extra_tensors: list[torch.Tensor] | None = None,
+ ) -> (
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
+ | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ ):
+ if self.device_communicator is not None:
+ return self.device_communicator.dispatch(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel,
+ extra_tensors,
+ )
+ else:
+ return hidden_states, topk_weights, topk_ids
+
def combine(
self, hidden_states, is_sequence_parallel: bool = False
) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 638602a92..bf8ec2dc6 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -7,17 +7,27 @@ import torch
from vllm.distributed import (
get_ep_group,
)
+from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
+from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
+ FlashInferA2APrepareAndFinalize,
+)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNaiveEP,
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
+logger = init_logger(__name__)
+
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
+    # NOTE(rob): we are migrating each quant_method to hold the MK
+    # in all cases. The allow_new_interface=False flag allows us to fall
+    # back to the old method for methods that have not yet been migrated.
+    #
+    # In the old method:
+    # * maybe_init_modular_kernel() calls this function. If we are
+    #   using no DP/EP or the naive all2all, this function returns None
+    #   and no ModularKernelMethod is created. If a non-naive all2all
+    #   is used, this returns a PrepareAndFinalize object and a
+    #   ModularKernelMethod is created.
+    # In the new method:
+    # * maybe_make_prepare_finalize() is called from the oracle. We
+    #   always return a PrepareAndFinalize object and the quant method
+    #   holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels:
- return None
+ if not allow_new_interface:
+ return None
+
+ # For DP/TP case, fall back to naive P/F.
+ if moe.moe_parallel_config.dp_size > 1:
+ logger.info_once(
+ "Detected DP deployment with no --enable-expert-parallel. "
+ "Falling back to AllGather+ReduceScatter dispatch/combine."
+ )
+ return MoEPrepareAndFinalizeNaiveEP(
+ is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
+ num_dispatchers=(
+ get_ep_group().device_communicator.all2all_manager.world_size
+ ),
+ )
+ else:
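+            # TP-only (or single GPU): no cross-rank dispatch/combine is needed.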
+ return MoEPrepareAndFinalizeNoEP()
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
- # TODO(rob): update this as part of the MoE refactor.
- assert not moe.use_flashinfer_cutlass_kernels, (
- "Must be created in modelopt.py or fp8.py"
- )
-
if moe.use_pplx_kernels:
assert quant_config is not None
@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch,
)
+ elif moe.use_fi_all2allv_kernels:
+ assert quant_config is not None
+ prepare_finalize = FlashInferA2APrepareAndFinalize(
+ num_dispatchers=all2all_manager.world_size,
+ )
+
+ elif moe.use_naive_all2all_kernels and allow_new_interface:
+ prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
+ is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
+ num_dispatchers=all2all_manager.world_size,
+ )
+
return prepare_finalize
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 9a28c3193..b275bf414 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv
@@ -862,6 +861,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
+ is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing
@property
@@ -883,6 +883,12 @@ class FusedMoEParallelConfig:
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
+ @property
+ def use_fi_all2allv_kernels(self):
+ return (
+ self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
+ )
+
@property
def use_batched_activation_format(self):
return self.use_deepep_ll_kernels or self.use_pplx_kernels
@@ -1014,6 +1020,7 @@ class FusedMoEParallelConfig:
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
+ is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
# DP + EP / TP + EP / DP + TP + EP
@@ -1033,6 +1040,7 @@ class FusedMoEParallelConfig:
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
+ is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe,
enable_eplb=vllm_parallel_config.enable_eplb,
)
@@ -1051,6 +1059,7 @@ class FusedMoEParallelConfig:
use_ep=False,
all2all_backend="naive",
enable_eplb=False,
+ is_sequence_parallel=False,
)
@@ -1145,12 +1154,9 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_mori_kernels
@property
- def use_flashinfer_cutlass_kernels(self):
- """
- Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
- """
- return (
- envs.VLLM_USE_FLASHINFER_MOE_FP4
- and has_flashinfer_cutlass_fused_moe()
- and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
- )
+ def use_fi_all2allv_kernels(self):
+ return self.moe_parallel_config.use_fi_all2allv_kernels
+
+ @property
+ def use_naive_all2all_kernels(self):
+ return self.moe_parallel_config.use_naive_all2all_kernels
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 0d8690638..86edbe303 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8(
or a2_scale.size(0) == a1q.shape[0]
), "Intermediate scale shape mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
- if expert_map is not None:
+
+    # NOTE(rob): the expert_map is used for the STANDARD case and
+    # expert_num_tokens is used for the BATCHED case.
+ # TODO(rob): update the MK interface to only pass the expert_map
+ # during the STANDARD case to make this clearer across all kernels.
+ if use_batched_format:
+ assert expert_num_tokens is not None
+ else:
assert expert_num_tokens is None
# We have two modes: batched experts and non-batched experts.
@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
- return not moe_parallel_config.use_all2all_kernels
+ return not (
+ moe_parallel_config.use_fi_all2allv_kernels
+ or moe_parallel_config.use_deepep_ht_kernels
+ )
def supports_chunking(self) -> bool:
return True
@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4(
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
- @staticmethod
- def expects_unquantized_inputs(
- moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
- ) -> bool:
+ @property
+ def expects_unquantized_inputs(self) -> bool:
return True
@staticmethod
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 1d5d039b6..fafcf6de6 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return True
+        # NOTE(rob): discovered an IMA (illegal memory access) with this
+        # combination. Needs investigation.
+ return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index 929cff799..514aa205a 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -103,6 +103,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool,
) -> Callable:
has_scales = token_scales is not None
@@ -174,6 +175,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights,
a1_scale,
quant_config,
+ defer_input_quant=defer_input_quant,
)
def _receiver(
@@ -187,6 +189,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool,
) -> mk.PrepareResultType:
if event.event is not None:
event.current_stream_wait()
@@ -221,14 +224,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens_per_expert_list, device=expert_x.device
)
- # Dispatch and Quant
- # DeepEP kernels only support dispatching block-quantized
- # activation scales.
- # Dispatch in bfloat16 and quantize afterwards
- if not quant_config.is_block_quantized:
+        # * For non-block quant, the activations were dispatched in bf16,
+        #   so quantize now (DeepEP kernels only support dispatching
+        #   block-quantized activation scales).
+        # * For expert kernels that require unquantized inputs, defer
+        #   quantization to the FusedMoEPermuteExpertsUnpermute.
+ if not quant_config.is_block_quantized and not defer_input_quant:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
+ # TODO: support per_act_token_quant,
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
@@ -257,6 +261,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.ReceiverType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
@@ -266,8 +271,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
a1 = a1 * topk_weights.to(a1.dtype)
- if quant_config.is_block_quantized:
- # Quant and Dispatch
+        # * DeepEP can only dispatch fp8 block-quantized scales, so for
+        #   block-quantized models quantize before the dispatch.
+        # * For all other quantization schemes, quantize after the dispatch.
+        # * For expert kernels that require unquantized inputs, defer
+        #   quantization to the FusedMoEPermuteExpertsUnpermute.
+ if quant_config.is_block_quantized and not defer_input_quant:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,
@@ -281,7 +290,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else:
a1q = a1
a1q_scale = None
- a1_post_scale = quant_config.a1_scale
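+            # nvfp4 quantizes after the dispatch using the global input
+            # scale (a1_gscale); other quant schemes use a1_scale.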
+ a1_post_scale = (
+ quant_config.a1_gscale
+ if quant_config.quant_dtype == "nvfp4"
+ else quant_config.a1_scale
+ )
return self._do_dispatch(
tokens=a1q,
@@ -291,6 +304,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config,
+ defer_input_quant=defer_input_quant,
)
def prepare(
@@ -302,6 +316,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
@@ -311,6 +326,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map,
apply_router_weight_on_input,
quant_config,
+ defer_input_quant,
)
return receiver()
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index c0e0b0b22..f5a3da438 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -242,7 +242,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
+ if defer_input_quant:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support defer_input_quant=True. "
+ "Please select an MoE kernel that accepts quantized inputs."
+ )
+
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
f"Hidden Size {hidden_size} not in supported list of hidden sizes"
@@ -344,7 +351,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.PrepareResultType:
+ if defer_input_quant:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support defer_input_quant=True. "
+ "Please select an MoE kernel that accepts quantized inputs."
+ )
hook, receiver = self.prepare_async(
a1,
topk_weights,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
new file mode 100644
index 000000000..39b373861
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
@@ -0,0 +1,226 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.distributed import get_ep_group
+from vllm.distributed.device_communicators.base_device_communicator import (
+ All2AllManagerBase,
+)
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+from vllm.utils.flashinfer import nvfp4_block_scale_interleave
+
+
+def get_local_sizes():
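+    # Per-DP-rank token counts for the current batch; used to size the
+    # FlashInfer alltoall dispatch.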
+ return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+
+
+class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+    """FlashInfer MoE prepare/finalize using MNNVL AllToAll communication."""
+
+ def __init__(
+ self,
+ num_dispatchers: int = 1,
+ ):
+ super().__init__()
+ self.num_dispatchers_ = num_dispatchers
+ self.all2all_manager = get_ep_group().device_communicator.all2all_manager
+
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return self.num_dispatchers_
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def _apply_router_weight_on_input(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ ) -> None:
+ """Apply router weight on input if needed."""
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ assert topk == 1, (
+ "apply_router_weight_on_input is only implemented for topk=1"
+ )
+ a1.mul_(topk_weights.to(a1.dtype))
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareResultType:
+ self._apply_router_weight_on_input(
+ a1, topk_weights, topk_ids, apply_router_weight_on_input
+ )
+ global_num_tokens_cpu = get_local_sizes()
+ top_k = topk_ids.size(1)
+
+ (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
+ flashinfer_alltoall_dispatch(
+ self.all2all_manager,
+ global_num_tokens_cpu,
+ a1,
+ quant_config.a1_gscale,
+ topk_ids,
+ topk_weights,
+ top_k,
+ num_experts,
+ quant_config,
+ defer_input_quant=defer_input_quant,
+ )
+ )
+
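+        # The STANDARD activation format needs no ExpertTokensMetadata,
+        # so the third element of the prepare result is None.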
+ return a1q, a1q_scale, None, topk_ids, topk_weights
+
+ def finalize(
+ self,
+ output: torch.Tensor,
+ fused_expert_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ weight_and_reduce_impl: mk.TopKWeightAndReduce,
+ ) -> None:
+ top_k = topk_ids.size(1)
+ token_count = output.shape[0]
+ fused_expert_output = flashinfer_alltoall_combine(
+ self.all2all_manager,
+ fused_expert_output,
+ top_k=top_k,
+ token_count=token_count,
+ alltoall_info=self.alltoall_info,
+ )
+ output.copy_(fused_expert_output)
+
+
+def flashinfer_alltoall_dispatch(
+ all2all_manager: All2AllManagerBase,
+ global_num_tokens_cpu: list[int],
+ x: torch.Tensor,
+ gs: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk_weights: torch.Tensor,
+ top_k: int,
+ num_experts: int,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+):
+ from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+
+ assert all2all_manager.ensure_alltoall_workspace_initialized(), (
+ "FlashInfer AllToAll workspace not available"
+ )
+
+ ep_rank = all2all_manager.rank
+ ep_size = all2all_manager.world_size
+ max_num_token = (
+ max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
+ )
+ orig_topk_weights_dtype = topk_weights.dtype
+ alltoall_info, topk_ids, topk_weights, _ = (
+ MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+ topk_ids,
+ topk_weights,
+ None,
+ all2all_manager.prepare_workspace_tensor,
+ max_num_token,
+ ep_rank,
+ ep_size,
+ num_experts,
+ num_experts,
+ top_k,
+ )
+ )
+ topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
+
+ if not defer_input_quant:
+ x, x_sf = moe_kernel_quantize_input(
+ x,
+ gs,
+ quant_config.quant_dtype,
+ quant_config.per_act_token_quant,
+ quant_config.block_shape,
+ # NOTE: swizzling pads the scales to multiple of 128
+ # which makes the scales tensor different shape than
+ # the hidden states, breaking the A2A kernel. So, we
+ # delay the swizzling until after the A2A.
+ is_fp4_scale_swizzled=False,
+ )
+
+ x = MnnvlMoe.mnnvl_moe_alltoallv(
+ x,
+ alltoall_info,
+ all2all_manager.workspace_tensor,
+ ep_rank,
+ ep_size,
+ )
+
+ x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
+ x_sf,
+ alltoall_info,
+ all2all_manager.workspace_tensor,
+ ep_rank,
+ ep_size,
+ )
+
+ # Swizzle after the A2A if nvfp4.
+ if quant_config.quant_dtype == "nvfp4":
+ if x_sf.element_size() == 1:
+ x_sf = x_sf.view(torch.uint8)
+ x_sf = nvfp4_block_scale_interleave(x_sf)
+ else:
+        # Deferred-quant path: dispatch the activations unquantized; the
+        # expert kernel applies input quantization itself.
+ x_sf = None
+ x = MnnvlMoe.mnnvl_moe_alltoallv(
+ x,
+ alltoall_info,
+ all2all_manager.workspace_tensor,
+ ep_rank,
+ ep_size,
+ )
+ return alltoall_info, topk_ids, topk_weights, x, x_sf
+
+
+def flashinfer_alltoall_combine(
+ all2all_manager: All2AllManagerBase,
+ output: torch.Tensor,
+ top_k: int,
+ token_count: int,
+ alltoall_info,
+):
+ from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+
+ assert all2all_manager.ensure_alltoall_workspace_initialized(), (
+ "FlashInfer AllToAll workspace not available"
+ )
+ return MnnvlMoe.mnnvl_moe_alltoallv_combine(
+ output,
+ alltoall_info,
+ all2all_manager.workspace_tensor,
+ ep_rank=all2all_manager.rank,
+ ep_size=all2all_manager.world_size,
+ top_k=top_k,
+ token_count=token_count,
+ )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index 8d5985875..faa654ea3 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -78,16 +78,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
- @staticmethod
- def expects_unquantized_inputs(
- moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
- ) -> bool:
- # NVFP4 TP kernels and FP8 block-quantized kernels apply
- # input quantization inside FusedMoEPermuteExpertsUnpermute.
- return (
- quant_config.use_nvfp4_w4a4
- and not moe_config.moe_parallel_config.use_all2all_kernels
- ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized)
+ @property
+ def expects_unquantized_inputs(self) -> bool:
+ return self.quant_config.use_fp8_w8a8 and self.quant_config.is_block_quantized
@staticmethod
def _supports_current_device() -> bool:
@@ -144,10 +137,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # FLASHINFER_CUTLASS currently uses its own P/F, which does not
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
- return (
- moe_parallel_config.dp_size == 1
- or moe_parallel_config.dp_size == moe_parallel_config.ep_size
- )
+ # TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
+ return not moe_parallel_config.is_sequence_parallel
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
@@ -194,8 +185,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
workspace1 = (M, K)
workspace2 = (0,)
- # For TP, the quantization is fused with fused_moe call.
- output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
+        # For NVFP4, the input activations are packed (two fp4 values per
+        # byte), so the unquantized output hidden dim is 2 * K here.
+ output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape)
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
deleted file mode 100644
index a56d68566..000000000
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ /dev/null
@@ -1,373 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.distributed import get_dp_group, get_ep_group
-from vllm.distributed.device_communicators.base_device_communicator import (
- All2AllManagerBase,
-)
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
-from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
- TopKWeightAndReduceNoOP,
-)
-from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
-from vllm.utils.flashinfer import nvfp4_block_scale_interleave
-
-
-def get_local_sizes():
- return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
-
-
-class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- """Base class for FlashInfer MoE prepare and finalize operations."""
-
- def __init__(
- self,
- use_dp: bool,
- num_dispatchers: int = 1,
- use_deepseek_fp8_block_scale: bool = False,
- ):
- super().__init__()
- self.num_dispatchers_ = num_dispatchers
- self.use_dp = use_dp
- self.local_tokens = None
- # Toggle for DeepSeek-style FP8 block-scale path where activations are
- # not quantized here and weight block scales are consumed by the kernel.
- self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale
-
- @property
- def activation_format(self) -> mk.FusedMoEActivationFormat:
- return mk.FusedMoEActivationFormat.Standard
-
- def max_num_tokens_per_rank(self) -> int | None:
- return None
-
- def topk_indices_dtype(self) -> torch.dtype | None:
- return None
-
- def num_dispatchers(self) -> int:
- return self.num_dispatchers_
-
- def output_is_reduced(self) -> bool:
- return False
-
- def _apply_router_weight_on_input(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- ) -> None:
- """Apply router weight on input if needed."""
- if apply_router_weight_on_input:
- topk = topk_ids.size(1)
- assert topk == 1, (
- "apply_router_weight_on_input is only implemented for topk=1"
- )
- a1.mul_(topk_weights.to(a1.dtype))
-
-
-class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
- """FlashInfer implementation using AllToAll communication."""
-
- def __init__(
- self,
- use_dp: bool,
- num_dispatchers: int = 1,
- use_deepseek_fp8_block_scale: bool = False,
- ):
- super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
- self.alltoall_info = None
-
- # Initialize all2all_manager only for DP case
- self.all2all_manager = None
- if self.use_dp:
- self.all2all_manager = get_ep_group().device_communicator.all2all_manager
-
- def prepare(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- ) -> mk.PrepareResultType:
- self._apply_router_weight_on_input(
- a1, topk_weights, topk_ids, apply_router_weight_on_input
- )
-
- if not self.use_dp:
- # Non-DP case: quantize activations unless using block-scale path
- if not self.use_deepseek_fp8_block_scale:
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- quant_config.a1_gscale,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- is_fp4_scale_swizzled=not self.use_dp,
- )
- else:
- a1q = a1
- a1q_scale = None
- else:
- # DP case: use FlashInfer AllToAll
- global_num_tokens_cpu = get_local_sizes()
- top_k = topk_ids.size(1)
-
- (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = (
- flashinfer_alltoall_dispatch(
- self.all2all_manager,
- global_num_tokens_cpu,
- a1,
- quant_config.a1_gscale,
- topk_ids,
- topk_weights,
- top_k,
- num_experts,
- quant_config,
- use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
- )
- )
-
- return a1q, a1q_scale, None, topk_ids, topk_weights
-
- def finalize(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> None:
- if self.use_dp:
- top_k = topk_ids.size(1)
- token_count = output.shape[0]
- fused_expert_output = flashinfer_alltoall_combine(
- self.all2all_manager,
- fused_expert_output,
- top_k=top_k,
- token_count=token_count,
- alltoall_info=self.alltoall_info,
- )
- output.copy_(fused_expert_output)
-
-
-class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize):
- def __init__(
- self,
- use_dp: bool,
- num_dispatchers: int = 1,
- use_deepseek_fp8_block_scale: bool = False,
- ):
- super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale)
-
- def prepare(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- ) -> mk.PrepareResultType:
- self._apply_router_weight_on_input(
- a1, topk_weights, topk_ids, apply_router_weight_on_input
- )
- is_nvfp4 = quant_config.quant_dtype == "nvfp4"
- if not self.use_dp and is_nvfp4:
- return a1, None, None, topk_ids, topk_weights
-
- if not self.use_deepseek_fp8_block_scale:
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- is_fp4_scale_swizzled=not self.use_dp,
- )
- else:
- # Block-scale path: pass activations through, omit per-token scales
- a1q = a1
- a1q_scale = None
-
- if self.use_dp:
- # Build gather list conditionally - omit a1q_scale if None
- # (block-scale path)
- gather_list = [topk_weights, topk_ids, a1q]
- if a1q_scale is not None:
- gather_list.append(a1q_scale)
- gathered = get_dp_group().all_gatherv(
- gather_list,
- dim=0,
- sizes=get_local_sizes(),
- )
- topk_weights, topk_ids, a1q, a1q_scale = gathered
- else:
- gathered = get_dp_group().all_gatherv(
- gather_list,
- dim=0,
- sizes=get_local_sizes(),
- )
- topk_weights, topk_ids, a1q = gathered
- a1q_scale = None
-
- if is_nvfp4 and a1q_scale is not None:
- if a1q_scale.element_size() == 1:
- a1q_scale = a1q_scale.view(torch.uint8)
- a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
-
- return a1q, a1q_scale, None, topk_ids, topk_weights
-
- def finalize(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> None:
- assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP)
-
- if self.use_dp:
- fused_expert_output = get_dp_group().reduce_scatterv(
- fused_expert_output, dim=0, sizes=get_local_sizes()
- )
- output.copy_(fused_expert_output)
-
-
-def flashinfer_alltoall_dispatch(
- all2all_manager: All2AllManagerBase,
- global_num_tokens_cpu: list[int],
- x: torch.Tensor,
- gs: torch.Tensor,
- topk_ids: torch.Tensor,
- topk_weights: torch.Tensor,
- top_k: int,
- num_experts: int,
- quant_config: FusedMoEQuantConfig,
- use_deepseek_fp8_block_scale: bool = False,
-):
- from flashinfer.comm.trtllm_alltoall import MnnvlMoe
-
- assert all2all_manager.ensure_alltoall_workspace_initialized(), (
- "FlashInfer AllToAll workspace not available"
- )
-
- ep_rank = all2all_manager.rank
- ep_size = all2all_manager.world_size
- max_num_token = (
- max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
- )
- orig_topk_weights_dtype = topk_weights.dtype
- alltoall_info, topk_ids, topk_weights, _ = (
- MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
- topk_ids,
- topk_weights,
- None,
- all2all_manager.prepare_workspace_tensor,
- max_num_token,
- ep_rank,
- ep_size,
- num_experts,
- num_experts,
- top_k,
- )
- )
- topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
-
- if not use_deepseek_fp8_block_scale:
- x, x_sf = moe_kernel_quantize_input(
- x,
- gs,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- is_fp4_scale_swizzled=False, # delay swizzle to after comm
- )
- x = MnnvlMoe.mnnvl_moe_alltoallv(
- x,
- alltoall_info,
- all2all_manager.workspace_tensor,
- ep_rank,
- ep_size,
- )
-
- x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
- x_sf,
- alltoall_info,
- all2all_manager.workspace_tensor,
- ep_rank,
- ep_size,
- )
- if quant_config.quant_dtype == "nvfp4":
- x_sf = nvfp4_block_scale_interleave(x_sf)
- else:
- # Block-scale path: pass activations through without quantization
- x_sf = None
- x = MnnvlMoe.mnnvl_moe_alltoallv(
- x,
- alltoall_info,
- all2all_manager.workspace_tensor,
- ep_rank,
- ep_size,
- )
- return alltoall_info, topk_ids, topk_weights, x, x_sf
-
-
-def flashinfer_alltoall_combine(
- all2all_manager: All2AllManagerBase,
- output: torch.Tensor,
- top_k: int,
- token_count: int,
- alltoall_info,
-):
- from flashinfer.comm.trtllm_alltoall import MnnvlMoe
-
- assert all2all_manager.ensure_alltoall_workspace_initialized(), (
- "FlashInfer AllToAll workspace not available"
- )
- return MnnvlMoe.mnnvl_moe_alltoallv_combine(
- output,
- alltoall_info,
- all2all_manager.workspace_tensor,
- ep_rank=all2all_manager.rank,
- ep_size=all2all_manager.world_size,
- top_k=top_k,
- token_count=token_count,
- )
-
-
-def create_flashinfer_prepare_finalize(
- use_dp: bool,
- use_nvfp4: bool = False,
- enable_alltoallv: bool = False,
- use_deepseek_fp8_block_scale: bool = False,
-) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP:
- """Factory function to create the appropriate FlashInfer implementation."""
-
- if use_dp:
- if enable_alltoallv:
- assert use_nvfp4
- return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
- return FlashInferAllGatherMoEPrepareAndFinalize(
- use_dp=True,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- )
- else:
- # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization
- # in a single call with the MoE experts kernel.
- defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4
- return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant)
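
For context on the factory removal above: the defer_input_quant constructor flag on MoEPrepareAndFinalizeNoEP is gone, and FusedMoEModularKernel now forwards defer_input_quant=experts.expects_unquantized_inputs into prepare() at call time. Below is a minimal sketch of the no-EP call pattern under that contract; the experts object and the helper name are placeholders, and the kernel's keyword arguments are assumed to match the call sites shown elsewhere in this diff.

    import torch
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNoEP,
    )

    def run_no_ep_moe(
        experts: mk.FusedMoEPermuteExpertsUnpermute,
        hidden_states: torch.Tensor,
        w13_weight: torch.Tensor,
        w2_weight: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        # The modular kernel reads experts.expects_unquantized_inputs and
        # passes it to prepare() as defer_input_quant, so experts kernels
        # that fuse quantization receive the raw (unquantized) activations.
        kernel = mk.FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), experts)
        return kernel(
            hidden_states=hidden_states,
            w1=w13_weight,
            w2=w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
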
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index 509bacfbc..c681e083a 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.PrepareResultType:
+ if defer_input_quant:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support defer_input_quant=True. "
+ "Please select an MoE kernel that accepts quantized inputs."
+ )
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index ce77eafa6..2e5167bdf 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -593,7 +593,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return True
+ return not moe_parallel_config.use_fi_all2allv_kernels
@property
def quant_type_id(self) -> int:
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 669a6e74b..0206e19de 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1951,7 +1951,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return True
+ return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 36562142c..3ad56cc4c 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -5,6 +5,7 @@ from abc import abstractmethod
import torch
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -26,6 +27,19 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__()
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
+ self.moe_mk: mk.FusedMoEModularKernel | None = None
+
+ @property
+ def supports_internal_mk(self) -> bool:
+        # NOTE(rob): temporary property indicating that this quant method
+        # has fully migrated to the new internal MK interface.
+ return self.moe_mk is not None
+
+ @property
+ def mk_owns_shared_expert(self) -> bool:
+        # NOTE(rob): temporary property indicating that the internal MK
+        # owns the shared expert computation for this quant method.
+ return self.moe_mk is not None and self.moe_mk.shared_experts is not None
@abstractmethod
def create_weights(
@@ -91,6 +105,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property
def topk_indices_dtype(self) -> torch.dtype | None:
+ if self.moe_mk is not None:
+ return self.moe_mk.prepare_finalize.topk_indices_dtype()
return None
@property
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 53df0dc85..7a2244a9b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
):
super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config
- self.fused_experts = experts
+ self.moe_mk = experts
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
- not self.fused_experts.supports_expert_map(),
+ not self.moe_mk.supports_expert_map(),
)
self.old_quant_method = old_quant_method
assert not self.old_quant_method.is_monolithic
@@ -57,10 +57,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
),
)
- @property
- def topk_indices_dtype(self) -> torch.dtype | None:
- return self.fused_experts.prepare_finalize.topk_indices_dtype()
-
@property
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb
@@ -96,7 +92,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- return self.fused_experts(
+ assert self.moe_mk is not None
+ return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index daf2e0e6b..e38275004 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -571,9 +571,6 @@ class FusedMoE(CustomOp):
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
)
- self.moe_config_use_flashinfer_cutlass_kernels = (
- self.moe_config.use_flashinfer_cutlass_kernels
- )
if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now."
@@ -646,6 +643,11 @@ class FusedMoE(CustomOp):
# This is called after all weight loading and post-processing, so it
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
+        # NOTE(rob): WIP refactor. For quant methods that own the MK,
+        # we create the MK during process_weights_after_loading.
+ if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
+ return None
+
self.ensure_moe_quant_config_init()
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
@@ -728,14 +730,6 @@ class FusedMoE(CustomOp):
def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels
- @property
- def use_flashinfer_cutlass_kernels(self):
- return (
- self.moe_quant_config is not None
- and self.moe_quant_config.quant_dtype == "nvfp4"
- and self.moe_config_use_flashinfer_cutlass_kernels
- )
-
@property
def use_marlin_kernels(self):
return getattr(self.quant_method, "use_marlin", False)
@@ -746,7 +740,7 @@ class FusedMoE(CustomOp):
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
- or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
+ or self.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@property
@@ -1532,7 +1526,7 @@ class FusedMoE(CustomOp):
assert self.quant_method is not None
return (
isinstance(self.quant_method, FusedMoEModularMethod)
- and self.quant_method.fused_experts.output_is_reduced()
+ and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr]
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
@@ -1765,7 +1759,7 @@ class FusedMoE(CustomOp):
self.ensure_dp_chunking_init()
has_separate_shared_experts = (
- not isinstance(self.quant_method, FusedMoEModularMethod)
+ not self.quant_method.mk_owns_shared_expert
and self.shared_experts is not None
)
@@ -1789,8 +1783,10 @@ class FusedMoE(CustomOp):
hidden_states, router_logits, has_separate_shared_experts
)
- do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
- self.quant_method, FusedMoEModularMethod
+ # NOTE(rob): once we finish migrating all the quant methods to use
+ # MKs, we can remove the naive dispatch/combine path from here.
+ do_naive_dispatch_combine = (
+ self.dp_size > 1 and not self.quant_method.supports_internal_mk
)
ctx = get_forward_context()
@@ -1818,7 +1814,7 @@ class FusedMoE(CustomOp):
else:
hidden_states_to_dispatch = hidden_states
- dispatch_res = get_ep_group().dispatch(
+ dispatch_res = get_ep_group().dispatch_router_logits(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
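
The two new FusedMoEMethodBase properties gate the forward path changed above; a small illustrative sketch of those decisions follows (the helper functions are not part of the diff):

    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    def uses_naive_dispatch_combine(layer: FusedMoE) -> bool:
        # Quant methods that own an internal modular kernel handle DP
        # dispatch/combine themselves; everything else still takes the
        # naive dispatch/combine path in FusedMoE.forward.
        return layer.dp_size > 1 and not layer.quant_method.supports_internal_mk

    def layer_runs_shared_experts(layer: FusedMoE) -> bool:
        # The layer only computes shared experts separately when the
        # internal modular kernel does not already own them.
        return (
            layer.shared_experts is not None
            and not layer.quant_method.mk_owns_shared_expert
        )
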
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 6aff6401e..940a2c55f 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool,
) -> PrepareResultType:
"""
Perform any quantization (and/or) dispatching needed for this kernel.
@@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
+        - defer_input_quant: Runtime flag indicating whether to defer input
+          quantization to the FusedMoEPermuteExpertsUnpermute, for cases
+          where the compute kernel expects unquantized inputs.
Returns a tuple of:
- quantized + dispatched a.
@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool,
) -> tuple[Callable, ReceiverType] | ReceiverType:
"""
Perform any quantization (and/or) dispatching needed for this kernel
@@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC):
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
+        - defer_input_quant: Runtime flag indicating whether to defer input
+          quantization to the FusedMoEPermuteExpertsUnpermute, for cases
+          where the compute kernel expects unquantized inputs.
Returns a callback or a hook callback pair that when invoked waits for
results from other workers and has the same return signature as
@@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
- @staticmethod
- def expects_unquantized_inputs(
- moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig
- ) -> bool:
+ @property
+ def expects_unquantized_inputs(self) -> bool:
"""
Whether or not the PrepareFinalize should defer input quantization
in the prepare step. If True, then the Experts kernel will
@@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
+ defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
else:
# Overlap shared expert compute with all2all dispatch.
@@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
+ defer_input_quant=self.fused_experts.expects_unquantized_inputs,
)
# TODO(lucas): refactor this in the alternative schedules followup
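
With expects_unquantized_inputs now an instance property, an experts implementation opts into deferred input quantization as sketched below; the class is illustrative only (AiterExperts later in this diff is a real example):

    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    class FusedQuantExperts(mk.FusedMoEPermuteExpertsUnpermute):
        """Illustrative experts impl that quantizes inputs in its own kernel."""

        @property
        def expects_unquantized_inputs(self) -> bool:
            # FusedMoEModularKernel forwards this value to prepare() as
            # defer_input_quant, so the activations arrive unquantized and
            # prepare/finalize backends that cannot honor the flag raise
            # NotImplementedError instead of silently double-quantizing.
            return True
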
diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
index 930e7ae3f..dc0f32dc1 100644
--- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
@@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.PrepareResultType:
"""
Returns a tuple of:
@@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
+ if defer_input_quant:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support defer_input_quant=True. "
+ "Please select an MoE kernel that accepts quantized inputs."
+ )
assert not apply_router_weight_on_input, (
"mori does not support apply_router_weight_on_input=True now."
)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index dd0ad4523..15fc6e237 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+ maybe_make_prepare_finalize,
+)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm,
)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
@@ -465,68 +465,52 @@ def make_fp8_moe_quant_config(
)
-def make_fp8_moe_kernel_for_mkm(
+def make_fp8_moe_kernel(
+ moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
- quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
-) -> mk.FusedMoEPermuteExpertsUnpermute:
+ fp8_backend: Fp8MoeBackend,
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ shared_experts: torch.nn.Module | None = None,
+) -> tuple[mk.FusedMoEModularKernel, bool]:
+ # Create Prepare/Finalize.
+ prepare_finalize = maybe_make_prepare_finalize(
+ moe=moe_config,
+ quant_config=moe_quant_config,
+ routing_tables=routing_tables,
+ allow_new_interface=True,
+ )
+ assert prepare_finalize is not None
+
+ logger.info_once("Using %s", prepare_finalize.__class__.__name__)
+
+ # Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
- max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
- assert max_num_tokens_per_rank is not None
+ max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
+ assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
- quant_config=quant_config,
- max_num_tokens=max_num_tokens_per_rank,
+ quant_config=moe_quant_config,
+ max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
- quant_config=quant_config,
+ quant_config=moe_quant_config,
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
-
-
-def make_fp8_moe_kernel(
- moe_quant_config: FusedMoEQuantConfig,
- moe_config: FusedMoEConfig,
- fp8_backend: Fp8MoeBackend,
- experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
-) -> tuple[mk.FusedMoEModularKernel, bool]:
- # TODO(rob): unify after we merge tp and dp/ep.
- if (
- moe_config.moe_parallel_config.use_all2all_kernels
- and moe_config.moe_parallel_config.all2all_backend
- not in ["allgather_reducescatter", "naive"]
- ):
- raise ValueError(
- "Fp8 Oracle should not create non-naive A2A P/F. "
- "This should happen via the ModularKernelMethod."
- )
-
- # Create Prepare/Finalize.
- prepare_finalize = MoEPrepareAndFinalizeNoEP(
- defer_input_quant=experts_cls.expects_unquantized_inputs(
- moe_config, moe_quant_config
- ),
- )
-
- # Create Experts.
- experts = experts_cls(
- moe_config=moe_config,
- quant_config=moe_quant_config,
- )
-
# NOTE(rob): we only want the mk to control the shared_expert
     # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
- shared_experts=None,
+ shared_experts=(
+ shared_experts
+ if moe_config.moe_parallel_config.use_all2all_kernels
+ else None
+ ),
moe_parallel_config=moe_config.moe_parallel_config,
)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index 897631e23..276d231eb 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -7,6 +7,9 @@ import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+ maybe_make_prepare_finalize,
+)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
-)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
@@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config(
)
-def make_nvfp4_moe_kernel_for_mkm(
+def make_nvfp4_moe_kernel(
+ moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
- quant_config: FusedMoEQuantConfig,
experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
-) -> mk.FusedMoEPermuteExpertsUnpermute:
+ routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ shared_experts: torch.nn.Module | None = None,
+) -> mk.FusedMoEModularKernel:
+ # Create Prepare/Finalize.
+ prepare_finalize = maybe_make_prepare_finalize(
+ moe=moe_config,
+ quant_config=moe_quant_config,
+ routing_tables=routing_tables,
+ allow_new_interface=True,
+ )
+ assert prepare_finalize is not None
+
+ logger.info_once("Using %s", prepare_finalize.__class__.__name__)
+
+ # Create Experts.
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
- max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
- assert max_num_tokens_per_rank is not None
+ max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
+ assert max_num_tokens is not None
experts = experts_cls(
moe_config=moe_config,
- quant_config=quant_config,
- max_num_tokens=max_num_tokens_per_rank,
+ quant_config=moe_quant_config,
+ max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
else:
experts = experts_cls(
moe_config=moe_config,
- quant_config=quant_config,
+ quant_config=moe_quant_config,
)
- logger.debug_once("Using %s", experts.__class__.__name__)
- return experts
-
-
-def make_nvfp4_moe_kernel(
- moe_quant_config: FusedMoEQuantConfig,
- moe_config: FusedMoEConfig,
- experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
-) -> mk.FusedMoEModularKernel:
- # TODO(rob): unify after we merge tp and dp/ep.
- if (
- moe_config.moe_parallel_config.use_all2all_kernels
- and moe_config.moe_parallel_config.all2all_backend
- not in ["allgather_reducescatter", "naive"]
- ):
- raise ValueError(
- "NvFP4 Oracle should not create non-naive A2A P/F. "
- "This should happen via the ModularKernelMethod."
- )
-
- # Create Prepare/Finalize.
- prepare_finalize = MoEPrepareAndFinalizeNoEP(
- defer_input_quant=experts_cls.expects_unquantized_inputs(
- moe_config, moe_quant_config
- ),
- )
-
- # Create Experts.
- experts = experts_cls(
- moe_config=moe_config,
- quant_config=moe_quant_config,
- )
-
# NOTE(rob): we only want the mk to control the shared_expert
     # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
kernel = mk.FusedMoEModularKernel(
prepare_finalize,
experts,
- shared_experts=None,
+ shared_experts=(
+ shared_experts
+ if moe_config.moe_parallel_config.use_all2all_kernels
+ else None
+ ),
moe_parallel_config=moe_config.moe_parallel_config,
)
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 8c4da9711..78b941498 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> tuple[Callable, mk.ReceiverType]:
+ if defer_input_quant:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support defer_input_quant=True. "
+ "Please select an MoE kernel that accepts quantized inputs."
+ )
+
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
@@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
a1,
@@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map,
apply_router_weight_on_input,
quant_config,
+ defer_input_quant=defer_input_quant,
)
hook()
return receiver()
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
index 5d806fa84..d10476702 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
@@ -4,19 +4,133 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.distributed import get_ep_group
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceContiguous,
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+from vllm.utils.flashinfer import nvfp4_block_scale_interleave
+
+
+class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
+ def __init__(
+ self,
+ is_sequence_parallel: bool = False,
+ num_dispatchers: int = 1,
+ ) -> None:
+ super().__init__()
+ self.is_sequence_parallel = is_sequence_parallel
+ self._num_dispatchers = num_dispatchers
+
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return self._num_dispatchers
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareResultType:
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ assert topk == 1, (
+ "apply_router_weight_on_input is only implemented for topk=1"
+ )
+ # Note: do not use inplace for shared experts overlap
+ a1 = a1 * topk_weights.to(a1.dtype)
+
+ # Defer input quantization to the MoE kernel.
+ use_nvfp4 = quant_config.use_nvfp4_w4a4
+ if defer_input_quant:
+ a1q = a1
+ a1q_scale = None
+ else:
+ a1q, a1q_scale = moe_kernel_quantize_input(
+ a1,
+ quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
+ quant_config.quant_dtype,
+ quant_config.per_act_token_quant,
+ quant_config.block_shape,
+                # NOTE: swizzling pads the scales to a multiple of 128,
+                # which gives the scales tensor a different shape from
+                # the hidden states and breaks the A2A kernel, so we
+                # delay the swizzling until after the A2A.
+ is_fp4_scale_swizzled=False,
+ )
+
+ # Skip gathering scales if we have static quantization
+ # (the scale is a scalar, replicated on all ranks) or
+ # if quantization is deferred.
+ skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
+ scales = None if skip_gather_scales else [a1q_scale]
+
+ res = get_ep_group().dispatch(
+ a1q,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel=self.is_sequence_parallel,
+ extra_tensors=scales,
+ )
+ if skip_gather_scales:
+ a1q, topk_weights, topk_ids = res
+ else:
+ a1q, topk_weights, topk_ids, scales = res
+ assert scales is not None and len(scales) == 1
+ a1q_scale = scales[0]
+ if quant_config.quant_dtype == "nvfp4":
+ assert a1q_scale is not None
+ if a1q_scale.element_size() == 1:
+ a1q_scale = a1q_scale.view(torch.uint8)
+ a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+
+ return a1q, a1q_scale, None, topk_ids, topk_weights
+
+ def finalize(
+ self,
+ output: torch.Tensor,
+ fused_expert_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ weight_and_reduce_impl: mk.TopKWeightAndReduce,
+ ) -> None:
+ if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+ weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+
+ out = weight_and_reduce_impl.apply(
+ output=None,
+ fused_expert_output=fused_expert_output,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ )
+
+ output.copy_(
+ get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
+ )
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
- def __init__(self, defer_input_quant: bool = False) -> None:
- super().__init__()
- self.defer_input_quant = defer_input_quant
-
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
@@ -54,12 +169,17 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
# Defer input quant to moe kernel for backends (e.g. AITER, FI)
# which use a single kernel call for quant + experts.
- if self.defer_input_quant:
+ if defer_input_quant:
return a1, None, None, None, None
+ input_sf = (
+ quant_config.a1_gscale
+ if quant_config.use_nvfp4_w4a4
+ else quant_config.a1_scale
+ )
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
- quant_config.a1_scale,
+ input_sf,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
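
MoEPrepareAndFinalizeNaiveEP plugs into the same modular-kernel interface as the other PrepareAndFinalize implementations; a minimal construction sketch follows (the experts argument is a placeholder and the num_dispatchers choice is illustrative):

    import vllm.model_executor.layers.fused_moe.modular_kernel as mk
    from vllm.distributed import get_ep_group
    from vllm.model_executor.layers.fused_moe.prepare_finalize import (
        MoEPrepareAndFinalizeNaiveEP,
    )

    def make_naive_ep_kernel(
        experts: mk.FusedMoEPermuteExpertsUnpermute,
        is_sequence_parallel: bool = False,
    ) -> mk.FusedMoEModularKernel:
        # prepare(): quantize (unless deferred), then all-gather activations,
        # scales, and top-k metadata across the EP group; finalize(): apply
        # top-k weights, reduce, and combine back to the local token shard.
        prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
            is_sequence_parallel=is_sequence_parallel,
            num_dispatchers=get_ep_group().world_size,
        )
        return mk.FusedMoEModularKernel(prepare_finalize, experts)
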
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 81ed0c504..33150da6f 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -287,17 +287,14 @@ def rocm_aiter_fused_experts(
class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
+ @property
+ def expects_unquantized_inputs(self) -> bool:
+ return True
+
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
- @staticmethod
- def expects_unquantized_inputs(
- fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig
- ) -> bool:
- # AITER fused MoE kernels handle input quantization internally.
- return True
-
@staticmethod
def _supports_current_device() -> bool:
return rocm_aiter_ops.is_fused_moe_enabled()
@@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- return True
+ return not moe_parallel_config.use_fi_all2allv_kernels
def supports_expert_map(self):
return True
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index a143347b1..63ddb24b8 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE):
use_overlapped
and not (
(self.enable_eplb and backend != "allgather_reducescatter")
- or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+ or self.moe_parallel_config.use_fi_all2allv_kernels
)
and self._shared_experts is not None
)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index ec120cab4..f26ddfb87 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
- make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
@@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
is_global_sf_supported_for_nvfp4_backend,
make_mxfp4_moe_quant_config,
make_nvfp4_moe_kernel,
- make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
@@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
- build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
@@ -243,7 +240,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.group_size = 32
self.mxfp4_backend = NvFp4MoeBackend.MARLIN
self.experts_cls = MarlinExperts
- self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
@@ -320,7 +316,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale
)
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight_packed.data, requires_grad=False
)
@@ -335,10 +331,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
- self.kernel = make_nvfp4_moe_kernel(
+ self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
+ shared_experts=layer.shared_experts,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
)
def apply(
@@ -348,8 +346,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.kernel is not None
- return self.kernel(
+ assert self.moe_mk is not None
+ return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
@@ -380,19 +378,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
activation_key=None if use_a16 else kNvfp4Dynamic,
)
- # Delay creation of the kernel until after process-weights.
- self.kernel: mk.FusedMoEModularKernel | None = None
-
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
- @property
- def topk_indices_dtype(self) -> torch.dtype | None:
- if self.kernel is not None:
- return self.kernel.prepare_finalize.topk_indices_dtype()
- return None
-
def create_weights(
self,
layer: torch.nn.Module,
@@ -506,7 +495,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ def process_weights_after_loading(self, layer: FusedMoE) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
@@ -572,48 +561,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config and (
- (not self.moe.moe_parallel_config.use_all2all_kernels)
- or self.moe.moe_parallel_config.use_naive_all2all_kernels
- ):
+ if self.moe_quant_config:
assert self.experts_cls is not None
- self.kernel = make_nvfp4_moe_kernel(
+ self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
+ shared_experts=layer.shared_experts,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- return None
- elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- # For no-EP case, don't use the MKM framework.
- if not self.moe.moe_parallel_config.use_all2all_kernels:
- return None
-
- prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- self.moe,
- use_deepseek_fp8_block_scale=False,
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- assert self.experts_cls is not None
- return make_nvfp4_moe_kernel_for_mkm(
- moe_config=self.moe,
- quant_config=self.moe_quant_config,
- experts_cls=self.experts_cls,
- prepare_finalize=prepare_finalize,
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
def get_fused_moe_quant_config(
@@ -684,8 +658,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts,
)
else:
- assert self.kernel is not None
- return self.kernel(
+ assert self.moe_mk is not None
+ return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
@@ -759,15 +733,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
allow_vllm_cutlass=True,
)
- # Delay creation of the kernel until after process-weights.
- self.kernel: mk.FusedMoEModularKernel | None = None
-
- @property
- def topk_indices_dtype(self) -> torch.dtype | None:
- if self.kernel is not None:
- return self.kernel.prepare_finalize.topk_indices_dtype()
- return None
-
def create_weights(
self,
layer: torch.nn.Module,
@@ -927,7 +892,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ def process_weights_after_loading(self, layer: FusedMoE) -> None:
# Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight
w2 = layer.w2_weight
@@ -989,49 +954,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config and (
- (not self.moe.moe_parallel_config.use_all2all_kernels)
- or self.moe.moe_parallel_config.use_naive_all2all_kernels
- ):
+ if self.moe_quant_config:
assert self.experts_cls is not None
- self.kernel, self.use_inplace = make_fp8_moe_kernel(
+ self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
+ shared_experts=layer.shared_experts,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return None
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # For no-EP case, don't use the MKM framework.
- if not self.moe.moe_parallel_config.use_all2all_kernels:
- return None
-
- prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- self.moe,
- use_deepseek_fp8_block_scale=self.block_quant,
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
- ) -> FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- assert self.experts_cls is not None
- return make_fp8_moe_kernel_for_mkm(
- moe_config=self.moe,
- quant_config=self.moe_quant_config,
- experts_cls=self.experts_cls,
- prepare_finalize=prepare_finalize,
+ ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
def get_fused_moe_quant_config(
@@ -1120,8 +1070,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
- assert self.kernel is not None
- return self.kernel(
+ assert self.moe_mk is not None
+ return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index fe59022cb..60600e1e3 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
- make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
- build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
allow_vllm_cutlass=False,
)
- # Delay creation of the kernel until after process-weights.
- self.kernel: mk.FusedMoEModularKernel | None = None
-
- @property
- def topk_indices_dtype(self) -> torch.dtype | None:
- if self.kernel is not None:
- return self.kernel.prepare_finalize.topk_indices_dtype()
- return None
-
def create_weights(
self,
layer: Module,
@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel(
self,
- layer: Module,
+ layer: FusedMoE,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config and (
- (not self.moe.moe_parallel_config.use_all2all_kernels)
- or self.moe.moe_parallel_config.use_naive_all2all_kernels
- ):
+ if self.moe_quant_config:
assert self.experts_cls is not None
- self.kernel, self.use_inplace = make_fp8_moe_kernel(
+ self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
+ shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: Module) -> None:
@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return None
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # For no-EP case, don't use the MKM framework.
- if not self.moe.moe_parallel_config.use_all2all_kernels:
- return None
-
- prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- self.moe,
- use_deepseek_fp8_block_scale=self.block_quant,
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- assert self.experts_cls is not None
- return make_fp8_moe_kernel_for_mkm(
- moe_config=self.moe,
- quant_config=self.moe_quant_config,
- experts_cls=self.experts_cls,
- prepare_finalize=prepare_finalize,
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
def get_fused_moe_quant_config(
@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.kernel is not None
+ assert self.moe_mk is not None
assert not self.is_monolithic
- return self.kernel(
+ return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 476ad618e..e10144ed1 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
- make_fp8_moe_kernel_for_mkm,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
@@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
- make_nvfp4_moe_kernel_for_mkm,
make_nvfp4_moe_quant_config,
select_nvfp4_moe_backend,
)
@@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe,
- build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
@@ -739,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
activation_key=kFp8StaticTensorSym,
)
- # Delay creation of the kernel until after process-weights.
- self.kernel: mk.FusedMoEModularKernel | None = None
-
- @property
- def topk_indices_dtype(self) -> torch.dtype | None:
- if self.kernel is not None:
- return self.kernel.prepare_finalize.topk_indices_dtype()
- return None
-
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- # TRT LLM not supported with all2all yet.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return None
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- # For no-EP case, don't use the MKM framework.
- if not self.moe.moe_parallel_config.use_all2all_kernels:
- return None
-
- prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- self.moe,
- use_deepseek_fp8_block_scale=False,
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- assert self.experts_cls is not None
- return make_fp8_moe_kernel_for_mkm(
- moe_config=self.moe,
- quant_config=self.moe_quant_config,
- experts_cls=self.experts_cls,
- prepare_finalize=prepare_finalize,
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
def create_weights(
@@ -863,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def _setup_kernel(
self,
- layer: torch.nn.Module,
+ layer: FusedMoE,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
@@ -893,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
- self.kernel, self.use_inplace = make_fp8_moe_kernel(
+ self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
+ shared_experts=layer.shared_experts,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -998,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
f"but got {layer.activation}"
)
- assert self.kernel is not None
- return self.kernel(
+ assert self.moe_mk is not None
+ return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1379,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
activation_key=kNvfp4Dynamic,
)
- # Delay creation of the kernel until after process-weights.
- self.kernel: mk.FusedMoEModularKernel | None = None
-
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend
)
- @property
- def topk_indices_dtype(self) -> torch.dtype | None:
- if self.kernel is not None:
- return self.kernel.prepare_finalize.topk_indices_dtype()
- return None
-
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- return None
- elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
- # For no-EP case, don't use the MKM framework.
- if not self.moe.moe_parallel_config.use_all2all_kernels:
- return None
- # For now, fp4 moe only works with the flashinfer dispatcher.
- prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
- self.moe
- )
- logger.debug_once("%s", prepare_finalize.__class__.__name__)
- return prepare_finalize
- else:
- return super().maybe_make_prepare_finalize(routing_tables)
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
+ )
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
- assert self.moe_quant_config is not None
- assert self.experts_cls is not None
- return make_nvfp4_moe_kernel_for_mkm(
- moe_config=self.moe,
- quant_config=self.moe_quant_config,
- experts_cls=self.experts_cls,
- prepare_finalize=prepare_finalize,
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
def uses_weight_scale_2_pattern(self) -> bool:
@@ -1547,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
layer.register_parameter("w2_input_scale", w2_input_scale)
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ def process_weights_after_loading(self, layer: FusedMoE) -> None:
"""
Convert NVFP4 MoE weights into kernel format and setup the kernel.
"""
@@ -1599,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
# TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
# in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config and (
- (not self.moe.moe_parallel_config.use_all2all_kernels)
- or self.moe.moe_parallel_config.use_naive_all2all_kernels
- ):
+ if self.moe_quant_config:
assert self.experts_cls is not None
- self.kernel = make_nvfp4_moe_kernel(
+ self.moe_mk = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
+ shared_experts=layer.shared_experts,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
)
@property
@@ -1708,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=layer.global_num_experts,
)
else:
- assert self.kernel is not None
- return self.kernel(
+ assert self.moe_mk is not None
+ return self.moe_mk(
x,
layer.w13_weight,
layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index 52d0a5c47..ae5a934fb 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
RoutingMethodType,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
-)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
@@ -42,7 +39,6 @@ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1",
- "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
#
@@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1(
)
-def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
- moe: FusedMoEConfig,
-) -> mk.FusedMoEPrepareAndFinalize:
- """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
- use_dp = moe.moe_parallel_config.dp_size > 1
- enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
- return create_flashinfer_prepare_finalize(
- use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
- )
-
-
def prepare_static_weights_for_trtllm_fp4_moe(
# args_dequant,
# args,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index d9824c107..cd82b5432 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -4,15 +4,8 @@ from enum import Enum
import torch
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig,
-)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
- create_flashinfer_prepare_finalize,
-)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
@@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi(
return g1_alphas, g2_alphas
-def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False
-) -> mk.FusedMoEPrepareAndFinalize:
- """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
- use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
- # Propagate block-scale flag so prepare/finalize can skip act quantization
- # and inform the kernel to consume per-block weight scales.
- return create_flashinfer_prepare_finalize(
- use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
- )
-
-
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index 21409de86..cd4efe1ca 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -14,7 +14,6 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
@@ -169,9 +168,10 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
# modular kernels could invoke deep_gemm_moe_fp8
return True
- mk: FusedMoEModularKernel = module.quant_method.fused_experts
# Further check if the ModularKernel implementation uses the DeepGemmExperts
- return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts))
+ return isinstance(
+        module.quant_method.moe_mk.fused_experts,
+        (DeepGemmExperts, TritonOrDeepGemmExperts),
+ )
FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set()
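
The warmup check above inspects the experts implementation stored inside the quant method's modular kernel rather than the kernel object itself; a condensed sketch (the helper name is illustrative):

    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
    from vllm.model_executor.layers.fused_moe.layer import FusedMoEModularMethod
    from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
        TritonOrDeepGemmExperts,
    )

    def mk_uses_deep_gemm(quant_method: FusedMoEModularMethod) -> bool:
        # The modular kernel exposes its experts implementation as
        # `fused_experts`; that is the object which may be DeepGEMM-backed,
        # not the FusedMoEModularKernel itself.
        mk_kernel = quant_method.moe_mk
        return mk_kernel is not None and isinstance(
            mk_kernel.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts)
        )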