[3/N][Attention] Move AttentionMetadata-related code from utils.py to backend.py (#32054)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-12 12:13:56 -05:00
committed by GitHub
parent 7c0d3c5152
commit 20228cb851
37 changed files with 374 additions and 370 deletions

View File

@@ -149,7 +149,7 @@ The CUDA Graphs wrapper no longer manages the warm-up logic. The warm-up process
## CUDA Graphs Compatibility of Attention Backends
To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backends.utils.AttentionCGSupport], which is an enum type that tracks the capability of the attention backend to support CUDA Graphs. The value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`.
To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backend.AttentionCGSupport], which tracks the capability of the attention backend to support CUDA Graphs. The values are ordered by capability, i.e., `ALWAYS` > `UNIFORM_BATCH` > `UNIFORM_SINGLE_TOKEN_DECODE` > `NEVER`.
```python
class AttentionCGSupport(enum.Enum):

View File

@@ -23,10 +23,9 @@ from vllm.utils.torch_utils import (
is_torch_equal_or_newer,
set_random_seed,
)
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -22,10 +22,10 @@ from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -18,12 +18,12 @@ from vllm.config import (
VllmConfig,
)
from vllm.config.model import ModelDType
from vllm.v1.attention.backend import AttentionImpl
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.backend import (
AttentionImpl,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -19,7 +19,7 @@ def sync_tracker():
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
lazy init syncs. Prints stack traces immediately when syncs occur.
"""
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backend import CommonAttentionMetadata
# Shared counter for cross-process communication (inherited by fork)
sync_count = multiprocessing.Value("i", 0)

View File

@@ -12,9 +12,9 @@ from tests.v1.attention.utils import (
try_get_attention_backend,
)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
if not is_flash_attn_varlen_func_available():
pytest.skip(

View File

@@ -8,11 +8,13 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches,
subclass_attention_backend,
)

View File

@@ -14,9 +14,9 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend

View File

@@ -12,9 +12,9 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend

View File

@@ -15,9 +15,9 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (

View File

@@ -16,10 +16,10 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.selector import get_attn_backend

View File

@@ -2,17 +2,22 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import numpy as np
import torch
from typing_extensions import deprecated
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec
class AttentionType(str, Enum):
@@ -271,6 +276,288 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    """(batch_size,), the total sequence length of each request (computed
    tokens plus the tokens scheduled this step; num_computed_tokens is
    derived as seq_lens - query_lens below)"""

    num_reqs: int
    """Number of requests"""

    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
    num_actual_tokens: int
    """Total number of tokens in batch"""

    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length (may be an upper bound)"""

    # Per-request KV block table and per-token slot mapping into the KV cache.
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: torch.Tensor | None = None
    encoder_seq_lens_cpu: np.ndarray | None = None

    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None
    """Sequence lengths of the local rank in decode context parallelism world"""

    # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
    # These back the deprecated lazy properties below; accessing them may
    # trigger an implicit device->host sync.
    _seq_lens_cpu: torch.Tensor | None = None
    _num_computed_tokens_cpu: torch.Tensor | None = None
    # Lazily-filled device-side cache for compute_num_computed_tokens().
    _num_computed_tokens_cache: torch.Tensor | None = None

    @property
    @deprecated(
        """
        Prefer using device seq_lens directly to avoid implicit H<>D sync.
        If a CPU copy is needed, use `seq_lens.cpu()` instead.
        Will be removed in a future release (v0.15.0)
        """
    )
    def seq_lens_cpu(self) -> torch.Tensor:
        # Lazily copies seq_lens to host on first access (implicit D->H sync).
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu

    @property
    @deprecated(
        """
        Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
        async scheduling. If a CPU copy is needed, it can be derived from
        query_start_loc_cpu and seq_lens.
        Will be removed in a future release (v0.15.0)
        """
    )
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        # num_computed_tokens = seq_lens - query_lens, computed on host.
        # Note: uses the deprecated seq_lens_cpu property, so first access
        # may sync.
        if self._num_computed_tokens_cpu is None:
            query_seq_lens = (
                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
            )
            self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
        return self._num_computed_tokens_cpu

    def compute_num_computed_tokens(self) -> torch.Tensor:
        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
        if self._num_computed_tokens_cache is None:
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            self._num_computed_tokens_cache = self.seq_lens - query_lens
        return self._num_computed_tokens_cache

    # TODO(lucas): remove once we have FULL-CG spec-decode support
    def unpadded(
        self, num_actual_tokens: int, num_actual_reqs: int
    ) -> "CommonAttentionMetadata":
        """Return a copy with per-request and per-token tensors sliced down
        to the actual (unpadded) request/token counts.

        NOTE(review): _num_computed_tokens_cache is intentionally not carried
        over; it is recomputed lazily on the sliced copy.
        """
        maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
            if self._seq_lens_cpu is not None
            else None,
            _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
            if self._num_computed_tokens_cpu is not None
            else None,
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
            max_query_len=self.max_query_len,
            max_seq_len=self.max_seq_len,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            logits_indices_padded=self.logits_indices_padded,
            num_logits_indices=self.num_logits_indices,
            encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
            encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
            dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
            dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
        )
