[6/N][Attention] Move utils to more appropriate locations (#32215)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-13 08:38:52 -05:00
committed by GitHub
parent fefce49807
commit 98f60e5acb
14 changed files with 171 additions and 181 deletions

View File

@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
import numpy as np
import torch
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True when `kv_cache_dtype` is anything other than "auto"."""
    uses_auto_dtype = kv_cache_dtype == "auto"
    return not uses_auto_dtype
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Text prepended to the base class name to form the name
            of the generated subclass.
        attention_backend_cls: Class used as the single base of the new type.
        builder_cls: Object returned by the generated `get_builder_cls`.

    Returns:
        A dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    def get_builder_cls():
        return builder_cls

    # Wrap in staticmethod so the zero-argument override can be invoked through
    # an instance as well as through the class; a bare function placed in the
    # class namespace would receive `self` on instance access and raise
    # TypeError.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(get_builder_cls)},
    )
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a new subclass of `attention_backend_cls` whose class namespace is
    populated from `overrides`.

    The generated class is named `name_prefix` followed by the base class
    name.
    """
    bases = (attention_backend_cls,)
    subclass_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    return type(subclass_name, bases, overrides)

View File

@@ -8,7 +8,6 @@ from typing import (
Any,
Literal,
Protocol,
TypeVar,
get_args,
)
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.worker.ubatch_utils import UBatchSlice
# Module-level logger, created by init_logger for this module's name.
logger = init_logger(__name__)
# Literal type listing the KV-cache layout strings accepted as valid.
KVCacheLayoutType = Literal["NHD", "HND"]
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Return the query start offsets for the requests selected by
    `request_slice`, rebased so the first selected offset becomes zero.

    The entry at index `request_slice.stop` is included, so the result holds
    one more element than the number of selected requests.

    Note: The subtraction produces a new tensor to hold the rebased offsets.
    This will break cudagraph compatibility.
    """
    first = request_slice.start
    end = request_slice.stop + 1
    window = query_start_loc[first:end]
    return window - query_start_loc[first]
def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    Build a new CommonAttentionMetadata restricted to the requests and tokens
    covered by `ubatch_slice`.

    Per-request fields are indexed with `ubatch_slice.request_slice` and
    `slot_mapping` with `ubatch_slice.token_slice`. When the token slice
    starts or ends inside a request, the sliced query start offsets (and, for
    the trailing request, the sequence lengths) are adjusted accordingly.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    req_range = ubatch_slice.request_slice
    tok_range = ubatch_slice.token_slice

    # Unsliced CPU offsets; kept untouched so later arithmetic can rely on them.
    orig_start_locs = attn_metadata.query_start_loc_cpu

    req_lo, req_hi = req_range.start, req_range.stop - 1
    tok_lo, tok_hi = tok_range.start, tok_range.stop - 1

    assert orig_start_locs[req_lo] <= tok_lo < orig_start_locs[req_lo + 1], (
        "Token slice start outside of first request"
    )

    # NOTE: with CG padding the last token may lie beyond the last request.
    # A request can straddle two ubatches, in which case the metadata needs
    # adjusting:
    #   head_is_partial: the first request here began in an earlier slice.
    #   tail_is_partial: the last request here continues into the next slice.
    head_is_partial = tok_lo > orig_start_locs[req_lo]
    tail_is_partial = tok_hi < orig_start_locs[req_hi + 1] - 1

    sliced_locs_cpu = slice_query_start_locs(orig_start_locs, req_range)
    sliced_locs = slice_query_start_locs(attn_metadata.query_start_loc, req_range)

    assert len(sliced_locs) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(sliced_locs)}"
    )

    if head_is_partial:
        head_skip = tok_lo - orig_start_locs[req_lo]
        sliced_locs[1:] -= head_skip
        sliced_locs_cpu[1:] -= head_skip

    seq_lens = attn_metadata.seq_lens[req_range]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[req_range]

    if tail_is_partial:
        # NOTE: computed from the original offsets, because the sliced CPU
        # offsets may already have been shifted by the head adjustment above.
        tail_skip = orig_start_locs[req_hi + 1] - tok_range.stop
        sliced_locs[-1] -= tail_skip
        sliced_locs_cpu[-1] -= tail_skip

        # Work on copies so the seq_lens tensors of attn_metadata are not
        # modified (not cudagraph compatible).
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tail_skip
        seq_lens_cpu[-1] -= tail_skip

    max_seq_len = int(seq_lens_cpu.max())

    per_request_lens = sliced_locs_cpu[1:] - sliced_locs_cpu[:-1]
    max_query_len = int(torch.max(torch.abs(per_request_lens)).item())
    if max_query_len == 0:
        # Covers a dummy run, where query_start_loc_cpu is full of 0s.
        max_query_len = attn_metadata.max_query_len

    return CommonAttentionMetadata(
        query_start_loc=sliced_locs,
        query_start_loc_cpu=sliced_locs_cpu,
        seq_lens=seq_lens,
        num_reqs=req_range.stop - req_range.start,
        num_actual_tokens=tok_range.stop - tok_range.start,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=attn_metadata.block_table_tensor[req_range],
        slot_mapping=attn_metadata.slot_mapping[tok_range],
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=attn_metadata.num_computed_tokens_cpu[req_range],
    )
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Produce one CommonAttentionMetadata per entry of `ubatch_slices`, each
    built from `common_attn_metadata` by `_make_metadata_with_slice`.

    The results are returned in the same order as `ubatch_slices`.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(piece, common_attn_metadata)
        for piece in ubatch_slices
    ]
@functools.lru_cache
def get_kv_cache_layout():
# Format specified by the code.
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
return common_attn_metadata
# Type variable used to parameterize AttentionMetadataBuilder in the
# signature of subclass_attention_backend.
M = TypeVar("M")
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Text prepended to the base class name to form the name
            of the generated subclass.
        attention_backend_cls: Class used as the single base of the new type.
        builder_cls: Object returned by the generated `get_builder_cls`.

    Returns:
        A dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    def get_builder_cls():
        return builder_cls

    # Wrap in staticmethod so the zero-argument override can be invoked through
    # an instance as well as through the class; a bare function placed in the
    # class namespace would receive `self` on instance access and raise
    # TypeError.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(get_builder_cls)},
    )
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a new subclass of `attention_backend_cls` whose class namespace is
    populated from `overrides`.

    The generated class is named `name_prefix` followed by the base class
    name.
    """
    bases = (attention_backend_cls,)
    subclass_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    return type(subclass_name, bases, overrides)
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,