[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Author: bnellnm
Date: 2026-02-05 13:07:18 -05:00
Committed by: GitHub
Commit: a57c8228ff (parent: 1ee95841bd)

37 changed files with 132 additions and 109 deletions
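
The pattern applied across every file below: callers stop passing inplace to each forward call and instead fix it when constructing the kernel. A minimal before/after sketch, assuming placeholder objects prepare_finalize and experts and test tensors (a, w1, w2, topk_weights, topk_ids, E) of the kind used in the hunks that follow:

# Before: inplace chosen on every call
#   mk = FusedMoEModularKernel(prepare_finalize, experts)
#   out = mk(hidden_states=a, w1=w1, w2=w2, topk_weights=topk_weights,
#            topk_ids=topk_ids, inplace=False, global_num_experts=E)

# After: inplace fixed once, at construction time
mk = FusedMoEModularKernel(
    prepare_finalize,
    experts,
    inplace=False,  # now a constructor argument, applied to all forward calls
)
out = mk(
    hidden_states=a,
    w1=w1,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    global_num_experts=E,
)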

View File

@@ -620,6 +620,7 @@ def make_modular_kernel(
modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize,
fused_experts=fused_experts,
inplace=False,
)
return modular_kernel

View File

@@ -74,7 +74,11 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
mk_triton = FusedMoEModularKernel(
prep_finalize,
triton_experts,
inplace=False,
)
out_triton = mk_triton(
hidden_states=a,
@@ -82,7 +86,6 @@ def test_batched_deepgemm_vs_triton(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
global_num_experts=E,
)
@@ -93,7 +96,11 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
mk_deepgemm = FusedMoEModularKernel(
prep_finalize,
deepgemm_experts,
inplace=False,
)
out_deepgemm = mk_deepgemm(
hidden_states=a,
@@ -101,7 +108,6 @@ def test_batched_deepgemm_vs_triton(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
global_num_experts=E,
)

View File

@@ -9,6 +9,7 @@ from tests.kernels.moe.utils import (
make_dummy_moe_config,
make_test_quant_config,
make_test_weights,
modular_triton_fused_moe,
)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
@@ -26,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
@@ -261,6 +259,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
inplace=False,
)
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):

View File

@@ -207,6 +207,7 @@ def run_with_expert_maps(
),
quant_config=new_quant_config,
),
inplace=False,
)
out_tensor = out_tensor + kernel(**kwargs)
@@ -266,6 +267,7 @@ def run_8_bit(
),
quant_config=quant_config,
),
inplace=False,
)
return kernel(**kwargs)

View File

@@ -194,8 +194,11 @@ def make_ll_modular_kernel(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
return FusedMoEModularKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
)
def make_ht_modular_kernel(
@@ -224,8 +227,11 @@ def make_ht_modular_kernel(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
return mk
return FusedMoEModularKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
)
def make_modular_kernel(
@@ -318,7 +324,6 @@ def deepep_deepgemm_moe_impl(
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),

View File

@@ -179,7 +179,11 @@ def make_modular_kernel(
quant_config=quant_config,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
mk = FusedMoEModularKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
)
return mk
@@ -256,7 +260,6 @@ def deep_ep_moe_impl(
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
inplace=False,
activation="silu",
global_num_experts=num_experts,
expert_map=build_expert_map(),

View File

@@ -115,6 +115,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
inplace=False,
)
# triton reference
@@ -135,7 +136,6 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 0.1%: {diff}"

View File

@@ -301,6 +301,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
moe_config=moe_config,
quant_config=quant_config,
),
inplace=False,
)
flashinfer_cutlass_output = kernel(
@@ -309,7 +310,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.w2_weight,
topk_weights,
topk_ids,
inplace=False,
activation=activation,
global_num_experts=e,
expert_map=None,

View File

@@ -108,6 +108,7 @@ def test_flashinfer_fp4_moe_no_graph(
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
inplace=False,
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]

View File

@@ -180,7 +180,11 @@ def oai_triton_moe_impl(
else:
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
mk = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
fused_experts,
inplace=False,
)
return mk.forward(
hidden_states=x,
@@ -188,7 +192,6 @@ def oai_triton_moe_impl(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation="swigluoai",
global_num_experts=num_experts,
expert_map=None,

View File

@@ -18,7 +18,11 @@ from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
from tests.kernels.moe.utils import (
fused_moe,
make_dummy_moe_config,
modular_triton_fused_moe,
)
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
@@ -36,9 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
batched_fused_marlin_moe,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias,
)

View File

@@ -95,6 +95,7 @@ def test_cutlass_fp4_moe_no_graph(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
),
inplace=False,
)
cutlass_output = kernel(

View File

@@ -172,6 +172,7 @@ def pplx_cutlass_moe(
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
inplace=False,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)

View File

@@ -592,6 +592,7 @@ def pplx_moe(
prepare_finalize,
experts,
shared_experts,
inplace=False,
)
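
Taken together, the call sites in this commit imply the constructor shape sketched below; it is inferred from the diff (parameter order and defaults are assumptions), not copied from the vLLM source:

# FusedMoEModularKernel(
#     prepare_finalize,   # prepare/finalize implementation, e.g. MoEPrepareAndFinalizeNoEP()
#     fused_experts,      # expert implementation, e.g. TritonExperts(...)
#     shared_experts,     # optional shared-experts module, passed positionally in pplx_moe above
#     inplace=False,      # the flag this commit moves out of forward()
# )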
# Note: for now use_compile will error out if the problem size is

View File

@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe import (
TritonExperts,
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import round_up
@@ -116,6 +123,7 @@ def batched_moe(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
inplace=False,
)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -157,6 +165,7 @@ def naive_batched_moe(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
inplace=False,
)
return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -554,3 +563,16 @@ def make_shared_experts(
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
finally:
torch.set_default_dtype(old_dtype)
def modular_triton_fused_moe(
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
shared_experts: torch.nn.Module | None = None,
) -> FusedMoEModularKernel:
return FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(moe_config, quant_config),
shared_experts,
inplace=False,
)
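
A usage sketch for the helper defined above, assuming moe_config, quant_config, and test tensors (a, w1, w2, topk_weights, topk_ids, E) are built the same way as elsewhere in these tests:

mk = modular_triton_fused_moe(moe_config, quant_config)
out = mk(
    hidden_states=a,
    w1=w1,
    w2=w2,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    global_num_experts=E,
)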