Refactor pplx init logic to make it modular (prepare for deepep) (#18200)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -10,7 +10,6 @@ from torch.nn import Module
 from torch.nn.parameter import Parameter
 
 import vllm.envs as envs
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -461,7 +460,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 logger.warning_once(
                     "DeepGemm not supported on the current platform.")
 
-        self.fused_experts = functools.partial(
+        self.fused_experts = functools.partial(  # type: ignore
            fused_experts,
            block_shape=self.quant_config.weight_block_size,
            allow_deep_gemm=self.allow_deep_gemm)
@@ -791,17 +790,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-    def set_prepare_finalize(
-        self,
-        dp_size: int,
-        world_size: int,
-        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
-    ) -> bool:
+    def select_gemm_impl(self, prepare_finalize):
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts)
 
-        if self.use_marlin or self.rocm_aiter_moe_enabled:
-            return False
+        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
+            "Marlin and ROCm AITER are not supported with all2all yet.")
 
         experts = TritonOrDeepGemmExperts(
             use_fp8_w8a8=True,
@@ -809,12 +803,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             allow_deep_gemm=self.allow_deep_gemm,
         )
 
-        self.fused_experts = mk.FusedMoEModularKernel(
-            prepare_finalize,
-            experts,
-        )
-
-        return True
+        return experts
 
     def apply(
         self,
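
Note: with this refactor the quantization method no longer assembles the modular kernel itself; it only selects the expert GEMM implementation, and the shared init logic owns the prepare/finalize wiring. A minimal sketch of the call site this enables, based on the code removed above (the driver function name is hypothetical, not part of this diff):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk

def init_prepare_finalize(quant_method, prepare_finalize):
    # Hypothetical driver: the quant method only picks the expert GEMM impl
    # (e.g. TritonOrDeepGemmExperts for Fp8MoEMethod above)...
    experts = quant_method.select_gemm_impl(prepare_finalize)
    # ...while kernel assembly, formerly done inside each method's
    # set_prepare_finalize, now lives in one shared place, so a new all2all
    # backend (e.g. deepep) only needs to supply its own
    # FusedMoEPrepareAndFinalize object.
    quant_method.fused_experts = mk.FusedMoEModularKernel(
        prepare_finalize, experts)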