[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -8,6 +8,9 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
 from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+    maybe_make_prepare_finalize,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     TritonExperts,
     fused_experts,
 )
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
+from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
 from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.deep_gemm import per_block_cast_to_fp8
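For downstream callers, the import migration above condenses to the following sketch (module paths taken verbatim from this diff):

# Removed by this commit:
#   from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
#   from vllm.model_executor.layers.fused_moe.prepare_finalize import MoEPrepareAndFinalizeNoEP

# Replacement imports:
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel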
@@ -125,7 +125,9 @@ def batched_moe(
         a2_scale=a2_scale,
     )
 
-    fused_experts = FusedMoEModularKernel(
+    moe_config = make_dummy_moe_config()
+
+    fused_experts = FusedMoEKernel(
         BatchedPrepareAndFinalize(
             max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
         ),
@@ -133,12 +135,22 @@ def batched_moe(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
             quant_config=quant_config,
-            moe_config=make_dummy_moe_config(),
+            moe_config=moe_config,
         ),
         inplace=False,
     )
 
-    return fused_experts(a, w1, w2, topk_weight, topk_ids)
+    return fused_experts.apply(
+        a,
+        w1,
+        w2,
+        topk_weight,
+        topk_ids,
+        global_num_experts=w1.shape[0],
+        activation=moe_config.activation,
+        apply_router_weight_on_input=False,
+        expert_map=None,
+    )
 
 
 def naive_batched_moe(
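Taken together, the two hunks above change both how the kernel is built (moe_config is hoisted so construction and call site share one config) and how it is invoked: the kernel object is no longer called directly, and every previously implicit argument is passed to .apply() explicitly. A minimal sketch of the new convention, assuming the surrounding test helpers (make_dummy_moe_config, BatchedPrepareAndFinalize) and using `experts` as a stand-in for the experts object whose constructor is only partially visible in this hunk:

moe_config = make_dummy_moe_config()  # hoisted: reused at the call site below

kernel = FusedMoEKernel(
    BatchedPrepareAndFinalize(
        max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
    ),
    experts,  # constructed with quant_config=quant_config, moe_config=moe_config
    inplace=False,
)

# Old call convention: out = kernel(a, w1, w2, topk_weight, topk_ids)
out = kernel.apply(
    a,
    w1,
    w2,
    topk_weight,
    topk_ids,
    global_num_experts=w1.shape[0],    # single-rank test: every expert is local
    activation=moe_config.activation,  # activation now travels via the config
    apply_router_weight_on_input=False,
    expert_map=None,                   # no expert-parallel remapping
)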
@@ -166,8 +178,9 @@ def naive_batched_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
     )
+    moe_config = make_dummy_moe_config()
 
-    fused_experts = FusedMoEModularKernel(
+    fused_experts = FusedMoEKernel(
         BatchedPrepareAndFinalize(
             max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
         ),
@@ -175,12 +188,22 @@ def naive_batched_moe(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
             quant_config=quant_config,
-            moe_config=make_dummy_moe_config(),
+            moe_config=moe_config,
         ),
         inplace=False,
     )
 
-    return fused_experts(a, w1, w2, topk_weight, topk_ids)
+    return fused_experts.apply(
+        a,
+        w1,
+        w2,
+        topk_weight,
+        topk_ids,
+        global_num_experts=w1.shape[0],
+        activation=moe_config.activation,
+        apply_router_weight_on_input=False,
+        expert_map=None,
+    )
 
 
 def chunk_scales(
@@ -581,9 +604,14 @@ def modular_triton_fused_moe(
     moe_config: FusedMoEConfig,
     quant_config: FusedMoEQuantConfig,
     shared_experts: torch.nn.Module | None = None,
-) -> FusedMoEModularKernel:
-    return FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
+) -> FusedMoEKernel:
+    return FusedMoEKernel(
+        maybe_make_prepare_finalize(
+            moe=moe_config,
+            quant_config=quant_config,
+            allow_new_interface=True,
+            use_monolithic=False,
+        ),
         TritonExperts(moe_config, quant_config),
         shared_experts,
         inplace=False,
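The same pattern reaches modular_triton_fused_moe: instead of hard-coding MoEPrepareAndFinalizeNoEP, the prepare/finalize object now comes from the maybe_make_prepare_finalize factory. A minimal construction sketch using only names visible in this diff; the comments on the two flags are inferred from their names rather than from documented behavior:

prepare_finalize = maybe_make_prepare_finalize(
    moe=moe_config,
    quant_config=quant_config,
    allow_new_interface=True,  # presumably opts in to the refactored MK interface
    use_monolithic=False,      # presumably selects the modular, non-monolithic path
)

kernel = FusedMoEKernel(
    prepare_finalize,
    TritonExperts(moe_config, quant_config),
    shared_experts,
    inplace=False,
)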