[Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)

Signed-off-by: root <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2026-03-31 13:38:59 -07:00
committed by GitHub
parent d9b90a07ac
commit 517b769b58

View File

@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
) -> Callable:
has_scales = token_scales is not None
# Capture a DeepEP event on the compute stream before yielding.
# This must happen before the yield so the event only covers this
# ubatch's compute work. If captured after, the compute stream tail
# may include the other ubatch's work, preventing overlap.
previous_event = dbo_get_previous_event(self.buffer.capture)
# We yield before launching the dispatch kernel since the dispatch
# kernel will block the CPU so we want to queue up all the compute
# for the other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()
# capture a DeepEP event and pass it as previous_event so
# DeepEP honors the dependency internally.
previous_event = dbo_get_previous_event(self.buffer.capture)
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
previous_event = dbo_get_previous_event(self.buffer.capture)
dbo_yield_and_switch_from_compute_to_comm()
assert fused_expert_output.dtype == torch.bfloat16, (
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
)
previous_event = dbo_get_previous_event(self.buffer.capture)
combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output,