[6/N][Attention] Move utils to more appropriate locations (#32215)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Report whether the KV cache dtype string denotes a quantized format.

    The sentinel value ``"auto"`` means "follow the model dtype" (i.e. no
    quantization); every other value is treated as quantized.
    """
    uses_model_dtype = kv_cache_dtype == "auto"
    return not uses_model_dtype
|
||||
|
||||
|
||||
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.
    """
    # Derive the subclass dynamically: the only member that differs from the
    # parent backend is the metadata-builder hook.
    subclass_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    namespace = {"get_builder_cls": lambda: builder_cls}
    return type(subclass_name, (attention_backend_cls,), namespace)
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a new subclass of `attention_backend_cls` whose class namespace is
    populated with `overrides` (each entry replaces the inherited attribute).
    """
    derived_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    # type() copies the namespace mapping into the new class, so passing a
    # shallow copy here is behaviorally identical.
    return type(derived_name, (attention_backend_cls,), dict(overrides))
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
get_args,
|
||||
)
|
||||
|
||||
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||
|
||||
logger = init_logger(__name__)
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
|
||||
return value in get_args(KVCacheLayoutType)
|
||||
|
||||
|
||||
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Creates a new query_start_loc that corresponds to the requests in
    request_slice.

    The sliced offsets are rebased so the first entry is 0.

    Note: This function creates a new tensor to hold the new query_start_locs.
    This will break cudagraph compatibility.
    """
    first = request_slice.start
    # Take one extra boundary element so N requests keep N + 1 offsets, then
    # rebase everything to the first request's start offset.
    window = query_start_loc[first : request_slice.stop + 1]
    return window - query_start_loc[first]
|
||||
|
||||
|
||||
def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    This function creates a new CommonAttentionMetadata that corresponds to
    the requests included in ubatch_slice.

    Both the request dimension (query_start_loc, seq_lens, block table, ...)
    and the token dimension (slot_mapping) are sliced. When a request is
    split across micro-batch boundaries, its per-request token counts are
    adjusted so each half only accounts for its own tokens.
    """

    # An empty slice would produce degenerate metadata (e.g. an empty
    # query_start_loc), so fail fast.
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    # Original (unsliced) CPU-side cumulative token offsets; kept around
    # because the sliced copies below may be mutated.
    start_locs = attn_metadata.query_start_loc_cpu
    first_req = request_slice.start
    first_tok = token_slice.start
    last_req = request_slice.stop - 1
    last_tok = token_slice.stop - 1

    # The token slice must begin somewhere inside the first request's span.
    assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
        "Token slice start outside of first request"
    )
    # NOTE: last token can be outside of the last request if we have CG padding.

    # If the request is split across ubatches, we have to adjust the metadata.
    # splits_first_request: The first request in this slice is the continuation of
    # a request that started in a previous slice.
    # splits_last_request: The last request in this slice continues into the
    # next slice.
    splits_first_request = first_tok > start_locs[first_req]
    splits_last_request = last_tok < start_locs[last_req + 1] - 1

    # Fresh tensors (rebased to 0) — safe to modify in-place below.
    query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, request_slice
    )

    # N requests need N + 1 boundaries, so fewer than 2 entries means an
    # empty/corrupt slice.
    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if splits_first_request:
        # Drop the tokens of the first request that were consumed by the
        # previous micro-batch: shift every boundary after index 0 left.
        tokens_skipped = first_tok - start_locs[first_req]
        query_start_loc[1:] -= tokens_skipped
        query_start_loc_cpu[1:] -= tokens_skipped
    # Views into the original seq_lens tensors; only cloned below if they
    # actually need to be modified.
    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]

    if splits_last_request:
        # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
        # the tokens skipped because query_start_loc_cpu might have been modified
        # if splits_first_request is True.
        tokens_skipped = start_locs[last_req + 1] - token_slice.stop
        query_start_loc[-1] -= tokens_skipped
        query_start_loc_cpu[-1] -= tokens_skipped

        # Make sure we don't modify the seq_lens tensors
        # (not cudagraph compatible)
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tokens_skipped
        seq_lens_cpu[-1] -= tokens_skipped

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    # Longest per-request query length from consecutive boundary deltas.
    # NOTE(review): abs() presumably guards against dummy-run zero tensors;
    # real deltas are non-negative by construction.
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
    )

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    # Per-request block table rows; per-token slot mapping.
    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )
|
||||
|
||||
|
||||
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Creates a new CommonAttentionMetadata instance that corresponds to the
    requests for each UBatchSlice in ubatch_slices.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
        for ubatch_slice in ubatch_slices
    ]
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
# Format specified by the code.
|
||||
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
|
||||
return common_attn_metadata
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.
    """
    # Build the derived class in one expression; only the builder hook
    # differs from the parent backend class.
    return type(
        name_prefix + attention_backend_cls.__name__,  # type: ignore
        (attention_backend_cls,),
        {"get_builder_cls": lambda: builder_cls},
    )
|
||||
|
||||
|
||||
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a subclass of `attention_backend_cls` with the given
    class-namespace overrides applied on top of the inherited members.
    """
    new_name: str = f"{name_prefix}{attention_backend_cls.__name__}"  # type: ignore
    return type(new_name, (attention_backend_cls,), overrides)
|
||||
|
||||
|
||||
def split_decodes_prefills_and_extends(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
|
||||
Reference in New Issue
Block a user