[6/N][Attention] Move utils to more appropriate locations (#32215)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-13 08:38:52 -05:00
committed by GitHub
parent fefce49807
commit 98f60e5acb
14 changed files with 171 additions and 181 deletions

View File

@@ -7,13 +7,15 @@ import torch
from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
_make_metadata_with_slice,
maybe_create_ubatch_slices,
slice_query_start_locs,
split_attn_metadata,
)
@pytest.fixture @pytest.fixture

View File

@@ -13,10 +13,10 @@ from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
subclass_attention_backend,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches, make_local_attention_virtual_batches,
subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (

View File

@@ -15,8 +15,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend

View File

@@ -13,8 +13,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend

View File

@@ -16,8 +16,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (

View File

@@ -17,11 +17,9 @@ from vllm.v1.attention.backend import (
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import (
subclass_attention_backend_with_overrides, subclass_attention_backend_with_overrides,
) )
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
import numpy as np import numpy as np
import torch import torch
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True for every KV-cache dtype string other than "auto"."""
    is_auto = kv_cache_dtype == "auto"
    return not is_auto
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Text prepended to the base class name to form the name
            of the generated class.
        attention_backend_cls: Backend class the generated class derives from.
        builder_cls: Object returned by the generated class's
            `get_builder_cls`.

    Returns:
        A dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    # A plain function stored in a class namespace is bound as a method when
    # looked up on an instance, so a zero-argument callable would raise
    # TypeError there. Wrapping it in staticmethod keeps it callable from
    # both the class and its instances.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a new subclass of `attention_backend_cls` whose class namespace is
    taken from `overrides`.

    The generated class is named `name_prefix` followed by the base class
    name.
    """
    base_name = attention_backend_cls.__name__
    derived_name: str = name_prefix + base_name  # type: ignore
    bases = (attention_backend_cls,)
    return type(derived_name, bases, overrides)

View File

@@ -8,7 +8,6 @@ from typing import (
Any, Any,
Literal, Literal,
Protocol, Protocol,
TypeVar,
get_args, get_args,
) )
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
subclass_attention_backend,
) )
from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__) logger = init_logger(__name__)
KVCacheLayoutType = Literal["NHD", "HND"] KVCacheLayoutType = Literal["NHD", "HND"]
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType) return value in get_args(KVCacheLayoutType)
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for just the requests selected by
    request_slice, rebased so the first selected request starts at 0.

    Note: The result is held in a newly created tensor.
    This will break cudagraph compatibility.
    """
    first = request_slice.start
    # Keep one entry past the last request so its end offset is included.
    window = query_start_loc[first : request_slice.stop + 1]
    return window - query_start_loc[first]
def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    Build the CommonAttentionMetadata that describes only the requests and
    tokens covered by ubatch_slice.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    req_range = ubatch_slice.request_slice
    tok_range = ubatch_slice.token_slice
    orig_start_locs = attn_metadata.query_start_loc_cpu

    req_lo, tok_lo = req_range.start, tok_range.start
    req_hi = req_range.stop - 1
    tok_hi = tok_range.stop - 1

    first_req_begin = orig_start_locs[req_lo]
    assert first_req_begin <= tok_lo < orig_start_locs[req_lo + 1], (
        "Token slice start outside of first request"
    )

    # NOTE: with CG padding the last token may lie outside the last request.
    # A request that is split across ubatches needs its metadata adjusted:
    #   continues_earlier_request: the first request here is the continuation
    #       of a request that started in a previous slice.
    #   cut_short_at_end: the last request here continues into the next slice.
    continues_earlier_request = tok_lo > first_req_begin
    last_req_end = orig_start_locs[req_hi + 1]
    cut_short_at_end = tok_hi < last_req_end - 1

    query_start_loc_cpu = slice_query_start_locs(orig_start_locs, req_range)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, req_range
    )
    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if continues_earlier_request:
        head_skip = tok_lo - first_req_begin
        query_start_loc[1:] -= head_skip
        query_start_loc_cpu[1:] -= head_skip

    seq_lens = attn_metadata.seq_lens[req_range]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[req_range]

    if cut_short_at_end:
        # Measured against the original offsets, because the sliced copies
        # may already have been shifted by the branch above.
        tail_skip = last_req_end - tok_range.stop
        query_start_loc[-1] -= tail_skip
        query_start_loc_cpu[-1] -= tail_skip

        # Work on clones so the source seq_lens tensors are not modified
        # (not cudagraph compatible).
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tail_skip
        seq_lens_cpu[-1] -= tail_skip

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[req_range]

    per_request_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    max_query_len = int(torch.max(torch.abs(per_request_lens)).item())
    # In a dummy run query_start_loc_cpu is full of 0s; fall back to the
    # value carried by the incoming metadata.
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=req_range.stop - req_range.start,
        num_actual_tokens=tok_range.stop - tok_range.start,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=attn_metadata.block_table_tensor[req_range],
        slot_mapping=attn_metadata.slot_mapping[tok_range],
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Produce one CommonAttentionMetadata per UBatchSlice in ubatch_slices,
    each corresponding to the requests of its slice.

    Note: common_attn_metadata itself is not modified.
    """
    return [
        _make_metadata_with_slice(piece, common_attn_metadata)
        for piece in ubatch_slices
    ]
