[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (#23693)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -13,6 +13,8 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
|
||||
_resize_cache, count_expert_num_tokens)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (dbo_enabled, dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook, dbo_yield)
|
||||
|
||||
#
|
||||
# This file defines a set of base classes used to make MoE kernels more modular.
|
||||
@@ -226,7 +228,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> ReceiverType:
|
||||
) -> tuple[Callable, ReceiverType]:
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel
|
||||
but do not wait for results from other workers.
|
||||
@@ -496,6 +498,23 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
return None
|
||||
|
||||
|
||||
class SharedResizableBuffer:
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = None
|
||||
|
||||
def get(self, shape: tuple[int, ...], device: torch.device,
|
||||
dtype: torch.dtype):
|
||||
shape_numel = prod(shape)
|
||||
if self.buffer is None or self.buffer.numel() < shape_numel:
|
||||
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
|
||||
assert self.buffer.device == device, \
|
||||
f"Buffer device mismatch: {self.buffer.device} != {device}"
|
||||
assert self.buffer.dtype == dtype, \
|
||||
f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
|
||||
return self.buffer[:shape_numel].view(*shape)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
@@ -509,6 +528,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
layer due to any layer specific state that may be used by the component
|
||||
objects.
|
||||
"""
|
||||
fused_out_buffer = SharedResizableBuffer()
|
||||
workspace13_buffer = SharedResizableBuffer()
|
||||
workspace2_buffer = SharedResizableBuffer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -559,12 +581,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.empty(prod(workspace2_shape),
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace13 = self.workspace13_buffer.get(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = self.workspace2_buffer.get(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
assert fused_out is None or fused_out.shape == fused_out_shape, (
|
||||
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
|
||||
@@ -656,9 +678,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
|
||||
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
|
||||
expert_tokens_meta)
|
||||
fused_out = torch.empty(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
fused_out = self.fused_out_buffer.get(fused_out_shape,
|
||||
device=a1q.device,
|
||||
dtype=a1.dtype)
|
||||
|
||||
def slice_input_tensors(
|
||||
chunk_idx: int
|
||||
@@ -801,8 +823,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
shared_output: torch.Tensor
|
||||
|
||||
if (not self.prepare_finalize.supports_async()
|
||||
or self.shared_experts is None):
|
||||
if not self.prepare_finalize.supports_async():
|
||||
# We shouldn't be running an a2a kernel that doesn't
|
||||
# support async prepare/finalize
|
||||
assert not dbo_enabled()
|
||||
|
||||
# Run shared experts serially with dispatch.
|
||||
if self.shared_experts is not None:
|
||||
@@ -822,7 +846,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
)
|
||||
else:
|
||||
# Overlap shared expert compute with all2all dispatch.
|
||||
receiver = self.prepare_finalize.prepare_async(
|
||||
dbo_maybe_run_recv_hook()
|
||||
hook, receiver = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
@@ -834,8 +859,16 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
assert self.shared_experts is not None
|
||||
shared_output = self.shared_experts(a1)
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(a1)
|
||||
|
||||
# If DBO is being used, register the hook with the ubatch context
|
||||
# and call it in dbo_maybe_run_recv_hook instead of passing it to
|
||||
# the receiver.
|
||||
dbo_register_recv_hook(hook)
|
||||
dbo_yield()
|
||||
if not dbo_enabled():
|
||||
hook()
|
||||
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = receiver()
|
||||
|
||||
Reference in New Issue
Block a user