[Perf] Use torch compile to fuse pack topk in trtllm moe (#37695)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
Wei Zhao
2026-03-27 19:30:46 -04:00
committed by GitHub
parent 88149b635e
commit b69bf2f0b1
3 changed files with 18 additions and 8 deletions

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
@@ -152,11 +153,8 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
import flashinfer
from flashinfer.fused_moe import Fp8QuantizationType
# Pack topk_ids and topk_weights into single tensor
# Format: (expert_id << 16) | (weight_bf16.view(int16))
packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(
torch.int16
)
# Pack topk ids and weights into format expected by the kernel.
packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
# trtllm_fp8_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.

View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
activation_to_flashinfer_int,
)
@@ -183,9 +184,7 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula
assert self.quant_config.w2_scale is not None
# Pack topk ids and weights into format expected by the kernel.
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16
).view(torch.int16)
packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
# trtllm_fp4_block_scale_routed_moe does not support autotuning
# so skip this kernel during dummy run for autotuning.

View File

@@ -323,3 +323,16 @@ def normalize_batched_scales_shape(
@functools.cache
def disable_inplace() -> bool:
    """Return ``is_torch_equal_or_newer("2.9")``, cached via functools.cache.

    The function takes no arguments, so after the first call every later
    call reuses the cached boolean instead of re-running the check.
    """
    version_check = is_torch_equal_or_newer("2.9")
    return version_check
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def trtllm_moe_pack_topk_ids_weights(
    topk_ids: torch.Tensor, topk_weights: torch.Tensor
) -> torch.Tensor:
    """
    Pack topk_ids and topk_weights into a single int32 tensor.

    Format: (expert_id << 16) | (raw bfloat16 weight bits in the low 16 bits)

    Args:
        topk_ids: Expert ids; cast to int32 before being shifted left by 16.
        topk_weights: Routing weights; cast to bfloat16 and reinterpreted
            bit-for-bit as 16-bit integers.

    Returns:
        An int32 tensor holding the expert id in the upper 16 bits and the
        bfloat16 bit pattern of the weight in the lower 16 bits.
    """
    expert_bits = topk_ids.to(torch.int32) << 16

    # view(torch.int16) yields a *signed* value. Promoting it to int32 for the
    # OR sign-extends it, so a weight whose sign bit is set (negative, -0.0)
    # would fill the upper 16 bits with ones and clobber the expert id.
    # Widen explicitly and keep only the low 16 bits; for non-negative weights
    # this is bit-identical to the unmasked form.
    weight_bits = topk_weights.to(torch.bfloat16).view(torch.int16)
    weight_bits = weight_bits.to(torch.int32) & 0xFFFF

    return expert_bits | weight_bits