[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -22,12 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id,
|
||||
dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook,
|
||||
dbo_yield,
|
||||
)
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -661,25 +661,6 @@ def _slice_scales(
|
||||
return None
|
||||
|
||||
|
||||
class SharedResizableBuffer:
    """A lazily-allocated, grow-only flat tensor served out as shaped views.

    The first `get` call allocates a 1-D backing tensor; later calls reuse it
    (as a zero-copy view) whenever it is large enough and matches the requested
    device and dtype, otherwise it is replaced by a fresh, larger allocation.
    The buffer never shrinks, so repeated calls with varying shapes stabilize
    at the worst-case size and stop allocating.
    """

    def __init__(self):
        # No memory is reserved until the first `get` call.
        self.buffer = None

    def get(
        self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a tensor of `shape` backed by the shared flat buffer.

        The result aliases `self.buffer`: callers must not assume the
        contents are zeroed, and overlapping lifetimes of two views from
        the same instance will clobber each other.
        """
        # Scalar (0-d) requests are not supported: a 0-element backing
        # buffer could never be viewed meaningfully.
        assert shape != ()
        needed = prod(shape)
        backing = self.buffer
        must_realloc = (
            backing is None
            or backing.numel() < needed
            or backing.device != device
            or backing.dtype != dtype
        )
        if must_realloc:
            # Allocate flat; views of arbitrary shape are carved out below.
            backing = torch.empty(needed, device=device, dtype=dtype)
            self.buffer = backing
        return backing[:needed].view(*shape)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
@@ -694,22 +675,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
objects.
|
||||
"""
|
||||
|
||||
class SharedBuffers:
|
||||
def __init__(self) -> None:
|
||||
self.fused_out = SharedResizableBuffer()
|
||||
self.workspace13 = SharedResizableBuffer()
|
||||
self.workspace2 = SharedResizableBuffer()
|
||||
|
||||
# Persistent buffers that are shared across `FusedMoEModularKernel`
|
||||
# instances (layers), to save memory and allocations.
|
||||
#
|
||||
# We have two sets of buffers to support dual batch overlap (DBO) where each
|
||||
# microbatch (ubatch) should use its own set of buffers to avoid
|
||||
# cross-ubatch contamination.
|
||||
# NOTE that memory is lazily allocated for these buffers, meaning that if
|
||||
# DBO isn't being used, the second SharedBuffers will be empty.
|
||||
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
@@ -806,10 +771,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
assert M_full > 0 and M_chunk > 0
|
||||
|
||||
num_chunks, _ = self._chunk_info(M_full)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
@@ -832,14 +793,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_tokens_meta,
|
||||
)
|
||||
)
|
||||
buffers.workspace13.get(
|
||||
max_workspace_13, device=device, dtype=workspace_dtype
|
||||
)
|
||||
buffers.workspace2.get(
|
||||
max_workspace_2, device=device, dtype=workspace_dtype
|
||||
)
|
||||
buffers.fused_out.get(
|
||||
max_fused_out_shape, device=device, dtype=workspace_dtype
|
||||
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(max_workspace_13, workspace_dtype),
|
||||
(max_workspace_2, workspace_dtype),
|
||||
(max_fused_out_shape, out_dtype),
|
||||
)
|
||||
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
@@ -866,22 +824,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = buffers.workspace13.get(
|
||||
workspace13_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
workspace2 = buffers.workspace2.get(
|
||||
workspace2_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
# Reuse workspace13 for the output in the non-chunked case as long
|
||||
# as it is large enough. This will not always be the case for standard
|
||||
# format experts and with experts that have empty workspaces.
|
||||
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
|
||||
workspace13, workspace2 = current_workspace_manager().get_simultaneous(
|
||||
(workspace13_shape, workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
)
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
else:
|
||||
fused_out = buffers.fused_out.get(
|
||||
fused_out_shape, device=device, dtype=out_dtype
|
||||
workspace13, workspace2, fused_out = (
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(workspace13_shape, workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
(fused_out_shape, out_dtype),
|
||||
)
|
||||
)
|
||||
|
||||
return workspace13, workspace2, fused_out
|
||||
|
||||
Reference in New Issue
Block a user