diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 97ee20ae9..bc9aab520 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -233,12 +233,13 @@ def flashinfer_alltoall_dispatch( max_num_token = ( max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0] ) + orig_topk_weights_dtype = topk_weights.dtype alltoall_info, topk_ids, topk_weights, _ = ( MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( topk_ids, topk_weights, None, - all2all_manager.prepare_workspace, + all2all_manager.prepare_workspace_tensor, max_num_token, ep_rank, ep_size, @@ -247,6 +248,7 @@ def flashinfer_alltoall_dispatch( top_k, ) ) + topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype) x, x_sf = moe_kernel_quantize_input( x,