diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 0e7ab7255..c239cb5d0 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -1131,7 +1131,7 @@ steps:
   - csrc/quantization/cutlass_w8a8/moe/
   - vllm/model_executor/layers/fused_moe/cutlass_moe.py
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
-  - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+  - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
   - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
   - vllm/v1/attention/backends/flashinfer.py
   - vllm/v1/attention/backends/mla/cutlass_mla.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c2587dd81..d1a536a07 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -1017,7 +1017,7 @@ steps:
   - csrc/quantization/cutlass_w8a8/moe/
   - vllm/model_executor/layers/fused_moe/cutlass_moe.py
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
-  - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+  - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
   - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
   - vllm/v1/attention/backends/flashinfer.py
   - vllm/v1/attention/backends/mla/cutlass_mla.py
diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml
index 2772d69b4..8641a18b4 100644
--- a/.buildkite/test_areas/kernels.yaml
+++ b/.buildkite/test_areas/kernels.yaml
@@ -85,7 +85,7 @@ steps:
  - csrc/quantization/cutlass_w8a8/moe/
  - vllm/model_executor/layers/fused_moe/cutlass_moe.py
  - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+ - vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
  - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
  - vllm/v1/attention/backends/flashinfer.py
  - vllm/v1/attention/backends/mla/cutlass_mla.py
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
index 4894d37c4..c1f4f0aa9 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
@@ -197,7 +197,7 @@ def bench_run(
     )
     kernel = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+        MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp4(
             make_dummy_moe_config(),
             quant_config=quant_config,
@@ -242,7 +242,7 @@ def bench_run(
     )
     kernel = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+        MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp4(
             make_dummy_moe_config(),
             quant_config=quant_config,
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 022c4f2e8..75ebee6ec 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -36,8 +36,7 @@ th {
 | pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
 | deepep_high_throughput | standard | fp8 | G(128),A,T2 | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
 | deepep_low_latency | batched | fp8 | G(128),A,T3 | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
-| flashinfer_all2allv | standard | 
nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] | -| flashinfer4 | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] | +| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] | | MoEPrepareAndFinalizeNoEP5 | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] | | BatchedPrepareAndFinalize5 | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] | diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 07964031f..4ee18e342 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -22,6 +22,9 @@ from vllm.distributed import ( ) from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, @@ -40,7 +43,6 @@ from .mk_objects import ( TestMoEQuantConfig, expert_info, make_fused_experts, - make_prepare_finalize, prepare_finalize_info, ) from .parallel_utils import ProcessGroupInfo @@ -603,10 +605,12 @@ def make_modular_kernel( routing_method=RoutingMethodType.DeepSeekV3, ) - # make modular kernel - prepare_finalize = make_prepare_finalize( - config.prepare_finalize_type, config.all2all_backend(), moe, quant_config + prepare_finalize = maybe_make_prepare_finalize( + moe=moe, + quant_config=quant_config, + allow_new_interface=True, ) + assert prepare_finalize is not None fused_experts = make_fused_experts( config.fused_experts_type, diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index f62e11641..d215d2ab6 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -7,9 +7,6 @@ import torch # Fused experts and PrepareFinalize imports import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe import TritonExperts -from vllm.model_executor.layers.fused_moe.all2all_utils import ( - maybe_make_prepare_finalize, -) from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) @@ -255,13 +252,12 @@ if has_pplx(): ) if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): + from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize, + ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize, - create_flashinfer_prepare_finalize, - ) register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, @@ -429,24 +425,6 
@@ if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): ] -def make_prepare_finalize( - prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, - backend: str | None, - moe: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, -) -> mk.FusedMoEPrepareAndFinalize: - if backend != "naive" and backend is not None: - prepare_finalize = maybe_make_prepare_finalize(moe, quant_config) - assert prepare_finalize is not None - return prepare_finalize - elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: - return create_flashinfer_prepare_finalize( - use_dp=moe.moe_parallel_config.dp_size > 1 - ) - else: - return MoEPrepareAndFinalizeNoEP() - - def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor: s = rank * num_local_experts e = s + num_local_experts diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index e5752e404..a484a3f21 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -294,12 +294,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( ) kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - defer_input_quant=FlashInferExperts.expects_unquantized_inputs( - moe_config=moe_config, - quant_config=quant_config, - ) - ), + MoEPrepareAndFinalizeNoEP(), FlashInferExperts( moe_config=moe_config, quant_config=quant_config, diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 66ceb0c64..9bb61ddfa 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph( ) flashinfer_experts = FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP( - defer_input_quant=FlashInferExperts.expects_unquantized_inputs( - moe_config=moe_config, - quant_config=quant_config, - ) - ), + MoEPrepareAndFinalizeNoEP(), FlashInferExperts(moe_config=moe_config, quant_config=quant_config), ) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 0011149e3..a22b2088b 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph( ) kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), + MoEPrepareAndFinalizeNoEP(), CutlassExpertsFp4( moe_config=make_dummy_moe_config(), quant_config=quant_config, diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 97a9a2815..678cd4580 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase): return buffer - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase): return hidden_states, router_logits + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if extra_tensors is not None: + raise NotImplementedError( + "extra_tensors is not supported for NaiveAll2AllManager" + ) + sp_size = self.tp_group.world_size if is_sequence_parallel else 1 + dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None + cu_tokens_across_sp_cpu = 
dp_metadata.cu_tokens_across_sp(sp_size) + + hidden_states = self.naive_multicast( + hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + topk_weights = self.naive_multicast( + topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + topk_ids = self.naive_multicast( + topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + return hidden_states, topk_weights, topk_ids + def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: @@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase): return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) return gathered_tensors[0], gathered_tensors[1] + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Gather hidden_states and router_logits from all dp ranks. + """ + dp_metadata = get_forward_context().dp_metadata + assert dp_metadata is not None + sizes = dp_metadata.get_chunk_sizes_across_dp_rank() + assert sizes is not None + dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() + assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] + + tensors_to_gather = [hidden_states, topk_weights, topk_ids] + if extra_tensors is not None: + tensors_to_gather.extend(extra_tensors) + + gathered_tensors = dist_group.all_gatherv( + tensors_to_gather, + dim=0, + sizes=sizes, + ) + + hidden_states = gathered_tensors[0] + topk_weights = gathered_tensors[1] + topk_ids = gathered_tensors[2] + + if extra_tensors is None: + return hidden_states, topk_weights, topk_ids + + return hidden_states, topk_weights, topk_ids, gathered_tensors[3:] + def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: @@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase): pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, ) - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase): ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + raise NotImplementedError + def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: @@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): def get_handle(self, kwargs): raise NotImplementedError - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + 
topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + raise NotImplementedError + def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index b09d5f44d..572bac80f 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from typing import Any from weakref import WeakValueDictionary import torch @@ -64,13 +63,32 @@ class All2AllManagerBase: # and reuse it for the same config. raise NotImplementedError - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> Any: + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + # Subclasses should either: + # - implement handling for extra_tensors, or + # - raise a clear error if extra_tensors is not supported. + raise NotImplementedError + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): # Subclasses should either: # - implement handling for extra_tensors, or # - raise a clear error if extra_tensors is not supported. @@ -280,7 +298,7 @@ class DeviceCommunicatorBase: for module in moe_modules: module.maybe_init_modular_kernel() - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -294,8 +312,29 @@ class DeviceCommunicatorBase: Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class. """ + if extra_tensors is not None: + return hidden_states, router_logits, extra_tensors return hidden_states, router_logits + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Dispatch the hidden states and topk weights/ids to the appropriate device. + This is a no-op in the base class. 
+ """ + if extra_tensors is not None: + return hidden_states, topk_weights, topk_ids, extra_tensors + return hidden_states, topk_weights, topk_ids + def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index 25db6a2eb..b5fbdfcc3 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase): ) -> dict[str, torch.Tensor | Any]: return self.dist_module.recv_tensor_dict(src) - def dispatch( # type: ignore[override] + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Dispatch the hidden states and router logits to the appropriate device. + This is a no-op in the base class. + """ + assert self.all2all_manager is not None - return self.all2all_manager.dispatch( + return self.all2all_manager.dispatch_router_logits( hidden_states, router_logits, is_sequence_parallel, - extra_tensors, # type: ignore[call-arg] + extra_tensors, + ) + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Dispatch the hidden states and topk weights/ids to the appropriate device. + This is a no-op in the base class. + """ + assert self.all2all_manager is not None + return self.all2all_manager.dispatch( + hidden_states, + topk_weights, + topk_ids, + is_sequence_parallel, + extra_tensors=extra_tensors, ) def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: + """ + Combine the hidden states and router logits from the appropriate device. + This is a no-op in the base class. + """ assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine( - hidden_states, is_sequence_parallel + return self.all2all_manager.combine( + hidden_states, + is_sequence_parallel, ) - return hidden_states class _CPUSHMDistributed: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index f67ddfd31..4c78871e1 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase): return output_list - def dispatch( # type: ignore[override] + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase): tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] ): + """ + Dispatch the hidden states and router logits to the appropriate device. + This is a no-op in the base class. 
+ """ + assert self.all2all_manager is not None - return self.all2all_manager.dispatch( + return self.all2all_manager.dispatch_router_logits( hidden_states, router_logits, is_sequence_parallel, - extra_tensors, # type: ignore[call-arg] + extra_tensors, + ) + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Dispatch the hidden states and topk weights/ids to the appropriate device. + This is a no-op in the base class. + """ + assert self.all2all_manager is not None + return self.all2all_manager.dispatch( + hidden_states, + topk_weights, + topk_ids, + is_sequence_parallel, + extra_tensors=extra_tensors, ) def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: + """ + Combine the hidden states and router logits from the appropriate device. + This is a no-op in the base class. + """ assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine( - hidden_states, is_sequence_parallel + return self.all2all_manager.combine( + hidden_states, + is_sequence_parallel, ) - return hidden_states diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py index 61aee2db4..81f4ae207 100644 --- a/vllm/distributed/device_communicators/mnnvl_compat.py +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import torch.distributed as dist from flashinfer.comm.mnnvl import CommBackend as CommBackend @@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend): dist.all_gather_object(gathered, data, group=self._group) return gathered + # NOTE(rob): CommBackend is an abstract class, and bcast/barrier + # are unimplemented on vLLM side. If we need to utilize these + # methods in the future, can create a concrete implementation. + def bcast(self, data: Any, root: int) -> Any: + raise NotImplementedError + + def barrier(self) -> None: + raise NotImplementedError + def Split(self, color: int, key: int) -> "CustomCommunicator": return self diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 6bc26b6f3..85c7f18e3 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase): def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: dist.broadcast(input_, src=src, group=self.device_group) - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Dispatch the hidden states and router logits to the appropriate device. + This is a no-op in the base class. 
+ """ + assert self.all2all_manager is not None - return self.all2all_manager.dispatch( + return self.all2all_manager.dispatch_router_logits( hidden_states, router_logits, is_sequence_parallel, - extra_tensors, # type: ignore[call-arg] + extra_tensors, + ) + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + """ + Dispatch the hidden states and topk weights/ids to the appropriate device. + This is a no-op in the base class. + """ + assert self.all2all_manager is not None + return self.all2all_manager.dispatch( + hidden_states, + topk_weights, + topk_ids, + is_sequence_parallel, + extra_tensors=extra_tensors, ) def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False ) -> torch.Tensor: + """ + Combine the hidden states and router logits from the appropriate device. + This is a no-op in the base class. + """ assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine( - hidden_states, is_sequence_parallel + return self.all2all_manager.combine( + hidden_states, + is_sequence_parallel, ) - return hidden_states diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d8c6ceba3..8f9dc0354 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1000,7 +1000,7 @@ class GroupCoordinator: if self.device_communicator is not None: self.device_communicator.prepare_communication_buffer_for_model(model) - def dispatch( + def dispatch_router_logits( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -1011,7 +1011,7 @@ class GroupCoordinator: | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] ): if self.device_communicator is not None: - return self.device_communicator.dispatch( # type: ignore[call-arg] + return self.device_communicator.dispatch_router_logits( hidden_states, router_logits, is_sequence_parallel, @@ -1020,6 +1020,28 @@ class GroupCoordinator: else: return hidden_states, router_logits + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ): + if self.device_communicator is not None: + return self.device_communicator.dispatch( + hidden_states, + topk_weights, + topk_ids, + is_sequence_parallel, + extra_tensors, + ) + else: + return hidden_states, topk_weights, topk_ids + def combine( self, hidden_states, is_sequence_parallel: bool = False ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 638602a92..bf8ec2dc6 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -7,17 +7,27 @@ import torch from vllm.distributed import ( get_ep_group, ) +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( + FlashInferA2APrepareAndFinalize, +) from 
vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPrepareAndFinalize,
)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNaiveEP,
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
 
+logger = init_logger(__name__)
+
 if current_platform.is_cuda_alike():
     if has_pplx():
         from .pplx_prepare_finalize import (
@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
     moe: FusedMoEConfig,
     quant_config: FusedMoEQuantConfig | None,
     routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+    allow_new_interface: bool = False,
 ) -> FusedMoEPrepareAndFinalize | None:
+    # NOTE(rob): we are migrating each quant_method to hold the MK
+    # in all cases. The allow_new_interface=False flag allows us to fall
+    # back to the old method for methods that have not yet been migrated.
+    #
+    # In the old method:
+    #   * maybe_init_modular_kernel() calls this function. If we are
+    #     using no DP/EP or naive all2all, this function returns None
+    #     and no ModularKernelMethod is created. If non-naive
+    #     all2all is used, this returns a PrepareAndFinalize object and
+    #     a ModularKernelMethod is created.
+    # In the new method:
+    #   * maybe_make_prepare_finalize() is called from the oracle. We
+    #     always return a PrepareAndFinalize object and the quant method
+    #     holds the ModularKernel.
     if not moe.moe_parallel_config.use_all2all_kernels:
-        return None
+        if not allow_new_interface:
+            return None
+
+        # For DP/TP case, fall back to naive P/F.
+        if moe.moe_parallel_config.dp_size > 1:
+            logger.info_once(
+                "Detected DP deployment with no --enable-expert-parallel. "
+                "Falling back to AllGather+ReduceScatter dispatch/combine."
+            )
+            return MoEPrepareAndFinalizeNaiveEP(
+                is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
+                num_dispatchers=(
+                    get_ep_group().device_communicator.all2all_manager.world_size
+                ),
+            )
+        else:
+            return MoEPrepareAndFinalizeNoEP()
 
     all2all_manager = get_ep_group().device_communicator.all2all_manager
     assert all2all_manager is not None
 
     prepare_finalize: FusedMoEPrepareAndFinalize | None = None
 
-    # TODO(rob): update this as part of the MoE refactor.
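To make the old-vs-new contract in the NOTE above concrete, here is a minimal sketch (not part of the patch) of how a caller might build the modular kernel under each interface. `make_mk_old`/`make_mk_new` and the `fused_experts` argument are illustrative names only; `maybe_make_prepare_finalize`, `allow_new_interface`, and `FusedMoEModularKernel` come from the code in this diff.

```python
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)


def make_mk_old(moe, quant_config, fused_experts):
    # Old interface: None means "no DP/EP or naive all2all",
    # so no modular kernel is created for this layer.
    pf = maybe_make_prepare_finalize(moe, quant_config)
    return None if pf is None else mk.FusedMoEModularKernel(pf, fused_experts)


def make_mk_new(moe, quant_config, fused_experts):
    # New interface: a PrepareAndFinalize is always returned (falling back to
    # MoEPrepareAndFinalizeNoEP / MoEPrepareAndFinalizeNaiveEP), so the quant
    # method can unconditionally own a modular kernel.
    pf = maybe_make_prepare_finalize(moe, quant_config, allow_new_interface=True)
    assert pf is not None
    return mk.FusedMoEModularKernel(pf, fused_experts)
```

The test harness change in `tests/kernels/moe/modular_kernel_tools/common.py` above follows the `make_mk_new` pattern.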
- assert not moe.use_flashinfer_cutlass_kernels, ( - "Must be created in modelopt.py or fp8.py" - ) - if moe.use_pplx_kernels: assert quant_config is not None @@ -203,4 +239,16 @@ def maybe_make_prepare_finalize( use_fp8_dispatch=use_fp8_dispatch, ) + elif moe.use_fi_all2allv_kernels: + assert quant_config is not None + prepare_finalize = FlashInferA2APrepareAndFinalize( + num_dispatchers=all2all_manager.world_size, + ) + + elif moe.use_naive_all2all_kernels and allow_new_interface: + prepare_finalize = MoEPrepareAndFinalizeNaiveEP( + is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel), + num_dispatchers=all2all_manager.world_size, + ) + return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 9a28c3193..b275bf414 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( ) from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.import_utils import has_triton_kernels from vllm.utils.math_utils import cdiv @@ -862,6 +861,7 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not all2all_backend: str # all2all backend for MoE communication + is_sequence_parallel: bool # whether sequence parallelism is used enable_eplb: bool # whether to enable expert load balancing @property @@ -883,6 +883,12 @@ class FusedMoEParallelConfig: def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + @property + def use_fi_all2allv_kernels(self): + return ( + self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv" + ) + @property def use_batched_activation_format(self): return self.use_deepep_ll_kernels or self.use_pplx_kernels @@ -1014,6 +1020,7 @@ class FusedMoEParallelConfig: ep_rank=0, use_ep=False, all2all_backend=vllm_parallel_config.all2all_backend, + is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe, enable_eplb=vllm_parallel_config.enable_eplb, ) # DP + EP / TP + EP / DP + TP + EP @@ -1033,6 +1040,7 @@ class FusedMoEParallelConfig: ep_rank=ep_rank, use_ep=True, all2all_backend=vllm_parallel_config.all2all_backend, + is_sequence_parallel=vllm_parallel_config.use_sequence_parallel_moe, enable_eplb=vllm_parallel_config.enable_eplb, ) @@ -1051,6 +1059,7 @@ class FusedMoEParallelConfig: use_ep=False, all2all_backend="naive", enable_eplb=False, + is_sequence_parallel=False, ) @@ -1145,12 +1154,9 @@ class FusedMoEConfig: return self.moe_parallel_config.use_mori_kernels @property - def use_flashinfer_cutlass_kernels(self): - """ - Whether to use FlashInfer cutlass kernels for NVFP4 MoE. 
- """ - return ( - envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput" - ) + def use_fi_all2allv_kernels(self): + return self.moe_parallel_config.use_fi_all2allv_kernels + + @property + def use_naive_all2all_kernels(self): + return self.moe_parallel_config.use_naive_all2all_kernels diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0d8690638..86edbe303 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -103,7 +103,14 @@ def run_cutlass_moe_fp8( or a2_scale.size(0) == a1q.shape[0] ), "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - if expert_map is not None: + + # NOTE(rob): the expert_map is used for the STANDARD case and + # the batched format is used by the BATCHED case. + # TODO(rob): update the MK interface to only pass the expert_map + # during the STANDARD case to make this clearer across all kernels. + if use_batched_format: + assert expert_num_tokens is not None + else: assert expert_num_tokens is None # We have two modes: batched experts and non-batched experts. @@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): # needed for STANDARD activation format kernels in DP/EP mode. # Note that the BATCHED activation format does not use # the expert map for identifying experts. - return not moe_parallel_config.use_all2all_kernels + return not ( + moe_parallel_config.use_fi_all2allv_kernels + or moe_parallel_config.use_deepep_ht_kernels + ) def supports_chunking(self) -> bool: return True @@ -641,10 +651,8 @@ def run_cutlass_moe_fp4( class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): - @staticmethod - def expects_unquantized_inputs( - moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig - ) -> bool: + @property + def expects_unquantized_inputs(self) -> bool: return True @staticmethod diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 1d5d039b6..fafcf6de6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True + # NOTE(rob): discovered an IMA with this combination. Needs investigation. 
+ return not moe_parallel_config.use_fi_all2allv_kernels def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 929cff799..514aa205a 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -103,6 +103,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts: int, a1_scale: torch.Tensor | None, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool, ) -> Callable: has_scales = token_scales is not None @@ -174,6 +175,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_topk_weights, a1_scale, quant_config, + defer_input_quant=defer_input_quant, ) def _receiver( @@ -187,6 +189,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_topk_weights: torch.Tensor | None, a1_scale: torch.Tensor | None, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool, ) -> mk.PrepareResultType: if event.event is not None: event.current_stream_wait() @@ -221,14 +224,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_num_tokens_per_expert_list, device=expert_x.device ) - # Dispatch and Quant - # DeepEP kernels only support dispatching block-quantized - # activation scales. - # Dispatch in bfloat16 and quantize afterwards - if not quant_config.is_block_quantized: + # * For non-block quant, dispatch in b16 and quantize now as + # DeepEP kernels only support dispatching block scales. + # * For expert kernels that require unquantized inputs, + # defer quantization to FusedMoEExpertsPermuteUnpermute. + if not quant_config.is_block_quantized and not defer_input_quant: # Quantize after dispatch. expert_x_scale = None if expert_x.numel() != 0: + # TODO: support per_act_token_quant, expert_x, expert_x_scale = moe_kernel_quantize_input( expert_x, a1_scale, @@ -257,6 +261,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.ReceiverType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -266,8 +271,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) a1 = a1 * topk_weights.to(a1.dtype) - if quant_config.is_block_quantized: - # Quant and Dispatch + # * DeepEP only supports fp8 block scales so quantize + # before the dispatch for these models. + # * For all other quantization, dispatch after. + # * For expert kernels that require unquantized inputs, + # defer quantization to FusedMoEExpertsPermuteUnpermute. 
+ if quant_config.is_block_quantized and not defer_input_quant: a1q, a1q_scale = moe_kernel_quantize_input( a1, quant_config.a1_scale, @@ -281,7 +290,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): else: a1q = a1 a1q_scale = None - a1_post_scale = quant_config.a1_scale + a1_post_scale = ( + quant_config.a1_gscale + if quant_config.quant_dtype == "nvfp4" + else quant_config.a1_scale + ) return self._do_dispatch( tokens=a1q, @@ -291,6 +304,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts=num_experts, a1_scale=a1_post_scale, quant_config=quant_config, + defer_input_quant=defer_input_quant, ) def prepare( @@ -302,6 +316,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: receiver = self.prepare_async( a1, @@ -311,6 +326,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map, apply_router_weight_on_input, quant_config, + defer_input_quant, ) return receiver() diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index c0e0b0b22..f5a3da438 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -242,7 +242,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> tuple[Callable, mk.ReceiverType]: + if defer_input_quant: + raise NotImplementedError( + f"{self.__class__.__name__} does not support defer_input_quant=True. " + "Please select an MoE kernel that accepts quantized inputs." + ) + hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, ( f"Hidden Size {hidden_size} not in supported list of hidden sizes" @@ -344,7 +351,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: + if defer_input_quant: + raise NotImplementedError( + f"{self.__class__.__name__} does not support defer_input_quant=True. " + "Please select an MoE kernel that accepts quantized inputs." 
+ ) hook, receiver = self.prepare_async( a1, topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py new file mode 100644 index 000000000..39b373861 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group +from vllm.distributed.device_communicators.base_device_communicator import ( + All2AllManagerBase, +) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + +def get_local_sizes(): + return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() + + +class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """Base class for FlashInfer MoE prepare and finalize operations.""" + + def __init__( + self, + num_dispatchers: int = 1, + ): + super().__init__() + self.num_dispatchers_ = num_dispatchers + self.all2all_manager = get_ep_group().device_communicator.all2all_manager + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def output_is_reduced(self) -> bool: + return False + + def _apply_router_weight_on_input( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """Apply router weight on input if needed.""" + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1.mul_(topk_weights.to(a1.dtype)) + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + self._apply_router_weight_on_input( + a1, topk_weights, topk_ids, apply_router_weight_on_input + ) + global_num_tokens_cpu = get_local_sizes() + top_k = topk_ids.size(1) + + (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = ( + flashinfer_alltoall_dispatch( + self.all2all_manager, + global_num_tokens_cpu, + a1, + quant_config.a1_gscale, + topk_ids, + topk_weights, + top_k, + num_experts, + quant_config, + defer_input_quant=defer_input_quant, + ) + ) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + top_k = topk_ids.size(1) + token_count = output.shape[0] + fused_expert_output = flashinfer_alltoall_combine( + self.all2all_manager, + fused_expert_output, + top_k=top_k, + token_count=token_count, + alltoall_info=self.alltoall_info, + ) + 
output.copy_(fused_expert_output) + + +def flashinfer_alltoall_dispatch( + all2all_manager: All2AllManagerBase, + global_num_tokens_cpu: list[int], + x: torch.Tensor, + gs: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + top_k: int, + num_experts: int, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, +): + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + "FlashInfer AllToAll workspace not available" + ) + + ep_rank = all2all_manager.rank + ep_size = all2all_manager.world_size + max_num_token = ( + max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0] + ) + orig_topk_weights_dtype = topk_weights.dtype + alltoall_info, topk_ids, topk_weights, _ = ( + MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + topk_ids, + topk_weights, + None, + all2all_manager.prepare_workspace_tensor, + max_num_token, + ep_rank, + ep_size, + num_experts, + num_experts, + top_k, + ) + ) + topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype) + + if not defer_input_quant: + x, x_sf = moe_kernel_quantize_input( + x, + gs, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + # NOTE: swizzling pads the scales to multiple of 128 + # which makes the scales tensor different shape than + # the hidden states, breaking the A2A kernel. So, we + # delay the swizzling until after the A2A. + is_fp4_scale_swizzled=False, + ) + + x = MnnvlMoe.mnnvl_moe_alltoallv( + x, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + + x_sf = MnnvlMoe.mnnvl_moe_alltoallv( + x_sf, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + + # Swizzle after the A2A if nvfp4. + if quant_config.quant_dtype == "nvfp4": + if x_sf.element_size() == 1: + x_sf = x_sf.view(torch.uint8) + x_sf = nvfp4_block_scale_interleave(x_sf) + else: + # Block-scale path: pass activations through without quantization + x_sf = None + x = MnnvlMoe.mnnvl_moe_alltoallv( + x, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + return alltoall_info, topk_ids, topk_weights, x, x_sf + + +def flashinfer_alltoall_combine( + all2all_manager: All2AllManagerBase, + output: torch.Tensor, + top_k: int, + token_count: int, + alltoall_info, +): + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + "FlashInfer AllToAll workspace not available" + ) + return MnnvlMoe.mnnvl_moe_alltoallv_combine( + output, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank=all2all_manager.rank, + ep_size=all2all_manager.world_size, + top_k=top_k, + token_count=token_count, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 8d5985875..faa654ea3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -78,16 +78,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): # - skip input activation quantization (kernel applies scaling) self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized - @staticmethod - def expects_unquantized_inputs( - moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig - ) -> bool: - # NVFP4 TP kernels and FP8 block-quantized kernels apply - # input quantization inside FusedMoEPermuteExpertsUnpermute. 
- return ( - quant_config.use_nvfp4_w4a4 - and not moe_config.moe_parallel_config.use_all2all_kernels - ) or (quant_config.use_fp8_w8a8 and quant_config.is_block_quantized) + @property + def expects_unquantized_inputs(self) -> bool: + return self.quant_config.use_fp8_w8a8 and self.quant_config.is_block_quantized @staticmethod def _supports_current_device() -> bool: @@ -144,10 +137,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): # FLASHINFER_CUTLASS currently uses its down P/F, which does not # work with SP. This will be removed in follow up after we get # rid of the FlashInfer specific P/F function. - return ( - moe_parallel_config.dp_size == 1 - or moe_parallel_config.dp_size == moe_parallel_config.ep_size - ) + # TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As. + return not moe_parallel_config.is_sequence_parallel @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -194,8 +185,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): """ workspace1 = (M, K) workspace2 = (0,) - # For TP, the quantization is fused with fused_moe call. - output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K) + # For NVFP4, the output is stored in a packed int8 format, + # so the actual hidden dim is 2x the size of K here. + output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K) # The workspace is determined by `aq`, since it comes after any # potential communication op and is involved in the expert computation. return (workspace1, workspace2, output_shape) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py deleted file mode 100644 index a56d68566..000000000 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ /dev/null @@ -1,373 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_dp_group, get_ep_group -from vllm.distributed.device_communicators.base_device_communicator import ( - All2AllManagerBase, -) -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP, -) -from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input -from vllm.utils.flashinfer import nvfp4_block_scale_interleave - - -def get_local_sizes(): - return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() - - -class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): - """Base class for FlashInfer MoE prepare and finalize operations.""" - - def __init__( - self, - use_dp: bool, - num_dispatchers: int = 1, - use_deepseek_fp8_block_scale: bool = False, - ): - super().__init__() - self.num_dispatchers_ = num_dispatchers - self.use_dp = use_dp - self.local_tokens = None - # Toggle for DeepSeek-style FP8 block-scale path where activations are - # not quantized here and weight block scales are consumed by the kernel. 
- self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - def max_num_tokens_per_rank(self) -> int | None: - return None - - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - def num_dispatchers(self) -> int: - return self.num_dispatchers_ - - def output_is_reduced(self) -> bool: - return False - - def _apply_router_weight_on_input( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: - """Apply router weight on input if needed.""" - if apply_router_weight_on_input: - topk = topk_ids.size(1) - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - a1.mul_(topk_weights.to(a1.dtype)) - - -class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): - """FlashInfer implementation using AllToAll communication.""" - - def __init__( - self, - use_dp: bool, - num_dispatchers: int = 1, - use_deepseek_fp8_block_scale: bool = False, - ): - super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale) - self.alltoall_info = None - - # Initialize all2all_manager only for DP case - self.all2all_manager = None - if self.use_dp: - self.all2all_manager = get_ep_group().device_communicator.all2all_manager - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - self._apply_router_weight_on_input( - a1, topk_weights, topk_ids, apply_router_weight_on_input - ) - - if not self.use_dp: - # Non-DP case: quantize activations unless using block-scale path - if not self.use_deepseek_fp8_block_scale: - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=not self.use_dp, - ) - else: - a1q = a1 - a1q_scale = None - else: - # DP case: use FlashInfer AllToAll - global_num_tokens_cpu = get_local_sizes() - top_k = topk_ids.size(1) - - (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = ( - flashinfer_alltoall_dispatch( - self.all2all_manager, - global_num_tokens_cpu, - a1, - quant_config.a1_gscale, - topk_ids, - topk_weights, - top_k, - num_experts, - quant_config, - use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, - ) - ) - - return a1q, a1q_scale, None, topk_ids, topk_weights - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - if self.use_dp: - top_k = topk_ids.size(1) - token_count = output.shape[0] - fused_expert_output = flashinfer_alltoall_combine( - self.all2all_manager, - fused_expert_output, - top_k=top_k, - token_count=token_count, - alltoall_info=self.alltoall_info, - ) - output.copy_(fused_expert_output) - - -class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): - def __init__( - self, - use_dp: bool, - num_dispatchers: int = 1, - use_deepseek_fp8_block_scale: bool = False, - ): - super().__init__(use_dp, num_dispatchers, use_deepseek_fp8_block_scale) - - def prepare( - self, - a1: torch.Tensor, - topk_weights: 
torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - ) -> mk.PrepareResultType: - self._apply_router_weight_on_input( - a1, topk_weights, topk_ids, apply_router_weight_on_input - ) - is_nvfp4 = quant_config.quant_dtype == "nvfp4" - if not self.use_dp and is_nvfp4: - return a1, None, None, topk_ids, topk_weights - - if not self.use_deepseek_fp8_block_scale: - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale if is_nvfp4 else quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=not self.use_dp, - ) - else: - # Block-scale path: pass activations through, omit per-token scales - a1q = a1 - a1q_scale = None - - if self.use_dp: - # Build gather list conditionally - omit a1q_scale if None - # (block-scale path) - gather_list = [topk_weights, topk_ids, a1q] - if a1q_scale is not None: - gather_list.append(a1q_scale) - gathered = get_dp_group().all_gatherv( - gather_list, - dim=0, - sizes=get_local_sizes(), - ) - topk_weights, topk_ids, a1q, a1q_scale = gathered - else: - gathered = get_dp_group().all_gatherv( - gather_list, - dim=0, - sizes=get_local_sizes(), - ) - topk_weights, topk_ids, a1q = gathered - a1q_scale = None - - if is_nvfp4 and a1q_scale is not None: - if a1q_scale.element_size() == 1: - a1q_scale = a1q_scale.view(torch.uint8) - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) - - return a1q, a1q_scale, None, topk_ids, topk_weights - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) - - if self.use_dp: - fused_expert_output = get_dp_group().reduce_scatterv( - fused_expert_output, dim=0, sizes=get_local_sizes() - ) - output.copy_(fused_expert_output) - - -def flashinfer_alltoall_dispatch( - all2all_manager: All2AllManagerBase, - global_num_tokens_cpu: list[int], - x: torch.Tensor, - gs: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - top_k: int, - num_experts: int, - quant_config: FusedMoEQuantConfig, - use_deepseek_fp8_block_scale: bool = False, -): - from flashinfer.comm.trtllm_alltoall import MnnvlMoe - - assert all2all_manager.ensure_alltoall_workspace_initialized(), ( - "FlashInfer AllToAll workspace not available" - ) - - ep_rank = all2all_manager.rank - ep_size = all2all_manager.world_size - max_num_token = ( - max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0] - ) - orig_topk_weights_dtype = topk_weights.dtype - alltoall_info, topk_ids, topk_weights, _ = ( - MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( - topk_ids, - topk_weights, - None, - all2all_manager.prepare_workspace_tensor, - max_num_token, - ep_rank, - ep_size, - num_experts, - num_experts, - top_k, - ) - ) - topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype) - - if not use_deepseek_fp8_block_scale: - x, x_sf = moe_kernel_quantize_input( - x, - gs, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - is_fp4_scale_swizzled=False, # delay swizzle to after comm - ) - x = MnnvlMoe.mnnvl_moe_alltoallv( - x, - alltoall_info, - all2all_manager.workspace_tensor, - ep_rank, - ep_size, - ) - - x_sf = 
MnnvlMoe.mnnvl_moe_alltoallv( - x_sf, - alltoall_info, - all2all_manager.workspace_tensor, - ep_rank, - ep_size, - ) - if quant_config.quant_dtype == "nvfp4": - x_sf = nvfp4_block_scale_interleave(x_sf) - else: - # Block-scale path: pass activations through without quantization - x_sf = None - x = MnnvlMoe.mnnvl_moe_alltoallv( - x, - alltoall_info, - all2all_manager.workspace_tensor, - ep_rank, - ep_size, - ) - return alltoall_info, topk_ids, topk_weights, x, x_sf - - -def flashinfer_alltoall_combine( - all2all_manager: All2AllManagerBase, - output: torch.Tensor, - top_k: int, - token_count: int, - alltoall_info, -): - from flashinfer.comm.trtllm_alltoall import MnnvlMoe - - assert all2all_manager.ensure_alltoall_workspace_initialized(), ( - "FlashInfer AllToAll workspace not available" - ) - return MnnvlMoe.mnnvl_moe_alltoallv_combine( - output, - alltoall_info, - all2all_manager.workspace_tensor, - ep_rank=all2all_manager.rank, - ep_size=all2all_manager.world_size, - top_k=top_k, - token_count=token_count, - ) - - -def create_flashinfer_prepare_finalize( - use_dp: bool, - use_nvfp4: bool = False, - enable_alltoallv: bool = False, - use_deepseek_fp8_block_scale: bool = False, -) -> FlashInferCutlassMoEPrepareAndFinalize | MoEPrepareAndFinalizeNoEP: - """Factory function to create the appropriate FlashInfer implementation.""" - - if use_dp: - if enable_alltoallv: - assert use_nvfp4 - return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) - return FlashInferAllGatherMoEPrepareAndFinalize( - use_dp=True, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ) - else: - # CUTLASS FP8 BLOCK and CUTLASS NVFP4 apply input quantization - # in a single call with the MoE experts kernel. - defer_input_quant = use_deepseek_fp8_block_scale or use_nvfp4 - return MoEPrepareAndFinalizeNoEP(defer_input_quant=defer_input_quant) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 509bacfbc..c681e083a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: + if defer_input_quant: + raise NotImplementedError( + f"{self.__class__.__name__} does not support defer_input_quant=True. " + "Please select an MoE kernel that accepts quantized inputs." 
+ ) assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index ce77eafa6..2e5167bdf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -593,7 +593,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True + return not moe_parallel_config.use_fi_all2allv_kernels @property def quant_type_id(self) -> int: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 669a6e74b..0206e19de 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1951,7 +1951,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True + return not moe_parallel_config.use_fi_all2allv_kernels def supports_chunking(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 36562142c..3ad56cc4c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -5,6 +5,7 @@ from abc import abstractmethod import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -26,6 +27,19 @@ class FusedMoEMethodBase(QuantizeMethodBase): super().__init__() self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None + self.moe_mk: mk.FusedMoEModularKernel | None = None + + @property + def supports_internal_mk(self) -> bool: + # NOTE(rob): temporary attribute indicating that this quant method has + # completed migration to the new internal MK interface. + return self.moe_mk is not None + + @property + def mk_owns_shared_expert(self) -> bool: + # NOTE(rob): temporary attribute indicating that the internal MK owns the + # layer's shared experts, so the layer must not run them separately. 
+ return self.moe_mk is not None and self.moe_mk.shared_experts is not None @abstractmethod def create_weights( @@ -91,6 +105,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): @property def topk_indices_dtype(self) -> torch.dtype | None: + if self.moe_mk is not None: + return self.moe_mk.prepare_finalize.topk_indices_dtype() return None @property diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 53df0dc85..7a2244a9b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ): super().__init__(old_quant_method.moe) self.moe_quant_config = old_quant_method.moe_quant_config - self.fused_experts = experts + self.moe_mk = experts self.disable_expert_map = getattr( old_quant_method, "disable_expert_map", - not self.fused_experts.supports_expert_map(), + not self.moe_mk.supports_expert_map(), ) self.old_quant_method = old_quant_method assert not self.old_quant_method.is_monolithic @@ -57,10 +57,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ), ) - @property - def topk_indices_dtype(self) -> torch.dtype | None: - return self.fused_experts.prepare_finalize.topk_indices_dtype() - @property def supports_eplb(self) -> bool: return self.old_quant_method.supports_eplb @@ -96,7 +92,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - return self.fused_experts( + assert self.moe_mk is not None + return self.moe_mk( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index daf2e0e6b..e38275004 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -571,9 +571,6 @@ class FusedMoE(CustomOp): device=vllm_config.device_config.device, routing_method=self.routing_method_type, ) - self.moe_config_use_flashinfer_cutlass_kernels = ( - self.moe_config.use_flashinfer_cutlass_kernels - ) if self.use_mori_kernels: assert self.rocm_aiter_fmoe_enabled, ( "Mori needs to be used with aiter fused_moe for now." @@ -646,6 +643,11 @@ class FusedMoE(CustomOp): # This is called after all weight loading and post-processing, so it # should be safe to swap out the quant_method. def maybe_init_modular_kernel(self) -> None: + # NOTE(rob): WIP refactor. For quant methods that own the MK + # we create the MK during process_weights_after_loading. + if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic: + return None + self.ensure_moe_quant_config_init() # routing_tables only needed for round-robin expert placement with # DeepEP all2all backend. 
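The fused_moe_method_base.py, fused_moe_modular_method.py and layer.py hunks above shift ownership of the modular kernel from the `FusedMoEModularMethod` wrapper into the quant method itself: once a method stores a kernel in `self.moe_mk`, `supports_internal_mk` turns True, `FusedMoE.maybe_init_modular_kernel()` becomes a no-op, and (in the later layer.py hunk) the naive dispatch/combine path is skipped. A minimal sketch of that lifecycle, assuming a hypothetical `ExampleFp8MoEMethod`; the `moe_mk` attribute and the `make_fp8_moe_kernel` signature come from this patch, while the class itself, `self.fp8_backend` and `self.experts_cls` are illustrative and the remaining abstract methods are omitted:

```python
import torch

from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.fused_moe.oracle.fp8 import make_fp8_moe_kernel


class ExampleFp8MoEMethod(FusedMoEMethodBase):
    def process_weights_after_loading(self, layer: FusedMoE) -> None:
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config is not None:
            # Storing the kernel here flips supports_internal_mk to True, so
            # FusedMoE.maybe_init_modular_kernel() returns early and the layer
            # never wraps this method in FusedMoEModularMethod.
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor:
        # The internally owned kernel handles prepare/finalize, the experts
        # compute, and (when configured) the shared experts.
        assert self.moe_mk is not None
        return self.moe_mk(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
```

The compressed-tensors, Fp8MoEMethod and ModelOpt hunks later in this patch follow this shape, while `FusedMoEModularMethod` remains in place for quant methods that have not migrated yet.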
@@ -728,14 +730,6 @@ class FusedMoE(CustomOp): def use_mori_kernels(self): return self.moe_parallel_config.use_mori_kernels - @property - def use_flashinfer_cutlass_kernels(self): - return ( - self.moe_quant_config is not None - and self.moe_quant_config.quant_dtype == "nvfp4" - and self.moe_config_use_flashinfer_cutlass_kernels - ) - @property def use_marlin_kernels(self): return getattr(self.quant_method, "use_marlin", False) @@ -746,7 +740,7 @@ class FusedMoE(CustomOp): self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels or self.moe_parallel_config.use_mori_kernels - or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) + or self.moe_parallel_config.use_fi_all2allv_kernels ) and envs.VLLM_ENABLE_MOE_DP_CHUNK @property @@ -1532,7 +1526,7 @@ class FusedMoE(CustomOp): assert self.quant_method is not None return ( isinstance(self.quant_method, FusedMoEModularMethod) - and self.quant_method.fused_experts.output_is_reduced() + and self.quant_method.moe_mk.output_is_reduced() # type: ignore[union-attr] ) def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): @@ -1765,7 +1759,7 @@ class FusedMoE(CustomOp): self.ensure_dp_chunking_init() has_separate_shared_experts = ( - not isinstance(self.quant_method, FusedMoEModularMethod) + not self.quant_method.mk_owns_shared_expert and self.shared_experts is not None ) @@ -1789,8 +1783,10 @@ class FusedMoE(CustomOp): hidden_states, router_logits, has_separate_shared_experts ) - do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( - self.quant_method, FusedMoEModularMethod + # NOTE(rob): once we finish migrating all the quant methods to use + # MKs, we can remove the naive dispatch/combine path from here. + do_naive_dispatch_combine = ( + self.dp_size > 1 and not self.quant_method.supports_internal_mk ) ctx = get_forward_context() @@ -1818,7 +1814,7 @@ class FusedMoE(CustomOp): else: hidden_states_to_dispatch = hidden_states - dispatch_res = get_ep_group().dispatch( + dispatch_res = get_ep_group().dispatch_router_logits( hidden_states_to_dispatch, router_logits, self.is_sequence_parallel, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6aff6401e..940a2c55f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -180,6 +180,7 @@ class FusedMoEPrepareAndFinalize(ABC): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool, ) -> PrepareResultType: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -192,6 +193,9 @@ class FusedMoEPrepareAndFinalize(ABC): - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - quant_config: Quantization info provided by the fused experts. + - defer_input_quant: Runtime parameter indicating whether or not to + defer input quantization to the FusedMoEPermuteExpertsUnpermute + in cases where the compute kernel expects unquantized inputs Returns a tuple of: - quantized + dispatched a. 
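The `defer_input_quant` argument documented above is a per-call switch: the modular kernel (see the `FusedMoEModularKernel` hunk below) derives it from the experts' `expects_unquantized_inputs` property, and backends that cannot hand raw activations to the experts kernel reject it with `NotImplementedError`, as the batched, pplx and mori hunks in this patch do. A minimal sketch of a backend that does support it; `ExamplePrepareAndFinalize` is illustrative, only `prepare()` is shown, and the remaining abstract methods are omitted:

```python
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input


class ExamplePrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        if defer_input_quant:
            # The experts kernel quantizes inside its own fused call
            # (e.g. AITER, FlashInfer CUTLASS), so pass raw activations.
            return a1, None, None, topk_ids, topk_weights

        # Otherwise quantize here, exactly as before this change.
        # (nvfp4 paths use quant_config.a1_gscale instead of a1_scale.)
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            quant_config.a1_scale,
            quant_config.quant_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
        )
        return a1q, a1q_scale, None, topk_ids, topk_weights
```

`MoEPrepareAndFinalizeNoEP` in prepare_finalize.py below follows this pattern, returning the unquantized activations when the flag is set.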
@@ -220,6 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool, ) -> tuple[Callable, ReceiverType] | ReceiverType: """ Perform any quantization (and/or) dispatching needed for this kernel @@ -235,6 +240,9 @@ class FusedMoEPrepareAndFinalize(ABC): space to the local expert space of the expert parallel shard. - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. + - defer_input_quant: Runtime parameter indicating whether or not to + defer input quantization to the FusedMoEPermuteExpertsUnpermute + in cases where the compute kernel expects unquantized inputs Returns a callback or a hook callback pair that when invoked waits for results from other workers and has the same return signature as @@ -407,10 +415,8 @@ class FusedMoEPermuteExpertsUnpermute(ABC): self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers - @staticmethod - def expects_unquantized_inputs( - moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig - ) -> bool: + @property + def expects_unquantized_inputs(self) -> bool: """ Whether or not the PrepareFinalize should defer input quantization in the prepare step. If True, then the Experts kernel will @@ -1069,6 +1075,7 @@ class FusedMoEModularKernel(torch.nn.Module): expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, + defer_input_quant=self.fused_experts.expects_unquantized_inputs, ) else: # Overlap shared expert compute with all2all dispatch. @@ -1081,6 +1088,7 @@ class FusedMoEModularKernel(torch.nn.Module): expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, + defer_input_quant=self.fused_experts.expects_unquantized_inputs, ) # TODO(lucas): refactor this in the alternative schedules followup diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index 930e7ae3f..dc0f32dc1 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -58,6 +58,7 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: """ Returns a tuple of: @@ -69,6 +70,11 @@ class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): - Optional dispatched expert topk IDs - Optional dispatched expert topk weight """ + if defer_input_quant: + raise NotImplementedError( + f"{self.__class__.__name__} does not support defer_input_quant=True. " + "Please select an MoE kernel that accepts quantized inputs." + ) assert not apply_router_weight_on_input, ( "mori does not support apply_router_weight_on_input=True now." 
) diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index dd0ad4523..15fc6e237 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -8,6 +8,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -17,9 +20,6 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( is_supported_config_trtllm, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, @@ -465,68 +465,52 @@ def make_fp8_moe_quant_config( ) -def make_fp8_moe_kernel_for_mkm( +def make_fp8_moe_kernel( + moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], - prepare_finalize: mk.FusedMoEPrepareAndFinalize, -) -> mk.FusedMoEPermuteExpertsUnpermute: + fp8_backend: Fp8MoeBackend, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + shared_experts: torch.nn.Module | None = None, +) -> tuple[mk.FusedMoEModularKernel, bool]: + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=routing_tables, + allow_new_interface=True, + ) + assert prepare_finalize is not None + + logger.info_once("Using %s", prepare_finalize.__class__.__name__) + + # Create Experts. if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None + max_num_tokens = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens is not None experts = experts_cls( moe_config=moe_config, - quant_config=quant_config, - max_num_tokens=max_num_tokens_per_rank, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), ) else: experts = experts_cls( moe_config=moe_config, - quant_config=quant_config, + quant_config=moe_quant_config, ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts - - -def make_fp8_moe_kernel( - moe_quant_config: FusedMoEQuantConfig, - moe_config: FusedMoEConfig, - fp8_backend: Fp8MoeBackend, - experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], -) -> tuple[mk.FusedMoEModularKernel, bool]: - # TODO(rob): unify after we merge tp and dp/ep. - if ( - moe_config.moe_parallel_config.use_all2all_kernels - and moe_config.moe_parallel_config.all2all_backend - not in ["allgather_reducescatter", "naive"] - ): - raise ValueError( - "Fp8 Oracle should not create non-naive A2A P/F. " - "This should happen via the ModularKernelMethod." - ) - - # Create Prepare/Finalize. - prepare_finalize = MoEPrepareAndFinalizeNoEP( - defer_input_quant=experts_cls.expects_unquantized_inputs( - moe_config, moe_quant_config - ), - ) - - # Create Experts. 
- experts = experts_cls( - moe_config=moe_config, - quant_config=moe_quant_config, - ) - # NOTE(rob): we only want the mk to control the shared_expert # if using all2all (for SBO). bnell is making this explict in # the new MoE runner class. kernel = mk.FusedMoEModularKernel( prepare_finalize, experts, - shared_experts=None, + shared_experts=( + shared_experts + if moe_config.moe_parallel_config.use_all2all_kernels + else None + ), moe_parallel_config=moe_config.moe_parallel_config, ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 897631e23..276d231eb 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -7,6 +7,9 @@ import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -14,9 +17,6 @@ from vllm.model_executor.layers.fused_moe.config import ( nvfp4_moe_quant_config, nvfp4_w4a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, -) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( is_supported_config_trtllm, prepare_nvfp4_moe_layer_for_fi_or_cutlass, @@ -391,67 +391,51 @@ def make_nvfp4_moe_quant_config( ) -def make_nvfp4_moe_kernel_for_mkm( +def make_nvfp4_moe_kernel( + moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - quant_config: FusedMoEQuantConfig, experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], - prepare_finalize: mk.FusedMoEPrepareAndFinalize, -) -> mk.FusedMoEPermuteExpertsUnpermute: + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + shared_experts: torch.nn.Module | None = None, +) -> mk.FusedMoEModularKernel: + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=routing_tables, + allow_new_interface=True, + ) + assert prepare_finalize is not None + + logger.info_once("Using %s", prepare_finalize.__class__.__name__) + + # Create Experts. if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() - assert max_num_tokens_per_rank is not None + max_num_tokens = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens is not None experts = experts_cls( moe_config=moe_config, - quant_config=quant_config, - max_num_tokens=max_num_tokens_per_rank, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), ) else: experts = experts_cls( moe_config=moe_config, - quant_config=quant_config, + quant_config=moe_quant_config, ) - logger.debug_once("Using %s", experts.__class__.__name__) - return experts - - -def make_nvfp4_moe_kernel( - moe_quant_config: FusedMoEQuantConfig, - moe_config: FusedMoEConfig, - experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], -) -> mk.FusedMoEModularKernel: - # TODO(rob): unify after we merge tp and dp/ep. 
- if ( - moe_config.moe_parallel_config.use_all2all_kernels - and moe_config.moe_parallel_config.all2all_backend - not in ["allgather_reducescatter", "naive"] - ): - raise ValueError( - "NvFP4 Oracle should not create non-naive A2A P/F. " - "This should happen via the ModularKernelMethod." - ) - - # Create Prepare/Finalize. - prepare_finalize = MoEPrepareAndFinalizeNoEP( - defer_input_quant=experts_cls.expects_unquantized_inputs( - moe_config, moe_quant_config - ), - ) - - # Create Experts. - experts = experts_cls( - moe_config=moe_config, - quant_config=moe_quant_config, - ) - # NOTE(rob): we only want the mk to control the shared_expert # if using all2all (for SBO). bnell is making this explict in # the new MoE runner class. kernel = mk.FusedMoEModularKernel( prepare_finalize, experts, - shared_experts=None, + shared_experts=( + shared_experts + if moe_config.moe_parallel_config.use_all2all_kernels + else None + ), moe_parallel_config=moe_config.moe_parallel_config, ) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 8c4da9711..78b941498 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -106,7 +106,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> tuple[Callable, mk.ReceiverType]: + if defer_input_quant: + raise NotImplementedError( + f"{self.__class__.__name__} does not support defer_input_quant=True. " + "Please select an MoE kernel that accepts quantized inputs." + ) + num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -274,6 +281,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: hook, receiver = self.prepare_async( a1, @@ -283,6 +291,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_map, apply_router_weight_on_input, quant_config, + defer_input_quant=defer_input_quant, ) hook() return receiver() diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 5d806fa84..d10476702 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -4,19 +4,133 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + +class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): + def __init__( + self, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, + ) -> None: + super().__init__() + self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + 
return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self._num_dispatchers + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) + + # Defer input quantization to the MoE kernel. + use_nvfp4 = quant_config.use_nvfp4_w4a4 + if defer_input_quant: + a1q = a1 + a1q_scale = None + else: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + # NOTE: swizzling pads the scales to multiple of 128 + # which makes the scales tensor different shape than + # the hidden states, breaking the A2A kernel. So, we + # delay the swizzling until after the A2A. + is_fp4_scale_swizzled=False, + ) + + # Skip gathering scales if we have static quantization + # (the scale is a scalar, replicated on all ranks) or + # if quantization is deferred. + skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 + scales = None if skip_gather_scales else [a1q_scale] + + res = get_ep_group().dispatch( + a1q, + topk_weights, + topk_ids, + is_sequence_parallel=self.is_sequence_parallel, + extra_tensors=scales, + ) + if skip_gather_scales: + a1q, topk_weights, topk_ids = res + else: + a1q, topk_weights, topk_ids, scales = res + assert scales is not None and len(scales) == 1 + a1q_scale = scales[0] + if quant_config.quant_dtype == "nvfp4": + assert a1q_scale is not None + if a1q_scale.element_size() == 1: + a1q_scale = a1q_scale.view(torch.uint8) + a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + + out = weight_and_reduce_impl.apply( + output=None, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + output.copy_( + get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel) + ) class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): - def __init__(self, defer_input_quant: bool = False) -> None: - super().__init__() - self.defer_input_quant = defer_input_quant - @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -42,6 +156,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): expert_map: torch.Tensor | None, apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, ) -> mk.PrepareResultType: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -54,12 +169,17 @@ class 
MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): # Defer input quant to moe kernel for backends (e.g. AITER, FI) # which use a single kernel call for quant + experts. - if self.defer_input_quant: + if defer_input_quant: return a1, None, None, None, None + input_sf = ( + quant_config.a1_gscale + if quant_config.use_nvfp4_w4a4 + else quant_config.a1_scale + ) a1q, a1q_scale = moe_kernel_quantize_input( a1, - quant_config.a1_scale, + input_sf, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 81ed0c504..33150da6f 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -287,17 +287,14 @@ def rocm_aiter_fused_experts( class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): + @property + def expects_unquantized_inputs(self) -> bool: + return True + @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard - @staticmethod - def expects_unquantized_inputs( - fused_moe_config: mk.FusedMoEConfig, quant_config: FusedMoEQuantConfig - ) -> bool: - # AITER fused MoE kernels handle input quantization internally. - return True - @staticmethod def _supports_current_device() -> bool: return rocm_aiter_ops.is_fused_moe_enabled() @@ -329,7 +326,7 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - return True + return not moe_parallel_config.use_fi_all2allv_kernels def supports_expert_map(self): return True diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py index a143347b1..63ddb24b8 100644 --- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py @@ -37,7 +37,7 @@ class SharedFusedMoE(FusedMoE): use_overlapped and not ( (self.enable_eplb and backend != "allgather_reducescatter") - or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) + or self.moe_parallel_config.use_fi_all2allv_kernels ) and self._shared_experts is not None ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ec120cab4..f26ddfb87 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, - make_fp8_moe_kernel_for_mkm, make_fp8_moe_quant_config, select_fp8_moe_backend, ) @@ -53,7 +52,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( is_global_sf_supported_for_nvfp4_backend, make_mxfp4_moe_quant_config, make_nvfp4_moe_kernel, - make_nvfp4_moe_kernel_for_mkm, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) @@ -67,7 +65,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, ) from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, @@ -243,7 +240,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): self.group_size = 32 self.mxfp4_backend = NvFp4MoeBackend.MARLIN self.experts_cls = MarlinExperts - self.kernel: mk.FusedMoEModularKernel | None = None def create_weights( self, @@ -320,7 +316,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale ) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def process_weights_after_loading(self, layer: FusedMoE) -> None: layer.w13_weight = torch.nn.Parameter( layer.w13_weight_packed.data, requires_grad=False ) @@ -335,10 +331,12 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config is not None: - self.kernel = make_nvfp4_moe_kernel( + self.moe_mk = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), ) def apply( @@ -348,8 +346,8 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.kernel is not None - return self.kernel( + assert self.moe_mk is not None + return self.moe_mk( x, layer.w13_weight, layer.w2_weight, @@ -380,19 +378,10 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): activation_key=None if use_a16 else kNvfp4Dynamic, ) - # Delay creation of the kernel until after process-weights. - self.kernel: mk.FusedMoEModularKernel | None = None - self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - @property - def topk_indices_dtype(self) -> torch.dtype | None: - if self.kernel is not None: - return self.kernel.prepare_finalize.topk_indices_dtype() - return None - def create_weights( self, layer: torch.nn.Module, @@ -506,7 +495,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ) set_weight_attrs(w2_input_scale, extra_weight_attrs) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def process_weights_after_loading(self, layer: FusedMoE) -> None: """ Convert NVFP4 MoE weights into kernel format and setup the kernel. """ @@ -572,48 +561,33 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # in both cases. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config and ( - (not self.moe.moe_parallel_config.use_all2all_kernels) - or self.moe.moe_parallel_config.use_naive_all2all_kernels - ): + if self.moe_quant_config: assert self.experts_cls is not None - self.kernel = make_nvfp4_moe_kernel( + self.moe_mk = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # For no-EP case, don't use the MKM framework. - if not self.moe.moe_parallel_config.use_all2all_kernels: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=False, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - assert self.experts_cls is not None - return make_nvfp4_moe_kernel_for_mkm( - moe_config=self.moe, - quant_config=self.moe_quant_config, - experts_cls=self.experts_cls, - prepare_finalize=prepare_finalize, + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." ) def get_fused_moe_quant_config( @@ -684,8 +658,8 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): global_num_experts=layer.global_num_experts, ) else: - assert self.kernel is not None - return self.kernel( + assert self.moe_mk is not None + return self.moe_mk( x, layer.w13_weight, layer.w2_weight, @@ -759,15 +733,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): allow_vllm_cutlass=True, ) - # Delay creation of the kernel until after process-weights. - self.kernel: mk.FusedMoEModularKernel | None = None - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - if self.kernel is not None: - return self.kernel.prepare_finalize.topk_indices_dtype() - return None - def create_weights( self, layer: torch.nn.Module, @@ -927,7 +892,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w13_input_scale = None layer.w2_input_scale = None - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def process_weights_after_loading(self, layer: FusedMoE) -> None: # Allow for accessing weights and scales in standard way. w13 = layer.w13_weight w2 = layer.w2_weight @@ -989,49 +954,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # in both cases. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config and ( - (not self.moe.moe_parallel_config.use_all2all_kernels) - or self.moe.moe_parallel_config.use_naive_all2all_kernels - ): + if self.moe_quant_config: assert self.experts_cls is not None - self.kernel, self.use_inplace = make_fp8_moe_kernel( + self.moe_mk, self.use_inplace = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # For no-EP case, don't use the MKM framework. - if not self.moe.moe_parallel_config.use_all2all_kernels: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - assert self.experts_cls is not None - return make_fp8_moe_kernel_for_mkm( - moe_config=self.moe, - quant_config=self.moe_quant_config, - experts_cls=self.experts_cls, - prepare_finalize=prepare_finalize, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." ) def get_fused_moe_quant_config( @@ -1120,8 +1070,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - assert self.kernel is not None - return self.kernel( + assert self.moe_mk is not None + return self.moe_mk( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index fe59022cb..60600e1e3 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, - make_fp8_moe_kernel_for_mkm, make_fp8_moe_quant_config, select_fp8_moe_backend, ) @@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): allow_vllm_cutlass=False, ) - # Delay creation of the kernel until after process-weights. 
- self.kernel: mk.FusedMoEModularKernel | None = None - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - if self.kernel is not None: - return self.kernel.prepare_finalize.topk_indices_dtype() - return None - def create_weights( self, layer: Module, @@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def _setup_kernel( self, - layer: Module, + layer: FusedMoE, w13: torch.Tensor, w2: torch.Tensor, w13_scale: torch.Tensor, @@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase): # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # in both cases. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config and ( - (not self.moe.moe_parallel_config.use_all2all_kernels) - or self.moe.moe_parallel_config.use_naive_all2all_kernels - ): + if self.moe_quant_config: assert self.experts_cls is not None - self.kernel, self.use_inplace = make_fp8_moe_kernel( + self.moe_mk, self.use_inplace = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, ) def process_weights_after_loading(self, layer: Module) -> None: @@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # For no-EP case, don't use the MKM framework. - if not self.moe.moe_parallel_config.use_all2all_kernels: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - assert self.experts_cls is not None - return make_fp8_moe_kernel_for_mkm( - moe_config=self.moe, - quant_config=self.moe_quant_config, - experts_cls=self.experts_cls, - prepare_finalize=prepare_finalize, + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." 
) def get_fused_moe_quant_config( @@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.kernel is not None + assert self.moe_mk is not None assert not self.is_monolithic - return self.kernel( + return self.moe_mk( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 476ad618e..e10144ed1 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, - make_fp8_moe_kernel_for_mkm, make_fp8_moe_quant_config, select_fp8_moe_backend, ) @@ -35,7 +34,6 @@ from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, - make_nvfp4_moe_kernel_for_mkm, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) @@ -54,13 +52,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, @@ -739,47 +735,23 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): activation_key=kFp8StaticTensorSym, ) - # Delay creation of the kernel until after process-weights. - self.kernel: mk.FusedMoEModularKernel | None = None - - @property - def topk_indices_dtype(self) -> torch.dtype | None: - if self.kernel is not None: - return self.kernel.prepare_finalize.topk_indices_dtype() - return None - def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - # TRT LLM not supported with all2all yet. - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - # For no-EP case, don't use the MKM framework. - if not self.moe.moe_parallel_config.use_all2all_kernels: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=False, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." 
+ ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - assert self.experts_cls is not None - return make_fp8_moe_kernel_for_mkm( - moe_config=self.moe, - quant_config=self.moe_quant_config, - experts_cls=self.experts_cls, - prepare_finalize=prepare_finalize, + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." ) def create_weights( @@ -863,7 +835,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def _setup_kernel( self, - layer: torch.nn.Module, + layer: FusedMoE, w13: torch.Tensor, w2: torch.Tensor, w13_scale: torch.Tensor, @@ -893,11 +865,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: assert self.experts_cls is not None - self.kernel, self.use_inplace = make_fp8_moe_kernel( + self.moe_mk, self.use_inplace = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -998,8 +972,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): f"but got {layer.activation}" ) - assert self.kernel is not None - return self.kernel( + assert self.moe_mk is not None + return self.moe_mk( x, layer.w13_weight, layer.w2_weight, @@ -1379,50 +1353,27 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): activation_key=kNvfp4Dynamic, ) - # Delay creation of the kernel until after process-weights. - self.kernel: mk.FusedMoEModularKernel | None = None - self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) - @property - def topk_indices_dtype(self) -> torch.dtype | None: - if self.kernel is not None: - return self.kernel.prepare_finalize.topk_indices_dtype() - return None - def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - return None - elif self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: - # For no-EP case, don't use the MKM framework. - if not self.moe.moe_parallel_config.use_all2all_kernels: - return None - # For now, fp4 moe only works with the flashinfer dispatcher. - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize - else: - return super().maybe_make_prepare_finalize(routing_tables) + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - assert self.moe_quant_config is not None - assert self.experts_cls is not None - return make_nvfp4_moe_kernel_for_mkm( - moe_config=self.moe, - quant_config=self.moe_quant_config, - experts_cls=self.experts_cls, - prepare_finalize=prepare_finalize, + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." 
) def uses_weight_scale_2_pattern(self) -> bool: @@ -1547,7 +1498,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): ) layer.register_parameter("w2_input_scale", w2_input_scale) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + def process_weights_after_loading(self, layer: FusedMoE) -> None: """ Convert NVFP4 MoE weights into kernel format and setup the kernel. """ @@ -1599,15 +1550,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # in both cases. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config and ( - (not self.moe.moe_parallel_config.use_all2all_kernels) - or self.moe.moe_parallel_config.use_naive_all2all_kernels - ): + if self.moe_quant_config: assert self.experts_cls is not None - self.kernel = make_nvfp4_moe_kernel( + self.moe_mk = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), ) @property @@ -1708,8 +1658,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): global_num_experts=layer.global_num_experts, ) else: - assert self.kernel is not None - return self.kernel( + assert self.moe_mk is not None + return self.moe_mk( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 52d0a5c47..ae5a934fb 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -15,9 +15,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEParallelConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kNvfp4Dynamic, @@ -42,7 +39,6 @@ __all__ = [ "is_flashinfer_fp4_cutlass_moe_available", "is_flashinfer_fp4_cutedsl_moe_available", "reorder_w1w3_to_w3w1", - "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] # @@ -163,17 +159,6 @@ def reorder_w1w3_to_w3w1( ) -def build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig, -) -> mk.FusedMoEPrepareAndFinalize: - """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" - use_dp = moe.moe_parallel_config.dp_size > 1 - enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv" - return create_flashinfer_prepare_finalize( - use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv - ) - - def prepare_static_weights_for_trtllm_fp4_moe( # args_dequant, # args, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index d9824c107..cd82b5432 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -4,15 +4,8 @@ from enum import Enum import torch -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, -) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - 
create_flashinfer_prepare_finalize, -) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up @@ -163,18 +156,6 @@ def make_fp8_moe_alpha_scales_for_fi( return g1_alphas, g2_alphas -def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig | None, use_deepseek_fp8_block_scale: bool = False -) -> mk.FusedMoEPrepareAndFinalize: - """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" - use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - # Propagate block-scale flag so prepare/finalize can skip act quantization - # and inform the kernel to consume per-block weight scales. - return create_flashinfer_prepare_finalize( - use_dp, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale - ) - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index 21409de86..cd4efe1ca 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -14,7 +14,6 @@ from vllm.distributed.parallel_state import get_dp_group, is_global_first_rank from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) @@ -169,9 +168,10 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: # modular kernels could invoke deep_gemm_moe_fp8 return True - mk: FusedMoEModularKernel = module.quant_method.fused_experts # Further check if the ModularKernel implementation uses the DeepGemmExperts - return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts)) + return isinstance( + module.quant_method.moe_mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts) + )
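Taken together, the fp8 and nvfp4 oracle hunks above now build their kernels through the single `maybe_make_prepare_finalize` entry point instead of constructing `MoEPrepareAndFinalizeNoEP` directly. A condensed sketch of that construction path, assuming `maybe_make_prepare_finalize` selects the appropriate prepare/finalize (no-EP, naive EP, or an all2all backend) for the current parallel configuration; the `build_moe_kernel` wrapper below is illustrative and omits the fp8-specific `use_inplace` return value, routing tables and shared-expert handling that the patched `make_fp8_moe_kernel` / `make_nvfp4_moe_kernel` carry:

```python
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)


def build_moe_kernel(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
) -> mk.FusedMoEModularKernel:
    # Single entry point for choosing the prepare/finalize implementation.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
    )
    assert prepare_finalize is not None

    # Batched backends need per-rank token limits; standard ones do not.
    batched = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format == batched:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts = experts_cls(
            moe_config=moe_config,
            quant_config=quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=prepare_finalize.num_dispatchers(),
        )
    else:
        experts = experts_cls(moe_config=moe_config, quant_config=quant_config)

    return mk.FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts=None,  # only set when all2all kernels manage them
        moe_parallel_config=moe_config.moe_parallel_config,
    )
```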