[Core/DBO][2/N] Dual-Batch Overlap: add DeepEP High Throughput support and Prefill support (#24845)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Lucas Wilkinson authored 2025-09-23 12:02:10 -04:00 (committed by GitHub)
parent a903669e10
commit cc1dc7ed6d
19 changed files with 602 additions and 236 deletions


@@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
     _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
-from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
+from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
+                                      dbo_maybe_run_recv_hook,
                                       dbo_register_recv_hook, dbo_yield)
 #
@@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[Callable, ReceiverType]:
+    ) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
         """
         Perform any quantization (and/or) dispatching needed for this kernel
         but do not wait for results from other workers.
@@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
         - apply_router_weight_on_input: When True, apply the weights to the
           activations, before quantization + dispatching.

-        Returns a callback that when invoked waits for results from other
-        workers and has the same return signature as `prepare`, e.g.
+        Returns either a callback or a (hook, callback) pair. The callback,
+        when invoked, waits for results from other workers and has the same
+        return signature as `prepare`. If a hook is returned, it is a more
+        lightweight check that the recv is complete, without doing extra
+        work (used by DBO; will be refactored in the near future), e.g.

-            receiver = obj.prepare_async(...)
+            ret = obj.prepare_async(...)
+            hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
+            if hook is not None:
+                hook()
             a, a_scales, expert_meta, topk_ids, topk_weights = receiver()

         is equivalent to:
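On the implementation side, the split is: the hook only waits for the communication to land, while the receiver does the heavier unpacking. A minimal sketch of an implementation returning the pair (the _FakeCommHandle type and its methods are hypothetical stand-ins, not part of the vLLM API):

    from typing import Callable

    class _FakeCommHandle:
        """Stand-in for an async all2all communication handle (hypothetical)."""

        def wait(self) -> None:
            # Would block until the dispatched tokens have been received.
            pass

        def unpack(self) -> tuple:
            # Would return the received (quantized) activations and metadata.
            return (None, None, None)

    def prepare_async_sketch(handle: _FakeCommHandle) -> tuple[Callable, Callable]:
        """Returns a (hook, receiver) pair per the new prepare_async contract."""

        def hook() -> None:
            # Lightweight: only confirms that the recv has completed.
            handle.wait()

        def receiver() -> tuple:
            # Heavier: unpacks buffers into prepare()'s return signature:
            # (a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights).
            a1q, a1q_scale, expert_meta = handle.unpack()
            return a1q, a1q_scale, expert_meta, None, None

        return hook, receiver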
@@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: TopKWeightAndReduce,
-    ) -> Callable:
+    ) -> Union[tuple[Callable, Callable], Callable]:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output but do not wait for results from other workers.
@@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
         - weight_and_reduce_impl: An optional TopKWeightAndReduce
           implementation.

-        Returns a callback that when invoked waits for results from other
-        workers and has the same return signature as `finalize`, e.g.
+        Returns either a callback or a (hook, callback) pair. The callback,
+        when invoked, waits for results from other workers and has the same
+        return signature as `finalize`. If a hook is returned, it is a more
+        lightweight check that the recv is complete, without doing extra
+        work (used by DBO; will be refactored in the near future), e.g.

-            receiver = obj.finalize_async(output, ...)
+            ret = obj.finalize_async(output, ...)
+            hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
             ... output not valid yet ...
+            if hook is not None:
+                hook()
             receiver()
             ... output valid here ...
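Caller-side, the union return can be normalized in one line regardless of which form an implementation chooses. A sketch of the calling pattern the docstring describes (run_finalize and pf are illustrative names; pf stands in for any FusedMoEPrepareAndFinalize implementation):

    def run_finalize(pf, output, fused_out, topk_weights, topk_ids,
                     apply_router_weight_on_input, weight_and_reduce_impl):
        """Normalize finalize_async's union return and wait for the result."""
        ret = pf.finalize_async(output, fused_out, topk_weights, topk_ids,
                                apply_router_weight_on_input,
                                weight_and_reduce_impl)
        hook, receiver = ret if isinstance(ret, tuple) else (None, ret)
        # ... overlap independent work here; `output` is not valid yet ...
        if hook is not None:
            hook()
        receiver()
        # `output` is valid here.
        return output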
@@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
     layer due to any layer specific state that may be used by the component
     objects.
     """

-    fused_out_buffer = SharedResizableBuffer()
-    workspace13_buffer = SharedResizableBuffer()
-    workspace2_buffer = SharedResizableBuffer()
+    class SharedBuffers:
+
+        def __init__(self) -> None:
+            self.fused_out = SharedResizableBuffer()
+            self.workspace13 = SharedResizableBuffer()
+            self.workspace2 = SharedResizableBuffer()
+
+    # Persistent buffers that are shared across `FusedMoEModularKernel`
+    # instances (layers), to save memory and allocations.
+    #
+    # We have two sets of buffers to support dual batch overlap (DBO), where
+    # each microbatch (ubatch) should use its own set of buffers to avoid
+    # cross-ubatch contamination.
+    # NOTE: memory is lazily allocated for these buffers, so if DBO isn't
+    # being used, the second SharedBuffers stays empty.
+    shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]

     def __init__(
         self,
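The lazy-allocation note above is what keeps the second SharedBuffers free in non-DBO runs: nothing is materialized until the first get() on a buffer. A rough sketch of that behavior, assuming SharedResizableBuffer (fused_moe/utils.py) works roughly like a grow-only scratch tensor (its real internals may differ):

    from typing import Optional

    import torch

    class SharedResizableBufferSketch:
        """Rough model of a lazily-allocated, shared scratch buffer."""

        def __init__(self) -> None:
            self.buffer: Optional[torch.Tensor] = None  # nothing allocated yet

        def get(self, shape, device, dtype) -> torch.Tensor:
            numel = 1
            for dim in shape:
                numel *= dim
            # Allocate on first use; grow only when a larger request arrives.
            if (self.buffer is None or self.buffer.numel() < numel
                    or self.buffer.dtype != dtype
                    or self.buffer.device != torch.device(device)):
                self.buffer = torch.empty(numel, device=device, dtype=dtype)
            return self.buffer[:numel].view(*shape)

Under this model, shared_buffers[1] never allocates unless dbo_current_ubatch_id() actually returns 1, i.e. unless DBO has split the batch into two ubatches.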
@@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module):
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)

+        # Select per-ubatch buffers to avoid cross-ubatch reuse under DBO.
+        ubatch_idx = dbo_current_ubatch_id()
+        buffers = self.shared_buffers[ubatch_idx]
+
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = self.workspace13_buffer.get(workspace13_shape,
-                                                  device=a1.device,
-                                                  dtype=workspace_dtype)
-        workspace2 = self.workspace2_buffer.get(workspace2_shape,
-                                                device=a1.device,
-                                                dtype=workspace_dtype)
+        workspace13 = buffers.workspace13.get(workspace13_shape,
+                                              device=a1.device,
+                                              dtype=workspace_dtype)
+        workspace2 = buffers.workspace2.get(workspace2_shape,
+                                            device=a1.device,
+                                            dtype=workspace_dtype)

         assert fused_out is None or fused_out.shape == fused_out_shape, (
             f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module):
         (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
             expert_tokens_meta)
-        fused_out = self.fused_out_buffer.get(fused_out_shape,
-                                              device=a1q.device,
-                                              dtype=a1.dtype)
+        ubatch_idx = dbo_current_ubatch_id()
+        buffers = self.shared_buffers[ubatch_idx]
+        fused_out = buffers.fused_out.get(fused_out_shape,
+                                          device=a1q.device,
+                                          dtype=a1.dtype)

         def slice_input_tensors(
             chunk_idx: int
@@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             if not self.prepare_finalize.supports_async():
                 # We shouldn't be running an a2a kernel that doesn't
                 # support async prepare/finalize.
+                # TODO(lucas): enable in follow-up
                 assert not dbo_enabled()

                 (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
@@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             else:
                 # Overlap shared expert compute with all2all dispatch.
                 dbo_maybe_run_recv_hook()
-                hook, receiver = self.prepare_finalize.prepare_async(
+                prepare_ret = self.prepare_finalize.prepare_async(
                     a1,
                     topk_weights,
                     topk_ids,
@@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module):
                     self.fused_experts.quant_config,
                 )

-                # If DBO is being used, register the hook with the ubatch
-                # context and call it in dbo_maybe_run_recv_hook instead of
-                # passing it to the receiver.
-                dbo_register_recv_hook(hook)
-                dbo_yield()
-
-                if not dbo_enabled():
-                    hook()
+                # TODO(lucas): refactor this in the alternative schedules
+                # follow-up. For now, unpack whether we got a (hook, receiver)
+                # pair or just a receiver (see the prepare_async docstring).
+                hook, receiver = prepare_ret \
+                    if isinstance(prepare_ret, tuple) else (None, prepare_ret)
+
+                if hook is not None:
+                    if dbo_enabled():
+                        # If DBO is being used, register the hook with the
+                        # ubatch context; dbo_maybe_run_recv_hook will call it
+                        # instead of it being passed to the receiver.
+                        dbo_register_recv_hook(hook)
+                        dbo_yield()
+                    else:
+                        hook()

                 (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
                  _expert_topk_weights) = receiver()
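The dbo_* helpers used here live in vllm/v1/worker/ubatching and are not part of this diff. Conceptually, each ubatch context holds at most one pending recv hook, dbo_yield hands execution to the other microbatch, and dbo_maybe_run_recv_hook fires the stored hook before compute resumes. A toy model of that contract (the function names match the real API; the bodies are simplified assumptions, since the real version switches between ubatch threads):

    from typing import Callable, Optional

    _current_ubatch_id = 0  # toy stand-in for the real per-thread context
    _pending_hooks: list[Optional[Callable]] = [None, None]

    def dbo_enabled() -> bool:
        return True  # toy: pretend DBO is always active

    def dbo_current_ubatch_id() -> int:
        return _current_ubatch_id

    def dbo_register_recv_hook(hook: Callable) -> None:
        # Stash this ubatch's lightweight recv-completion hook.
        _pending_hooks[_current_ubatch_id] = hook

    def dbo_maybe_run_recv_hook() -> None:
        # Run (and clear) this ubatch's pending hook, if any.
        hook = _pending_hooks[_current_ubatch_id]
        if hook is not None:
            _pending_hooks[_current_ubatch_id] = None
            hook()

    def dbo_yield() -> None:
        # Toy: the real version suspends this microbatch and resumes the
        # other one, so this ubatch's comms overlap the other's compute.
        global _current_ubatch_id
        _current_ubatch_id ^= 1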
@@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 if self.shared_experts is not None:
                     shared_output = self.shared_experts(a1)
             else:
-                recv_hook = self.prepare_finalize.finalize_async(
+                finalize_ret = self.prepare_finalize.finalize_async(
                     output,
                     fused_out,
                     topk_weights,
@@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
                 if self.shared_experts is not None:
                     shared_output = self.shared_experts(a1)

-                assert recv_hook is not None
-                dbo_register_recv_hook(recv_hook)
-                dbo_yield()
-
-                if not dbo_enabled():
-                    recv_hook()
+                # TODO(lucas): refactor this in the alternative schedules
+                # follow-up. For now, unpack whether we got a (hook, receiver)
+                # pair or just a receiver (see the finalize_async docstring).
+                hook, receiver = finalize_ret \
+                    if isinstance(finalize_ret, tuple) else (None, finalize_ret)
+
+                if hook is not None:
+                    if dbo_enabled():
+                        # If DBO is being used, register the hook with the
+                        # ubatch context; dbo_maybe_run_recv_hook will call it
+                        # instead of it being passed to the receiver.
+                        dbo_register_recv_hook(hook)
+                        dbo_yield()
+                    else:
+                        hook()
+
+                receiver()

         if self.shared_experts is None:
             return output