M = TypeVar("M")
class AttentionCGSupport(Enum):
    """Level of CUDA Graphs support provided by an attention backend.

    Members are ordered by capability (a higher value means more capable).
    Cascade attention is not considered here, as it is currently never
    cudagraph supported.
    """

    ALWAYS = 3
    """Cudagraphs supported for any batch, including mixed prefill-decode."""

    UNIFORM_BATCH = 2
    """Cudagraphs supported only when every request in the batch has the
    same query length; usable for spec-decode, where "decodes" are
    1 + num_speculative_tokens."""

    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraphs supported only for pure decode batches (query_len == 1)."""

    NEVER = 0
    """No cudagraph support."""
class AttentionMetadataBuilder(ABC, Generic[M]):
    """Abstract base class for builders that turn per-batch
    CommonAttentionMetadata into backend-specific metadata of type M."""

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    # Do not access directly. Call get_cudagraph_support() instead.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int | None = None

    # Does this backend/builder support updating the block table in existing
    # metadata
    supports_update_block_table: bool = False

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: "AttentionSpec",
        layer_names: list[str],
        vllm_config: "VllmConfig",
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @classmethod
    def get_cudagraph_support(
        cls: type["AttentionMetadataBuilder"],
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
    ) -> AttentionCGSupport:
        """Get the cudagraph support level of this builder class."""
        return cls._cudagraph_support

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
        """Initialize reorder_batch_threshold, widening it for
        spec-as-decode kernels and forcing it to 1 under decode context
        parallelism when varlen is not supported."""
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
            # If the backend supports spec-as-decode kernels, then we can set
            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
            speculative_config = self.vllm_config.speculative_config
            if (
                speculative_config is not None
                and speculative_config.num_speculative_tokens is not None
            ):
                self.reorder_batch_threshold = max(
                    self.reorder_batch_threshold,
                    1 + speculative_config.num_speculative_tokens,
                )

        # NOTE: this override is applied after the spec-decode widening above,
        # so DCP without varlen support always ends up with threshold 1.
        if (
            self.vllm_config.parallel_config.decode_context_parallel_size > 1
            and not supports_dcp_with_varlen
        ):
            self.reorder_batch_threshold = 1

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def update_block_table(
        self,
        metadata: M,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> M:
        """
        Update the block table for the attention metadata.
        Faster when there are multiple kv-cache groups that create virtually
        the same metadata but just with different block tables.
        Only needs to be implemented if supports_update_block_table is True.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        # fast_build=True: drafting metadata is short-lived, so prefer build
        # speed over execution speed.
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        # Default: backends opt in to cascade attention by overriding this.
        return False
class AttentionLayer(Protocol):
_q_scale: torch.Tensor
_k_scale: torch.Tensor

View File

@@ -13,12 +13,12 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec

View File

@@ -41,10 +41,12 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_kv_cache_layout,
)

View File

@@ -43,14 +43,14 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_strictly_contiguous
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
KVCacheLayoutType,
get_dcp_local_seq_lens,
get_kv_cache_layout,

