[MoE Refactor][2/N] Use Modular Kernels for Fp8 (#30825)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -51,7 +50,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
|||||||
FlashinferMoeBackend,
|
FlashinferMoeBackend,
|
||||||
apply_flashinfer_per_tensor_scale_fp8,
|
apply_flashinfer_per_tensor_scale_fp8,
|
||||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
|
||||||
flashinfer_cutlass_moe_fp8,
|
|
||||||
get_flashinfer_moe_backend,
|
get_flashinfer_moe_backend,
|
||||||
register_moe_scaling_factors,
|
register_moe_scaling_factors,
|
||||||
rotate_flashinfer_fp8_moe_weights,
|
rotate_flashinfer_fp8_moe_weights,
|
||||||
@@ -728,18 +726,28 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||||
if self.block_quant:
|
if self.block_quant and self.weight_block_size != [128, 128]:
|
||||||
assert self.weight_block_size == [128, 128], (
|
raise NotImplementedError(
|
||||||
f"Only support weight_block_size == [128, 128], "
|
"FlashInfer CUTLASS FP8 MoE backend only supports block "
|
||||||
f"got {self.weight_block_size}"
|
"size [128, 128]."
|
||||||
|
)
|
||||||
|
if not self.block_quant:
|
||||||
|
if layer.renormalize or layer.custom_routing_function is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FlashInfer CUTLASS FP8 MoE backend does custom routing "
|
||||||
|
f"function or renormalization, but got {layer.renormalize} and "
|
||||||
|
f"{layer.custom_routing_function}."
|
||||||
|
)
|
||||||
|
if layer.scoring_func != "sigmoid":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FlashInfer CUTLASS FP8 MoE backend only supports "
|
||||||
|
f"'sigmoid' scoring function, but got {layer.scoring_func}."
|
||||||
|
)
|
||||||
|
if layer.activation != "silu":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
|
||||||
|
"activation function, but got {layer.activation}."
|
||||||
)
|
)
|
||||||
self.flashinfer_moe_fn = partial(
|
|
||||||
flashinfer_cutlass_moe_fp8,
|
|
||||||
moe=self.moe,
|
|
||||||
use_deepseek_fp8_block_scale=self.block_quant,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -928,7 +936,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||||
# it ahead of time for performance reasons.
|
# it ahead of time for performance reasons.
|
||||||
if self.allow_deep_gemm:
|
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
|
||||||
dg_w13_weight, dg_w13_weight_scale_inv = (
|
dg_w13_weight, dg_w13_weight_scale_inv = (
|
||||||
deepgemm_post_process_fp8_weight_block(
|
deepgemm_post_process_fp8_weight_block(
|
||||||
wq=layer.w13_weight.data,
|
wq=layer.w13_weight.data,
|
||||||
@@ -1039,6 +1047,61 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
|
|
||||||
|
# NOTE(rob): this is a WIP refactor. We are first migrating
|
||||||
|
# all of the kernels in the TP case to use mk. Once this is
|
||||||
|
# done, then we will initialize the TP case and DP/EP case
|
||||||
|
# via the same code path (i.e. via maybe_init_modular_kernel).
|
||||||
|
# NOTE(rob): in progress migrating all into this format.
|
||||||
|
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||||
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||||
|
FlashInferExperts,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||||
|
FlashInferAllGatherMoEPrepareAndFinalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = self.get_fused_moe_quant_config(layer)
|
||||||
|
assert config is not None
|
||||||
|
self.moe_quant_config = config
|
||||||
|
|
||||||
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
|
FlashInferAllGatherMoEPrepareAndFinalize(
|
||||||
|
use_dp=(self.moe.dp_size > 1),
|
||||||
|
use_deepseek_fp8_block_scale=self.block_quant,
|
||||||
|
),
|
||||||
|
FlashInferExperts(
|
||||||
|
out_dtype=torch.get_default_dtype(),
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
ep_rank=self.moe.ep_rank,
|
||||||
|
ep_size=self.moe.ep_size,
|
||||||
|
tp_rank=self.moe.tp_rank,
|
||||||
|
tp_size=self.moe.tp_size,
|
||||||
|
use_dp=(self.moe.dp_size > 1),
|
||||||
|
use_deepseek_fp8_block_scale=self.block_quant,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.use_inplace = False
|
||||||
|
|
||||||
|
elif self.fp8_backend in [Fp8MoeBackend.DEEPGEMM, Fp8MoeBackend.TRITON]:
|
||||||
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
|
TritonOrDeepGemmExperts,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
|
MoEPrepareAndFinalizeNoEP,
|
||||||
|
)
|
||||||
|
|
||||||
|
config = self.get_fused_moe_quant_config(layer)
|
||||||
|
assert config is not None
|
||||||
|
self.moe_quant_config = config
|
||||||
|
self.kernel = mk.FusedMoEModularKernel(
|
||||||
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
|
TritonOrDeepGemmExperts(
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.use_inplace = True
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self,
|
||||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||||
@@ -1091,7 +1154,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert max_num_tokens_per_rank is not None
|
assert max_num_tokens_per_rank is not None
|
||||||
|
|
||||||
experts_impl = (
|
experts_impl = (
|
||||||
BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
|
BatchedDeepGemmExperts
|
||||||
|
if self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||||
|
else BatchedTritonExperts
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
|
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
|
||||||
@@ -1126,7 +1191,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
return TritonOrDeepGemmExperts(
|
return TritonOrDeepGemmExperts(
|
||||||
quant_config=self.moe_quant_config,
|
quant_config=self.moe_quant_config,
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_fused_moe_quant_config(
|
def get_fused_moe_quant_config(
|
||||||
@@ -1164,6 +1229,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||||
|
# TODO(rob): convert this to MK.
|
||||||
if layer.enable_eplb:
|
if layer.enable_eplb:
|
||||||
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
|
||||||
assert layer.activation == "silu", (
|
assert layer.activation == "silu", (
|
||||||
@@ -1228,6 +1294,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
rocm_aiter_fused_experts,
|
rocm_aiter_fused_experts,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(rob): convert this to MK.
|
||||||
result = rocm_aiter_fused_experts(
|
result = rocm_aiter_fused_experts(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@@ -1240,6 +1307,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
quant_config=self.moe_quant_config,
|
quant_config=self.moe_quant_config,
|
||||||
)
|
)
|
||||||
elif self.use_marlin:
|
elif self.use_marlin:
|
||||||
|
# TODO(rob): convert this to MK.
|
||||||
assert layer.activation == "silu", (
|
assert layer.activation == "silu", (
|
||||||
f"{layer.activation} not supported for Marlin MoE."
|
f"{layer.activation} not supported for Marlin MoE."
|
||||||
)
|
)
|
||||||
@@ -1261,47 +1329,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
input_dtype=self.marlin_input_dtype,
|
input_dtype=self.marlin_input_dtype,
|
||||||
workspace=layer.workspace,
|
workspace=layer.workspace,
|
||||||
)
|
)
|
||||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
else:
|
||||||
assert layer.activation == "silu", (
|
result = self.kernel(
|
||||||
f"Expected 'silu' activation but got {layer.activation}"
|
|
||||||
)
|
|
||||||
if not self.block_quant:
|
|
||||||
assert (
|
|
||||||
not layer.renormalize and layer.custom_routing_function is not None
|
|
||||||
)
|
|
||||||
assert layer.scoring_func == "sigmoid", (
|
|
||||||
f"Expected 'sigmoid' scoring func but got {layer.scoring_func}"
|
|
||||||
)
|
|
||||||
# Delegate to CUTLASS FlashInfer path; function already bound with
|
|
||||||
# use_deepseek_fp8_block_scale for block-quant when applicable
|
|
||||||
result = self.flashinfer_moe_fn(
|
|
||||||
x,
|
x,
|
||||||
layer,
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
inplace=False,
|
inplace=self.use_inplace,
|
||||||
activation=layer.activation,
|
activation=layer.activation,
|
||||||
global_num_experts=layer.global_num_experts,
|
global_num_experts=layer.global_num_experts,
|
||||||
expert_map=layer.expert_map,
|
expert_map=layer.expert_map,
|
||||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
|
||||||
|
|
||||||
result = fused_experts(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
inplace=True,
|
|
||||||
activation=layer.activation,
|
|
||||||
global_num_experts=layer.global_num_experts,
|
|
||||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
|
||||||
expert_map=layer.expert_map,
|
|
||||||
quant_config=self.moe_quant_config,
|
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
|
||||||
)
|
|
||||||
|
|
||||||
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
if layer.zero_expert_num != 0 and layer.zero_expert_type is not None:
|
||||||
assert not isinstance(result, tuple), (
|
assert not isinstance(result, tuple), (
|
||||||
|
|||||||
Reference in New Issue
Block a user