diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py index 57d303cd5..a2d267bd7 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -76,9 +76,13 @@ def _fwd_kernel_ep_scatter_1( ) tokens_per_expert = round_up_128(tokens_per_expert) cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert - tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) - cur_expert_start = tl.load(expert_start_loc + cur_expert) + # Extract this block's offset from the register vector (warp shuffle, + # no global memory round-trip) then write it once to expert_start_loc. + cur_expert_start = tl.sum( + tl.where(offset_cumsum == cur_expert, cumsum, tl.zeros_like(cumsum)) + ) + tl.store(expert_start_loc + cur_expert, cur_expert_start) cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) m_indices_start_ptr = m_indices + cur_expert_start @@ -87,7 +91,7 @@ def _fwd_kernel_ep_scatter_1( # any rows in the per-expert aligned region that do not correspond to # real tokens are left untouched here and should remain initialized to # -1 so DeepGEMM can skip them - for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + for start_m in tl.range(0, cur_expert_token_num, BLOCK_E): offs = start_m + off_expert mask = offs < cur_expert_token_num tl.store( @@ -186,6 +190,7 @@ def ep_scatter( grid = num_experts assert m_indices.shape[0] % BLOCK_E == 0 + assert expert_start_loc.shape[0] == num_experts _fwd_kernel_ep_scatter_1[(grid,)]( num_recv_tokens_per_expert,