Refactor pplx init logic to make it modular (prepare for deepep) (#18200)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-05-23 23:43:43 +08:00
Committed by: GitHub
Parent: 022d8abe29
Commit: 6a7988c55b

16 changed files with 300 additions and 287 deletions
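
What this part of the refactor does to Fp8MoEMethod: the old set_prepare_finalize hook both chose the experts (GEMM) implementation and built the FusedMoEModularKernel itself; the new select_gemm_impl hook only returns the experts implementation, leaving the composition with the all2all prepare/finalize object to the shared FusedMoE layer code. Below is a minimal caller-side sketch reconstructed from the removed lines in this diff; the helper name init_prepare_finalize and its free-function form are assumptions for illustration, not code from this commit.

# Hypothetical caller-side sketch (not part of this diff): compose the
# quant method's experts implementation with a prepare/finalize object.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk

def init_prepare_finalize(quant_method, prepare_finalize):
    # The quant method now only selects the GEMM/experts kernel ...
    experts = quant_method.select_gemm_impl(prepare_finalize)
    # ... and the shared layer code owns building the modular kernel,
    # which previously happened inside Fp8MoEMethod.set_prepare_finalize.
    quant_method.fused_experts = mk.FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )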


@@ -10,7 +10,6 @@ from torch.nn import Module
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -461,7 +460,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 logger.warning_once(
                     "DeepGemm not supported on the current platform.")
-        self.fused_experts = functools.partial(
+        self.fused_experts = functools.partial( # type: ignore
             fused_experts,
             block_shape=self.quant_config.weight_block_size,
             allow_deep_gemm=self.allow_deep_gemm)
@@ -791,17 +790,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         del layer.w13_input_scale
         del layer.w2_input_scale
 
-    def set_prepare_finalize(
-        self,
-        dp_size: int,
-        world_size: int,
-        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
-    ) -> bool:
+    def select_gemm_impl(self, prepare_finalize):
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts)
-        if self.use_marlin or self.rocm_aiter_moe_enabled:
-            return False
+        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
+            "Marlin and ROCm AITER are not supported with all2all yet.")
         experts = TritonOrDeepGemmExperts(
             use_fp8_w8a8=True,
@@ -809,12 +803,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             allow_deep_gemm=self.allow_deep_gemm,
         )
-        self.fused_experts = mk.FusedMoEModularKernel(
-            prepare_finalize,
-            experts,
-        )
-        return True
+        return experts
 
     def apply(
         self,