[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
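
As a rough sketch of what this refactor means for callers (the pre-refactor signature is not shown in this diff, so the "per call" framing is inferred from the commit title), the `inplace` flag that controls whether the kernel writes its output back into the activation buffer is now supplied once when the `FusedMoEModularKernel` is constructed, rather than at each invocation:

# Minimal sketch, assuming `moe_config` and `quant_config` are already built
# as in the test helpers changed below; only the constructor call is taken
# directly from this diff.
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import MoEPrepareAndFinalizeNoEP

kernel = FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    TritonExperts(moe_config, quant_config),
    None,  # optional shared-experts module, as in the helper added below
    inplace=False,  # constructor-level flag introduced by this commit
)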
@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
 from tests.kernels.quant_utils import per_block_cast_to_int8
 from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe import (
+    TritonExperts,
+    fused_experts,
+    fused_topk,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     NaiveBatchedExperts,
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.deep_gemm import per_block_cast_to_fp8
 from vllm.utils.math_utils import round_up
@@ -116,6 +123,7 @@ def batched_moe(
             quant_config=quant_config,
             moe_config=make_dummy_moe_config(),
         ),
+        inplace=False,
     )
 
     return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -157,6 +165,7 @@ def naive_batched_moe(
             quant_config=quant_config,
             moe_config=make_dummy_moe_config(),
         ),
+        inplace=False,
     )
 
     return fused_experts(a, w1, w2, topk_weight, topk_ids)
@@ -554,3 +563,16 @@ def make_shared_experts(
         return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
     finally:
         torch.set_default_dtype(old_dtype)
+
+
+def modular_triton_fused_moe(
+    moe_config: FusedMoEConfig,
+    quant_config: FusedMoEQuantConfig,
+    shared_experts: torch.nn.Module | None = None,
+) -> FusedMoEModularKernel:
+    return FusedMoEModularKernel(
+        MoEPrepareAndFinalizeNoEP(),
+        TritonExperts(moe_config, quant_config),
+        shared_experts,
+        inplace=False,
+    )