[6/N][Attention] Move utils to more appropriate locations (#32215)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -7,13 +7,15 @@ import torch
|
|||||||
from tests.v1.attention.test_attention_backends import BATCH_SPECS
|
from tests.v1.attention.test_attention_backends import BATCH_SPECS
|
||||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
UBatchSlice,
|
|
||||||
_make_metadata_with_slice,
|
|
||||||
slice_query_start_locs,
|
|
||||||
split_attn_metadata,
|
|
||||||
split_decodes_and_prefills,
|
split_decodes_and_prefills,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
|
from vllm.v1.worker.ubatch_utils import (
|
||||||
|
UBatchSlice,
|
||||||
|
_make_metadata_with_slice,
|
||||||
|
maybe_create_ubatch_slices,
|
||||||
|
slice_query_start_locs,
|
||||||
|
split_attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
make_local_attention_virtual_batches,
|
make_local_attention_virtual_batches,
|
||||||
subclass_attention_backend,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.selector import get_attn_backend
|
from vllm.v1.attention.selector import get_attn_backend
|
||||||
from vllm.v1.kv_cache_interface import (
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
|||||||
@@ -15,8 +15,6 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
)
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
|
||||||
subclass_attention_backend,
|
subclass_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.selector import get_attn_backend
|
from vllm.v1.attention.selector import get_attn_backend
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
)
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
|
||||||
subclass_attention_backend,
|
subclass_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.selector import get_attn_backend
|
from vllm.v1.attention.selector import get_attn_backend
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
)
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
|
||||||
subclass_attention_backend,
|
subclass_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||||
|
|||||||
@@ -17,11 +17,9 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
)
|
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
|
||||||
from vllm.v1.attention.backends.utils import (
|
|
||||||
subclass_attention_backend_with_overrides,
|
subclass_attention_backend_with_overrides,
|
||||||
)
|
)
|
||||||
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
from vllm.v1.attention.selector import get_attn_backend
|
from vllm.v1.attention.selector import get_attn_backend
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
|||||||
|
|
||||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||||
return kv_cache_dtype != "auto"
|
return kv_cache_dtype != "auto"
|
||||||
|
|
||||||
|
|
||||||
|
def subclass_attention_backend(
|
||||||
|
name_prefix: str,
|
||||||
|
attention_backend_cls: type[AttentionBackend],
|
||||||
|
builder_cls: type[AttentionMetadataBuilder[M]],
|
||||||
|
) -> type[AttentionBackend]:
|
||||||
|
"""
|
||||||
|
Return a new subclass where `get_builder_cls` returns `builder_cls`.
|
||||||
|
"""
|
||||||
|
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||||
|
|
||||||
|
return type(
|
||||||
|
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def subclass_attention_backend_with_overrides(
|
||||||
|
name_prefix: str,
|
||||||
|
attention_backend_cls: type[AttentionBackend],
|
||||||
|
overrides: dict[str, Any],
|
||||||
|
) -> type[AttentionBackend]:
|
||||||
|
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||||
|
return type(name, (attention_backend_cls,), overrides)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Literal,
|
Literal,
|
||||||
Protocol,
|
Protocol,
|
||||||
TypeVar,
|
|
||||||
get_args,
|
get_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
|
|||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionImpl,
|
AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder,
|
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||||
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
|
|||||||
return value in get_args(KVCacheLayoutType)
|
return value in get_args(KVCacheLayoutType)
|
||||||
|
|
||||||
|
|
||||||
def slice_query_start_locs(
|
|
||||||
query_start_loc: torch.Tensor,
|
|
||||||
request_slice: slice,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Creates a new query_start_loc that corresponds to the requests in
|
|
||||||
request_slice.
|
|
||||||
|
|
||||||
Note: This function creates a new tensor to hold the new query_start_locs.
|
|
||||||
This will break cudagraph compatibility.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
query_start_loc[request_slice.start : request_slice.stop + 1]
|
|
||||||
- query_start_loc[request_slice.start]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_metadata_with_slice(
|
|
||||||
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
|
|
||||||
) -> CommonAttentionMetadata:
|
|
||||||
"""
|
|
||||||
This function creates a new CommonAttentionMetadata that corresponds to
|
|
||||||
the requests included in ubatch_slice
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
|
|
||||||
|
|
||||||
request_slice = ubatch_slice.request_slice
|
|
||||||
token_slice = ubatch_slice.token_slice
|
|
||||||
|
|
||||||
start_locs = attn_metadata.query_start_loc_cpu
|
|
||||||
first_req = request_slice.start
|
|
||||||
first_tok = token_slice.start
|
|
||||||
last_req = request_slice.stop - 1
|
|
||||||
last_tok = token_slice.stop - 1
|
|
||||||
|
|
||||||
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
|
|
||||||
"Token slice start outside of first request"
|
|
||||||
)
|
|
||||||
# NOTE: last token can be outside of the last request if we have CG padding.
|
|
||||||
|
|
||||||
# If the request is split across ubatches, we have to adjust the metadata.
|
|
||||||
# splits_first_request: The first request in this slice is the continuation of
|
|
||||||
# a request that started in a previous slice.
|
|
||||||
# splits_last_request: The last request in this slice continues into the
|
|
||||||
# next slice.
|
|
||||||
splits_first_request = first_tok > start_locs[first_req]
|
|
||||||
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
|
||||||
|
|
||||||
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
|
||||||
query_start_loc = slice_query_start_locs(
|
|
||||||
attn_metadata.query_start_loc, request_slice
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(query_start_loc) >= 2, (
|
|
||||||
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if splits_first_request:
|
|
||||||
tokens_skipped = first_tok - start_locs[first_req]
|
|
||||||
query_start_loc[1:] -= tokens_skipped
|
|
||||||
query_start_loc_cpu[1:] -= tokens_skipped
|
|
||||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
|
||||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
|
||||||
|
|
||||||
if splits_last_request:
|
|
||||||
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
|
|
||||||
# the tokens skipped because query_start_loc_cpu might have been modified
|
|
||||||
# if splits_first_request is True.
|
|
||||||
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
|
|
||||||
query_start_loc[-1] -= tokens_skipped
|
|
||||||
query_start_loc_cpu[-1] -= tokens_skipped
|
|
||||||
|
|
||||||
# Make sure we don't modify the seq_lens tensors
|
|
||||||
# (not cudagraph compatible)
|
|
||||||
seq_lens = seq_lens.clone()
|
|
||||||
seq_lens_cpu = seq_lens_cpu.clone()
|
|
||||||
seq_lens[-1] -= tokens_skipped
|
|
||||||
seq_lens_cpu[-1] -= tokens_skipped
|
|
||||||
|
|
||||||
max_seq_len = int(seq_lens_cpu.max())
|
|
||||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
|
|
||||||
|
|
||||||
num_requests = request_slice.stop - request_slice.start
|
|
||||||
num_actual_tokens = token_slice.stop - token_slice.start
|
|
||||||
max_query_len = int(
|
|
||||||
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
|
|
||||||
)
|
|
||||||
|
|
||||||
# This is to account for the case where we are in a dummy
|
|
||||||
# run and query_start_loc_cpu is full of 0s
|
|
||||||
if max_query_len == 0:
|
|
||||||
max_query_len = attn_metadata.max_query_len
|
|
||||||
|
|
||||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
|
||||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
|
||||||
|
|
||||||
return CommonAttentionMetadata(
|
|
||||||
query_start_loc=query_start_loc,
|
|
||||||
query_start_loc_cpu=query_start_loc_cpu,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
num_reqs=num_requests,
|
|
||||||
num_actual_tokens=num_actual_tokens,
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
block_table_tensor=block_table_tensor,
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
_seq_lens_cpu=seq_lens_cpu,
|
|
||||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def split_attn_metadata(
|
|
||||||
ubatch_slices: list[UBatchSlice],
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
) -> list[CommonAttentionMetadata]:
|
|
||||||
"""
|
|
||||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
|
||||||
requests for each UBatchSlice in ubatch_slices.
|
|
||||||
|
|
||||||
Note: This function does not modify common_attn_metadata
|
|
||||||
"""
|
|
||||||
results = []
|
|
||||||
for ubatch_slice in ubatch_slices:
|
|
||||||
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_kv_cache_layout():
|
def get_kv_cache_layout():
|
||||||
# Format specified by the code.
|
# Format specified by the code.
|
||||||
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
|
|||||||
return common_attn_metadata
|
return common_attn_metadata
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar("M")
|
|
||||||
|
|
||||||
|
|
||||||
def subclass_attention_backend(
|
|
||||||
name_prefix: str,
|
|
||||||
attention_backend_cls: type[AttentionBackend],
|
|
||||||
builder_cls: type[AttentionMetadataBuilder[M]],
|
|
||||||
) -> type[AttentionBackend]:
|
|
||||||
"""
|
|
||||||
Return a new subclass where `get_builder_cls` returns `builder_cls`.
|
|
||||||
"""
|
|
||||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
|
||||||
|
|
||||||
return type(
|
|
||||||
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def subclass_attention_backend_with_overrides(
|
|
||||||
name_prefix: str,
|
|
||||||
attention_backend_cls: type[AttentionBackend],
|
|
||||||
overrides: dict[str, Any],
|
|
||||||
) -> type[AttentionBackend]:
|
|
||||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
|
||||||
return type(name, (attention_backend_cls,), overrides)
|
|
||||||
|
|
||||||
|
|
||||||
def split_decodes_prefills_and_extends(
|
def split_decodes_prefills_and_extends(
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
decode_threshold: int = 1,
|
decode_threshold: int = 1,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.config.compilation import CUDAGraphMode
|
from vllm.config.compilation import CUDAGraphMode
|
||||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.compilation import CUDAGraphMode
|
from vllm.config.compilation import CUDAGraphMode
|
||||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import (
|
|||||||
create_fast_prefill_custom_backend,
|
create_fast_prefill_custom_backend,
|
||||||
get_dcp_local_seq_lens,
|
get_dcp_local_seq_lens,
|
||||||
reorder_batch_to_split_decodes_and_prefills,
|
reorder_batch_to_split_decodes_and_prefills,
|
||||||
split_attn_metadata,
|
|
||||||
)
|
)
|
||||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||||
from vllm.v1.kv_cache_interface import (
|
from vllm.v1.kv_cache_interface import (
|
||||||
@@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import (
|
|||||||
UBatchSlices,
|
UBatchSlices,
|
||||||
check_ubatch_thresholds,
|
check_ubatch_thresholds,
|
||||||
maybe_create_ubatch_slices,
|
maybe_create_ubatch_slices,
|
||||||
|
split_attn_metadata,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||||
from vllm.v1.worker.workspace import lock_workspace
|
from vllm.v1.worker.workspace import lock_workspace
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from dataclasses import dataclass
|
|||||||
from typing import TypeAlias
|
from typing import TypeAlias
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
|
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -110,3 +112,132 @@ def maybe_create_ubatch_slices(
|
|||||||
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
|
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
|
||||||
|
|
||||||
return ubatch_slices, ubatch_slices_padded
|
return ubatch_slices, ubatch_slices_padded
|
||||||
|
|
||||||
|
|
||||||
|
def slice_query_start_locs(
|
||||||
|
query_start_loc: torch.Tensor,
|
||||||
|
request_slice: slice,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Creates a new query_start_loc that corresponds to the requests in
|
||||||
|
request_slice.
|
||||||
|
|
||||||
|
Note: This function creates a new tensor to hold the new query_start_locs.
|
||||||
|
This will break cudagraph compatibility.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
query_start_loc[request_slice.start : request_slice.stop + 1]
|
||||||
|
- query_start_loc[request_slice.start]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_metadata_with_slice(
|
||||||
|
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
|
||||||
|
) -> CommonAttentionMetadata:
|
||||||
|
"""
|
||||||
|
This function creates a new CommonAttentionMetadata that corresponds to
|
||||||
|
the requests included in ubatch_slice
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
|
||||||
|
|
||||||
|
request_slice = ubatch_slice.request_slice
|
||||||
|
token_slice = ubatch_slice.token_slice
|
||||||
|
|
||||||
|
start_locs = attn_metadata.query_start_loc_cpu
|
||||||
|
first_req = request_slice.start
|
||||||
|
first_tok = token_slice.start
|
||||||
|
last_req = request_slice.stop - 1
|
||||||
|
last_tok = token_slice.stop - 1
|
||||||
|
|
||||||
|
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
|
||||||
|
"Token slice start outside of first request"
|
||||||
|
)
|
||||||
|
# NOTE: last token can be outside of the last request if we have CG padding.
|
||||||
|
|
||||||
|
# If the request is split across ubatches, we have to adjust the metadata.
|
||||||
|
# splits_first_request: The first request in this slice is the continuation of
|
||||||
|
# a request that started in a previous slice.
|
||||||
|
# splits_last_request: The last request in this slice continues into the
|
||||||
|
# next slice.
|
||||||
|
splits_first_request = first_tok > start_locs[first_req]
|
||||||
|
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
||||||
|
|
||||||
|
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
||||||
|
query_start_loc = slice_query_start_locs(
|
||||||
|
attn_metadata.query_start_loc, request_slice
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(query_start_loc) >= 2, (
|
||||||
|
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if splits_first_request:
|
||||||
|
tokens_skipped = first_tok - start_locs[first_req]
|
||||||
|
query_start_loc[1:] -= tokens_skipped
|
||||||
|
query_start_loc_cpu[1:] -= tokens_skipped
|
||||||
|
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||||
|
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||||
|
|
||||||
|
if splits_last_request:
|
||||||
|
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
|
||||||
|
# the tokens skipped because query_start_loc_cpu might have been modified
|
||||||
|
# if splits_first_request is True.
|
||||||
|
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
|
||||||
|
query_start_loc[-1] -= tokens_skipped
|
||||||
|
query_start_loc_cpu[-1] -= tokens_skipped
|
||||||
|
|
||||||
|
# Make sure we don't modify the seq_lens tensors
|
||||||
|
# (not cudagraph compatible)
|
||||||
|
seq_lens = seq_lens.clone()
|
||||||
|
seq_lens_cpu = seq_lens_cpu.clone()
|
||||||
|
seq_lens[-1] -= tokens_skipped
|
||||||
|
seq_lens_cpu[-1] -= tokens_skipped
|
||||||
|
|
||||||
|
max_seq_len = int(seq_lens_cpu.max())
|
||||||
|
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
|
||||||
|
|
||||||
|
num_requests = request_slice.stop - request_slice.start
|
||||||
|
num_actual_tokens = token_slice.stop - token_slice.start
|
||||||
|
max_query_len = int(
|
||||||
|
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
|
||||||
|
)
|
||||||
|
|
||||||
|
# This is to account for the case where we are in a dummy
|
||||||
|
# run and query_start_loc_cpu is full of 0s
|
||||||
|
if max_query_len == 0:
|
||||||
|
max_query_len = attn_metadata.max_query_len
|
||||||
|
|
||||||
|
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||||
|
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||||
|
|
||||||
|
return CommonAttentionMetadata(
|
||||||
|
query_start_loc=query_start_loc,
|
||||||
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
num_reqs=num_requests,
|
||||||
|
num_actual_tokens=num_actual_tokens,
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
block_table_tensor=block_table_tensor,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
_seq_lens_cpu=seq_lens_cpu,
|
||||||
|
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def split_attn_metadata(
|
||||||
|
ubatch_slices: list[UBatchSlice],
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
) -> list[CommonAttentionMetadata]:
|
||||||
|
"""
|
||||||
|
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||||
|
requests for each UBatchSlice in ubatch_slices.
|
||||||
|
|
||||||
|
Note: This function does not modify common_attn_metadata
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for ubatch_slice in ubatch_slices:
|
||||||
|
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
|
|||||||
from vllm.multimodal.registry import MultiModalRegistry
|
from vllm.multimodal.registry import MultiModalRegistry
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||||
from vllm.v1.attention.backend import AttentionBackend
|
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user