diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 734819fcd..66edaf0a7 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -7,13 +7,15 @@ import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm.v1.attention.backends.utils import ( - UBatchSlice, - _make_metadata_with_slice, - slice_query_start_locs, - split_attn_metadata, split_decodes_and_prefills, ) -from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices +from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, + _make_metadata_with_slice, + maybe_create_ubatch_slices, + slice_query_start_locs, + split_attn_metadata, +) @pytest.fixture diff --git a/vllm/model_executor/layers/attention/chunked_local_attention.py b/vllm/model_executor/layers/attention/chunked_local_attention.py index 8916ff0c4..0fae51443 100644 --- a/vllm/model_executor/layers/attention/chunked_local_attention.py +++ b/vllm/model_executor/layers/attention/chunked_local_attention.py @@ -13,10 +13,10 @@ from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + subclass_attention_backend, ) from vllm.v1.attention.backends.utils import ( make_local_attention_virtual_batches, - subclass_attention_backend, ) from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import ( diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index a16981a83..a3f1f1072 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -15,8 +15,6 @@ from vllm.v1.attention.backend import ( AttentionMetadata, AttentionType, CommonAttentionMetadata, -) -from vllm.v1.attention.backends.utils import ( subclass_attention_backend, ) from vllm.v1.attention.selector import get_attn_backend diff --git a/vllm/model_executor/layers/attention/encoder_only_attention.py b/vllm/model_executor/layers/attention/encoder_only_attention.py index 8df9e05c8..89a92ca1b 100644 --- a/vllm/model_executor/layers/attention/encoder_only_attention.py +++ b/vllm/model_executor/layers/attention/encoder_only_attention.py @@ -13,8 +13,6 @@ from vllm.v1.attention.backend import ( AttentionMetadata, AttentionType, CommonAttentionMetadata, -) -from vllm.v1.attention.backends.utils import ( subclass_attention_backend, ) from vllm.v1.attention.selector import get_attn_backend diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index f7ec382b3..a869226ea 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -16,8 +16,6 @@ from vllm.v1.attention.backend import ( AttentionMetadata, AttentionType, CommonAttentionMetadata, -) -from vllm.v1.attention.backends.utils import ( subclass_attention_backend, ) from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( diff --git a/vllm/model_executor/models/whisper_utils.py b/vllm/model_executor/models/whisper_utils.py index d41ccde0a..4d9f7ccf0 100644 --- a/vllm/model_executor/models/whisper_utils.py +++ b/vllm/model_executor/models/whisper_utils.py @@ -17,11 +17,9 @@ from vllm.v1.attention.backend import ( AttentionMetadata, AttentionType, CommonAttentionMetadata, -) -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.attention.backends.utils import ( subclass_attention_backend_with_overrides, ) +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.selector import get_attn_backend from vllm.v1.kv_cache_interface import AttentionSpec diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 0fd3d6eb3..6c6bb808b 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args import numpy as np import torch @@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: return kv_cache_dtype != "auto" + + +def subclass_attention_backend( + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]], +) -> type[AttentionBackend]: + """ + Return a new subclass where `get_builder_cls` returns `builder_cls`. + """ + name: str = name_prefix + attention_backend_cls.__name__ # type: ignore + + return type( + name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} + ) + + +def subclass_attention_backend_with_overrides( + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + overrides: dict[str, Any], +) -> type[AttentionBackend]: + name: str = name_prefix + attention_backend_cls.__name__ # type: ignore + return type(name, (attention_backend_cls,), overrides) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index c549bf7b5..1c254b836 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -8,7 +8,6 @@ from typing import ( Any, Literal, Protocol, - TypeVar, get_args, ) @@ -33,10 +32,9 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionMetadata, - AttentionMetadataBuilder, CommonAttentionMetadata, + subclass_attention_backend, ) -from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) KVCacheLayoutType = Literal["NHD", "HND"] @@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) -def slice_query_start_locs( - query_start_loc: torch.Tensor, - request_slice: slice, -) -> torch.Tensor: - """ - Creates a new query_start_loc that corresponds to the requests in - request_slice. - - Note: This function creates a new tensor to hold the new query_start_locs. - This will break cudagraph compatibility. - """ - return ( - query_start_loc[request_slice.start : request_slice.stop + 1] - - query_start_loc[request_slice.start] - ) - - -def _make_metadata_with_slice( - ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata -) -> CommonAttentionMetadata: - """ - This function creates a new CommonAttentionMetadata that corresponds to - the requests included in ubatch_slice - """ - - assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty" - - request_slice = ubatch_slice.request_slice - token_slice = ubatch_slice.token_slice - - start_locs = attn_metadata.query_start_loc_cpu - first_req = request_slice.start - first_tok = token_slice.start - last_req = request_slice.stop - 1 - last_tok = token_slice.stop - 1 - - assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], ( - "Token slice start outside of first request" - ) - # NOTE: last token can be outside of the last request if we have CG padding. - - # If the request is split across ubatches, we have to adjust the metadata. - # splits_first_request: The first request in this slice is the continuation of - # a request that started in a previous slice. - # splits_last_request: The last request in this slice continues into the - # next slice. - splits_first_request = first_tok > start_locs[first_req] - splits_last_request = last_tok < start_locs[last_req + 1] - 1 - - query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) - query_start_loc = slice_query_start_locs( - attn_metadata.query_start_loc, request_slice - ) - - assert len(query_start_loc) >= 2, ( - f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}" - ) - - if splits_first_request: - tokens_skipped = first_tok - start_locs[first_req] - query_start_loc[1:] -= tokens_skipped - query_start_loc_cpu[1:] -= tokens_skipped - seq_lens = attn_metadata.seq_lens[request_slice] - seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] - - if splits_last_request: - # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate - # the tokens skipped because query_start_loc_cpu might have been modified - # if splits_first_request is True. - tokens_skipped = start_locs[last_req + 1] - token_slice.stop - query_start_loc[-1] -= tokens_skipped - query_start_loc_cpu[-1] -= tokens_skipped - - # Make sure we don't modify the seq_lens tensors - # (not cudagraph compatible) - seq_lens = seq_lens.clone() - seq_lens_cpu = seq_lens_cpu.clone() - seq_lens[-1] -= tokens_skipped - seq_lens_cpu[-1] -= tokens_skipped - - max_seq_len = int(seq_lens_cpu.max()) - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] - - num_requests = request_slice.stop - request_slice.start - num_actual_tokens = token_slice.stop - token_slice.start - max_query_len = int( - torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item() - ) - - # This is to account for the case where we are in a dummy - # run and query_start_loc_cpu is full of 0s - if max_query_len == 0: - max_query_len = attn_metadata.max_query_len - - block_table_tensor = attn_metadata.block_table_tensor[request_slice] - slot_mapping = attn_metadata.slot_mapping[token_slice] - - return CommonAttentionMetadata( - query_start_loc=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens=seq_lens, - num_reqs=num_requests, - num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - block_table_tensor=block_table_tensor, - slot_mapping=slot_mapping, - _seq_lens_cpu=seq_lens_cpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, - ) - - -def split_attn_metadata( - ubatch_slices: list[UBatchSlice], - common_attn_metadata: CommonAttentionMetadata, -) -> list[CommonAttentionMetadata]: - """ - Creates a new CommonAttentionMetadata instance that corresponds to the - requests for each UBatchSlice in ubatch_slices. - - Note: This function does not modify common_attn_metadata - """ - results = [] - for ubatch_slice in ubatch_slices: - results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) - - return results - - @functools.lru_cache def get_kv_cache_layout(): # Format specified by the code. @@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( return common_attn_metadata -M = TypeVar("M") - - -def subclass_attention_backend( - name_prefix: str, - attention_backend_cls: type[AttentionBackend], - builder_cls: type[AttentionMetadataBuilder[M]], -) -> type[AttentionBackend]: - """ - Return a new subclass where `get_builder_cls` returns `builder_cls`. - """ - name: str = name_prefix + attention_backend_cls.__name__ # type: ignore - - return type( - name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} - ) - - -def subclass_attention_backend_with_overrides( - name_prefix: str, - attention_backend_cls: type[AttentionBackend], - overrides: dict[str, Any], -) -> type[AttentionBackend]: - name: str = name_prefix + attention_backend_cls.__name__ # type: ignore - return type(name, (attention_backend_cls,), overrides) - - def split_decodes_prefills_and_extends( common_attn_metadata: CommonAttentionMetadata, decode_threshold: int = 1, diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index d5095af18..51784cdc6 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.forward_context import set_forward_context -from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_attn_metadata diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index ed9260120..30a833379 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.triton_utils import tl, triton -from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.block_table import BlockTables diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py index dcdeedda6..c4a511778 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py @@ -6,7 +6,7 @@ import torch from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode -from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.cudagraph_utils import ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 525ad5db4..de2a1e371 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import ( create_fast_prefill_custom_backend, get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, - split_attn_metadata, ) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( @@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import ( UBatchSlices, check_ubatch_thresholds, maybe_create_ubatch_slices, + split_attn_metadata, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.workspace import lock_workspace diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index f68891735..7c4172647 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -4,8 +4,10 @@ from dataclasses import dataclass from typing import TypeAlias import numpy as np +import torch from vllm.config import ParallelConfig +from vllm.v1.attention.backend import CommonAttentionMetadata @dataclass @@ -110,3 +112,132 @@ def maybe_create_ubatch_slices( assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded return ubatch_slices, ubatch_slices_padded + + +def slice_query_start_locs( + query_start_loc: torch.Tensor, + request_slice: slice, +) -> torch.Tensor: + """ + Creates a new query_start_loc that corresponds to the requests in + request_slice. + + Note: This function creates a new tensor to hold the new query_start_locs. + This will break cudagraph compatibility. + """ + return ( + query_start_loc[request_slice.start : request_slice.stop + 1] + - query_start_loc[request_slice.start] + ) + + +def _make_metadata_with_slice( + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata +) -> CommonAttentionMetadata: + """ + This function creates a new CommonAttentionMetadata that corresponds to + the requests included in ubatch_slice + """ + + assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty" + + request_slice = ubatch_slice.request_slice + token_slice = ubatch_slice.token_slice + + start_locs = attn_metadata.query_start_loc_cpu + first_req = request_slice.start + first_tok = token_slice.start + last_req = request_slice.stop - 1 + last_tok = token_slice.stop - 1 + + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], ( + "Token slice start outside of first request" + ) + # NOTE: last token can be outside of the last request if we have CG padding. + + # If the request is split across ubatches, we have to adjust the metadata. + # splits_first_request: The first request in this slice is the continuation of + # a request that started in a previous slice. + # splits_last_request: The last request in this slice continues into the + # next slice. + splits_first_request = first_tok > start_locs[first_req] + splits_last_request = last_tok < start_locs[last_req + 1] - 1 + + query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) + query_start_loc = slice_query_start_locs( + attn_metadata.query_start_loc, request_slice + ) + + assert len(query_start_loc) >= 2, ( + f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}" + ) + + if splits_first_request: + tokens_skipped = first_tok - start_locs[first_req] + query_start_loc[1:] -= tokens_skipped + query_start_loc_cpu[1:] -= tokens_skipped + seq_lens = attn_metadata.seq_lens[request_slice] + seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + + if splits_last_request: + # NOTE: We use start_locs (the original query_start_loc_cpu) to calculate + # the tokens skipped because query_start_loc_cpu might have been modified + # if splits_first_request is True. + tokens_skipped = start_locs[last_req + 1] - token_slice.stop + query_start_loc[-1] -= tokens_skipped + query_start_loc_cpu[-1] -= tokens_skipped + + # Make sure we don't modify the seq_lens tensors + # (not cudagraph compatible) + seq_lens = seq_lens.clone() + seq_lens_cpu = seq_lens_cpu.clone() + seq_lens[-1] -= tokens_skipped + seq_lens_cpu[-1] -= tokens_skipped + + max_seq_len = int(seq_lens_cpu.max()) + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] + + num_requests = request_slice.stop - request_slice.start + num_actual_tokens = token_slice.stop - token_slice.start + max_query_len = int( + torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item() + ) + + # This is to account for the case where we are in a dummy + # run and query_start_loc_cpu is full of 0s + if max_query_len == 0: + max_query_len = attn_metadata.max_query_len + + block_table_tensor = attn_metadata.block_table_tensor[request_slice] + slot_mapping = attn_metadata.slot_mapping[token_slice] + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + num_reqs=num_requests, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + ) + + +def split_attn_metadata( + ubatch_slices: list[UBatchSlice], + common_attn_metadata: CommonAttentionMetadata, +) -> list[CommonAttentionMetadata]: + """ + Creates a new CommonAttentionMetadata instance that corresponds to the + requests for each UBatchSlice in ubatch_slices. + + Note: This function does not modify common_attn_metadata + """ + results = [] + for ubatch_slice in ubatch_slices: + results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + + return results diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 85acc1679..7fd6161a9 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry from vllm.platforms import current_platform from vllm.utils.mem_utils import MemorySnapshot, format_gib -from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec