[Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
jiahanc
2025-11-10 09:34:57 -08:00
committed by GitHub
parent b039bfda8f
commit 34553b9d27
7 changed files with 78 additions and 30 deletions

View File

@@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):
def calculate_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    """Choose the per-CTA tile size (in tokens) for the FlashInfer MoE kernel.

    The estimate starts from the number of tokens each expert would receive
    if routing were perfectly balanced, inflates it by a fixed factor to
    allow for uneven routing, rounds it with FlashInfer's
    ``next_positive_power_of_2`` and finally clamps it to ``[8, 64]``.

    Args:
        num_tokens: Number of tokens being routed.
        top_k: Number of experts selected per token.
        num_experts: Number of experts the tokens are spread over. Must be
            non-zero; a zero value raises ``ZeroDivisionError``.

    Returns:
        The tile size, always between 8 and 64 inclusive.
    """
    # Kept function-local, as in the original; presumably so that importing
    # this module does not require FlashInfer to be installed.
    from flashinfer import next_positive_power_of_2

    # Tokens are not perfectly balanced among experts, so over-provision.
    imbalance_factor = 1.3

    # Tokens per expert assuming a perfectly even distribution.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Apply the imbalance factor (truncating back to an integer).
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)

    # Pad the estimate to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)

    # Clamp to 8-64 tokens per CTA tile, the range supported by the kernel.
    return min(max(tile_tokens_dim, 8), 64)