[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (#23693)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Sage Moore authored 2025-09-16 09:21:48 -07:00, committed by GitHub
parent 08369289af
commit 567939953b
22 changed files with 1257 additions and 172 deletions


@@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
     _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
+from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
+                                      dbo_register_recv_hook, dbo_yield)
 
 #
 # This file defines a set of base classes used to make MoE kernels more modular.
@@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> ReceiverType:
) -> tuple[Callable, ReceiverType]:
"""
Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers.
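The signature change above means prepare_async now hands back two callables: a hook that waits for the outstanding all2all dispatch, and a receiver that unpacks the dispatched tensors once the hook has run. A minimal, self-contained sketch of that split-phase pattern, using a thread pool as a stand-in for the communication kernel (none of the names below are vLLM APIs):

from concurrent.futures import ThreadPoolExecutor
from typing import Callable

_pool = ThreadPoolExecutor(max_workers=1)

def prepare_async(data: list[int]) -> tuple[Callable[[], None], Callable[[], list[int]]]:
    # Kick off the "dispatch" in the background and return (hook, receiver):
    #   hook()     blocks until the dispatch has completed
    #   receiver() returns the dispatched result (call it after hook())
    future = _pool.submit(lambda: [x * 2 for x in data])  # stand-in for an all2all

    def hook() -> None:
        future.result()  # block until the background dispatch finishes

    def receiver() -> list[int]:
        return future.result()

    return hook, receiver

hook, receiver = prepare_async([1, 2, 3])
# ... overlap other work here, e.g. shared-expert compute ...
hook()            # wait for the async dispatch
out = receiver()  # then materialize the results

Splitting the wait (hook) from the unpack (receiver) is what lets the caller decide whether the wait happens inline or on the other micro-batch's turn.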
@@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
     return None
 
 
+class SharedResizableBuffer:
+
+    def __init__(self):
+        self.buffer = None
+
+    def get(self, shape: tuple[int, ...], device: torch.device,
+            dtype: torch.dtype):
+        shape_numel = prod(shape)
+        if self.buffer is None or self.buffer.numel() < shape_numel:
+            self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
+        assert self.buffer.device == device, \
+            f"Buffer device mismatch: {self.buffer.device} != {device}"
+        assert self.buffer.dtype == dtype, \
+            f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
+        return self.buffer[:shape_numel].view(*shape)
+
+
 @final
 class FusedMoEModularKernel(torch.nn.Module):
     """
@@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module):
     layer due to any layer specific state that may be used by the component
     objects.
     """
+    fused_out_buffer = SharedResizableBuffer()
+    workspace13_buffer = SharedResizableBuffer()
+    workspace2_buffer = SharedResizableBuffer()
 
     def __init__(
         self,
@@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module):
 
         # We can reuse the memory between cache1 and cache3 because by the
         # time we need cache3, we're done with cache1.
-        workspace13 = torch.empty(prod(workspace13_shape),
-                                  device=a1.device,
-                                  dtype=workspace_dtype)
-        workspace2 = torch.empty(prod(workspace2_shape),
-                                 device=a1.device,
-                                 dtype=workspace_dtype)
+        workspace13 = self.workspace13_buffer.get(workspace13_shape,
+                                                  device=a1.device,
+                                                  dtype=workspace_dtype)
+        workspace2 = self.workspace2_buffer.get(workspace2_shape,
+                                                device=a1.device,
+                                                dtype=workspace_dtype)
 
         assert fused_out is None or fused_out.shape == fused_out_shape, (
             f"fused_out {fused_out.shape} but expected {fused_out_shape}")
@@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module):
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
fused_out = self.fused_out_buffer.get(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
def slice_input_tensors(
chunk_idx: int
@@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module):
 
         shared_output: torch.Tensor
 
-        if (not self.prepare_finalize.supports_async()
-                or self.shared_experts is None):
+        if not self.prepare_finalize.supports_async():
+            # We shouldn't be running an a2a kernel that doesn't
+            # support async prepare/finalize
+            assert not dbo_enabled()
+
             # Run shared experts serially with dispatch.
             if self.shared_experts is not None:
@@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module):
)
else:
# Overlap shared expert compute with all2all dispatch.
receiver = self.prepare_finalize.prepare_async(
dbo_maybe_run_recv_hook()
hook, receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
@@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module):
                 self.fused_experts.quant_config,
             )
 
-            assert self.shared_experts is not None
-            shared_output = self.shared_experts(a1)
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            # If DBO is being used, register the hook with the ubatch context
+            # and call it in dbo_maybe_run_recv_hook instead of passing it to
+            # the receiver.
+            dbo_register_recv_hook(hook)
+            dbo_yield()
+
+            if not dbo_enabled():
+                hook()
 
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = receiver()
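Taken together, the hunks above give the dual-batch overlap its shape: each micro-batch kicks off its all2all, registers the recv hook with the ubatch context, yields the thread to the other micro-batch while the communication is in flight, and runs its pending hook when it is resumed; when DBO is off, the hook is simply called inline before receiver(). A schematic, self-contained model of that scheduling idea, using plain generators as stand-ins for the vllm.v1.worker.ubatching helpers (all names and behavior below are illustrative, not the real API):

from typing import Callable, Iterator

pending_hooks: dict[str, Callable[[], None]] = {}

def microbatch(name: str) -> Iterator[None]:
    print(f"{name}: launch all2all dispatch")
    # register the recv hook instead of waiting here (cf. dbo_register_recv_hook)
    pending_hooks[name] = lambda: print(f"{name}: recv hook waits for dispatch")
    yield  # hand the thread to the peer micro-batch (cf. dbo_yield)
    print(f"{name}: run experts on received tokens")

def run_dbo() -> None:
    ub0, ub1 = microbatch("ubatch0"), microbatch("ubatch1")
    schedule = [(ub0, "ubatch0"), (ub1, "ubatch1")] * 2
    for gen, name in schedule:
        # drain this micro-batch's pending recv hook (cf. dbo_maybe_run_recv_hook)
        hook = pending_hooks.pop(name, None)
        if hook is not None:
            hook()
        try:
            next(gen)
        except StopIteration:
            pass

run_dbo()
# Output interleaves the two micro-batches, so one batch's communication
# overlaps the other's compute:
#   ubatch0: launch all2all dispatch
#   ubatch1: launch all2all dispatch
#   ubatch0: recv hook waits for dispatch
#   ubatch0: run experts on received tokens
#   ubatch1: recv hook waits for dispatch
#   ubatch1: run experts on received tokens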