From b69bf2f0b170ac5b43f72f4dd4139c5388fa5de8 Mon Sep 17 00:00:00 2001 From: Wei Zhao <51183510+wzhao18@users.noreply.github.com> Date: Fri, 27 Mar 2026 19:30:46 -0400 Subject: [PATCH] [Perf] Use torch compile to fuse pack topk in trtllm moe (#37695) Signed-off-by: wzhao18 Signed-off-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com> --- .../layers/fused_moe/experts/trtllm_fp8_moe.py | 8 +++----- .../layers/fused_moe/experts/trtllm_nvfp4_moe.py | 5 ++--- vllm/model_executor/layers/fused_moe/utils.py | 13 +++++++++++++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 671435a88..9a6f67b42 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( activation_to_flashinfer_int, ) @@ -152,11 +153,8 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): import flashinfer from flashinfer.fused_moe import Fp8QuantizationType - # Pack topk_ids and topk_weights into single tensor - # Format: (expert_id << 16) | (weight_bf16.view(int16)) - packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view( - torch.int16 - ) + # Pack topk ids and weights into format expected by the kernel. + packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights) # trtllm_fp8_block_scale_routed_moe does not support autotuning # so skip this kernel during dummy run for autotuning. diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index 7960bdf44..84beb6abb 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( activation_to_flashinfer_int, ) @@ -183,9 +184,7 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula assert self.quant_config.w2_scale is not None # Pack topk ids and weights into format expected by the kernel. - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) + packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights) # trtllm_fp4_block_scale_routed_moe does not support autotuning # so skip this kernel during dummy run for autotuning. diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ba4494f6c..c576b0a25 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -323,3 +323,16 @@ def normalize_batched_scales_shape( @functools.cache def disable_inplace() -> bool: return is_torch_equal_or_newer("2.9") + + +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def trtllm_moe_pack_topk_ids_weights( + topk_ids: torch.Tensor, topk_weights: torch.Tensor +) -> torch.Tensor: + """ + Pack topk_ids and topk_weights into a single int32 tensor. + Format: (expert_id << 16) | weight_bf16.view(int16) + """ + return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view( + torch.int16 + )