View File

@@ -32,12 +32,10 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -7,12 +7,14 @@ from dataclasses import dataclass
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)

View File

@@ -5,13 +5,13 @@ from dataclasses import dataclass
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

View File

@@ -7,14 +7,11 @@ import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -10,11 +10,13 @@ import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)

View File

@@ -217,12 +217,12 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MLAAttentionImpl,
)
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_per_layer_parameters,
infer_global_hyperparameters,

View File

@@ -11,6 +11,7 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
logger = init_logger(__name__)

View File

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
@@ -31,7 +32,6 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,

View File

@@ -10,6 +10,7 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
@@ -21,7 +22,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)

View File

@@ -13,7 +13,12 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
@@ -23,7 +28,6 @@ from vllm.v1.attention.backends.mla.common import (
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
)

View File

@@ -16,15 +16,15 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
split_decodes_and_prefills,

View File

@@ -11,12 +11,12 @@ from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backend import (
AttentionBackend,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)

View File

@@ -8,7 +8,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionLayer, MultipleOf
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
@@ -17,7 +17,6 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -13,18 +13,16 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:

View File

@@ -15,14 +15,14 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_prefills_and_extends,
)
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states

View File

@@ -16,16 +16,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)

View File

@@ -14,12 +14,12 @@ from vllm.logger import init_logger
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.triton_unified_attention import unified_attention

View File

