[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
Po-Han Huang (NVIDIA)
2025-08-08 11:25:01 +08:00
committed by GitHub
parent 157f9c1368
commit af473f0a85
2 changed files with 17 additions and 8 deletions

View File

@@ -6,14 +6,22 @@ import torch
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
# with the necessary kernels is released.
tile_tokens_dim = 8
# from flashinfer import next_positive_power_of_2
# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim