[NVIDIA] [Perf] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#26714)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
jiahanc
2025-10-16 16:20:25 -07:00
committed by GitHub
parent fb5e10d3fb
commit 41d3071918
7 changed files with 25 additions and 96 deletions

View File

@@ -37,7 +37,7 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
trtllm_fp4_block_scale_moe,
)
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
@dataclass
@@ -319,7 +319,7 @@ def tg_mxfp4_moe(
if transpose_optimized:
for i in range(num_experts):
# w13 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
@@ -330,7 +330,7 @@ def tg_mxfp4_moe(
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
permute_sf_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
@@ -344,7 +344,7 @@ def tg_mxfp4_moe(
)
)
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
permute_bias_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
@@ -356,7 +356,7 @@ def tg_mxfp4_moe(
.contiguous()
)
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
@@ -367,7 +367,7 @@ def tg_mxfp4_moe(
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
permute_sf_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
@@ -381,7 +381,7 @@ def tg_mxfp4_moe(
)
)
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
_cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,