[Perf] Fix DBO overlap: capture DeepEP event before yield (#38451)
Signed-off-by: root <conway.zhu@cohere.com>
This commit is contained in:
@@ -107,15 +107,17 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
) -> Callable:
|
||||
has_scales = token_scales is not None
|
||||
|
||||
# Capture a DeepEP event on the compute stream before yielding.
|
||||
# This must happen before the yield so the event only covers this
|
||||
# ubatch's compute work. If captured after, the compute stream tail
|
||||
# may include the other ubatch's work, preventing overlap.
|
||||
previous_event = dbo_get_previous_event(self.buffer.capture)
|
||||
|
||||
# We yield before launching the dispatch kernel since the dispatch
|
||||
# kernel will block the CPU so we want to queue up all the compute
|
||||
# for the other ubatch before the dispatch kernel starts.
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
|
||||
# capture a DeepEP event and pass it as previous_event so
|
||||
# DeepEP honors the dependency internally.
|
||||
previous_event = dbo_get_previous_event(self.buffer.capture)
|
||||
|
||||
(
|
||||
num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank,
|
||||
@@ -357,11 +359,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
topk_ids=topk_ids,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
previous_event = dbo_get_previous_event(self.buffer.capture)
|
||||
dbo_yield_and_switch_from_compute_to_comm()
|
||||
assert fused_expert_output.dtype == torch.bfloat16, (
|
||||
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
|
||||
)
|
||||
previous_event = dbo_get_previous_event(self.buffer.capture)
|
||||
combined_x, _, event = self.buffer.combine(
|
||||
# HT combine only supports BF16
|
||||
x=fused_expert_output,
|
||||
|
||||
Reference in New Issue
Block a user