[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
bnellnm
2026-02-05 13:07:18 -05:00
committed by GitHub
parent 1ee95841bd
commit a57c8228ff
37 changed files with 132 additions and 109 deletions

View File

@@ -1083,13 +1083,16 @@ class FusedMoEConfig:
router_logits_dtype: torch.dtype | None = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
is_act_and_mul: bool = True
is_lora_enabled: bool = False
# This flag is used to disable the inplace optimization
# in MoE kernels. If this flag is True then the kernel
# should not be using inplace. If the flag is false, the
# kernel is free to use inplace or not.
disable_inplace: bool = True
def __post_init__(self):
if self.dp_size > 1:
logger.debug_once(

View File

@@ -1165,6 +1165,7 @@ def cutlass_moe_w4a8_fp8(
quant_config=quant_config,
group_size=group_size,
),
inplace=False,
)
return fn(

View File

@@ -267,6 +267,7 @@ def fused_marlin_moe(
if inplace:
assert output is None, "Conflicting request"
assert not disable_inplace()
quant_type = ScalarType.from_id(quant_type_id)
assert quant_type in [
@@ -356,10 +357,7 @@ def fused_marlin_moe(
).view(-1, topk, K)
if output is None:
if inplace and not disable_inplace():
output = hidden_states
else:
output = torch.empty_like(hidden_states)
output = hidden_states if inplace else torch.empty_like(hidden_states)
if moe_sum is None:
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)

View File

@@ -27,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
@@ -1511,7 +1508,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if inplace and not disable_inplace():
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
@@ -1534,6 +1531,8 @@ def fused_experts(
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
assert not inplace or not disable_inplace()
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
@@ -1593,7 +1592,7 @@ def fused_experts_impl(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
inplace: bool,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
@@ -1712,10 +1711,7 @@ def fused_experts_impl(
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace and not disable_inplace():
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
@@ -2291,15 +2287,3 @@ class TritonWNA16Experts(TritonExperts):
# separate function is required for MoE + LoRA
self.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(moe_config, quant_config),
shared_experts,
)

View File

@@ -113,10 +113,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def supports_eplb(self) -> bool:
return False
@property
def allow_inplace(self) -> bool:
return False
@property
def method_name(self) -> str:
return self.__class__.__name__

View File

@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method: FusedMoEMethodBase,
prepare_finalize: FusedMoEPrepareAndFinalize,
shared_experts: torch.nn.Module | None,
inplace: bool = False,
) -> "FusedMoEModularMethod":
return FusedMoEModularMethod(
old_quant_method,
@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
moe_parallel_config=moe_layer.moe_parallel_config,
inplace=inplace,
),
)
@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def supports_eplb(self) -> bool:
return self.old_quant_method.supports_eplb
@property
def allow_inplace(self) -> bool:
return self.old_quant_method.allow_inplace
@property
def method_name(self) -> str:
return self.old_quant_method.method_name
@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.allow_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,

View File

@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import (
disable_inplace,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
@@ -560,6 +563,8 @@ class FusedMoE(CustomOp):
activation=activation,
device=vllm_config.device_config.device,
routing_method=self.routing_method_type,
# TODO: in_dtype == out_dtype?
disable_inplace=disable_inplace() or self.shared_experts is not None,
)
if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
@@ -650,7 +655,11 @@ class FusedMoE(CustomOp):
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
)
self.quant_method = FusedMoEModularMethod.make(
self, self.quant_method, prepare_finalize, self.shared_experts
self,
self.quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
)
@property

View File

@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self.inplace = inplace
# prefer an explicit FusedMoEParallelConfig when available (from
# FusedMoE layers / tests).
@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- topk_weights (torch.Tensor): The topk weights applied at the end of
the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
- torch.Tensor: The output tensor after applying the MoE layer.
"""
if inplace and self.shared_experts is None and not disable_inplace():
if self.inplace:
assert self.shared_experts is None
assert not disable_inplace()
output = hidden_states
else:
output = torch.zeros_like(hidden_states)

View File

@@ -472,7 +472,7 @@ def make_fp8_moe_kernel(
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
) -> mk.FusedMoEModularKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
@@ -512,8 +512,10 @@ def make_fp8_moe_kernel(
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=(
not moe_config.disable_inplace
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
),
)
# TODO(rob): update inplace logic to be part of the kernel.
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
return kernel, inplace
return kernel

View File

@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel(
else None
),
moe_parallel_config=moe_config.moe_parallel_config,
inplace=False,
)
# TODO(rob): update inplace logic to be part of the kernel.

View File

@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
) -> tuple[mk.FusedMoEModularKernel | None, bool]:
use_inplace = True
) -> mk.FusedMoEModularKernel | None:
if backend in UNSUPPORTED_BACKEND:
return None, use_inplace
return None
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
use_inplace = False
elif backend == UnquantizedMoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
@@ -184,6 +183,7 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts
@@ -194,6 +194,7 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.XPU:
from vllm.model_executor.layers.fused_moe import XPUExperts
@@ -204,5 +205,6 @@ def make_unquantized_moe_kernel(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=not moe_config.disable_inplace,
)
return kernel, use_inplace
return kernel

View File

@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
self.kernel, self.use_inplace = make_unquantized_moe_kernel(
self.kernel = make_unquantized_moe_kernel(
backend=self.unquantized_backend,
quant_config=self.moe_quant_config,
moe_config=self.moe,
@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,

View File

@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
w2_zeros=layer.w2_qzeros,
workspace=layer.workspace,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
)

View File

@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,

View File

@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
# TODO(rob): investigate the disable_expert_map introduced by:
@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
workspace=layer.workspace,
input_dtype=self.marlin_input_dtype,
is_k_full=self.is_k_full,
inplace=not self.moe.disable_inplace,
)
@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_packed,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
s_strides1=self.s_strides1,
s_strides2=self.s_strides2,
group_size=self.group_size,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
@property

View File

@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,

View File

@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
@property
def allow_inplace(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,

View File

@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
)

View File

@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
self.moe_mk = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -967,7 +967,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=self.use_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,

View File

@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,

View File

@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)
@property
def allow_inplace(self) -> bool:
return True
@property
def is_monolithic(self) -> bool:
return (
@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=layer.activation,
expert_map=layer.expert_map,
input_dtype=self.marlin_input_dtype,
inplace=not self.moe.disable_inplace,
)
assert _can_support_mxfp4(

View File

@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
inplace=not self.moe.disable_inplace,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape=None,
)
@property
def allow_inplace(self) -> bool:
return True
def apply(
self,
layer: FusedMoE,
@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,