@@ -19,14 +19,12 @@ from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (

View File

@@ -1,16 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
Literal,
Protocol,
TypeVar,
@@ -19,7 +14,7 @@ from typing import (
import numpy as np
import torch
from typing_extensions import deprecated, runtime_checkable
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
@@ -38,8 +33,9 @@ from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
logger = init_logger(__name__)
@@ -53,123 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    """(batch_size,), the total sequence length of each request (computed
    tokens plus the tokens scheduled this step; num_computed_tokens is
    derived as seq_lens - query_lens below)"""

    num_reqs: int
    """Number of requests"""

    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
    num_actual_tokens: int
    """Total number of tokens in batch"""

    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length (may be an upper bound)"""

    # Per-request KV block table and per-token slot mapping into the KV cache.
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: torch.Tensor | None = None
    encoder_seq_lens_cpu: np.ndarray | None = None

    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None
    """Sequence lengths of the local rank in decode context parallelism world"""

    # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
    # These back the deprecated lazy properties below; accessing them may
    # trigger an implicit device->host sync.
    _seq_lens_cpu: torch.Tensor | None = None
    _num_computed_tokens_cpu: torch.Tensor | None = None
    # Lazily-filled device-side cache for compute_num_computed_tokens().
    _num_computed_tokens_cache: torch.Tensor | None = None

    @property
    @deprecated(
        """
        Prefer using device seq_lens directly to avoid implicit H<>D sync.
        If a CPU copy is needed, use `seq_lens.cpu()` instead.
        Will be removed in a future release (v0.15.0)
        """
    )
    def seq_lens_cpu(self) -> torch.Tensor:
        # Lazily copies seq_lens to host on first access (implicit D->H sync).
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu

    @property
    @deprecated(
        """
        Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
        async scheduling. If a CPU copy is needed, it can be derived from
        query_start_loc_cpu and seq_lens.
        Will be removed in a future release (v0.15.0)
        """
    )
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        # num_computed_tokens = seq_lens - query_lens, computed on host.
        # Note: uses the deprecated seq_lens_cpu property, so first access
        # may sync.
        if self._num_computed_tokens_cpu is None:
            query_seq_lens = (
                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
            )
            self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
        return self._num_computed_tokens_cpu

    def compute_num_computed_tokens(self) -> torch.Tensor:
        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
        if self._num_computed_tokens_cache is None:
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            self._num_computed_tokens_cache = self.seq_lens - query_lens
        return self._num_computed_tokens_cache

    # TODO(lucas): remove once we have FULL-CG spec-decode support
    def unpadded(
        self, num_actual_tokens: int, num_actual_reqs: int
    ) -> "CommonAttentionMetadata":
        """Return a copy with per-request and per-token tensors sliced down
        to the actual (unpadded) request/token counts.

        NOTE(review): _num_computed_tokens_cache is intentionally not carried
        over; it is recomputed lazily on the sliced copy.
        """
        maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
        return CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
            if self._seq_lens_cpu is not None
            else None,
            _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
            if self._num_computed_tokens_cpu is not None
            else None,
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
            max_query_len=self.max_query_len,
            max_seq_len=self.max_seq_len,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            logits_indices_padded=self.logits_indices_padded,
            num_logits_indices=self.num_logits_indices,
            encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
            encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
            dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
            dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
        )
def slice_query_start_locs(
query_start_loc: torch.Tensor,
request_slice: slice,
@@ -299,171 +178,6 @@ def split_attn_metadata(
return results
M = TypeVar("M")
class AttentionCGSupport(enum.Enum):
    """Level of CUDA Graphs support provided by an attention backend.

    Members are ordered by capability (a higher value means more capable).
    Cascade attention is not considered here, as it is currently never
    cudagraph supported.
    """

    ALWAYS = 3
    """Cudagraphs supported for any batch, including mixed prefill-decode."""

    UNIFORM_BATCH = 2
    """Cudagraphs supported only when every request in the batch has the
    same query length; usable for spec-decode, where "decodes" are
    1 + num_speculative_tokens."""

    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraphs supported only for pure decode batches (query_len == 1)."""

    NEVER = 0
    """No cudagraph support."""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """Abstract builder that turns ``CommonAttentionMetadata`` into the
    backend-specific attention metadata type ``M``.

    Subclasses declare their CUDA Graphs capability via ``_cudagraph_support``
    and implement ``build``; the remaining hooks have reasonable defaults.
    """

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    # Do not access directly. Call get_cudagraph_support() instead.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int | None = None

    # Does this backend/builder support updating the block table in existing
    # metadata (see update_block_table below)?
    supports_update_block_table: bool = False

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Abstract so every subclass defines its own __init__, but they are
        # expected to call this base implementation to stash the common state.
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @classmethod
    def get_cudagraph_support(
        cls: type["AttentionMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        """Get the cudagraph support level of this builder class.

        The default just returns the class-level ``_cudagraph_support``;
        subclasses may override to make the answer depend on the config
        or KV-cache spec.
        """
        return cls._cudagraph_support

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
        """Initialize ``reorder_batch_threshold`` from the given baseline,
        widening it for spec-decode when supported and clamping it to 1 when
        decode-context-parallel is active without varlen support."""
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
            # If the backend supports spec-as-decode kernels, then we can set
            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
            speculative_config = self.vllm_config.speculative_config
            if (
                speculative_config is not None
                and speculative_config.num_speculative_tokens is not None
            ):
                # Spec-decode "decodes" carry 1 + num_speculative_tokens
                # query tokens, so the threshold must cover that length.
                self.reorder_batch_threshold = max(
                    self.reorder_batch_threshold,
                    1 + speculative_config.num_speculative_tokens,
                )
        if (
            self.vllm_config.parallel_config.decode_context_parallel_size > 1
            and not supports_dcp_with_varlen
        ):
            # Without varlen support under DCP, only single-token decodes
            # may be pulled to the front of the batch.
            self.reorder_batch_threshold = 1

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The metadata will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def update_block_table(
        self,
        metadata: M,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> M:
        """
        Update the block table for the attention metadata.
        Faster when there's multiple kv-cache groups that create virtually the
        same metadata but just with different block tables.
        Only needs to be implemented if supports_update_block_table is True.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        # Cascade attention is opt-in; backends that support it override this.
        return False
@functools.lru_cache
def get_kv_cache_layout():
# Format specified by the code.
@@ -834,6 +548,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
return common_attn_metadata
M = TypeVar("M")
def subclass_attention_backend(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],

View File

@@ -26,16 +26,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadata,
TreeAttentionMetadataBuilder,
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS

View File

@@ -7,8 +7,8 @@ import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import (
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)

View File

@@ -100,15 +100,15 @@ from vllm.utils.torch_utils import (
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,