[Core/DBO][2/N] Dual-Batch Overlap add DeepEP High Throughput support and Prefill support (#24845)
Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
|
||||
from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook, dbo_yield)
|
||||
|
||||
#
|
||||
@@ -223,7 +224,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[Callable, ReceiverType]:
|
||||
) -> Union[tuple[Callable, ReceiverType], ReceiverType]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
@@ -239,10 +240,21 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `prepare`, e.g.
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
`prepare`, if a hook is returned this is a more lightweight check that
|
||||
the recv is complete without doing extra work (used by DBO, will be
|
||||
refactored in the very near future)
|
||||
|
||||
e.g.
|
||||
|
||||
receiver = obj.prepare_async(...)
|
||||
ret = obj.prepare_async(...)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
hook, receiver = ret
|
||||
hook()
|
||||
|
||||
if hook is not None:
|
||||
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
|
||||
|
||||
is equivalent to:
|
||||
@@ -284,7 +296,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||
) -> Callable:
|
||||
) -> Union[tuple[Callable, Callable], Callable]:
|
||||
"""
|
||||
Perform any combine plus apply weights and perform a reduction on the
|
||||
fused experts output but do not wait for results from other workers.
|
||||
@@ -298,11 +310,17 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- weight_and_reduce_impl: An optional TopKWeightAndReduce
|
||||
implementation.
|
||||
|
||||
Returns a callback that when invoked waits for results from other
|
||||
workers and has the same return signature as `finalize`, e.g.
|
||||
Returns a callback or a hook callback pair that when invoked waits for
|
||||
results from other workers and has the same return signature as
|
||||
`finalize`, if a hook is returned this is a more lightweight check that
|
||||
the recv is complete without doing extra work (used by DBO, will be
|
||||
refactored in the very near future)
|
||||
|
||||
receiver = obj.finalize_async(output, ...)
|
||||
ret = obj.finalize_async(output, ...)
|
||||
... output not valid yet ...
|
||||
if isinstance(ret, tuple):
|
||||
hook, receiver = ret
|
||||
hook()
|
||||
receiver()
|
||||
... output valid here ...
|
||||
|
||||
@@ -600,9 +618,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
fused_out_buffer = SharedResizableBuffer()
|
||||
workspace13_buffer = SharedResizableBuffer()
|
||||
workspace2_buffer = SharedResizableBuffer()
|
||||
|
||||
class SharedBuffers:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.fused_out = SharedResizableBuffer()
|
||||
self.workspace13 = SharedResizableBuffer()
|
||||
self.workspace2 = SharedResizableBuffer()
|
||||
|
||||
# Persistent buffers that are shared across `FusedMoEModularKernel`
|
||||
# instances (layers), to save memory and allocations.
|
||||
#
|
||||
# We have two sets of buffers to support dual batch overlap (DBO) where each
|
||||
# microbatch (ubatch) should use its own set of buffers to avoid
|
||||
# cross-ubatch contamination.
|
||||
# NOTE that memory is lazily allocated for these buffers, meaning that if
|
||||
# DBO isn't being used, the second SharedBuffers will be empty.
|
||||
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -647,14 +679,18 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = self.workspace13_buffer.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = self.workspace2_buffer.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace13 = buffers.workspace13.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = buffers.workspace2.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
@@ -733,9 +769,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = self.fused_out_buffer.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
fused_out = buffers.fused_out.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
@@ -868,6 +906,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
# TODO(lucas): enable in follow-up
|
||||
assert not dbo_enabled()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
@@ -883,7 +922,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
dbo_maybe_run_recv_hook()
|
||||
hook, receiver = self.prepare_finalize.prepare_async(
|
||||
prepare_ret = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@@ -893,13 +932,21 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
# If DBO is being used, register the hook with the ubatch context
|
||||
# and call it in dbo_maybe_run_recv_hook instead of passing it to
|
||||
# the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
hook()
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = prepare_ret \
|
||||
if isinstance(prepare_ret, tuple) else (None, prepare_ret)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
@@ -952,7 +999,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
else:
|
||||
recv_hook = self.prepare_finalize.finalize_async(
|
||||
finalize_ret = self.prepare_finalize.finalize_async(
|
||||
output,
|
||||
fused_out,
|
||||
topk_weights,
|
||||
@@ -964,11 +1011,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
assert recv_hook is not None
|
||||
dbo_register_recv_hook(recv_hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
recv_hook()
|
||||
# TODO(lucas): refactor this in the alternative schedules followup
|
||||
# currently unpack if we have hook + receiver pair or just
|
||||
# receiver (see finalize_async docstring)
|
||||
hook, receiver = finalize_ret \
|
||||
if isinstance(finalize_ret, tuple) else (None, finalize_ret)
|
||||
|
||||
if hook is not None:
|
||||
if dbo_enabled():
|
||||
# If DBO is being used, register the hook with the ubatch
|
||||
# context and call it in dbo_maybe_run_recv_hook instead of
|
||||
# passing it to the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
else:
|
||||
hook()
|
||||
|
||||
receiver()
|
||||
|
||||
if self.shared_experts is None:
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user