[Perf] Use torch compile to fuse pack topk in trtllm moe (#37695)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
@@ -152,11 +153,8 @@ class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular):
|
||||
import flashinfer
|
||||
from flashinfer.fused_moe import Fp8QuantizationType
|
||||
|
||||
# Pack topk_ids and topk_weights into single tensor
|
||||
# Format: (expert_id << 16) | (weight_bf16.view(int16))
|
||||
packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(
|
||||
torch.int16
|
||||
)
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_topk_ids = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
|
||||
|
||||
# trtllm_fp8_block_scale_routed_moe does not support autotuning
|
||||
# so skip this kernel during dummy run for autotuning.
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import trtllm_moe_pack_topk_ids_weights
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
activation_to_flashinfer_int,
|
||||
)
|
||||
@@ -183,9 +184,7 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula
|
||||
assert self.quant_config.w2_scale is not None
|
||||
|
||||
# Pack topk ids and weights into format expected by the kernel.
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16
|
||||
).view(torch.int16)
|
||||
packed_tensor = trtllm_moe_pack_topk_ids_weights(topk_ids, topk_weights)
|
||||
|
||||
# trtllm_fp4_block_scale_routed_moe does not support autotuning
|
||||
# so skip this kernel during dummy run for autotuning.
|
||||
|
||||
@@ -323,3 +323,16 @@ def normalize_batched_scales_shape(
|
||||
@functools.cache
|
||||
def disable_inplace() -> bool:
|
||||
return is_torch_equal_or_newer("2.9")
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def trtllm_moe_pack_topk_ids_weights(
|
||||
topk_ids: torch.Tensor, topk_weights: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pack topk_ids and topk_weights into a single int32 tensor.
|
||||
Format: (expert_id << 16) | weight_bf16.view(int16)
|
||||
"""
|
||||
return (topk_ids.to(torch.int32) << 16) | topk_weights.to(torch.bfloat16).view(
|
||||
torch.int16
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user