[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: amirkl94 <203507526+amirkl94@users.noreply.github.com>
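
Illustrative sketch (not part of this commit): the modular-kernel (MK) composition this refactor builds on. The class names come from imports that appear in the diff below; the constructor signature and the helper name build_modular_kernel are assumptions for illustration only.

# Assumed composition: a modular kernel pairs a prepare/finalize stage
# (token dispatch and combine) with a permute/experts/unpermute stage
# (the expert GEMMs). The constructor signature is assumed, not taken
# from this diff.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk

def build_modular_kernel(
    prepare_finalize: "mk.FusedMoEPrepareAndFinalize",
    experts: "mk.FusedMoEPermuteExpertsUnpermute",
) -> "mk.FusedMoEModularKernel":
    return mk.FusedMoEModularKernel(prepare_finalize, experts)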
@@ -33,7 +33,6 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_kernel_for_mkm,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
@@ -53,7 +52,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
@@ -679,15 +677,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            allow_vllm_cutlass=False,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def create_weights(
        self,
        layer: Module,
@@ -813,7 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):

    def _setup_kernel(
        self,
        layer: Module,
        layer: FusedMoE,
        w13: torch.Tensor,
        w2: torch.Tensor,
        w13_scale: torch.Tensor,
@@ -845,16 +834,15 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
@@ -909,33 +897,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("%s", prepare_finalize.__class__.__name__)
            return prepare_finalize
        return super().maybe_make_prepare_finalize(routing_tables)
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None
        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def get_fused_moe_quant_config(
@@ -1037,9 +1011,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.kernel is not None
        assert self.moe_mk is not None
        assert not self.is_monolithic
        return self.kernel(
        return self.moe_mk(
            x,
            layer.w13_weight,
            layer.w2_weight,
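
Illustrative recap (not part of this commit): a sketch of the lifecycle these hunks appear to establish, using only names that occur above; the exact call order is an assumption.

# Assumed flow after this refactor:
#   1. process_weights_after_loading() -> _setup_kernel(...) builds the kernel
#      whenever a quant config exists:
#          self.moe_mk, self.use_inplace = make_fp8_moe_kernel(...)
#   2. apply() dispatches through that kernel:
#          assert self.moe_mk is not None
#          return self.moe_mk(x, layer.w13_weight, layer.w2_weight, ...)
#   3. maybe_make_prepare_finalize() and select_gemm_impl() now raise
#      ValueError, since Fp8MoEMethod no longer uses the older MKM
#      initialization path.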