[NVIDIA] [Perf] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#26714)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -37,7 +37,7 @@ if TRTLLM_GEN_MXFP4_AVAILABLE:
|
||||
trtllm_fp4_block_scale_moe,
|
||||
)
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
|
||||
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -319,7 +319,7 @@ def tg_mxfp4_moe(
|
||||
if transpose_optimized:
|
||||
for i in range(num_experts):
|
||||
# w13 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -330,7 +330,7 @@ def tg_mxfp4_moe(
|
||||
.contiguous()
|
||||
)
|
||||
# w13 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -344,7 +344,7 @@ def tg_mxfp4_moe(
|
||||
)
|
||||
)
|
||||
# w13 bias shuffling
|
||||
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_bias_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
@@ -356,7 +356,7 @@ def tg_mxfp4_moe(
|
||||
.contiguous()
|
||||
)
|
||||
# w2 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -367,7 +367,7 @@ def tg_mxfp4_moe(
|
||||
.contiguous()
|
||||
)
|
||||
# w2 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -381,7 +381,7 @@ def tg_mxfp4_moe(
|
||||
)
|
||||
)
|
||||
# w2 bias shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
_cache_permute_indices,
|
||||
w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
|
||||
Reference in New Issue
Block a user