[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEActivationFormat,
    FusedMoEMethodBase,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
@@ -35,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_kernel_for_mkm,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
    select_cutlass_fp8_gemm_impl,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
@@ -79,8 +78,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    is_layer_skipped,
    kFp8Dynamic128Sym,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8Static128BlockSym,
    kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -658,38 +659,36 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.weight_scale_name = (
            "weight_scale_inv" if self.block_quant else "weight_scale"
        )
        self.fp8_backend = select_fp8_moe_backend(
            block_quant=self.block_quant,
            tp_size=layer.moe_parallel_config.tp_size,
            with_lora_support=self.moe.is_lora_enabled,
            is_act_and_mul=self.moe.is_act_and_mul,
        )

        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            if self.block_quant and self.weight_block_size != [128, 128]:
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports block "
                    "size [128, 128]."
                )
            if layer.activation != "silu":
                raise NotImplementedError(
                    "FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
                    "activation function, but got {layer.activation}."
                )
        dynamic_per_token = (
            not self.block_quant and self.quant_config.activation_scheme != "static"
        )
        if dynamic_per_token and self.fp8_backend in [
            Fp8MoeBackend.FLASHINFER_TRTLLM,
            Fp8MoeBackend.FLASHINFER_CUTLASS,
        ]:
            raise NotImplementedError(
                "FlashInfer FP8 MoE backend does not support dynamic per token "
                "activation quantization."
            )
        # Set weight key and activation key for kernel compatibility
        if self.block_quant:
            weight_key = kFp8Static128BlockSym
            activation_key = kFp8Dynamic128Sym
        else:
            weight_key = kFp8StaticTensorSym
            activation_key = (
                kFp8StaticTensorSym
                if self.quant_config.activation_scheme == "static"
                else kFp8DynamicTensorSym
            )

        # Select Fp8 MoE backend
        self.fp8_backend, self.experts_cls = select_fp8_moe_backend(
            config=self.moe,
            weight_key=weight_key,
            activation_key=activation_key,
            allow_vllm_cutlass=False,
        )

        # Delay creation of the kernel until after process-weights.
        self.kernel: mk.FusedMoEModularKernel | None = None

    @property
    def topk_indices_dtype(self) -> torch.dtype | None:
        if self.kernel is not None:
            return self.kernel.prepare_finalize.topk_indices_dtype()
        return None

    def create_weights(
        self,
        layer: Module,
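The hunk above collapses the layer's weight and activation quantization schemes into compatibility keys and hands them to the oracle, which now returns both a backend and an experts class; kernel construction is deferred until after weight processing. A minimal standalone sketch of the key-collapsing step, using plain strings as hypothetical stand-ins for the kFp8* constants (illustration only, not the vLLM API):

```python
def pick_fp8_scheme_keys(block_quant: bool, activation_scheme: str) -> tuple[str, str]:
    # Mirrors the weight_key / activation_key selection in the hunk above.
    # The string values are hypothetical stand-ins for the kFp8* scheme keys.
    if block_quant:
        # Block-quantized weights pair with 128-group dynamic activation quant.
        return "fp8-static-128-block-sym", "fp8-dynamic-128-sym"
    weight_key = "fp8-static-tensor-sym"
    activation_key = (
        "fp8-static-tensor-sym"
        if activation_scheme == "static"
        else "fp8-dynamic-tensor-sym"
    )
    return weight_key, activation_key


print(pick_fp8_scheme_keys(block_quant=True, activation_scheme="dynamic"))
print(pick_fp8_scheme_keys(block_quant=False, activation_scheme="static"))
```

The selected backend and experts class are stored on the method object, and self.kernel starts as None so the modular kernel can be built once the weights are in their final format.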
@@ -842,14 +841,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
            replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)

        # Setup modular kernel for TP case.
        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
        if self.moe_quant_config and (
            (not self.moe.moe_parallel_config.use_all2all_kernels)
            or self.moe.moe_parallel_config.use_naive_all2all_kernels
        ):
            assert self.experts_cls is not None
            self.kernel, self.use_inplace = make_fp8_moe_kernel(
                layer=layer,
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
            )

    def process_weights_after_loading(self, layer: Module) -> None:
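With the widened condition above, the modular kernel is built eagerly for plain TP and for the naive all2all DP/EP path, while the non-naive all2all path defers to a ModularKernelMethod created elsewhere. The predicate restated as a standalone helper (name and form are illustrative, not part of the codebase):

```python
def builds_kernel_eagerly(use_all2all_kernels: bool, use_naive_all2all_kernels: bool) -> bool:
    # Mirrors the guard around make_fp8_moe_kernel() in the hunk above.
    return (not use_all2all_kernels) or use_naive_all2all_kernels


assert builds_kernel_eagerly(False, False)        # plain TP: build the kernel here
assert builds_kernel_eagerly(True, True)          # naive all2all DP/EP: build the kernel here
assert not builds_kernel_eagerly(True, False)     # non-naive all2all: deferred to ModularKernelMethod
```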
@@ -904,13 +910,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        if self.fp8_backend in [
            Fp8MoeBackend.AITER,
            Fp8MoeBackend.MARLIN,
            Fp8MoeBackend.FLASHINFER_TRTLLM,
        ]:
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
            return None
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # For no-EP case, don't use the MKM framework.
            if not self.moe.moe_parallel_config.use_all2all_kernels:
                return None

            prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
                self.moe,
                use_deepseek_fp8_block_scale=self.block_quant,
@@ -924,73 +930,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        prepare_finalize: FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        from vllm.model_executor.layers.fused_moe import (
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
        )

        if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
            raise NotImplementedError(
                "Marlin and ROCm AITER are not supported with all2all yet."
            )

        assert self.moe_quant_config is not None

        if (
            prepare_finalize.activation_format
            == FusedMoEActivationFormat.BatchedExperts
        ):
            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
            assert max_num_tokens_per_rank is not None

            experts_impl = (
                BatchedDeepGemmExperts
                if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
                else BatchedTritonExperts
            )
            logger.debug(
                "%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                experts_impl.__name__,
                self.__class__.__name__,
                max_num_tokens_per_rank,
                self.weight_block_size,
                False,
            )
            return experts_impl(
                max_num_tokens=max_num_tokens_per_rank,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                quant_config=self.moe_quant_config,
            )
        elif self.moe.is_lora_enabled:
            return TritonExperts(quant_config=self.moe_quant_config)
        elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            # Select GEMM experts with block-scale when weights are block-quantized
            experts = select_cutlass_fp8_gemm_impl(
                self.moe,
                self.moe_quant_config,
                use_deepseek_fp8_block_scale=self.block_quant,
            )
            logger.debug_once("Using %s", experts.__class__.__name__)
            return experts
        elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
            logger.debug(
                "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonOrDeepGemmExperts(self.moe_quant_config)
        else:
            assert self.fp8_backend == Fp8MoeBackend.TRITON
            logger.debug(
                "TritonExperts(%s): block_size=%s, per_act_token=%s",
                self.__class__.__name__,
                self.weight_block_size,
                False,
            )
            return TritonExperts(self.moe_quant_config)
        assert self.experts_cls is not None
        return make_fp8_moe_kernel_for_mkm(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            experts_cls=self.experts_cls,
            prepare_finalize=prepare_finalize,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
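The bulk of the removed body above was a per-backend if/elif ladder that chose an experts implementation at call time; after this change that decision is already made by the oracle at init time, so the method reduces to handing self.experts_cls to make_fp8_moe_kernel_for_mkm. A toy sketch of the dispatch style being retired, with hypothetical names that only mimic the removed logic (not the vLLM API):

```python
from enum import Enum, auto


class Backend(Enum):
    # Hypothetical stand-in for Fp8MoeBackend, trimmed to the branches
    # that appeared in the removed ladder above.
    DEEPGEMM = auto()
    FLASHINFER_CUTLASS = auto()
    TRITON = auto()


def pick_experts_name(backend: Backend, batched_format: bool) -> str:
    # Batched activation formats pick a Batched* implementation; otherwise
    # the choice follows the backend, falling back to plain Triton experts.
    if batched_format:
        return "BatchedDeepGemmExperts" if backend is Backend.DEEPGEMM else "BatchedTritonExperts"
    if backend is Backend.FLASHINFER_CUTLASS:
        return "CUTLASS experts via select_cutlass_fp8_gemm_impl"
    if backend is Backend.DEEPGEMM:
        return "TritonOrDeepGemmExperts"
    return "TritonExperts"


print(pick_experts_name(Backend.DEEPGEMM, batched_format=True))
print(pick_experts_name(Backend.TRITON, batched_format=False))
```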
@@ -1067,7 +1014,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                routed_scaling=layer.routed_scaling_factor,
            )
        else:
            result = apply_fi_trtllm_fp8_per_tensor_moe(
            return apply_fi_trtllm_fp8_per_tensor_moe(
                layer=layer,
                hidden_states=x,
                router_logits=router_logits,