@functools.lru_cache @functools.lru_cache
def get_kv_cache_layout(): def get_kv_cache_layout():
# Format specified by the code. # Format specified by the code.
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
return common_attn_metadata return common_attn_metadata
M = TypeVar("M")
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    builder_cls: type[AttentionMetadataBuilder[M]],
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Text prepended to the base class name to form the name
            of the generated class.
        attention_backend_cls: Backend class the generated class derives from.
        builder_cls: Object returned by the generated class's
            `get_builder_cls`.

    Returns:
        A dynamically created subclass of `attention_backend_cls`.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    # A plain function stored in a class namespace is bound as a method when
    # looked up on an instance, so a zero-argument callable would raise
    # TypeError there. Wrapping it in staticmethod keeps it callable from
    # both the class and its instances.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Return a new subclass of `attention_backend_cls` whose class namespace is
    taken from `overrides`.

    The generated class is named `name_prefix` followed by the base class
    name.
    """
    base_name = attention_backend_cls.__name__
    derived_name: str = name_prefix + base_name  # type: ignore
    bases = (attention_backend_cls,)
    return type(derived_name, bases, overrides)
def split_decodes_prefills_and_extends( def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1, decode_threshold: int = 1,

View File

@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import build_attn_metadata

View File

@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables

View File

@@ -6,7 +6,7 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import ( from vllm.v1.worker.gpu.cudagraph_utils import (

View File

@@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend, create_fast_prefill_custom_backend,
get_dcp_local_seq_lens, get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills, reorder_batch_to_split_decodes_and_prefills,
split_attn_metadata,
) )
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
@@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import (
UBatchSlices, UBatchSlices,
check_ubatch_thresholds, check_ubatch_thresholds,
maybe_create_ubatch_slices, maybe_create_ubatch_slices,
split_attn_metadata,
) )
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace from vllm.v1.worker.workspace import lock_workspace

View File

@@ -4,8 +4,10 @@ from dataclasses import dataclass
from typing import TypeAlias from typing import TypeAlias
import numpy as np import numpy as np
import torch
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
@dataclass @dataclass
@@ -110,3 +112,132 @@ def maybe_create_ubatch_slices(
assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
return ubatch_slices, ubatch_slices_padded return ubatch_slices, ubatch_slices_padded
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for just the requests selected by
    request_slice, rebased so the first selected request starts at 0.

    Note: The result is held in a newly created tensor.
    This will break cudagraph compatibility.
    """
    first = request_slice.start
    # Keep one entry past the last request so its end offset is included.
    window = query_start_loc[first : request_slice.stop + 1]
    return window - query_start_loc[first]
def _make_metadata_with_slice(
    ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata
) -> CommonAttentionMetadata:
    """
    Build the CommonAttentionMetadata that describes only the requests and
    tokens covered by ubatch_slice.
    """
    assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty"

    req_range = ubatch_slice.request_slice
    tok_range = ubatch_slice.token_slice
    orig_start_locs = attn_metadata.query_start_loc_cpu

    req_lo, tok_lo = req_range.start, tok_range.start
    req_hi = req_range.stop - 1
    tok_hi = tok_range.stop - 1

    first_req_begin = orig_start_locs[req_lo]
    assert first_req_begin <= tok_lo < orig_start_locs[req_lo + 1], (
        "Token slice start outside of first request"
    )

    # NOTE: with CG padding the last token may lie outside the last request.
    # A request that is split across ubatches needs its metadata adjusted:
    #   continues_earlier_request: the first request here is the continuation
    #       of a request that started in a previous slice.
    #   cut_short_at_end: the last request here continues into the next slice.
    continues_earlier_request = tok_lo > first_req_begin
    last_req_end = orig_start_locs[req_hi + 1]
    cut_short_at_end = tok_hi < last_req_end - 1

    query_start_loc_cpu = slice_query_start_locs(orig_start_locs, req_range)
    query_start_loc = slice_query_start_locs(
        attn_metadata.query_start_loc, req_range
    )
    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}"
    )

    if continues_earlier_request:
        head_skip = tok_lo - first_req_begin
        query_start_loc[1:] -= head_skip
        query_start_loc_cpu[1:] -= head_skip

    seq_lens = attn_metadata.seq_lens[req_range]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[req_range]

    if cut_short_at_end:
        # Measured against the original offsets, because the sliced copies
        # may already have been shifted by the branch above.
        tail_skip = last_req_end - tok_range.stop
        query_start_loc[-1] -= tail_skip
        query_start_loc_cpu[-1] -= tail_skip

        # Work on clones so the source seq_lens tensors are not modified
        # (not cudagraph compatible).
        seq_lens = seq_lens.clone()
        seq_lens_cpu = seq_lens_cpu.clone()
        seq_lens[-1] -= tail_skip
        seq_lens_cpu[-1] -= tail_skip

    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[req_range]

    per_request_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    max_query_len = int(torch.max(torch.abs(per_request_lens)).item())
    # In a dummy run query_start_loc_cpu is full of 0s; fall back to the
    # value carried by the incoming metadata.
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        num_reqs=req_range.stop - req_range.start,
        num_actual_tokens=tok_range.stop - tok_range.start,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=attn_metadata.block_table_tensor[req_range],
        slot_mapping=attn_metadata.slot_mapping[tok_range],
        _seq_lens_cpu=seq_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Produce one CommonAttentionMetadata per UBatchSlice in ubatch_slices,
    each corresponding to the requests of its slice.

    Note: common_attn_metadata itself is not modified.
    """
    return [
        _make_metadata_with_slice(piece, common_attn_metadata)
        for piece in ubatch_slices
    ]

View File

@@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec