[3/N][Attention] Move AttentionMetadata-related code from utils.py to backend.py (#32054)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -149,7 +149,7 @@ The CUDA Graphs wrapper no longer manages the warm-up logic. The warm-up process
|
||||
|
||||
## CUDA Graphs Compatibility of Attention Backends
|
||||
|
||||
To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backends.utils.AttentionCGSupport], which is an enum type that tracks the capability of the attention backend to support CUDA Graphs. The value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`.
|
||||
To signal the CUDA Graphs compatibility of the attention backends, we introduce a new enum type [AttentionCGSupport][vllm.v1.attention.backend.AttentionCGSupport], which is an enum type that tracks the capability of the attention backend to support CUDA Graphs. The value is sorted in the order of the capability, i.e., `ALWAYS`> `UNIFORM_BATCH`> `UNIFORM_SINGLE_TOKEN_DECODE`> `NEVER`.
|
||||
|
||||
```python
|
||||
class AttentionCGSupport(enum.Enum):
|
||||
|
||||
@@ -23,10 +23,9 @@ from vllm.utils.torch_utils import (
|
||||
is_torch_equal_or_newer,
|
||||
set_random_seed,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
from vllm.v1.attention.backend import AttentionType, CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
set_kv_cache_layout,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
@@ -22,10 +22,10 @@ from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
|
||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
@@ -18,12 +18,12 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.model import ModelDType
|
||||
from vllm.v1.attention.backend import AttentionImpl
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def sync_tracker():
|
||||
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
|
||||
lazy init syncs. Prints stack traces immediately when syncs occur.
|
||||
"""
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
|
||||
# Shared counter for cross-process communication (inherited by fork)
|
||||
sync_count = multiprocessing.Value("i", 0)
|
||||
|
||||
@@ -12,9 +12,9 @@ from tests.v1.attention.utils import (
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
if not is_flash_attn_varlen_func_available():
|
||||
pytest.skip(
|
||||
|
||||
@@ -8,11 +8,13 @@ from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config.vllm import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
make_local_attention_virtual_batches,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
|
||||
@@ -14,9 +14,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
|
||||
@@ -12,9 +12,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
|
||||
@@ -15,9 +15,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
|
||||
@@ -16,10 +16,10 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
subclass_attention_backend_with_overrides,
|
||||
)
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
|
||||
@@ -2,17 +2,22 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AttentionType(str, Enum):
|
||||
@@ -271,6 +276,288 @@ class AttentionMetadata:
|
||||
T = TypeVar("T", bound=AttentionMetadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
max_query_len: int
|
||||
"""Longest query in batch"""
|
||||
max_seq_len: int
|
||||
"""Longest context length (may be an upper bound)"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
causal: bool = True
|
||||
|
||||
# Needed by FastPrefillAttentionBuilder
|
||||
logits_indices_padded: torch.Tensor | None = None
|
||||
num_logits_indices: int | None = None
|
||||
|
||||
# Needed by CrossAttentionBuilder
|
||||
encoder_seq_lens: torch.Tensor | None = None
|
||||
encoder_seq_lens_cpu: np.ndarray | None = None
|
||||
|
||||
dcp_local_seq_lens: torch.Tensor | None = None
|
||||
dcp_local_seq_lens_cpu: torch.Tensor | None = None
|
||||
"""Sequence lengths of the local rank in decode context parallelism world"""
|
||||
|
||||
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
|
||||
_seq_lens_cpu: torch.Tensor | None = None
|
||||
_num_computed_tokens_cpu: torch.Tensor | None = None
|
||||
|
||||
_num_computed_tokens_cache: torch.Tensor | None = None
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
Prefer using device seq_lens directly to avoid implicit H<>D sync.
|
||||
If a CPU copy is needed, use `seq_lens.cpu()` instead.
|
||||
Will be removed in a future release (v0.15.0)
|
||||
"""
|
||||
)
|
||||
def seq_lens_cpu(self) -> torch.Tensor:
|
||||
if self._seq_lens_cpu is None:
|
||||
self._seq_lens_cpu = self.seq_lens.to("cpu")
|
||||
return self._seq_lens_cpu
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
|
||||
async scheduling. If a CPU copy is needed, it can be derived from
|
||||
query_start_loc_cpu and seq_lens.
|
||||
Will be removed in a future release (v0.15.0)
|
||||
"""
|
||||
)
|
||||
def num_computed_tokens_cpu(self) -> torch.Tensor:
|
||||
if self._num_computed_tokens_cpu is None:
|
||||
query_seq_lens = (
|
||||
self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
|
||||
)
|
||||
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
|
||||
return self._num_computed_tokens_cpu
|
||||
|
||||
def compute_num_computed_tokens(self) -> torch.Tensor:
|
||||
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
|
||||
if self._num_computed_tokens_cache is None:
|
||||
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
|
||||
self._num_computed_tokens_cache = self.seq_lens - query_lens
|
||||
return self._num_computed_tokens_cache
|
||||
|
||||
# TODO(lucas): remove once we have FULL-CG spec-decode support
|
||||
def unpadded(
|
||||
self, num_actual_tokens: int, num_actual_reqs: int
|
||||
) -> "CommonAttentionMetadata":
|
||||
maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
|
||||
seq_lens=self.seq_lens[:num_actual_reqs],
|
||||
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
|
||||
if self._seq_lens_cpu is not None
|
||||
else None,
|
||||
_num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
|
||||
if self._num_computed_tokens_cpu is not None
|
||||
else None,
|
||||
num_reqs=num_actual_reqs,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_seq_len=self.max_seq_len,
|
||||
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
|
||||
slot_mapping=self.slot_mapping[:num_actual_tokens],
|
||||
causal=self.causal,
|
||||
logits_indices_padded=self.logits_indices_padded,
|
||||
num_logits_indices=self.num_logits_indices,
|
||||
encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
|
||||
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
|
||||
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
|
||||
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
|
||||
)
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
class AttentionCGSupport(Enum):
|
||||
"""Constants for the cudagraph support of the attention backend
|
||||
Here we do not consider the cascade attention, as currently
|
||||
it is never cudagraph supported."""
|
||||
|
||||
ALWAYS = 3
|
||||
"""Cudagraph always supported; supports mixed-prefill-decode"""
|
||||
UNIFORM_BATCH = 2
|
||||
"""Cudagraph supported for batches the only contain query lengths that are
|
||||
the same, this can be used for spec-decode
|
||||
i.e. "decodes" are 1 + num_speculative_tokens"""
|
||||
UNIFORM_SINGLE_TOKEN_DECODE = 1
|
||||
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
|
||||
NEVER = 0
|
||||
"""NO cudagraph support"""
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||
# Do not access directly. Call get_cudagraph_support() instead.
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: int | None = None
|
||||
# Does this backend/builder support updating the block table in existing
|
||||
# metadata
|
||||
supports_update_block_table: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: "AttentionSpec",
|
||||
layer_names: list[str],
|
||||
vllm_config: "VllmConfig",
|
||||
device: torch.device,
|
||||
):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.layer_names = layer_names
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AttentionMetadataBuilder"],
|
||||
vllm_config: "VllmConfig",
|
||||
kv_cache_spec: "AttentionSpec",
|
||||
) -> AttentionCGSupport:
|
||||
"""Get the cudagraph support level of this builder class."""
|
||||
return cls._cudagraph_support
|
||||
|
||||
def _init_reorder_batch_threshold(
|
||||
self,
|
||||
reorder_batch_threshold: int | None = 1,
|
||||
supports_spec_as_decode: bool = False,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
) -> None:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold
|
||||
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
||||
# If the backend supports spec-as-decode kernels, then we can set
|
||||
# the reorder_batch_threshold based on the number of speculative
|
||||
# tokens from the config.
|
||||
speculative_config = self.vllm_config.speculative_config
|
||||
if (
|
||||
speculative_config is not None
|
||||
and speculative_config.num_speculative_tokens is not None
|
||||
):
|
||||
self.reorder_batch_threshold = max(
|
||||
self.reorder_batch_threshold,
|
||||
1 + speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
self.vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and not supports_dcp_with_varlen
|
||||
):
|
||||
self.reorder_batch_threshold = 1
|
||||
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> M:
|
||||
"""
|
||||
Central method that builds attention metadata.
|
||||
Some builders (MLA) require reorder_batch to be called prior to build.
|
||||
|
||||
Args:
|
||||
common_prefix_len: The length of the common prefix of the batch.
|
||||
common_attn_metadata: The common attention metadata.
|
||||
fast_build: The meta-data will prioritize speed of building over
|
||||
then speed at execution. Can be used for spec-decode where the
|
||||
result of a build call may only be used for few layers/iters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: M,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
"""
|
||||
Update the block table for the attention metadata.
|
||||
Faster when theres multiple kv-cache groups that create virtually the
|
||||
same metadata but just with different block tables.
|
||||
|
||||
Only needs to be implemented if supports_update_block_table is True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> M:
|
||||
"""
|
||||
Build attention metadata for CUDA graph capture. Uses build by default.
|
||||
Subclasses that override this method should call self.build or
|
||||
super().build_for_cudagraph_capture.
|
||||
"""
|
||||
return self.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
def build_for_drafting(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
draft_index: int,
|
||||
) -> M:
|
||||
"""
|
||||
Build attention metadata for draft model. Uses build by default.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: The common attention metadata.
|
||||
draft_index: The index of the current draft operation.
|
||||
When speculating a chain of tokens, this index refers to the
|
||||
draft attempt for the i-th token.
|
||||
For tree-based attention, this index instead refers to the
|
||||
draft attempt for the i-th level in the tree of tokens.
|
||||
"""
|
||||
return self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
fast_build=True,
|
||||
)
|
||||
|
||||
def use_cascade_attention(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_alibi: bool,
|
||||
use_sliding_window: bool,
|
||||
use_local_attention: bool,
|
||||
num_sms: int,
|
||||
dcp_world_size: int,
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class AttentionLayer(Protocol):
|
||||
_q_scale: torch.Tensor
|
||||
_k_scale: torch.Tensor
|
||||
|
||||
@@ -13,12 +13,12 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
|
||||
|
||||
@@ -41,10 +41,12 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
get_dcp_local_seq_lens,
|
||||
get_kv_cache_layout,
|
||||
)
|
||||
|
||||
@@ -43,14 +43,14 @@ from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import is_strictly_contiguous
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
KVCacheLayoutType,
|
||||
get_dcp_local_seq_lens,
|
||||
get_kv_cache_layout,
|
||||
|
||||
@@ -32,12 +32,10 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
@@ -7,12 +7,14 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
@@ -5,13 +5,13 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
|
||||
@@ -7,14 +7,11 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba_attn import (
|
||||
BaseMambaAttentionMetadata,
|
||||
BaseMambaAttentionMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
|
||||
@@ -10,11 +10,13 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
@@ -217,12 +217,12 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_dcp_local_seq_lens,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters,
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
@@ -31,7 +32,6 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
flash_attn_varlen_func,
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
@@ -21,7 +22,7 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
|
||||
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -13,7 +13,12 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@@ -23,7 +28,6 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
|
||||
@@ -16,15 +16,15 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
split_decodes_and_prefills,
|
||||
|
||||
@@ -11,12 +11,12 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backend import AttentionLayer, MultipleOf
|
||||
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
@@ -17,7 +17,6 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
|
||||
@@ -13,18 +13,16 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -15,14 +15,14 @@ from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_prefills_and_extends,
|
||||
)
|
||||
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
|
||||
|
||||
@@ -16,16 +16,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode,
|
||||
)
|
||||
|
||||
@@ -14,12 +14,12 @@ from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
|
||||
@@ -19,14 +19,12 @@ from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
|
||||
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import abc
|
||||
import enum
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field, fields, make_dataclass
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Generic,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
@@ -19,7 +14,7 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated, runtime_checkable
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.utils.math_utils import cdiv
|
||||
@@ -38,8 +33,9 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -53,123 +49,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
|
||||
return value in get_args(KVCacheLayoutType)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
query_start_loc_cpu: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
"""(batch_size,), the number of computed tokens for each request"""
|
||||
|
||||
num_reqs: int
|
||||
"""Number of requests"""
|
||||
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
|
||||
num_actual_tokens: int
|
||||
"""Total number of tokens in batch"""
|
||||
max_query_len: int
|
||||
"""Longest query in batch"""
|
||||
max_seq_len: int
|
||||
"""Longest context length (may be an upper bound)"""
|
||||
|
||||
block_table_tensor: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
causal: bool = True
|
||||
|
||||
# Needed by FastPrefillAttentionBuilder
|
||||
logits_indices_padded: torch.Tensor | None = None
|
||||
num_logits_indices: int | None = None
|
||||
|
||||
# Needed by CrossAttentionBuilder
|
||||
encoder_seq_lens: torch.Tensor | None = None
|
||||
encoder_seq_lens_cpu: np.ndarray | None = None
|
||||
|
||||
dcp_local_seq_lens: torch.Tensor | None = None
|
||||
dcp_local_seq_lens_cpu: torch.Tensor | None = None
|
||||
"""Sequence lengths of the local rank in decode context parallelism world"""
|
||||
|
||||
# WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
|
||||
_seq_lens_cpu: torch.Tensor | None = None
|
||||
_num_computed_tokens_cpu: torch.Tensor | None = None
|
||||
|
||||
_num_computed_tokens_cache: torch.Tensor | None = None
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
Prefer using device seq_lens directly to avoid implicit H<>D sync.
|
||||
If a CPU copy is needed, use `seq_lens.cpu()` instead.
|
||||
Will be removed in a future release (v0.15.0)
|
||||
"""
|
||||
)
|
||||
def seq_lens_cpu(self) -> torch.Tensor:
|
||||
if self._seq_lens_cpu is None:
|
||||
self._seq_lens_cpu = self.seq_lens.to("cpu")
|
||||
return self._seq_lens_cpu
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
"""
|
||||
Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
|
||||
async scheduling. If a CPU copy is needed, it can be derived from
|
||||
query_start_loc_cpu and seq_lens.
|
||||
Will be removed in a future release (v0.15.0)
|
||||
"""
|
||||
)
|
||||
def num_computed_tokens_cpu(self) -> torch.Tensor:
|
||||
if self._num_computed_tokens_cpu is None:
|
||||
query_seq_lens = (
|
||||
self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
|
||||
)
|
||||
self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
|
||||
return self._num_computed_tokens_cpu
|
||||
|
||||
def compute_num_computed_tokens(self) -> torch.Tensor:
|
||||
"""Compute num_computed_tokens on device (seq_lens - query_lens)."""
|
||||
if self._num_computed_tokens_cache is None:
|
||||
query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
|
||||
self._num_computed_tokens_cache = self.seq_lens - query_lens
|
||||
return self._num_computed_tokens_cache
|
||||
|
||||
# TODO(lucas): remove once we have FULL-CG spec-decode support
|
||||
def unpadded(
|
||||
self, num_actual_tokens: int, num_actual_reqs: int
|
||||
) -> "CommonAttentionMetadata":
|
||||
maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
|
||||
return CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
|
||||
seq_lens=self.seq_lens[:num_actual_reqs],
|
||||
_seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
|
||||
if self._seq_lens_cpu is not None
|
||||
else None,
|
||||
_num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
|
||||
if self._num_computed_tokens_cpu is not None
|
||||
else None,
|
||||
num_reqs=num_actual_reqs,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=self.max_query_len,
|
||||
max_seq_len=self.max_seq_len,
|
||||
block_table_tensor=self.block_table_tensor[:num_actual_reqs],
|
||||
slot_mapping=self.slot_mapping[:num_actual_tokens],
|
||||
causal=self.causal,
|
||||
logits_indices_padded=self.logits_indices_padded,
|
||||
num_logits_indices=self.num_logits_indices,
|
||||
encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
|
||||
encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
|
||||
dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
|
||||
dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
|
||||
)
|
||||
|
||||
|
||||
def slice_query_start_locs(
|
||||
query_start_loc: torch.Tensor,
|
||||
request_slice: slice,
|
||||
@@ -299,171 +178,6 @@ def split_attn_metadata(
|
||||
return results
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
class AttentionCGSupport(enum.Enum):
|
||||
"""Constants for the cudagraph support of the attention backend
|
||||
Here we do not consider the cascade attention, as currently
|
||||
it is never cudagraph supported."""
|
||||
|
||||
ALWAYS = 3
|
||||
"""Cudagraph always supported; supports mixed-prefill-decode"""
|
||||
UNIFORM_BATCH = 2
|
||||
"""Cudagraph supported for batches the only contain query lengths that are
|
||||
the same, this can be used for spec-decode
|
||||
i.e. "decodes" are 1 + num_speculative_tokens"""
|
||||
UNIFORM_SINGLE_TOKEN_DECODE = 1
|
||||
"""Cudagraph supported for batches the only contain query_len==1 decodes"""
|
||||
NEVER = 0
|
||||
"""NO cudagraph support"""
|
||||
|
||||
|
||||
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention (default: no).
|
||||
# Do not access directly. Call get_cudagraph_support() instead.
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: int | None = None
|
||||
# Does this backend/builder support updating the block table in existing
|
||||
# metadata
|
||||
supports_update_block_table: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.layer_names = layer_names
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
cls: type["AttentionMetadataBuilder"],
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
) -> AttentionCGSupport:
|
||||
"""Get the cudagraph support level of this builder class."""
|
||||
return cls._cudagraph_support
|
||||
|
||||
def _init_reorder_batch_threshold(
|
||||
self,
|
||||
reorder_batch_threshold: int | None = 1,
|
||||
supports_spec_as_decode: bool = False,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
) -> None:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold
|
||||
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
||||
# If the backend supports spec-as-decode kernels, then we can set
|
||||
# the reorder_batch_threshold based on the number of speculative
|
||||
# tokens from the config.
|
||||
speculative_config = self.vllm_config.speculative_config
|
||||
if (
|
||||
speculative_config is not None
|
||||
and speculative_config.num_speculative_tokens is not None
|
||||
):
|
||||
self.reorder_batch_threshold = max(
|
||||
self.reorder_batch_threshold,
|
||||
1 + speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
self.vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and not supports_dcp_with_varlen
|
||||
):
|
||||
self.reorder_batch_threshold = 1
|
||||
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> M:
|
||||
"""
|
||||
Central method that builds attention metadata.
|
||||
Some builders (MLA) require reorder_batch to be called prior to build.
|
||||
|
||||
Args:
|
||||
common_prefix_len: The length of the common prefix of the batch.
|
||||
common_attn_metadata: The common attention metadata.
|
||||
fast_build: The meta-data will prioritize speed of building over
|
||||
then speed at execution. Can be used for spec-decode where the
|
||||
result of a build call may only be used for few layers/iters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_block_table(
|
||||
self,
|
||||
metadata: M,
|
||||
blk_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> M:
|
||||
"""
|
||||
Update the block table for the attention metadata.
|
||||
Faster when theres multiple kv-cache groups that create virtually the
|
||||
same metadata but just with different block tables.
|
||||
|
||||
Only needs to be implemented if supports_update_block_table is True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
) -> M:
|
||||
"""
|
||||
Build attention metadata for CUDA graph capture. Uses build by default.
|
||||
Subclasses that override this method should call self.build or
|
||||
super().build_for_cudagraph_capture.
|
||||
"""
|
||||
return self.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
def build_for_drafting(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
draft_index: int,
|
||||
) -> M:
|
||||
"""
|
||||
Build attention metadata for draft model. Uses build by default.
|
||||
|
||||
Args:
|
||||
common_attn_metadata: The common attention metadata.
|
||||
draft_index: The index of the current draft operation.
|
||||
When speculating a chain of tokens, this index refers to the
|
||||
draft attempt for the i-th token.
|
||||
For tree-based attention, this index instead refers to the
|
||||
draft attempt for the i-th level in the tree of tokens.
|
||||
"""
|
||||
return self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
fast_build=True,
|
||||
)
|
||||
|
||||
def use_cascade_attention(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_alibi: bool,
|
||||
use_sliding_window: bool,
|
||||
use_local_attention: bool,
|
||||
num_sms: int,
|
||||
dcp_world_size: int,
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
# Format specified by the code.
|
||||
@@ -834,6 +548,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
|
||||
return common_attn_metadata
|
||||
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
def subclass_attention_backend(
|
||||
name_prefix: str,
|
||||
attention_backend_cls: type[AttentionBackend],
|
||||
|
||||
@@ -26,16 +26,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.attention.backends.tree_attn import (
|
||||
TreeAttentionMetadata,
|
||||
TreeAttentionMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
||||
|
||||
@@ -7,8 +7,8 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
|
||||
@@ -100,15 +100,15 @@ from vllm.utils.torch_utils import (
|
||||
)
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
create_fast_prefill_custom_backend,
|
||||
get_dcp_local_seq_lens,
|
||||
reorder_batch_to_split_decodes_and_prefills,
|
||||
|
||||
Reference in New Issue
Block a user