[6/N][Attention] Move utils to more appropriate locations (#32215)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -7,13 +7,15 @@ import torch
|
||||
from tests.v1.attention.test_attention_backends import BATCH_SPECS
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
|
||||
from vllm.v1.worker.ubatch_utils import (
|
||||
UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
maybe_create_ubatch_slices,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -13,10 +13,10 @@ from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
make_local_attention_virtual_batches,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
|
||||
@@ -15,8 +15,6 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
|
||||
@@ -13,8 +13,6 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
|
||||
@@ -16,8 +16,6 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
|
||||
@@ -17,11 +17,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
|
||||
return kv_cache_dtype != "auto"
|
||||
|
||||
|
||||
def subclass_attention_backend(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
builder_cls: type[AttentionMetadataBuilder[M]],
|
||||
) -> type[AttentionBackend]:
|
||||
"""
|
||||
Return a new subclass where `get_builder_cls` returns `builder_cls`.
|
||||
"""
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
|
||||
return type(
|
||||
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
|
||||
)
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
overrides: dict[str, Any],
|
||||
) -> type[AttentionBackend]:
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
return type(name, (attention_backend_cls,), overrides)
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
get_args,
|
||||
)
|
||||
|
||||
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||
|
||||
logger = init_logger(__name__)
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
|
||||
return value in get_args(KVCacheLayoutType)
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Creates a new query_start_loc that corresponds to the requests in
|
||||
request_slice.
|
||||
|
||||
Note: This function creates a new tensor to hold the new query_start_locs.
|
||||
This will break cudagraph compatibility.
|
||||
"""
|
||||
return (
|
||||
query_start_loc[request_slice.start : request_slice.stop + 1]
|
||||
- query_start_loc[request_slice.start]
|
||||
)
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
|
||||
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
|
||||
) -> CommonAttentionMetadata:
|
||||
"""
|
||||
This function creates a new CommonAttentionMetadata that corresponds to
|
||||
the requests included in ubatch_slice
|
||||
"""
|
||||
|
||||
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
|
||||
|
||||
request_slice = ubatch_slice.request_slice
|
||||
token_slice = ubatch_slice.token_slice
|
||||
|
||||
start_locs = attn_metadata.query_start_loc_cpu
|
||||
first_req = request_slice.start
|
||||
first_tok = token_slice.start
|
||||
last_req = request_slice.stop - 1
|
||||
last_tok = token_slice.stop - 1
|
||||
|
||||
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
|
||||
"Token slice start outside of first request"
|
||||
)
|
||||
# NOTE: last token can be outside of the last request if we have CG padding.
|
||||
|
||||
# If the request is split across ubatches, we have to adjust the metadata.
|
||||
# splits_first_request: The first request in this slice is the continuation of
|
||||
# a request that started in a previous slice.
|
||||
# splits_last_request: The last request in this slice continues into the
|
||||
# next slice.
|
||||
splits_first_request = first_tok > start_locs[first_req]
|
||||
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
||||
|
||||
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
||||
query_start_loc = slice_query_start_locs(
|
||||
attn_metadata.query_start_loc, request_slice
|
||||
)
|
||||
|
||||
assert len(query_start_loc) >= 2, (
|
||||
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
|
||||
)
|
||||
|
||||
if splits_first_request:
|
||||
tokens_skipped = first_tok - start_locs[first_req]
|
||||
query_start_loc[1:] -= tokens_skipped
|
||||
query_start_loc_cpu[1:] -= tokens_skipped
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
|
||||
if splits_last_request:
|
||||
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
|
||||
# the tokens skipped because query_start_loc_cpu might have been modified
|
||||
# if splits_first_request is True.
|
||||
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
|
||||
query_start_loc[-1] -= tokens_skipped
|
||||
query_start_loc_cpu[-1] -= tokens_skipped
|
||||
|
||||
# Make sure we don't modify the seq_lens tensors
|
||||
# (not cudagraph compatible)
|
||||
seq_lens = seq_lens.clone()
|
||||
seq_lens_cpu = seq_lens_cpu.clone()
|
||||
seq_lens[-1] -= tokens_skipped
|
||||
seq_lens_cpu[-1] -= tokens_skipped
|
||||
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
|
||||
|
||||
num_requests = request_slice.stop - request_slice.start
|
||||
num_actual_tokens = token_slice.stop - token_slice.start
|
||||
max_query_len = int(
|
||||
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
|
||||
)
|
||||
|
||||
# This is to account for the case where we are in a dummy
|
||||
# run and query_start_loc_cpu is full of 0s
|
||||
if max_query_len == 0:
|
||||
max_query_len = attn_metadata.max_query_len
|
||||
|
||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_requests,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
)
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UBatchSlice],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||
requests for each UBatchSlice in ubatch_slices.
|
||||
|
||||
Note: This function does not modify common_attn_metadata
|
||||
"""
|
||||
results = []
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
# Format specified by the code.
|
||||
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
|
||||
return common_attn_metadata
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
def subclass_attention_backend(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
builder_cls: type[AttentionMetadataBuilder[M]],
|
||||
) -> type[AttentionBackend]:
|
||||
"""
|
||||
Return a new subclass where `get_builder_cls` returns `builder_cls`.
|
||||
"""
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
|
||||
return type(
|
||||
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
|
||||
)
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
overrides: dict[str, Any],
|
||||
) -> type[AttentionBackend]:
|
||||
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||
return type(name, (attention_backend_cls,), overrides)
|
||||
|
||||
|
||||
def split_decodes_prefills_and_extends(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
|
||||
@@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import (
|
||||
create_fast_prefill_custom_backend,
|
||||
get_dcp_local_seq_lens,
|
||||
reorder_batch_to_split_decodes_and_prefills,
|
||||
split_attn_metadata,
|
||||
)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
@@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import (
|
||||
UBatchSlices,
|
||||
check_ubatch_thresholds,
|
||||
maybe_create_ubatch_slices,
|
||||
split_attn_metadata,
|
||||
)
|
||||
from vllm.v1.worker.utils import is_residual_scattered_for_sp
|
||||
from vllm.v1.worker.workspace import lock_workspace
|
||||
|
||||
@@ -4,8 +4,10 @@ from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -110,3 +112,132 @@ def maybe_create_ubatch_slices(
|
||||
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
|
||||
|
||||
return ubatch_slices, ubatch_slices_padded
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Creates a new query_start_loc that corresponds to the requests in
|
||||
request_slice.
|
||||
|
||||
Note: This function creates a new tensor to hold the new query_start_locs.
|
||||
This will break cudagraph compatibility.
|
||||
"""
|
||||
return (
|
||||
query_start_loc[request_slice.start : request_slice.stop + 1]
|
||||
- query_start_loc[request_slice.start]
|
||||
)
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
|
||||
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
|
||||
) -> CommonAttentionMetadata:
|
||||
"""
|
||||
This function creates a new CommonAttentionMetadata that corresponds to
|
||||
the requests included in ubatch_slice
|
||||
"""
|
||||
|
||||
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
|
||||
|
||||
request_slice = ubatch_slice.request_slice
|
||||
token_slice = ubatch_slice.token_slice
|
||||
|
||||
start_locs = attn_metadata.query_start_loc_cpu
|
||||
first_req = request_slice.start
|
||||
first_tok = token_slice.start
|
||||
last_req = request_slice.stop - 1
|
||||
last_tok = token_slice.stop - 1
|
||||
|
||||
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
|
||||
"Token slice start outside of first request"
|
||||
)
|
||||
# NOTE: last token can be outside of the last request if we have CG padding.
|
||||
|
||||
# If the request is split across ubatches, we have to adjust the metadata.
|
||||
# splits_first_request: The first request in this slice is the continuation of
|
||||
# a request that started in a previous slice.
|
||||
# splits_last_request: The last request in this slice continues into the
|
||||
# next slice.
|
||||
splits_first_request = first_tok > start_locs[first_req]
|
||||
splits_last_request = last_tok < start_locs[last_req + 1] - 1
|
||||
|
||||
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
|
||||
query_start_loc = slice_query_start_locs(
|
||||
attn_metadata.query_start_loc, request_slice
|
||||
)
|
||||
|
||||
assert len(query_start_loc) >= 2, (
|
||||
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
|
||||
)
|
||||
|
||||
if splits_first_request:
|
||||
tokens_skipped = first_tok - start_locs[first_req]
|
||||
query_start_loc[1:] -= tokens_skipped
|
||||
query_start_loc_cpu[1:] -= tokens_skipped
|
||||
seq_lens = attn_metadata.seq_lens[request_slice]
|
||||
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
|
||||
|
||||
if splits_last_request:
|
||||
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
|
||||
# the tokens skipped because query_start_loc_cpu might have been modified
|
||||
# if splits_first_request is True.
|
||||
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
|
||||
query_start_loc[-1] -= tokens_skipped
|
||||
query_start_loc_cpu[-1] -= tokens_skipped
|
||||
|
||||
# Make sure we don't modify the seq_lens tensors
|
||||
# (not cudagraph compatible)
|
||||
seq_lens = seq_lens.clone()
|
||||
seq_lens_cpu = seq_lens_cpu.clone()
|
||||
seq_lens[-1] -= tokens_skipped
|
||||
seq_lens_cpu[-1] -= tokens_skipped
|
||||
|
||||
max_seq_len = int(seq_lens_cpu.max())
|
||||
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
|
||||
|
||||
num_requests = request_slice.stop - request_slice.start
|
||||
num_actual_tokens = token_slice.stop - token_slice.start
|
||||
max_query_len = int(
|
||||
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
|
||||
)
|
||||
|
||||
# This is to account for the case where we are in a dummy
|
||||
# run and query_start_loc_cpu is full of 0s
|
||||
if max_query_len == 0:
|
||||
max_query_len = attn_metadata.max_query_len
|
||||
|
||||
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
|
||||
slot_mapping = attn_metadata.slot_mapping[token_slice]
|
||||
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=seq_lens,
|
||||
num_reqs=num_requests,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
_seq_lens_cpu=seq_lens_cpu,
|
||||
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
)
|
||||
|
||||
|
||||
def split_attn_metadata(
|
||||
ubatch_slices: list[UBatchSlice],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[CommonAttentionMetadata]:
|
||||
"""
|
||||
Creates a new CommonAttentionMetadata instance that corresponds to the
|
||||
requests for each UBatchSlice in ubatch_slices.
|
||||
|
||||
Note: This function does not modify common_attn_metadata
|
||||
"""
|
||||
results = []
|
||||
for ubatch_slice in ubatch_slices:
|
||||
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
|
||||
|
||||
return results
|
||||
|
||||
@@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
|
||||
Reference in New Issue
Block a user