[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     fp8_w8a8_moe_quant_config,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
+from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
 from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
@@ -170,7 +170,7 @@ def make_ll_modular_kernel(
     q_dtype: torch.dtype | None,
     test_config: TestConfig,
     quant_config: FusedMoEQuantConfig,
-) -> FusedMoEModularKernel:
+) -> FusedMoEKernel:
     assert test_config.low_latency
     assert test_config.use_fp8_dispatch is not None
 
@@ -195,7 +195,7 @@ def make_ll_modular_kernel(
         quant_config=quant_config,
         moe_config=make_dummy_moe_config(),
     )
-    return FusedMoEModularKernel(
+    return FusedMoEKernel(
         prepare_finalize=a2a,
         fused_experts=fused_experts,
         inplace=False,
@@ -210,7 +210,7 @@ def make_ht_modular_kernel(
     q_dtype: torch.dtype | None,
     test_config: TestConfig,
     quant_config: FusedMoEQuantConfig,
-) -> FusedMoEModularKernel:
+) -> FusedMoEKernel:
     assert not test_config.low_latency
     assert test_config.use_fp8_dispatch is None
 
@@ -228,7 +228,7 @@ def make_ht_modular_kernel(
         moe_config=make_dummy_moe_config(),
         quant_config=quant_config,
     )
-    return FusedMoEModularKernel(
+    return FusedMoEKernel(
         prepare_finalize=a2a,
         fused_experts=fused_experts,
         inplace=False,
@@ -242,11 +242,11 @@ def make_modular_kernel(
     num_local_experts: int,
     test_tensors: TestTensors,
     quant_config: FusedMoEQuantConfig,
-) -> FusedMoEModularKernel:
+) -> FusedMoEKernel:
     q_dtype = torch.float8_e4m3fn
     test_config = test_tensors.config
 
-    mk: FusedMoEModularKernel
+    mk: FusedMoEKernel
     # Make modular kernel
     if test_config.low_latency:
         max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
@@ -307,7 +307,7 @@ def deepep_deepgemm_moe_impl(
     )
 
     # Make modular kernel
-    mk: FusedMoEModularKernel = make_modular_kernel(
+    mk: FusedMoEKernel = make_modular_kernel(
         pg=pg,
         pgi=pgi,
         dp_size=dp_size,
@@ -319,7 +319,7 @@ def deepep_deepgemm_moe_impl(
     with with_dp_metadata(
         M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
     ):
-        out = mk.forward(
+        out = mk.apply(
             hidden_states=test_tensors.rank_tokens,
             w1=w1,
             w2=w2,
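For reference, a minimal sketch of the caller-side pattern after this rename, based only on what the hunks above show. The helper name run_fused_moe and its parameters are illustrative, and any constructor or apply() arguments beyond those visible in the diff are assumptions rather than the full signatures.

from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel

def run_fused_moe(a2a, fused_experts, rank_tokens, w1, w2):
    # FusedMoEModularKernel is now constructed as FusedMoEKernel; the
    # constructor keywords shown in this diff are unchanged.
    mk = FusedMoEKernel(
        prepare_finalize=a2a,
        fused_experts=fused_experts,
        inplace=False,
    )
    # The entry point is renamed from forward() to apply(); the keyword
    # arguments visible in the diff carry over as-is.
    return mk.apply(
        hidden_states=rank_tokens,
        w1=w1,
        w2=w2,
    )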