[Bugfix] ep_scatter kernel store-load race condition (#34991)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
@@ -76,9 +76,13 @@ def _fwd_kernel_ep_scatter_1(
     )
     tokens_per_expert = round_up_128(tokens_per_expert)
     cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
-    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

-    cur_expert_start = tl.load(expert_start_loc + cur_expert)
+    # Extract this block's offset from the register vector (warp shuffle,
+    # no global memory round-trip) then write it once to expert_start_loc.
+    cur_expert_start = tl.sum(
+        tl.where(offset_cumsum == cur_expert, cumsum, tl.zeros_like(cumsum))
+    )
+    tl.store(expert_start_loc + cur_expert, cur_expert_start)
     cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

     m_indices_start_ptr = m_indices + cur_expert_start
@@ -87,7 +91,7 @@ def _fwd_kernel_ep_scatter_1(
     # any rows in the per-expert aligned region that do not correspond to
     # real tokens are left untouched here and should remain initialized to
     # -1 so DeepGEMM can skip them
-    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
+    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E):
         offs = start_m + off_expert
         mask = offs < cur_expert_token_num
         tl.store(
@@ -186,6 +190,7 @@ def ep_scatter(
     grid = num_experts

     assert m_indices.shape[0] % BLOCK_E == 0
+    assert expert_start_loc.shape[0] == num_experts

     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
Reference in New Issue
Block a user