[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
bnellnm
2026-02-05 13:07:18 -05:00
committed by GitHub
parent 1ee95841bd
commit a57c8228ff
37 changed files with 132 additions and 109 deletions

View File

@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe import (
TritonExperts,
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import round_up
@@ -116,6 +123,7 @@ def batched_moe(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
inplace=False,
)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -157,6 +165,7 @@ def naive_batched_moe(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
inplace=False,
)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -554,3 +563,16 @@ def make_shared_experts(
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
finally:
torch.set_default_dtype(old_dtype)
def modular_triton_fused_moe(
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
return FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(moe_config, quant_config),
shared_experts,
inplace=False,
)