diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index 19c02fc88..af9e5b5ba 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -149,7 +149,7 @@ The CUDA Graphs wrapper no longer manages the warm-up logic. The warm-up process ## CUDA Graphs Compatibility of Attention Backends -To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backends.utils.AttentionCGSupport], which is an enum type that tracks the capability of the attention backend to support CUDA Graphs. The value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`. +To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backend.AttentionCGSupport], which is an enum type that tracks the capability of the attention backend to support CUDA Graphs. The value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`. ```python class AttentionCGSupport(enum.Enum): diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 2068c30c0..6e2bb44e0 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -23,10 +23,9 @@ from vllm.utils.torch_utils import ( is_torch_equal_or_newer, set_random_seed, ) -from vllm.v1.attention.backend import AttentionType +from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, set_kv_cache_layout, ) from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index de80c556b..85efc5d8f 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -22,10 +22,10 @@ from vllm.config.vllm import set_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 71e74f4d5..da4cea8fc 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -18,12 +18,12 @@ from vllm.config import ( VllmConfig, ) from vllm.config.model import ModelDType -from vllm.v1.attention.backend import AttentionImpl -from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( + AttentionImpl, AttentionMetadataBuilder, CommonAttentionMetadata, ) +from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/e2e/test_async_spec_decode.py b/tests/v1/e2e/test_async_spec_decode.py index 561f37a52..4bf76da45 100644 --- a/tests/v1/e2e/test_async_spec_decode.py +++ b/tests/v1/e2e/test_async_spec_decode.py @@ -19,7 +19,7 @@ def sync_tracker(): Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect lazy init syncs. Prints stack traces immediately when syncs occur. """ - from vllm.v1.attention.backends.utils import CommonAttentionMetadata + from vllm.v1.attention.backend import CommonAttentionMetadata # Shared counter for cross-process communication (inherited by fork) sync_count = multiprocessing.Value("i", 0) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index a0f140cca..b5ce37ea4 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -12,9 +12,9 @@ from tests.v1.attention.utils import ( try_get_attention_backend, ) from vllm.config import ParallelConfig, SpeculativeConfig +from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.attention.backends.utils import CommonAttentionMetadata if not is_flash_attn_varlen_func_available(): pytest.skip( diff --git a/vllm/model_executor/layers/attention/chunked_local_attention.py b/vllm/model_executor/layers/attention/chunked_local_attention.py index a34506934..8916ff0c4 100644 --- a/vllm/model_executor/layers/attention/chunked_local_attention.py +++ b/vllm/model_executor/layers/attention/chunked_local_attention.py @@ -8,11 +8,13 @@ from vllm.attention.layer import Attention from vllm.config import CacheConfig from vllm.config.vllm import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( + AttentionBackend, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.utils import ( make_local_attention_virtual_batches, subclass_attention_backend, ) diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 9c3bc3403..a16981a83 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -14,9 +14,9 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionMetadata, AttentionType, + CommonAttentionMetadata, ) from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, subclass_attention_backend, ) from vllm.v1.attention.selector import get_attn_backend diff --git a/vllm/model_executor/layers/attention/encoder_only_attention.py b/vllm/model_executor/layers/attention/encoder_only_attention.py index c130fd095..8df9e05c8 100644 --- a/vllm/model_executor/layers/attention/encoder_only_attention.py +++ b/vllm/model_executor/layers/attention/encoder_only_attention.py @@ -12,9 +12,9 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionMetadata, AttentionType, + CommonAttentionMetadata, ) from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, subclass_attention_backend, ) from vllm.v1.attention.selector import get_attn_backend diff --git a/vllm/model_executor/layers/attention/static_sink_attention.py b/vllm/model_executor/layers/attention/static_sink_attention.py index 918dff560..f7ec382b3 100644 --- a/vllm/model_executor/layers/attention/static_sink_attention.py +++ b/vllm/model_executor/layers/attention/static_sink_attention.py @@ -15,9 +15,9 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionMetadata, AttentionType, + CommonAttentionMetadata, ) from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, subclass_attention_backend, ) from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( diff --git a/vllm/model_executor/models/whisper_utils.py b/vllm/model_executor/models/whisper_utils.py index 0bd0db061..d41ccde0a 100644 --- a/vllm/model_executor/models/whisper_utils.py +++ b/vllm/model_executor/models/whisper_utils.py @@ -16,10 +16,10 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionMetadata, AttentionType, + CommonAttentionMetadata, ) from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, subclass_attention_backend_with_overrides, ) from vllm.v1.attention.selector import get_attn_backend diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index e51a48cef..0fd3d6eb3 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -2,17 +2,22 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args +import numpy as np import torch +from typing_extensions import deprecated if TYPE_CHECKING: + from vllm.config import VllmConfig from vllm.config.cache import CacheDType from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backends.utils import KVCacheLayoutType + from vllm.v1.kv_cache_interface import AttentionSpec class AttentionType(str, Enum): @@ -271,6 +276,288 @@ class AttentionMetadata: T = TypeVar("T", bound=AttentionMetadata) +@dataclass +class CommonAttentionMetadata: + """ + Per-batch attention metadata, shared across layers and backends. + AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. + """ + + query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor + """(batch_size + 1,), the start location of each request in query Tensor""" + + seq_lens: torch.Tensor + """(batch_size,), the number of computed tokens for each request""" + + num_reqs: int + """Number of requests""" + # TODO(lucas): rename to num_tokens since it may be padded and this is misleading + num_actual_tokens: int + """Total number of tokens in batch""" + max_query_len: int + """Longest query in batch""" + max_seq_len: int + """Longest context length (may be an upper bound)""" + + block_table_tensor: torch.Tensor + slot_mapping: torch.Tensor + + causal: bool = True + + # Needed by FastPrefillAttentionBuilder + logits_indices_padded: torch.Tensor | None = None + num_logits_indices: int | None = None + + # Needed by CrossAttentionBuilder + encoder_seq_lens: torch.Tensor | None = None + encoder_seq_lens_cpu: np.ndarray | None = None + + dcp_local_seq_lens: torch.Tensor | None = None + dcp_local_seq_lens_cpu: torch.Tensor | None = None + """Sequence lengths of the local rank in decode context parallelism world""" + + # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) + _seq_lens_cpu: torch.Tensor | None = None + _num_computed_tokens_cpu: torch.Tensor | None = None + + _num_computed_tokens_cache: torch.Tensor | None = None + + @property + @deprecated( + """ + Prefer using device seq_lens directly to avoid implicit H<>D sync. + If a CPU copy is needed, use `seq_lens.cpu()` instead. + Will be removed in a future release (v0.15.0) + """ + ) + def seq_lens_cpu(self) -> torch.Tensor: + if self._seq_lens_cpu is None: + self._seq_lens_cpu = self.seq_lens.to("cpu") + return self._seq_lens_cpu + + @property + @deprecated( + """ + Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full + async scheduling. If a CPU copy is needed, it can be derived from + query_start_loc_cpu and seq_lens. + Will be removed in a future release (v0.15.0) + """ + ) + def num_computed_tokens_cpu(self) -> torch.Tensor: + if self._num_computed_tokens_cpu is None: + query_seq_lens = ( + self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] + ) + self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens + return self._num_computed_tokens_cpu + + def compute_num_computed_tokens(self) -> torch.Tensor: + """Compute num_computed_tokens on device (seq_lens - query_lens).""" + if self._num_computed_tokens_cache is None: + query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] + self._num_computed_tokens_cache = self.seq_lens - query_lens + return self._num_computed_tokens_cache + + # TODO(lucas): remove once we have FULL-CG spec-decode support + def unpadded( + self, num_actual_tokens: int, num_actual_reqs: int + ) -> "CommonAttentionMetadata": + maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None + return CommonAttentionMetadata( + query_start_loc=self.query_start_loc[: num_actual_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], + seq_lens=self.seq_lens[:num_actual_reqs], + _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs] + if self._seq_lens_cpu is not None + else None, + _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] + if self._num_computed_tokens_cpu is not None + else None, + num_reqs=num_actual_reqs, + num_actual_tokens=num_actual_tokens, + max_query_len=self.max_query_len, + max_seq_len=self.max_seq_len, + block_table_tensor=self.block_table_tensor[:num_actual_reqs], + slot_mapping=self.slot_mapping[:num_actual_tokens], + causal=self.causal, + logits_indices_padded=self.logits_indices_padded, + num_logits_indices=self.num_logits_indices, + encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens), + encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu), + dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens), + dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu), + ) + + +M = TypeVar("M") + + +class AttentionCGSupport(Enum): + """Constants for the cudagraph support of the attention backend + Here we do not consider the cascade attention, as currently + it is never cudagraph supported.""" + + ALWAYS = 3 + """Cudagraph always supported; supports mixed-prefill-decode""" + UNIFORM_BATCH = 2 + """Cudagraph supported for batches the only contain query lengths that are + the same, this can be used for spec-decode + i.e. "decodes" are 1 + num_speculative_tokens""" + UNIFORM_SINGLE_TOKEN_DECODE = 1 + """Cudagraph supported for batches the only contain query_len==1 decodes""" + NEVER = 0 + """NO cudagraph support""" + + +class AttentionMetadataBuilder(ABC, Generic[M]): + # Does this backend/builder support CUDA Graphs for attention (default: no). + # Do not access directly. Call get_cudagraph_support() instead. + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + # Does this backend/builder reorder the batch? + # If not, set this to None. Otherwise set it to the query + # length that will be pulled into the front of the batch. + reorder_batch_threshold: int | None = None + # Does this backend/builder support updating the block table in existing + # metadata + supports_update_block_table: bool = False + + @abstractmethod + def __init__( + self, + kv_cache_spec: "AttentionSpec", + layer_names: list[str], + vllm_config: "VllmConfig", + device: torch.device, + ): + self.kv_cache_spec = kv_cache_spec + self.layer_names = layer_names + self.vllm_config = vllm_config + self.device = device + + @classmethod + def get_cudagraph_support( + cls: type["AttentionMetadataBuilder"], + vllm_config: "VllmConfig", + kv_cache_spec: "AttentionSpec", + ) -> AttentionCGSupport: + """Get the cudagraph support level of this builder class.""" + return cls._cudagraph_support + + def _init_reorder_batch_threshold( + self, + reorder_batch_threshold: int | None = 1, + supports_spec_as_decode: bool = False, + supports_dcp_with_varlen: bool = False, + ) -> None: + self.reorder_batch_threshold = reorder_batch_threshold + if self.reorder_batch_threshold is not None and supports_spec_as_decode: + # If the backend supports spec-as-decode kernels, then we can set + # the reorder_batch_threshold based on the number of speculative + # tokens from the config. + speculative_config = self.vllm_config.speculative_config + if ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + ): + self.reorder_batch_threshold = max( + self.reorder_batch_threshold, + 1 + speculative_config.num_speculative_tokens, + ) + + if ( + self.vllm_config.parallel_config.decode_context_parallel_size > 1 + and not supports_dcp_with_varlen + ): + self.reorder_batch_threshold = 1 + + @abstractmethod + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: + """ + Central method that builds attention metadata. + Some builders (MLA) require reorder_batch to be called prior to build. + + Args: + common_prefix_len: The length of the common prefix of the batch. + common_attn_metadata: The common attention metadata. + fast_build: The meta-data will prioritize speed of building over + then speed at execution. Can be used for spec-decode where the + result of a build call may only be used for few layers/iters. + """ + raise NotImplementedError + + def update_block_table( + self, + metadata: M, + blk_table: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> M: + """ + Update the block table for the attention metadata. + Faster when theres multiple kv-cache groups that create virtually the + same metadata but just with different block tables. + + Only needs to be implemented if supports_update_block_table is True. + """ + raise NotImplementedError + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: + """ + Build attention metadata for CUDA graph capture. Uses build by default. + Subclasses that override this method should call self.build or + super().build_for_cudagraph_capture. + """ + return self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + def build_for_drafting( + self, + common_attn_metadata: CommonAttentionMetadata, + draft_index: int, + ) -> M: + """ + Build attention metadata for draft model. Uses build by default. + + Args: + common_attn_metadata: The common attention metadata. + draft_index: The index of the current draft operation. + When speculating a chain of tokens, this index refers to the + draft attempt for the i-th token. + For tree-based attention, this index instead refers to the + draft attempt for the i-th level in the tree of tokens. + """ + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) + + def use_cascade_attention( + self, + common_prefix_len: int, + query_lens: np.ndarray, + num_query_heads: int, + num_kv_heads: int, + use_alibi: bool, + use_sliding_window: bool, + use_local_attention: bool, + num_sms: int, + dcp_world_size: int, + ) -> bool: + return False + + class AttentionLayer(Protocol): _q_scale: torch.Tensor _k_scale: torch.Tensor diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 3fc53278a..3eb9b4782 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -13,12 +13,12 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionLayer, + AttentionMetadataBuilder, AttentionType, + CommonAttentionMetadata, is_quantized_kv_cache, ) from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index aa51c1a43..6fec5001b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -41,10 +41,12 @@ from vllm.model_executor.layers.batch_invariant import ( ) from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.utils import ( get_dcp_local_seq_lens, get_kv_cache_layout, ) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 314a8f2bb..9892c360d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -43,14 +43,14 @@ from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.torch_utils import is_strictly_contiguous from vllm.v1.attention.backend import ( AttentionBackend, + AttentionCGSupport, AttentionImpl, + AttentionMetadataBuilder, AttentionType, + CommonAttentionMetadata, MultipleOf, ) from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, KVCacheLayoutType, get_dcp_local_seq_lens, get_kv_cache_layout, diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 994bbe3c9..48c8ac6a8 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -32,12 +32,10 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, - AttentionType, - is_quantized_kv_cache, -) -from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, + AttentionType, CommonAttentionMetadata, + is_quantized_kv_cache, ) from vllm.v1.kv_cache_interface import AttentionSpec diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 1d58ac683..426c17689 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -7,12 +7,14 @@ from dataclasses import dataclass import torch from vllm.config import VllmConfig -from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, +from vllm.v1.attention.backend import ( + AttentionBackend, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, compute_causal_conv1d_metadata, split_decodes_and_prefills, ) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index b1aad30ee..4ef565691 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -5,13 +5,13 @@ from dataclasses import dataclass import torch from vllm.config import VllmConfig -from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( + AttentionBackend, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - split_decodes_and_prefills, ) +from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index a5f661d5d..f45315f1e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -7,14 +7,11 @@ import torch from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backend import AttentionBackend +from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadata, BaseMambaAttentionMetadataBuilder, ) -from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, -) from vllm.v1.kv_cache_interface import AttentionSpec diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 2d4335664..0c55877a5 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -10,11 +10,13 @@ import torch from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, +from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, compute_causal_conv1d_metadata, split_decodes_and_prefills, ) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index a5bd949e9..5cd5cc566 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -217,12 +217,12 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionLayer, AttentionMetadata, + AttentionMetadataBuilder, + CommonAttentionMetadata, MLAAttentionImpl, ) from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, get_dcp_local_seq_lens, get_per_layer_parameters, infer_global_hyperparameters, diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 8cb8fa1f5..55a8703c6 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -11,6 +11,7 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backend import ( + AttentionCGSupport, AttentionLayer, AttentionType, MultipleOf, @@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonMetadata, MLACommonMetadataBuilder, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport logger = init_logger(__name__) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 2e0a19ac5..eedaef72d 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.batch_invariant import ( ) from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backend import ( + AttentionCGSupport, AttentionLayer, AttentionType, MultipleOf, @@ -31,7 +32,6 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import ( # type: ignore[attr-defined] flash_attn_varlen_func, diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index c0442b13f..ffd2d47c8 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -10,6 +10,7 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.platforms.interface import DeviceCapability from vllm.v1.attention.backend import ( + AttentionCGSupport, AttentionLayer, AttentionType, MultipleOf, @@ -21,7 +22,7 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType +from vllm.v1.attention.backends.utils import KVCacheLayoutType logger = init_logger(__name__) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 24ef6dd4d..cb79f4541 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -13,7 +13,12 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms.interface import DeviceCapability -from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf +from vllm.v1.attention.backend import ( + AttentionCGSupport, + AttentionLayer, + AttentionType, + MultipleOf, +) from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -23,7 +28,6 @@ from vllm.v1.attention.backends.mla.common import ( QueryLenSupport, ) from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode, ) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 282880adf..a2554a53a 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -16,15 +16,15 @@ from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import ( AttentionBackend, + AttentionCGSupport, AttentionLayer, AttentionMetadata, + AttentionMetadataBuilder, + CommonAttentionMetadata, MultipleOf, ) from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode, split_decodes_and_prefills, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 351cbc8a6..3af785620 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -11,12 +11,12 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backend import ( AttentionBackend, - MultipleOf, -) -from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + MultipleOf, +) +from vllm.v1.attention.backends.utils import ( split_decodes_and_prefills, split_prefill_chunks, ) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index d43516e55..9eacd5ee7 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -8,7 +8,7 @@ import torch from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig -from vllm.v1.attention.backend import AttentionLayer, MultipleOf +from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -17,7 +17,6 @@ from vllm.v1.attention.backends.mla.common import ( MLACommonMetadataBuilder, QueryLenSupport, ) -from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index 7d05879d9..997b1f62a 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -13,18 +13,16 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backend import ( AttentionBackend, + AttentionCGSupport, AttentionLayer, AttentionMetadata, + AttentionMetadataBuilder, + CommonAttentionMetadata, ) from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims from vllm.v1.attention.backends.mla.flashmla_sparse import ( triton_convert_req_index_to_global_index, ) -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, -) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index da14a8484..f384aaa46 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -15,14 +15,14 @@ from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import get_cu_count from vllm.v1.attention.backend import ( AttentionBackend, + AttentionCGSupport, AttentionImpl, + AttentionMetadataBuilder, AttentionType, + CommonAttentionMetadata, MultipleOf, ) from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, split_decodes_prefills_and_extends, ) from vllm.v1.attention.ops.merge_attn_states import merge_attn_states diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 9d00d8fa6..6ec6825cc 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -16,16 +16,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.v1.attention.backend import ( AttentionBackend, + AttentionCGSupport, AttentionImpl, + AttentionMetadataBuilder, AttentionType, + CommonAttentionMetadata, MultipleOf, ) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, -) from vllm.v1.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode, ) diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index b6e58a25f..c9c85ddc7 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -14,12 +14,12 @@ from vllm.logger import init_logger from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, + AttentionMetadataBuilder, AttentionType, + CommonAttentionMetadata, MultipleOf, ) from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, split_decodes_and_prefills, ) from vllm.v1.attention.ops.triton_unified_attention import unified_attention diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 3ef5b4a22..4cc438d9f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,14 +19,12 @@ from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import next_power_of_2 from vllm.v1.attention.backend import ( AttentionBackend, - AttentionImpl, - AttentionType, - MultipleOf, -) -from vllm.v1.attention.backends.utils import ( AttentionCGSupport, + AttentionImpl, AttentionMetadataBuilder, + AttentionType, CommonAttentionMetadata, + MultipleOf, ) from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 42c7ead72..c549bf7b5 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1,16 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import abc -import enum import functools -from abc import abstractmethod from collections.abc import Callable from dataclasses import dataclass, field, fields, make_dataclass from typing import ( TYPE_CHECKING, Any, - ClassVar, - Generic, Literal, Protocol, TypeVar, @@ -19,7 +14,7 @@ from typing import ( import numpy as np import torch -from typing_extensions import deprecated, runtime_checkable +from typing_extensions import runtime_checkable from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils.math_utils import cdiv @@ -38,8 +33,9 @@ from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionMetadata, + AttentionMetadataBuilder, + CommonAttentionMetadata, ) -from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.ubatch_utils import UBatchSlice logger = init_logger(__name__) @@ -53,123 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) -@dataclass -class CommonAttentionMetadata: - """ - Per-batch attention metadata, shared across layers and backends. - AttentionMetadataBuilder instances use it to construct per-layer metadata. - - For many of the tensors we keep both GPU and CPU versions. - """ - - query_start_loc: torch.Tensor - query_start_loc_cpu: torch.Tensor - """(batch_size + 1,), the start location of each request in query Tensor""" - - seq_lens: torch.Tensor - """(batch_size,), the number of computed tokens for each request""" - - num_reqs: int - """Number of requests""" - # TODO(lucas): rename to num_tokens since it may be padded and this is misleading - num_actual_tokens: int - """Total number of tokens in batch""" - max_query_len: int - """Longest query in batch""" - max_seq_len: int - """Longest context length (may be an upper bound)""" - - block_table_tensor: torch.Tensor - slot_mapping: torch.Tensor - - causal: bool = True - - # Needed by FastPrefillAttentionBuilder - logits_indices_padded: torch.Tensor | None = None - num_logits_indices: int | None = None - - # Needed by CrossAttentionBuilder - encoder_seq_lens: torch.Tensor | None = None - encoder_seq_lens_cpu: np.ndarray | None = None - - dcp_local_seq_lens: torch.Tensor | None = None - dcp_local_seq_lens_cpu: torch.Tensor | None = None - """Sequence lengths of the local rank in decode context parallelism world""" - - # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0) - _seq_lens_cpu: torch.Tensor | None = None - _num_computed_tokens_cpu: torch.Tensor | None = None - - _num_computed_tokens_cache: torch.Tensor | None = None - - @property - @deprecated( - """ - Prefer using device seq_lens directly to avoid implicit H<>D sync. - If a CPU copy is needed, use `seq_lens.cpu()` instead. - Will be removed in a future release (v0.15.0) - """ - ) - def seq_lens_cpu(self) -> torch.Tensor: - if self._seq_lens_cpu is None: - self._seq_lens_cpu = self.seq_lens.to("cpu") - return self._seq_lens_cpu - - @property - @deprecated( - """ - Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full - async scheduling. If a CPU copy is needed, it can be derived from - query_start_loc_cpu and seq_lens. - Will be removed in a future release (v0.15.0) - """ - ) - def num_computed_tokens_cpu(self) -> torch.Tensor: - if self._num_computed_tokens_cpu is None: - query_seq_lens = ( - self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1] - ) - self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens - return self._num_computed_tokens_cpu - - def compute_num_computed_tokens(self) -> torch.Tensor: - """Compute num_computed_tokens on device (seq_lens - query_lens).""" - if self._num_computed_tokens_cache is None: - query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1] - self._num_computed_tokens_cache = self.seq_lens - query_lens - return self._num_computed_tokens_cache - - # TODO(lucas): remove once we have FULL-CG spec-decode support - def unpadded( - self, num_actual_tokens: int, num_actual_reqs: int - ) -> "CommonAttentionMetadata": - maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None - return CommonAttentionMetadata( - query_start_loc=self.query_start_loc[: num_actual_reqs + 1], - query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1], - seq_lens=self.seq_lens[:num_actual_reqs], - _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs] - if self._seq_lens_cpu is not None - else None, - _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs] - if self._num_computed_tokens_cpu is not None - else None, - num_reqs=num_actual_reqs, - num_actual_tokens=num_actual_tokens, - max_query_len=self.max_query_len, - max_seq_len=self.max_seq_len, - block_table_tensor=self.block_table_tensor[:num_actual_reqs], - slot_mapping=self.slot_mapping[:num_actual_tokens], - causal=self.causal, - logits_indices_padded=self.logits_indices_padded, - num_logits_indices=self.num_logits_indices, - encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens), - encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu), - dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens), - dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu), - ) - - def slice_query_start_locs( query_start_loc: torch.Tensor, request_slice: slice, @@ -299,171 +178,6 @@ def split_attn_metadata( return results -M = TypeVar("M") - - -class AttentionCGSupport(enum.Enum): - """Constants for the cudagraph support of the attention backend - Here we do not consider the cascade attention, as currently - it is never cudagraph supported.""" - - ALWAYS = 3 - """Cudagraph always supported; supports mixed-prefill-decode""" - UNIFORM_BATCH = 2 - """Cudagraph supported for batches the only contain query lengths that are - the same, this can be used for spec-decode - i.e. "decodes" are 1 + num_speculative_tokens""" - UNIFORM_SINGLE_TOKEN_DECODE = 1 - """Cudagraph supported for batches the only contain query_len==1 decodes""" - NEVER = 0 - """NO cudagraph support""" - - -class AttentionMetadataBuilder(abc.ABC, Generic[M]): - # Does this backend/builder support CUDA Graphs for attention (default: no). - # Do not access directly. Call get_cudagraph_support() instead. - _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER - # Does this backend/builder reorder the batch? - # If not, set this to None. Otherwise set it to the query - # length that will be pulled into the front of the batch. - reorder_batch_threshold: int | None = None - # Does this backend/builder support updating the block table in existing - # metadata - supports_update_block_table: bool = False - - @abstractmethod - def __init__( - self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - ): - self.kv_cache_spec = kv_cache_spec - self.layer_names = layer_names - self.vllm_config = vllm_config - self.device = device - - @classmethod - def get_cudagraph_support( - cls: type["AttentionMetadataBuilder"], - vllm_config: VllmConfig, - kv_cache_spec: AttentionSpec, - ) -> AttentionCGSupport: - """Get the cudagraph support level of this builder class.""" - return cls._cudagraph_support - - def _init_reorder_batch_threshold( - self, - reorder_batch_threshold: int | None = 1, - supports_spec_as_decode: bool = False, - supports_dcp_with_varlen: bool = False, - ) -> None: - self.reorder_batch_threshold = reorder_batch_threshold - if self.reorder_batch_threshold is not None and supports_spec_as_decode: - # If the backend supports spec-as-decode kernels, then we can set - # the reorder_batch_threshold based on the number of speculative - # tokens from the config. - speculative_config = self.vllm_config.speculative_config - if ( - speculative_config is not None - and speculative_config.num_speculative_tokens is not None - ): - self.reorder_batch_threshold = max( - self.reorder_batch_threshold, - 1 + speculative_config.num_speculative_tokens, - ) - - if ( - self.vllm_config.parallel_config.decode_context_parallel_size > 1 - and not supports_dcp_with_varlen - ): - self.reorder_batch_threshold = 1 - - @abstractmethod - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> M: - """ - Central method that builds attention metadata. - Some builders (MLA) require reorder_batch to be called prior to build. - - Args: - common_prefix_len: The length of the common prefix of the batch. - common_attn_metadata: The common attention metadata. - fast_build: The meta-data will prioritize speed of building over - then speed at execution. Can be used for spec-decode where the - result of a build call may only be used for few layers/iters. - """ - raise NotImplementedError - - def update_block_table( - self, - metadata: M, - blk_table: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> M: - """ - Update the block table for the attention metadata. - Faster when theres multiple kv-cache groups that create virtually the - same metadata but just with different block tables. - - Only needs to be implemented if supports_update_block_table is True. - """ - raise NotImplementedError - - def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata - ) -> M: - """ - Build attention metadata for CUDA graph capture. Uses build by default. - Subclasses that override this method should call self.build or - super().build_for_cudagraph_capture. - """ - return self.build( - common_prefix_len=0, common_attn_metadata=common_attn_metadata - ) - - def build_for_drafting( - self, - common_attn_metadata: CommonAttentionMetadata, - draft_index: int, - ) -> M: - """ - Build attention metadata for draft model. Uses build by default. - - Args: - common_attn_metadata: The common attention metadata. - draft_index: The index of the current draft operation. - When speculating a chain of tokens, this index refers to the - draft attempt for the i-th token. - For tree-based attention, this index instead refers to the - draft attempt for the i-th level in the tree of tokens. - """ - return self.build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - fast_build=True, - ) - - def use_cascade_attention( - self, - common_prefix_len: int, - query_lens: np.ndarray, - num_query_heads: int, - num_kv_heads: int, - use_alibi: bool, - use_sliding_window: bool, - use_local_attention: bool, - num_sms: int, - dcp_world_size: int, - ) -> bool: - return False - - @functools.lru_cache def get_kv_cache_layout(): # Format specified by the code. @@ -834,6 +548,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( return common_attn_metadata +M = TypeVar("M") + + def subclass_attention_backend( name_prefix: str, attention_backend_cls: type[AttentionBackend], diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index cd4f55b79..9820f5109 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -26,16 +26,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.triton_utils import triton from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backend import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.tree_attn import ( TreeAttentionMetadata, TreeAttentionMetadataBuilder, ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, -) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index f62a71858..70c622fc0 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -7,8 +7,8 @@ import torch from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.v1.attention.backend import AttentionBackend -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( + AttentionBackend, AttentionMetadataBuilder, CommonAttentionMetadata, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e08463c40..525ad5db4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -100,15 +100,15 @@ from vllm.utils.torch_utils import ( ) from vllm.v1.attention.backend import ( AttentionBackend, + AttentionCGSupport, AttentionMetadata, + AttentionMetadataBuilder, AttentionType, + CommonAttentionMetadata, MultipleOf, ) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, create_fast_prefill_custom_backend, get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills,