[NVIDIA] [Perf] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#26714)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -50,7 +50,6 @@ from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import (
|
||||
has_triton_kernels,
|
||||
is_torch_equal_or_newer,
|
||||
next_power_of_2,
|
||||
round_up,
|
||||
)
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
@@ -97,12 +96,6 @@ def get_mxfp4_backend():
|
||||
and has_flashinfer()
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
):
|
||||
logger.info_once(
|
||||
"Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
|
||||
"for high concurrency throughput workloads consider setting "
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
|
||||
"performance"
|
||||
)
|
||||
return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
|
||||
elif current_platform.is_device_capability(100) and has_flashinfer():
|
||||
logger.info_once(
|
||||
@@ -357,7 +350,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
|
||||
):
|
||||
from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
|
||||
from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
|
||||
from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
|
||||
|
||||
layer.gemm1_alpha = Parameter(
|
||||
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
@@ -449,7 +442,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
for i in range(self.num_experts):
|
||||
# w13 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
self._cache_permute_indices,
|
||||
w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -460,7 +453,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
.contiguous()
|
||||
)
|
||||
# w13 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||
self._cache_permute_indices,
|
||||
w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -476,7 +469,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
)
|
||||
# w13 bias shuffling
|
||||
permute_bias_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_bias_indices = get_w2_permute_indices_with_cache(
|
||||
self._cache_permute_indices,
|
||||
w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
@@ -488,7 +481,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
.contiguous()
|
||||
)
|
||||
# w2 weight shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
self._cache_permute_indices,
|
||||
w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -499,7 +492,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
.contiguous()
|
||||
)
|
||||
# w2 scale shuffling
|
||||
permute_sf_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_sf_indices = get_w2_permute_indices_with_cache(
|
||||
self._cache_permute_indices,
|
||||
w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m,
|
||||
@@ -515,7 +508,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
)
|
||||
# w2 bias shuffling
|
||||
permute_indices = _maybe_get_cached_w2_permute_indices(
|
||||
permute_indices = get_w2_permute_indices_with_cache(
|
||||
self._cache_permute_indices,
|
||||
w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m,
|
||||
@@ -735,30 +728,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# - 1.0 means perfect expert distribution.
|
||||
# - > 1.0 means some experts have more
|
||||
# tokens than the perfect distribution.
|
||||
# - < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert
|
||||
# assuming perfect distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile
|
||||
# as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
@@ -1037,7 +1006,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer.ep_rank * layer.local_num_experts, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
None,
|
||||
1 if renormalize else 0, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
tune_max_num_tokens=self.max_capture_size,
|
||||
|
||||
Reference in New Issue
Block a user