[Perf] Add TRTLLM FP8 MoE Modular Kernel (#36307)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Wei Zhao
2026-03-12 10:32:31 -04:00
committed by GitHub
parent 7f1f36bf91
commit 2e693f48e7
3 changed files with 236 additions and 114 deletions

View File

@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
TrtLlmFp8Experts,
TrtLlmFp8ExpertsMonolithic,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
@@ -247,7 +247,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
allow_new_interface=True,
use_monolithic=True,
),
TrtLlmFp8Experts(
TrtLlmFp8ExpertsMonolithic(
moe_config=td.layer.moe,
quant_config=quant_config,
),

View File

@@ -4,6 +4,7 @@
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
@@ -22,10 +26,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
class TrtLlmFp8ExpertsBase:
"""
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
Fp8 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic
interfaces.
"""
def __init__(
@@ -33,8 +40,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
self.routing_method_type = moe_config.routing_method
self.topk = moe_config.experts_per_token
self.intermediate_size_per_partition = (
@@ -44,24 +49,7 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank
# Make additional scales for per-tensor interface.
if self.quant_config.is_per_tensor:
w1_scale = self.quant_config.w1_scale
assert w1_scale is not None
a1_scale = self.quant_config.a1_scale
assert a1_scale is not None
w2_scale = self.quant_config.w2_scale
assert w2_scale is not None
a2_scale = self.quant_config.a2_scale
assert a2_scale is not None
self._g1_alphas = (w1_scale * a1_scale).squeeze()
self._g2_alphas = (w2_scale * a2_scale).squeeze()
self._g1_scale_c = (
self._g1_alphas / self.quant_config.a2_scale
if moe_config.is_act_and_mul
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
)
self.quant_config = quant_config
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
@@ -79,50 +67,11 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
"""Does not support non-gated MoE (i.e. Nemotron-3-Nano)."""
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
"""Supports only SiLU and RELU^2 non-gated activation."""
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""Monolithic kernel so only use with naive DP/EP and TP."""
@@ -153,6 +102,178 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
def supports_expert_map(self) -> bool:
return False
class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports modular interface.
"""
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
return (workspace1, workspace2, output)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
import flashinfer
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16
)
# trtllm_fp8_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.
import vllm.utils.flashinfer as fi_utils
if fi_utils._is_fi_autotuning:
return
assert a1q_scale is not None
# `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the
# output tensor in-place so we need to manually copy the result to the
# output tensor
# https://github.com/flashinfer-ai/flashinfer/issues/2703
result = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
topk_ids=packed_topk_ids,
routing_bias=None,
hidden_states=hidden_states,
hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr]
gemm1_weights=w1,
gemm1_weights_scale=self.quant_config.w1_scale,
gemm2_weights=w2,
gemm2_weights_scale=self.quant_config.w2_scale,
num_experts=global_num_experts,
top_k=self.topk,
n_group=None,
topk_group=None,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.ep_rank * self.local_num_experts,
local_num_experts=self.local_num_experts,
routed_scaling_factor=None,
routing_method_type=1,
use_shuffled_weight=False,
weight_layout=0,
# output=output,
)
output.copy_(result)
class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolithic):
"""
Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
"""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
super().__init__(moe_config, quant_config)
# Make additional scales for per-tensor interface.
if self.quant_config.is_per_tensor:
w1_scale = self.quant_config.w1_scale
assert w1_scale is not None
a1_scale = self.quant_config.a1_scale
assert a1_scale is not None
w2_scale = self.quant_config.w2_scale
assert w2_scale is not None
a2_scale = self.quant_config.a2_scale
assert a2_scale is not None
self._g1_alphas = (w1_scale * a1_scale).squeeze()
self._g2_alphas = (w2_scale * a2_scale).squeeze()
self._g1_scale_c = (
self._g1_alphas / self.quant_config.a2_scale
if moe_config.is_act_and_mul
else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
)
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Supports Fp8 per-tensor and Fp8 block."""
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_routing_method(
routing_method: RoutingMethodType,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
"""Monolithic kernels need to express router support."""
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
# NOTE(dbari): Default is not implemented and should not be enabled until it is
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
]
else:
raise ValueError("Unsupported quantization scheme.")
def _apply_per_block(
self,
hidden_states: torch.Tensor,

View File

@@ -104,83 +104,84 @@ def _get_priority_backends(
def backend_to_kernel_cls(
backend: Fp8MoeBackend,
) -> type[mk.FusedMoEExperts]:
) -> list[type[mk.FusedMoEExperts]]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
TrtLlmFp8Experts,
TrtLlmFp8ExpertsModular,
TrtLlmFp8ExpertsMonolithic,
)
return TrtLlmFp8Experts
return [TrtLlmFp8ExpertsMonolithic, TrtLlmFp8ExpertsModular]
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
return FlashInferExperts
return [FlashInferExperts]
elif backend == Fp8MoeBackend.DEEPGEMM:
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
return TritonOrDeepGemmExperts
return [TritonOrDeepGemmExperts]
elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
return BatchedDeepGemmExperts
return [BatchedDeepGemmExperts]
elif backend == Fp8MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return MarlinExperts
return [MarlinExperts]
elif backend == Fp8MoeBackend.TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
return TritonExperts
return [TritonExperts]
elif backend == Fp8MoeBackend.BATCHED_TRITON:
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
return BatchedTritonExperts
return [BatchedTritonExperts]
elif backend == Fp8MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
return AiterExperts
return [AiterExperts]
elif backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
TritonOrCutlassExperts,
)
return TritonOrCutlassExperts
return [TritonOrCutlassExperts]
elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8,
)
return CutlassBatchedExpertsFp8
return [CutlassBatchedExpertsFp8]
elif backend == Fp8MoeBackend.XPU:
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExpertsFp8,
)
return XPUExpertsFp8
return [XPUExpertsFp8]
else:
raise ValueError(f"Unknown FP8 MoE backend: {backend.value}")
@@ -215,8 +216,9 @@ def select_fp8_moe_backend(
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0]
# NOTE: the kernels are selected in the following order.
AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key)
@@ -256,13 +258,13 @@ def select_fp8_moe_backend(
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
@@ -312,7 +314,7 @@ def select_fp8_moe_backend(
raise ValueError(
f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
)
k_cls = backend_to_kernel_cls(backend)
k_cls = backend_to_kernel_cls(backend)[0]
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
@@ -322,23 +324,23 @@ def select_fp8_moe_backend(
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no "
"FlashInfer FP8 MoE backend supports the configuration."
@@ -382,20 +384,19 @@ def select_fp8_moe_backend(
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
# TODO(rob): per discussion with TPU team, we need a way to register
# MoE backends by OOT plugins, rather than having an explicit list