[6/N][Attention] Move utils to more appropriate locations (#32215)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-13 08:38:52 -05:00
committed by GitHub
parent fefce49807
commit 98f60e5acb
14 changed files with 171 additions and 181 deletions

View File

@@ -7,13 +7,15 @@ import torch
from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (
UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata,
split_decodes_and_prefills,
)
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
_make_metadata_with_slice,
maybe_create_ubatch_slices,
slice_query_start_locs,
split_attn_metadata,
)
@pytest.fixture

View File

@@ -13,10 +13,10 @@ from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches,
subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (

View File

@@ -15,8 +15,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend

View File

@@ -13,8 +13,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend

View File

@@ -16,8 +16,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
subclass_attention_backend,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (

View File

@@ -17,11 +17,9 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import (
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
import numpy as np
import torch
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto"
def subclass_attention_backend(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
"""
Return a new subclass where `get_builder_cls` returns `builder_cls`.
"""
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
)
def subclass_attention_backend_with_overrides(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
overrides: dict[str, Any],
) -> type[AttentionBackend]:
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(name, (attention_backend_cls,), overrides)

View File

@@ -8,7 +8,6 @@ from typing import (
Any,
Literal,
Protocol,
TypeVar,
get_args,
)
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"]
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
def slice_query_start_locs(
query_start_loc: torch.Tensor,
request_slice: slice,
) -> torch.Tensor:
"""
Creates a new query_start_loc that corresponds to the requests in
request_slice.
Note: This function creates a new tensor to hold the new query_start_locs.
This will break cudagraph compatibility.
"""
return (
query_start_loc[request_slice.start : request_slice.stop + 1]
- query_start_loc[request_slice.start]
)
def _make_metadata_with_slice(
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
"""
This function creates a new CommonAttentionMetadata that corresponds to
the requests included in ubatch_slice
"""
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
request_slice = ubatch_slice.request_slice
token_slice = ubatch_slice.token_slice
start_locs = attn_metadata.query_start_loc_cpu
first_req = request_slice.start
first_tok = token_slice.start
last_req = request_slice.stop - 1
last_tok = token_slice.stop - 1
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
"Token slice start outside of first request"
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
query_start_loc = slice_query_start_locs(
attn_metadata.query_start_loc, request_slice
)
assert len(query_start_loc) >= 2, (
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
)
if splits_first_request:
tokens_skipped = first_tok - start_locs[first_req]
query_start_loc[1:] -= tokens_skipped
query_start_loc_cpu[1:] -= tokens_skipped
seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens = seq_lens.clone()
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens[-1] -= tokens_skipped
seq_lens_cpu[-1] -= tokens_skipped
max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
num_requests = request_slice.stop - request_slice.start
num_actual_tokens = token_slice.stop - token_slice.start
max_query_len = int(
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
)
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if max_query_len == 0:
max_query_len = attn_metadata.max_query_len
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
slot_mapping = attn_metadata.slot_mapping[token_slice]
return CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
num_reqs=num_requests,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
)
def split_attn_metadata(
ubatch_slices: list[UBatchSlice],
common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata
"""
results = []
for ubatch_slice in ubatch_slices:
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
return results
@functools.lru_cache
def get_kv_cache_layout():
# Format specified by the code.
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
return common_attn_metadata
M = TypeVar("M")
def subclass_attention_backend(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
"""
Return a new subclass where `get_builder_cls` returns `builder_cls`.
"""
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(
name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls}
)
def subclass_attention_backend_with_overrides(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
overrides: dict[str, Any],
) -> type[AttentionBackend]:
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
return type(name, (attention_backend_cls,), overrides)
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,

View File

@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata

View File

@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables

View File

@@ -6,7 +6,7 @@ import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (

View File

@@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
split_attn_metadata,
)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (
@@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
maybe_create_ubatch_slices,
split_attn_metadata,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace

View File

@@ -4,8 +4,10 @@ from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from vllm.config import ParallelConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
@dataclass
@@ -110,3 +112,132 @@ def maybe_create_ubatch_slices(
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
return ubatch_slices, ubatch_slices_padded
def slice_query_start_locs(
query_start_loc: torch.Tensor,
request_slice: slice,
) -> torch.Tensor:
"""
Creates a new query_start_loc that corresponds to the requests in
request_slice.
Note: This function creates a new tensor to hold the new query_start_locs.
This will break cudagraph compatibility.
"""
return (
query_start_loc[request_slice.start : request_slice.stop + 1]
- query_start_loc[request_slice.start]
)
def _make_metadata_with_slice(
ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
"""
This function creates a new CommonAttentionMetadata that corresponds to
the requests included in ubatch_slice
"""
assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"
request_slice = ubatch_slice.request_slice
token_slice = ubatch_slice.token_slice
start_locs = attn_metadata.query_start_loc_cpu
first_req = request_slice.start
first_tok = token_slice.start
last_req = request_slice.stop - 1
last_tok = token_slice.stop - 1
assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], (
"Token slice start outside of first request"
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request = first_tok > start_locs[first_req]
splits_last_request = last_tok < start_locs[last_req + 1] - 1
query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice)
query_start_loc = slice_query_start_locs(
attn_metadata.query_start_loc, request_slice
)
assert len(query_start_loc) >= 2, (
f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
)
if splits_first_request:
tokens_skipped = first_tok - start_locs[first_req]
query_start_loc[1:] -= tokens_skipped
query_start_loc_cpu[1:] -= tokens_skipped
seq_lens = attn_metadata.seq_lens[request_slice]
seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
if splits_last_request:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped = start_locs[last_req + 1] - token_slice.stop
query_start_loc[-1] -= tokens_skipped
query_start_loc_cpu[-1] -= tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens = seq_lens.clone()
seq_lens_cpu = seq_lens_cpu.clone()
seq_lens[-1] -= tokens_skipped
seq_lens_cpu[-1] -= tokens_skipped
max_seq_len = int(seq_lens_cpu.max())
num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice]
num_requests = request_slice.stop - request_slice.start
num_actual_tokens = token_slice.stop - token_slice.start
max_query_len = int(
torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item()
)
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if max_query_len == 0:
max_query_len = attn_metadata.max_query_len
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
slot_mapping = attn_metadata.slot_mapping[token_slice]
return CommonAttentionMetadata(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
num_reqs=num_requests,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
)
def split_attn_metadata(
ubatch_slices: list[UBatchSlice],
common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata
"""
results = []
for ubatch_slice in ubatch_slices:
results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata))
return results

View File

@@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec