[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
bnellnm
2026-02-05 13:07:18 -05:00
committed by GitHub
parent 1ee95841bd
commit a57c8228ff
37 changed files with 132 additions and 109 deletions

View File

@@ -74,7 +74,11 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
mk_triton = FusedMoEModularKernel(
prep_finalize,
triton_experts,
inplace=False,
)
out_triton = mk_triton(
hidden_states=a,
@@ -82,7 +86,6 @@ def test_batched_deepgemm_vs_triton(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
global_num_experts=E,
)
@@ -93,7 +96,11 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
mk_deepgemm = FusedMoEModularKernel(
prep_finalize,
deepgemm_experts,
inplace=False,
)
out_deepgemm = mk_deepgemm(
hidden_states=a,
@@ -101,7 +108,6 @@ def test_batched_deepgemm_vs_triton(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
global_num_experts=E,
)