[flashinfer] fix FI all2all with FI cutlass moe (#28166)

Signed-off-by: Xiaozhu <mxz297@gmail.com>
This commit is contained in:
Xiaozhu Meng
2025-11-05 21:52:16 -08:00
committed by GitHub
parent bde5039325
commit e31946f86e

View File

@@ -233,12 +233,13 @@ def flashinfer_alltoall_dispatch(
max_num_token = (
max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
)
orig_topk_weights_dtype = topk_weights.dtype
alltoall_info, topk_ids, topk_weights, _ = (
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
topk_ids,
topk_weights,
None,
all2all_manager.prepare_workspace,
all2all_manager.prepare_workspace_tensor,
max_num_token,
ep_rank,
ep_size,
@@ -247,6 +248,7 @@ def flashinfer_alltoall_dispatch(
top_k,
)
)
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
x, x_sf = moe_kernel_quantize_input(
x,