[NVIDIA] [Perf] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#26714)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
jiahanc
2025-10-16 16:20:25 -07:00
committed by GitHub
parent fb5e10d3fb
commit 41d3071918
7 changed files with 25 additions and 96 deletions

View File

@@ -11,7 +11,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.utils import next_power_of_2
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -65,30 +64,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (M, K)
return (workspace1, workspace2, output)
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
    """Pick the per-CTA tile size (in tokens) to pass as ``tile_tokens_dim``.

    The estimate starts from the number of rows in ``x`` (``x.shape[0]``),
    multiplies by ``top_k`` and floor-divides by ``local_num_experts`` to get
    the tokens each expert would receive under a perfectly even routing.
    That share is then scaled by an imbalance factor, rounded up to a power
    of two and clamped to the 8-64 range.

    Args:
        x: Input tensor; only ``x.shape[0]`` (the token count) is read.
        top_k: Number of experts selected per token.
        local_num_experts: Number of experts the tokens are spread over.

    Returns:
        The tile size, a value between 8 and 64 inclusive.
    """
    # Ratio of max_real_num_tokens_per_expert to
    # perfect_num_tokens_per_expert:
    #   1.0   -> perfectly even expert distribution
    #   > 1.0 -> some experts receive more tokens than the even share
    #   < 1.0 -> does not make sense
    expert_imbalance = 1.3

    # Even share per expert (floor division happens before the scaling,
    # and the scaled value is truncated back to an int).
    even_share = (x.shape[0] * top_k) // local_num_experts
    skewed_share = int(even_share * expert_imbalance)

    # Pad up to the next power of 2, then keep the result inside 8-64
    # tokens per CTA tile, which is the range supported by the kernel.
    padded_share = next_power_of_2(skewed_share)
    return max(8, min(padded_share, 64))
def apply(
self,
output: torch.Tensor,
@@ -148,9 +123,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
"local_expert_offset": local_expert_offset,
"local_num_experts": local_num_experts,
"routed_scaling_factor": None,
"tile_tokens_dim": self._get_tile_tokens_dim(
x_quant, topk, local_num_experts
),
"tile_tokens_dim": None,
"routing_method_type": 1,
"do_finalize": True,
"output": output,

View File

@@ -72,7 +72,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from vllm.scalar_type import scalar_types
from vllm.utils import next_power_of_2
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
@@ -1125,16 +1124,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
return out.view(*output_shape)
def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Pick the per-CTA tile size (in tokens) to pass as ``tile_tokens_dim``.

    Assumes a perfectly even expert distribution: the ``num_tokens * top_k``
    routed tokens are floor-divided across ``num_experts``; that share is
    rounded up to a power of two and clamped to the 8-64 range.

    Args:
        num_tokens: Number of input tokens.
        top_k: Number of experts selected per token.
        num_experts: Number of experts the tokens are spread over.

    Returns:
        The tile size, a value between 8 and 64 inclusive.
    """
    # First guess: tokens per expert under a perfectly even distribution.
    even_share = (num_tokens * top_k) // num_experts

    # Pad up to the next power of 2, then keep the result inside 8-64
    # tokens per CTA tile, which is the range supported by the kernel.
    padded_share = next_power_of_2(even_share)
    return max(8, min(padded_share, 64))
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
@@ -1332,8 +1321,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
):
from flashinfer import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import (
_maybe_get_cached_w2_permute_indices,
_maybe_get_cached_w3_w1_permute_indices,
get_w2_permute_indices_with_cache,
)
"""Prepare quantized weights for kernel (done offline with weights)."""
@@ -1394,7 +1383,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
)
)
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
gemm2_weights_fp4[i].view(torch.uint8),
epilogue_tile_m,
@@ -1405,7 +1394,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
.contiguous()
)
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
gemm2_scales_linear_fp4[i].view(torch.uint8),
epilogue_tile_m,
@@ -1664,9 +1653,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
tile_tokens_dim=_get_tile_tokens_dim(
x.shape[0], top_k, layer.local_num_experts
),
tile_tokens_dim=None,
routing_method_type=routing_method_type,
do_finalize=True,
)[0]

View File

@@ -50,7 +50,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils import (
has_triton_kernels,
is_torch_equal_or_newer,
next_power_of_2,
round_up,
)
from vllm.utils.flashinfer import has_flashinfer
@@ -97,12 +96,6 @@ def get_mxfp4_backend():
and has_flashinfer()
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
logger.info_once(
"Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
"for high concurrency throughput workloads consider setting "
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
"performance"
)
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
elif current_platform.is_device_capability(100) and has_flashinfer():
logger.info_once(
@@ -357,7 +350,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
):
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -449,7 +442,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
# w13 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_weight[i].view(torch.uint8),
epilogue_tile_m,
@@ -460,7 +453,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
.contiguous()
)
# w13 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
@@ -476,7 +469,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
)
# w13 bias shuffling
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
permute_bias_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
@@ -488,7 +481,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
.contiguous()
)
# w2 weight shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_weight[i].view(torch.uint8),
epilogue_tile_m,
@@ -499,7 +492,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
.contiguous()
)
# w2 scale shuffling
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
permute_sf_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m,
@@ -515,7 +508,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
)
# w2 bias shuffling
permute_indices = _maybe_get_cached_w2_permute_indices(
permute_indices = get_w2_permute_indices_with_cache(
self._cache_permute_indices,
w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m,
@@ -735,30 +728,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
    """Pick the per-CTA tile size (in tokens) to pass as ``tile_tokens_dim``.

    The estimate starts from the number of rows in ``x`` (``x.shape[0]``),
    multiplies by ``top_k`` and floor-divides by ``self.num_experts`` to get
    the tokens each expert would receive under a perfectly even routing.
    That share is then scaled by an imbalance factor, rounded up to a power
    of two and clamped to the 8-64 range.

    Args:
        x: Input tensor; only ``x.shape[0]`` (the token count) is read.
        top_k: Number of experts selected per token.

    Returns:
        The tile size, a value between 8 and 64 inclusive.
    """
    # Ratio of max_real_num_tokens_per_expert to
    # perfect_num_tokens_per_expert:
    #   1.0   -> perfectly even expert distribution
    #   > 1.0 -> some experts receive more tokens than the even share
    #   < 1.0 -> does not make sense
    expert_imbalance = 1.3

    # Even share per expert (floor division happens before the scaling,
    # and the scaled value is truncated back to an int).
    even_share = (x.shape[0] * top_k) // self.num_experts
    skewed_share = int(even_share * expert_imbalance)

    # Pad up to the next power of 2, then keep the result inside 8-64
    # tokens per CTA tile, which is the range supported by the kernel.
    padded_share = next_power_of_2(skewed_share)
    return max(8, min(padded_share, 64))
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
@@ -1037,7 +1006,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.ep_rank * layer.local_num_experts, # local_expert_offset
self.num_experts, # local num experts
None,
self._get_tile_tokens_dim(x, top_k),
None,
1 if renormalize else 0, # routing_method_type, renormalize
True, # do finalize
tune_max_num_tokens=self.max_capture_size,