diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py
index 327cd44f6..893968b5c 100644
--- a/tests/kernels/moe/modular_kernel_tools/common.py
+++ b/tests/kernels/moe/modular_kernel_tools/common.py
@@ -620,6 +620,7 @@ def make_modular_kernel(
     modular_kernel = mk.FusedMoEModularKernel(
         prepare_finalize=prepare_finalize,
         fused_experts=fused_experts,
+        inplace=False,
     )
     return modular_kernel
diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py
index 081a5fd0b..2c6c45a5f 100644
--- a/tests/kernels/moe/test_batched_deepgemm.py
+++ b/tests/kernels/moe/test_batched_deepgemm.py
@@ -74,7 +74,11 @@ def test_batched_deepgemm_vs_triton(
         quant_config=quant_config,
         moe_config=make_dummy_moe_config(),
     )
-    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
+    mk_triton = FusedMoEModularKernel(
+        prep_finalize,
+        triton_experts,
+        inplace=False,
+    )

     out_triton = mk_triton(
         hidden_states=a,
@@ -82,7 +86,6 @@ def test_batched_deepgemm_vs_triton(
         w2=w2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        inplace=False,
         global_num_experts=E,
     )

@@ -93,7 +96,11 @@ def test_batched_deepgemm_vs_triton(
         quant_config=quant_config,
         moe_config=make_dummy_moe_config(),
     )
-    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
+    mk_deepgemm = FusedMoEModularKernel(
+        prep_finalize,
+        deepgemm_experts,
+        inplace=False,
+    )

     out_deepgemm = mk_deepgemm(
         hidden_states=a,
@@ -101,7 +108,6 @@ def test_batched_deepgemm_vs_triton(
         w2=w2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        inplace=False,
         global_num_experts=E,
     )
diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py
index 508df9e32..66508568e 100644
--- a/tests/kernels/moe/test_block_fp8.py
+++ b/tests/kernels/moe/test_block_fp8.py
@@ -9,6 +9,7 @@ from tests.kernels.moe.utils import (
     make_dummy_moe_config,
     make_test_quant_config,
     make_test_weights,
+    modular_triton_fused_moe,
 )
 from tests.kernels.quant_utils import (
     native_per_token_group_quant_fp8,
@@ -26,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm_shape,
 )
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    modular_triton_fused_moe,
-)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP,
 )
@@ -261,6 +259,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
             moe_config=make_dummy_moe_config(),
             quant_config=quant_config,
         ),
+        inplace=False,
     )

     def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 3a5a66a38..d232d00fc 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -207,6 +207,7 @@ def run_with_expert_maps(
             ),
             quant_config=new_quant_config,
         ),
+        inplace=False,
     )

     out_tensor = out_tensor + kernel(**kwargs)
@@ -266,6 +267,7 @@ def run_8_bit(
             ),
             quant_config=quant_config,
         ),
+        inplace=False,
     )

     return kernel(**kwargs)
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 1bf5ced2e..11f535715 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -194,8 +194,11 @@ def make_ll_modular_kernel(
         quant_config=quant_config,
         moe_config=make_dummy_moe_config(),
     )
-    mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
-    return mk
+    return FusedMoEModularKernel(
+        prepare_finalize=a2a,
+        fused_experts=fused_experts,
+        inplace=False,
+    )


 def make_ht_modular_kernel(
@@ -224,8 +227,11 @@ def make_ht_modular_kernel(
         moe_config=make_dummy_moe_config(),
         quant_config=quant_config,
     )
-    mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
-    return mk
+    return FusedMoEModularKernel(
+        prepare_finalize=a2a,
+        fused_experts=fused_experts,
+        inplace=False,
+    )


 def make_modular_kernel(
@@ -318,7 +324,6 @@ def deepep_deepgemm_moe_impl(
         w2=w2,
         topk_weights=test_tensors.topk_weights,
         topk_ids=test_tensors.topk,
-        inplace=False,
         activation="silu",
         global_num_experts=num_experts,
         expert_map=build_expert_map(),
diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py
index f740f5bf9..8d3ca1650 100644
--- a/tests/kernels/moe/test_deepep_moe.py
+++ b/tests/kernels/moe/test_deepep_moe.py
@@ -179,7 +179,11 @@ def make_modular_kernel(
         quant_config=quant_config,
     )

-    mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
+    mk = FusedMoEModularKernel(
+        prepare_finalize=a2a,
+        fused_experts=fused_experts,
+        inplace=False,
+    )
     return mk

@@ -256,7 +260,6 @@ def deep_ep_moe_impl(
         w2=w2,
         topk_weights=topk_weights_chunk,
         topk_ids=topk_chunk,
-        inplace=False,
         activation="silu",
         global_num_experts=num_experts,
         expert_map=build_expert_map(),
diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index 729b54753..7f9bccb73 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -115,6 +115,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
             moe_config=make_dummy_moe_config(),
             quant_config=quant_config,
         ),
+        inplace=False,
     )

     # triton reference
@@ -135,7 +136,6 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
         w2=w2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        inplace=False,
     )
     diff = calc_diff(out_deepgemm, out_triton)
     assert diff < 0.001, f"Diff exceeded 1%: {diff}"
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index 1c512b5b1..e62cf7941 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -301,6 +301,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
             moe_config=moe_config,
             quant_config=quant_config,
         ),
+        inplace=False,
     )

     flashinfer_cutlass_output = kernel(
@@ -309,7 +310,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
         td.layer.w2_weight,
         topk_weights,
         topk_ids,
-        inplace=False,
         activation=activation,
         global_num_experts=e,
         expert_map=None,
diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py
index 9bb61ddfa..113649afe 100644
--- a/tests/kernels/moe/test_flashinfer_moe.py
+++ b/tests/kernels/moe/test_flashinfer_moe.py
@@ -108,6 +108,7 @@ def test_flashinfer_fp4_moe_no_graph(
     flashinfer_experts = FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
+        inplace=False,
     )

     fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
diff --git a/tests/kernels/moe/test_modular_oai_triton_moe.py b/tests/kernels/moe/test_modular_oai_triton_moe.py
index 38022e0e6..bebf18ef0 100644
--- a/tests/kernels/moe/test_modular_oai_triton_moe.py
+++ b/tests/kernels/moe/test_modular_oai_triton_moe.py
@@ -180,7 +180,11 @@ def oai_triton_moe_impl(
     else:
         fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
-    mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
+    mk = FusedMoEModularKernel(
+        MoEPrepareAndFinalizeNoEP(),
+        fused_experts,
+        inplace=False,
+    )

     return mk.forward(
         hidden_states=x,
@@ -188,7 +192,6 @@ def oai_triton_moe_impl(
         w2=w2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        inplace=True,
         activation="swigluoai",
         global_num_experts=num_experts,
         expert_map=None,
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index a304e70fc..53fb43e3c 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -18,7 +18,11 @@ from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

 import vllm.model_executor.layers.fused_moe  # noqa
-from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
+from tests.kernels.moe.utils import (
+    fused_moe,
+    make_dummy_moe_config,
+    modular_triton_fused_moe,
+)
 from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig, set_current_vllm_config
@@ -36,9 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     batched_fused_marlin_moe,
     fused_marlin_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    modular_triton_fused_moe,
-)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_permute_bias,
 )
diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py
index a22b2088b..10678e376 100644
--- a/tests/kernels/moe/test_nvfp4_moe.py
+++ b/tests/kernels/moe/test_nvfp4_moe.py
@@ -95,6 +95,7 @@ def test_cutlass_fp4_moe_no_graph(
             moe_config=make_dummy_moe_config(),
             quant_config=quant_config,
         ),
+        inplace=False,
     )

     cutlass_output = kernel(
diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py
index ef37c1c74..213d28cda 100644
--- a/tests/kernels/moe/test_pplx_cutlass_moe.py
+++ b/tests/kernels/moe/test_pplx_cutlass_moe.py
@@ -172,6 +172,7 @@ def pplx_cutlass_moe(
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
         experts,
+        inplace=False,
     )

     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py
index 08519087e..deb3b9eb4 100644
--- a/tests/kernels/moe/test_pplx_moe.py
+++ b/tests/kernels/moe/test_pplx_moe.py
@@ -592,6 +592,7 @@ def pplx_moe(
         prepare_finalize,
         experts,
         shared_experts,
+        inplace=False,
     )

     # Note: for now use_compile will error out if the problem size is
diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py
index 4883085cb..897bfddce 100644
--- a/tests/kernels/moe/utils.py
+++ b/tests/kernels/moe/utils.py
@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
 from tests.kernels.quant_utils import per_block_cast_to_int8
 from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe import (
+    TritonExperts,
+    fused_experts,
+    fused_topk,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     NaiveBatchedExperts,
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.deep_gemm import per_block_cast_to_fp8
 from vllm.utils.math_utils import round_up
@@ -116,6 +123,7 @@ def batched_moe(
             quant_config=quant_config,
             moe_config=make_dummy_moe_config(),
         ),
+        inplace=False,
     )

     return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -157,6 +165,7 @@ def naive_batched_moe(
             quant_config=quant_config,
             moe_config=make_dummy_moe_config(),
         ),
+        inplace=False,
     )

     return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -554,3 +563,16 @@ def make_shared_experts(
         return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
     finally:
         torch.set_default_dtype(old_dtype)
+
+
+def modular_triton_fused_moe(
+    moe_config: FusedMoEConfig,
+    quant_config: FusedMoEQuantConfig,
+    shared_experts: torch.nn.Module | None = None,
+) -> FusedMoEModularKernel:
+    return FusedMoEModularKernel(
+        MoEPrepareAndFinalizeNoEP(),
+        TritonExperts(moe_config, quant_config),
+        shared_experts,
+        inplace=False,
+    )
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 6650367da..3a8c13b3a 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -1083,13 +1083,16 @@ class FusedMoEConfig:
     router_logits_dtype: torch.dtype | None = None

     max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

-    has_bias: bool = False
-    is_act_and_mul: bool = True
-    is_lora_enabled: bool = False
+    # This flag disables the inplace optimization in MoE kernels.
+    # If it is True, the kernel must not use the inplace path.
+    # If it is False, the kernel is free to use the inplace path
+    # or not.
+    disable_inplace: bool = True
+
     def __post_init__(self):
         if self.dp_size > 1:
             logger.debug_once(
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 74f05a2c0..ac5a86067 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -1165,6 +1165,7 @@ def cutlass_moe_w4a8_fp8(
             quant_config=quant_config,
             group_size=group_size,
         ),
+        inplace=False,
     )

     return fn(
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 2e5167bdf..8d95665f7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -267,6 +267,7 @@ def fused_marlin_moe(
     if inplace:
         assert output is None, "Conflicting request"
+        assert not disable_inplace()

     quant_type = ScalarType.from_id(quant_type_id)
     assert quant_type in [
@@ -356,10 +357,7 @@ def fused_marlin_moe(
     ).view(-1, topk, K)

     if output is None:
-        if inplace and not disable_inplace():
-            output = hidden_states
-        else:
-            output = torch.empty_like(hidden_states)
+        output = hidden_states if inplace else torch.empty_like(hidden_states)

     if moe_sum is None:
         return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 120b3c2d1..e0907368b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -27,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size,
 )
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
@@ -1511,7 +1508,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:


 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
-    if inplace and not disable_inplace():
+    if inplace:
         return torch_vllm_inplace_fused_experts
     return torch_vllm_outplace_fused_experts

@@ -1534,6 +1531,8 @@ def fused_experts(
     if quant_config is None:
         quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

+    assert not inplace or not disable_inplace()
+
     return dispatch_fused_experts_func(inplace)(
         hidden_states=hidden_states,
         w1=w1,
@@ -1593,7 +1592,7 @@ def fused_experts_impl(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    inplace: bool = False,
+    inplace: bool,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1712,10 +1711,7 @@ def fused_experts_impl(
     else:
         raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

-    if inplace and not disable_inplace():
-        out_hidden_states = hidden_states
-    else:
-        out_hidden_states = torch.empty_like(hidden_states)
+    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

     if ocp_mx_scheme is not None:
         # TODO: On platforms for which `current_platform.supports_mx()` is True
@@ -2291,15 +2287,3 @@ class TritonWNA16Experts(TritonExperts):

         # separate function is required for MoE + LoRA
         self.moe_sum(intermediate_cache3, output)
-
-
-def modular_triton_fused_moe(
-    moe_config: FusedMoEConfig,
-    quant_config: FusedMoEQuantConfig,
-    shared_experts: torch.nn.Module | None = None,
-) -> mk.FusedMoEModularKernel:
-    return mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
-        TritonExperts(moe_config, quant_config),
-        shared_experts,
-    )
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index 3ad56cc4c..93db1c545 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -113,10 +113,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
     def supports_eplb(self) -> bool:
         return False

-    @property
-    def allow_inplace(self) -> bool:
-        return False
-
     @property
     def method_name(self) -> str:
         return self.__class__.__name__
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 7a2244a9b..c30eeb6dc 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
         old_quant_method: FusedMoEMethodBase,
         prepare_finalize: FusedMoEPrepareAndFinalize,
         shared_experts: torch.nn.Module | None,
+        inplace: bool = False,
     ) -> "FusedMoEModularMethod":
         return FusedMoEModularMethod(
             old_quant_method,
@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
                 moe_parallel_config=moe_layer.moe_parallel_config,
+                inplace=inplace,
             ),
         )

@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
     def supports_eplb(self) -> bool:
         return self.old_quant_method.supports_eplb

-    @property
-    def allow_inplace(self) -> bool:
-        return self.old_quant_method.allow_inplace
-
     @property
     def method_name(self) -> str:
         return self.old_quant_method.method_name
@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=self.allow_inplace,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b092cf6cf..3935fe374 100755
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
 from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
     UnquantizedFusedMoEMethod,
 )
+from vllm.model_executor.layers.fused_moe.utils import (
+    disable_inplace,
+)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
 )
@@ -560,6 +563,8 @@ class FusedMoE(CustomOp):
             activation=activation,
             device=vllm_config.device_config.device,
             routing_method=self.routing_method_type,
+            # TODO: in_dtype == out_dtype?
+            disable_inplace=disable_inplace() or self.shared_experts is not None,
         )
         if self.use_mori_kernels:
             assert self.rocm_aiter_fmoe_enabled, (
@@ -650,7 +655,11 @@ class FusedMoE(CustomOp):
             "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
         )
         self.quant_method = FusedMoEModularMethod.make(
-            self, self.quant_method, prepare_finalize, self.shared_experts
+            self,
+            self.quant_method,
+            prepare_finalize,
+            self.shared_experts,
+            inplace=not self.moe_config.disable_inplace,
         )

     @property
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 940a2c55f..598374af2 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
         moe_parallel_config: FusedMoEParallelConfig | None = None,
+        inplace: bool = False,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
+        self.inplace = inplace

         # prefer an explicit FusedMoEParallelConfig when available (from
         # FusedMoE layers / tests).
@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        inplace: bool = False,
         activation: str = "silu",
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
        - topk_weights (torch.Tensor): The topk weights applied at
          the end of the layer.
        - topk_ids (torch.Tensor): A map of row to expert id.
-       - inplace (bool): If True, perform the operation in-place.
-         Defaults to False.
        - activation (str): The activation function to apply after the first
         MoE layer.
        - global_num_experts (int): The total number of experts in the global
@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):

        - torch.Tensor: The output tensor after applying the MoE layer.
        """
-        if inplace and self.shared_experts is None and not disable_inplace():
+        if self.inplace:
+            assert self.shared_experts is None
+            assert not disable_inplace()
             output = hidden_states
         else:
             output = torch.zeros_like(hidden_states)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index 70c251674..bc0fc9a88 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -472,7 +472,7 @@ def make_fp8_moe_kernel(
     fp8_backend: Fp8MoeBackend,
     routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     shared_experts: torch.nn.Module | None = None,
-) -> tuple[mk.FusedMoEModularKernel, bool]:
+) -> mk.FusedMoEModularKernel:
     # Create Prepare/Finalize.
     prepare_finalize = maybe_make_prepare_finalize(
         moe=moe_config,
@@ -512,8 +512,10 @@ def make_fp8_moe_kernel(
             else None
         ),
         moe_parallel_config=moe_config.moe_parallel_config,
+        inplace=(
+            not moe_config.disable_inplace
+            and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
+        ),
     )

-    # TODO(rob): update inplace logic to be part of the kernel.
-    inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
-    return kernel, inplace
+    return kernel
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index 276d231eb..dc3ac61ad 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel(
             else None
         ),
         moe_parallel_config=moe_config.moe_parallel_config,
+        inplace=False,
     )

     # TODO(rob): update inplace logic to be part of the kernel.
diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
index a8754d6d6..c4a19ecb6 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel(
     backend: UnquantizedMoeBackend,
     quant_config: FusedMoEQuantConfig,
     moe_config: FusedMoEConfig,
-) -> tuple[mk.FusedMoEModularKernel | None, bool]:
-    use_inplace = True
-
+) -> mk.FusedMoEModularKernel | None:
     if backend in UNSUPPORTED_BACKEND:
-        return None, use_inplace
+        return None

     if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
         from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
+           inplace=False,
        )
-       use_inplace = False
+
    elif backend == UnquantizedMoeBackend.AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,
        )
                moe_config=moe_config,
                quant_config=quant_config,
            ),
+           inplace=not moe_config.disable_inplace,
        )
    elif backend == UnquantizedMoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe import TritonExperts
                moe_config=moe_config,
                quant_config=quant_config,
            ),
+           inplace=not moe_config.disable_inplace,
        )
    elif backend == UnquantizedMoeBackend.XPU:
        from vllm.model_executor.layers.fused_moe import XPUExperts
                moe_config=moe_config,
                quant_config=quant_config,
            ),
+           inplace=not moe_config.disable_inplace,
        )
-    return kernel, use_inplace
+    return kernel
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 6fdd8ecf7..8a35be78b 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def supports_eplb(self) -> bool:
         return True

-    @property
-    def allow_inplace(self) -> bool:
-        return True
-
     def maybe_make_prepare_finalize(
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.moe_quant_config = self.get_fused_moe_quant_config(layer)
         assert self.moe_quant_config is not None

-        self.kernel, self.use_inplace = make_unquantized_moe_kernel(
+        self.kernel = make_unquantized_moe_kernel(
             backend=self.unquantized_backend,
             quant_config=self.moe_quant_config,
             moe_config=self.moe,
@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=self.use_inplace,
             activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 163ee78a7..642088a45 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
             w2_zeros=layer.w2_qzeros,
             workspace=layer.workspace,
             input_dtype=self.input_dtype,
+            inplace=not self.moe.disable_inplace,
         )
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index 8b6b1e445..2fd567d7f 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
             w2=w2,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 5152c5ccc..e25a415a5 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights,
             topk_ids,
-            inplace=False,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights,
             topk_ids,
-            inplace=False,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self.moe_quant_config = self.get_fused_moe_quant_config(layer)
         if self.moe_quant_config:
             assert self.experts_cls is not None
-            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
+            self.moe_mk = make_fp8_moe_kernel(
                 moe_quant_config=self.moe_quant_config,
                 moe_config=self.moe,
                 fp8_backend=self.fp8_backend,
@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights,
             topk_ids,
-            inplace=self.use_inplace,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             # TODO(rob): investigate the disable_expert_map introduced by:
@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             workspace=layer.workspace,
             input_dtype=self.marlin_input_dtype,
             is_k_full=self.is_k_full,
+            inplace=not self.moe.disable_inplace,
         )

@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight_packed,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             s_strides1=self.s_strides1,
             s_strides2=self.s_strides2,
             group_size=self.group_size,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
         )

     @property
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 5a0bb5d30..176bfe040 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 8b9fe0f3e..a61239706 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.moe_quant_config = self.get_fused_moe_quant_config(layer)
         if self.moe_quant_config:
             assert self.experts_cls is not None
-            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
+            self.moe_mk = make_fp8_moe_kernel(
                 moe_quant_config=self.moe_quant_config,
                 moe_config=self.moe,
                 fp8_backend=self.fp8_backend,
@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def supports_eplb(self) -> bool:
         return True

-    @property
-    def allow_inplace(self) -> bool:
-        return True
-
     @property
     def is_monolithic(self) -> bool:
         return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight,
             topk_weights,
             topk_ids,
-            inplace=self.use_inplace,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 698855c09..d18c7207d 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             workspace=layer.workspace,
             is_k_full=self.is_k_full,
             input_dtype=self.input_dtype,
+            inplace=not self.moe.disable_inplace,
         )
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index e76c109ec..4474e630b 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         self.moe_quant_config = self.get_fused_moe_quant_config(layer)
         if self.moe_quant_config:
             assert self.experts_cls is not None
-            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
+            self.moe_mk = make_fp8_moe_kernel(
                 moe_quant_config=self.moe_quant_config,
                 moe_config=self.moe,
                 fp8_backend=self.fp8_backend,
@@ -967,7 +966,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight,
             topk_weights,
             topk_ids,
-            inplace=self.use_inplace,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             layer.w2_weight,
             topk_weights,
             topk_ids,
-            inplace=False,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 34628591f..bca2516d4 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.w2_qweight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=not self.moe.disable_inplace,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
             global_num_experts=layer.global_num_experts,
             expert_map=layer.expert_map,
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index a50fa4bee..50009445d 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
         )

-    @property
-    def allow_inplace(self) -> bool:
-        return True
-
     @property
     def is_monolithic(self) -> bool:
         return (
@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 activation=layer.activation,
                 expert_map=layer.expert_map,
                 input_dtype=self.marlin_input_dtype,
+                inplace=not self.moe.disable_inplace,
             )

         assert _can_support_mxfp4(
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index d2f0213e8..fc836c56b 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 global_num_experts=layer.global_num_experts,
                 expert_map=layer.expert_map,
+                inplace=not self.moe.disable_inplace,
             )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                inplace=True,
+                inplace=not self.moe.disable_inplace,
                 activation=layer.activation,
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 global_num_experts=layer.global_num_experts,
@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             block_shape=None,
         )

-    @property
-    def allow_inplace(self) -> bool:
-        return True
-
     def apply(
         self,
         layer: FusedMoE,
@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=not self.moe.disable_inplace,
             activation=layer.activation,
             global_num_experts=layer.global_num_experts,
             apply_router_weight_on_input=layer.apply_router_weight_on_input,
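Taken together, the behavior that used to hinge on a per-call `inplace=` argument to forward() is now a construction-time property that forward() checks itself. A minimal sketch of that invariant, using assumed toy names rather than the real FusedMoEModularKernel, looks like this:

# Hedged sketch (toy class, assumed names): inplace is fixed at construction and
# forward() asserts the preconditions mirrored from FusedMoEModularKernel.forward().
import torch


class SketchModularKernel:
    def __init__(
        self,
        inplace: bool = False,
        shared_experts: object | None = None,
        globally_disabled: bool = True,  # stand-in for disable_inplace()
    ):
        self.inplace = inplace
        self.shared_experts = shared_experts
        self.globally_disabled = globally_disabled

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            # Mirrors the asserts added to forward(): inplace is only legal with
            # no shared experts and when the global kill switch is off.
            assert self.shared_experts is None
            assert not self.globally_disabled
            output = hidden_states
        else:
            output = torch.zeros_like(hidden_states)
        # Stand-in for prepare/experts/finalize writing the result into `output`.
        torch.add(hidden_states, 1.0, out=output)
        return output


x = torch.randn(2, 4)
out = SketchModularKernel(inplace=False).forward(x)
assert out.data_ptr() != x.data_ptr()
out2 = SketchModularKernel(inplace=True, globally_disabled=False).forward(x)
assert out2.data_ptr() == x.data_ptr()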