[Core/DBO][1/N] Add Dual-Batch Overlap mechanism to VLLM (#23693)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||
|
||||
logger = init_logger(__name__)
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
@@ -81,12 +82,6 @@ class CommonAttentionMetadata:
|
||||
encoder_seq_lens: Optional[np.ndarray] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UbatchSlice:
|
||||
request_slice: slice
|
||||
token_slice: slice
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
@@ -103,7 +98,7 @@ def slice_query_start_locs(
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
|
||||
ubatch_slice: UbatchSlice,
|
||||
ubatch_slice: UBatchSlice,
|
||||
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
|
||||
"""
|
||||
This function creates a new CommonAttentionMetadata that corresponds to
|
||||
@@ -133,6 +128,11 @@ def _make_metadata_with_slice(
|
||||
torch.max(torch.abs(query_start_loc_cpu[1:] -
|
||||
query_start_loc_cpu[:-1])).item())
|
||||
|
||||
# This is to account for the case where we are in a dummy
|
||||
# run and query_start_loc_cpu is full of 0s
|
||||
if max_query_len == 0:
|
||||
max_query_len = attn_metadata.max_query_len
|
||||
|
||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||
|
||||
@@ -152,12 +152,12 @@ def _make_metadata_with_slice(
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UbatchSlice],
|
||||
ubatch_slices: list[UBatchSlice],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||
requests for each UbatchSlice in ubatch_slices.
|
||||
requests for each UBatchSlice in ubatch_slices.
|
||||
|
||||
Note: This function does not modify common_attn_metadata
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user