[Perf] Optimize deepgemm experts initialization, 3.9% TTFT improvement (#30494)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: li-jinpeng <3332126450@qq.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -84,10 +84,16 @@ def _fwd_kernel_ep_scatter_1(
|
|||||||
m_indices_start_ptr = m_indices + cur_expert_start
|
m_indices_start_ptr = m_indices + cur_expert_start
|
||||||
off_expert = tl.arange(0, BLOCK_E)
|
off_expert = tl.arange(0, BLOCK_E)
|
||||||
|
|
||||||
|
# any rows in the per-expert aligned region that do not correspond to
|
||||||
|
# real tokens are left untouched here and should remain initialized to
|
||||||
|
# -1 so DeepGEMM can skip them
|
||||||
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
|
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
|
||||||
|
offs = start_m + off_expert
|
||||||
|
mask = offs < cur_expert_token_num
|
||||||
tl.store(
|
tl.store(
|
||||||
m_indices_start_ptr + start_m + off_expert,
|
m_indices_start_ptr + offs,
|
||||||
cur_expert,
|
cur_expert,
|
||||||
|
mask=mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -366,12 +372,17 @@ def deepgemm_moe_permute(
|
|||||||
(M_sum, H // block_k), device=device, dtype=torch.float32
|
(M_sum, H // block_k), device=device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
maybe_has_empty_blocks = (expert_tokens_meta is None) or (
|
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark
|
||||||
expert_tokens_meta.expert_num_tokens_cpu is None
|
# completely invalid / padded blocks that should be skipped. We always
|
||||||
|
# initialize expert_ids to -1 so any row that is not explicitly written
|
||||||
|
# by the scatter kernel will be treated as invalid and skipped by
|
||||||
|
# DeepGEMM's scheduler.
|
||||||
|
expert_ids = torch.full(
|
||||||
|
(M_sum,),
|
||||||
|
fill_value=-1,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty
|
|
||||||
|
|
||||||
expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
|
|
||||||
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
|
inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32)
|
||||||
|
|
||||||
expert_num_tokens = None
|
expert_num_tokens = None
|
||||||
|
|||||||
Reference in New Issue
Block a user