[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -620,6 +620,7 @@ def make_modular_kernel(
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
@@ -74,7 +74,11 @@ def test_batched_deepgemm_vs_triton(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
||||
mk_triton = FusedMoEModularKernel(
|
||||
prep_finalize,
|
||||
triton_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
out_triton = mk_triton(
|
||||
hidden_states=a,
|
||||
@@ -82,7 +86,6 @@ def test_batched_deepgemm_vs_triton(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
@@ -93,7 +96,11 @@ def test_batched_deepgemm_vs_triton(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
||||
mk_deepgemm = FusedMoEModularKernel(
|
||||
prep_finalize,
|
||||
deepgemm_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
out_deepgemm = mk_deepgemm(
|
||||
hidden_states=a,
|
||||
@@ -101,7 +108,6 @@ def test_batched_deepgemm_vs_triton(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from tests.kernels.moe.utils import (
|
||||
make_dummy_moe_config,
|
||||
make_test_quant_config,
|
||||
make_test_weights,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
@@ -26,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
@@ -261,6 +259,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
|
||||
|
||||
@@ -207,6 +207,7 @@ def run_with_expert_maps(
|
||||
),
|
||||
quant_config=new_quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
out_tensor = out_tensor + kernel(**kwargs)
|
||||
|
||||
@@ -266,6 +267,7 @@ def run_8_bit(
|
||||
),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
return kernel(**kwargs)
|
||||
|
||||
|
||||
@@ -194,8 +194,11 @@ def make_ll_modular_kernel(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
return FusedMoEModularKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
|
||||
def make_ht_modular_kernel(
|
||||
@@ -224,8 +227,11 @@ def make_ht_modular_kernel(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
return FusedMoEModularKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
@@ -318,7 +324,6 @@ def deepep_deepgemm_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
|
||||
@@ -179,7 +179,11 @@ def make_modular_kernel(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
return mk
|
||||
|
||||
|
||||
@@ -256,7 +260,6 @@ def deep_ep_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
|
||||
@@ -115,6 +115,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# triton reference
|
||||
@@ -135,7 +136,6 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
@@ -301,6 +301,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
flashinfer_cutlass_output = kernel(
|
||||
@@ -309,7 +310,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
td.layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
|
||||
@@ -108,6 +108,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
@@ -180,7 +180,11 @@ def oai_triton_moe_impl(
|
||||
else:
|
||||
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
|
||||
mk = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return mk.forward(
|
||||
hidden_states=x,
|
||||
@@ -188,7 +192,6 @@ def oai_triton_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation="swigluoai",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
|
||||
@@ -18,7 +18,11 @@ from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
|
||||
from tests.kernels.moe.utils import (
|
||||
fused_moe,
|
||||
make_dummy_moe_config,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -36,9 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
batched_fused_marlin_moe,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_permute_bias,
|
||||
)
|
||||
|
||||
@@ -95,6 +95,7 @@ def test_cutlass_fp4_moe_no_graph(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
cutlass_output = kernel(
|
||||
|
||||
@@ -172,6 +172,7 @@ def pplx_cutlass_moe(
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
|
||||
@@ -592,6 +592,7 @@ def pplx_moe(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
|
||||
@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
TritonExperts,
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
from vllm.utils.math_utils import round_up
|
||||
@@ -116,6 +123,7 @@ def batched_moe(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
@@ -157,6 +165,7 @@ def naive_batched_moe(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
@@ -554,3 +563,16 @@ def make_shared_experts(
|
||||
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> FusedMoEModularKernel:
|
||||
return FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(moe_config, quant_config),
|
||||
shared_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user