Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com> Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1009 lines
34 KiB
Python
1009 lines
34 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, replace
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
|
|
|
|
import numpy as np
|
|
import torch
|
|
from typing_extensions import deprecated
|
|
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
kFp8StaticTensorSym,
|
|
kNvfp4Dynamic,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
from vllm.config.cache import CacheDType
|
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
|
from vllm.platforms.interface import DeviceCapability
|
|
from vllm.v1.attention.backends.utils import KVCacheLayoutType
|
|
from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode
|
|
|
|
from vllm.v1.kv_cache_interface import get_kv_quant_mode
|
|
|
|
|
|
class AttentionType(str, Enum):
    """
    Attention type.
    Use string to be compatible with `torch.compile`
    (str-valued enum members compare equal to their plain-string form).
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between decoder Q and encoder K/V for encoder-decoder."""
|
|
|
|
|
|
class MultipleOf:
    """Marker for a kernel block-size constraint.

    Used in `AttentionBackend.get_supported_kernel_block_sizes()` to express
    that any positive multiple of `base` is acceptable (e.g. `MultipleOf(16)`
    admits 16, 32, 48, ...), as opposed to a plain `int` fixed size.
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # Without this, supported-block-size lists render as opaque
        # "<MultipleOf object at 0x...>" in logs and error messages.
        return f"MultipleOf({self.base})"
|
|
|
|
|
|
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend bundles: a name, an impl class, a metadata-builder class,
    a KV-cache layout, and a set of `supports_*` capability predicates that
    `validate_configuration` aggregates for backend selection.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False
    # Activation dtypes this backend can run with.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # KV-cache dtypes this backend can run with; an empty list means
    # "no restriction" (see supports_kv_cache_dtype below).
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
        "auto",
        "float16",
        "bfloat16",
    ]

    # Does attention's forward() include kv cache update?
    forward_includes_kv_cache_update: bool = True

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # Default: any block size (every int is a multiple of 1).
        return [MultipleOf(1)]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImplBase"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @classmethod
    def get_kv_cache_block_dim(
        cls,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> int:
        """Discover which tensor dim is the block index, since different
        backends lay out dims differently."""
        # Probe with a sentinel num_blocks value unlikely to collide with any
        # other dimension, then find where it lands in the reported shape.
        _S = 1234567
        shape = cls.get_kv_cache_shape(
            _S,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str=cache_dtype_str,
        )
        return shape.index(_S)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Get the physical (memory layout) ordering of the kv cache dimensions.
        e.g. if the KV cache shape is
        [2, num_blocks, block_size, num_heads, head_size],
        and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
        ordering of dimensions is
        [num_blocks, num_heads, 2, block_size, head_size].

        If this function is unimplemented / raises NotImplementedError,
        the physical layout of the KV cache will match the logical shape.

        Args:
            include_num_layers_dimension: if True, includes an additional
                num_layers dimension, which is assumed to be prepended
                to the logical KV cache shape.
                With the above example, a return value (2, 4, 0, 1, 3, 5)
                corresponds to
                [num_blocks, num_heads, num_layers, 2, block_size, head_size].

                If an additional dimension is NOT included in the returned
                tuple, the physical layout will not include a layers dimension.

        Returns:
            A tuple of ints which is a permutation of range(len(shape)).
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return (module, qualified name), e.g. for serialization/lookup."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # Empty list means "no restriction" (see supports_head_size).
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        # None means the caller imposes no constraint.
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True

        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
        if not supported_kernel_block_sizes:
            return True

        for supported_size in supported_kernel_block_sizes:
            if isinstance(supported_size, MultipleOf):
                supported_size = supported_size.base
            # With hybrid_blocks feature, the framework-level block size
            # only needs to be a multiple of the kernel's requirement,
            # even if the kernel requires a fixed block_size.
            if block_size % supported_size == 0:
                return True
        return False

    @classmethod
    def get_preferred_block_size(cls, default_block_size: int) -> int:
        """Return `default_block_size` if this backend can use it, otherwise
        the smallest block size the backend's kernels accept."""
        supported_sizes = cls.get_supported_kernel_block_sizes()
        if not supported_sizes:
            return default_block_size

        if cls.supports_block_size(default_block_size):
            return default_block_size

        return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)

    @classmethod
    def is_mla(cls) -> bool:
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def supports_alibi_sqrt(cls) -> bool:
        return False

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        return False

    @classmethod
    def supports_non_causal(cls) -> bool:
        """Check if backend supports non-causal (bidirectional) attention
        for decoder models.

        Unlike ENCODER_ONLY attention type which implies a different
        execution model, this refers to non-causal attention within the
        standard paged-KV-cache decoder path.
        """
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Check if backend supports a given attention type.

        By default, only supports decoder attention.
        Backends should override this to support other attention types.
        """
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        # Hook for constraints that span multiple parameters; return a
        # human-readable rejection reason, or None if the combo is fine.
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        use_per_head_quant_scales: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
        use_non_causal: bool = False,
    ) -> list[str]:
        """Run every capability check and collect ALL failure reasons
        (not just the first), so callers can report why a backend was
        rejected. Empty list means the configuration is supported."""
        invalid_reasons = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        if use_mm_prefix and not cls.supports_mm_prefix():
            invalid_reasons.append(
                "partial multimodal token full attention not supported"
            )
        # MLA-ness must match exactly: an MLA backend cannot serve a
        # non-MLA model and vice versa.
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("attention sinks not supported")
        # Sparseness must also match exactly (see use_mla above).
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if use_per_head_quant_scales and not cls.supports_per_head_quant_scales():
            invalid_reasons.append("per-head quant scales not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            invalid_reasons.append(f"attention type {attn_type} not supported")
        if use_non_causal and not cls.supports_non_causal():
            invalid_reasons.append("non-causal attention not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        # None means the backend has no layout requirement.
        return None

    @classmethod
    def is_ssm(cls) -> bool:
        return False
|
|
|
|
|
|
class AttentionMetadata:
    """Marker base class for backend-specific attention metadata objects."""

    pass
|
|
|
|
|
|
T = TypeVar("T", bound=AttentionMetadata)
|
|
|
|
|
|
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both GPU and CPU versions.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length (may be an upper bound)"""

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: torch.Tensor | None = None
    encoder_seq_lens_cpu: np.ndarray | None = None

    dcp_local_seq_lens: torch.Tensor | None = None
    dcp_local_seq_lens_cpu: torch.Tensor | None = None
    """Sequence lengths of the local rank in decode context parallelism world"""

    is_prefilling: torch.Tensor | None = None
    """(batch_size,) bool tensor: True if request is still in prefill phase
    (num_computed_tokens < num_prompt_tokens). Used by some backends to
    distinguish actual decodes from short extends."""

    # WARNING: Deprecated fields. Will be removed in a future release (v0.15.0)
    _seq_lens_cpu: torch.Tensor | None = None
    _num_computed_tokens_cpu: torch.Tensor | None = None

    # Lazy cache for compute_num_computed_tokens(); derived, never set by callers.
    _num_computed_tokens_cache: torch.Tensor | None = None

    def batch_size(self) -> int:
        return self.seq_lens.shape[0]

    def naive_query_lens(self) -> torch.Tensor:
        """Naive because it assumes that query ends where the next query starts."""
        return self.query_start_loc[1:] - self.query_start_loc[:-1]

    def replace(self, **kwargs) -> "CommonAttentionMetadata":
        # Thin wrapper over dataclasses.replace: returns a shallow copy with
        # the given fields overridden.
        return replace(self, **kwargs)

    @property
    @deprecated(
        """
        Prefer using device seq_lens directly to avoid implicit H<>D sync.
        If a CPU copy is needed, use `seq_lens.cpu()` instead.
        Will be removed in a future release, please migrate as soon as possible.
        """
    )
    def seq_lens_cpu(self) -> torch.Tensor:
        # Lazily materialize (and cache) a host copy; this triggers a
        # device->host sync on first access.
        if self._seq_lens_cpu is None:
            self._seq_lens_cpu = self.seq_lens.to("cpu")
        return self._seq_lens_cpu

    @property
    @deprecated(
        """
        Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full
        async scheduling. If a CPU copy is needed, it can be derived from
        query_start_loc_cpu and seq_lens.
        Will be removed in a future release, please migrate as soon as possible.
        """
    )
    def num_computed_tokens_cpu(self) -> torch.Tensor:
        if self._num_computed_tokens_cpu is None:
            query_seq_lens = (
                self.query_start_loc_cpu[1:] - self.query_start_loc_cpu[:-1]
            )
            # Note: goes through the (deprecated) seq_lens_cpu property, so
            # this may also trigger a device->host sync.
            self._num_computed_tokens_cpu = self.seq_lens_cpu - query_seq_lens
        return self._num_computed_tokens_cpu

    def compute_num_computed_tokens(self) -> torch.Tensor:
        """Compute num_computed_tokens on device (seq_lens - query_lens)."""
        if self._num_computed_tokens_cache is None:
            query_lens = self.query_start_loc[1:] - self.query_start_loc[:-1]
            self._num_computed_tokens_cache = self.seq_lens - query_lens
        return self._num_computed_tokens_cache

    # TODO(lucas): remove once we have FULL-CG spec-decode support
    def unpadded(
        self, num_actual_tokens: int, num_actual_reqs: int
    ) -> "CommonAttentionMetadata":
        """Return a copy sliced down to the actual (unpadded) request/token
        counts. `_num_computed_tokens_cache` is intentionally not carried
        over; it is recomputed lazily on the new instance."""
        maybe_slice_reqs = lambda x: x[:num_actual_reqs] if x is not None else None
        return CommonAttentionMetadata(
            # +1 because query_start_loc has batch_size + 1 entries.
            query_start_loc=self.query_start_loc[: num_actual_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[: num_actual_reqs + 1],
            seq_lens=self.seq_lens[:num_actual_reqs],
            _seq_lens_cpu=self._seq_lens_cpu[:num_actual_reqs]
            if self._seq_lens_cpu is not None
            else None,
            _num_computed_tokens_cpu=self._num_computed_tokens_cpu[:num_actual_reqs]
            if self._num_computed_tokens_cpu is not None
            else None,
            num_reqs=num_actual_reqs,
            num_actual_tokens=num_actual_tokens,
            # NOTE(review): max_query_len/max_seq_len are kept as-is, so they
            # remain upper bounds after slicing.
            max_query_len=self.max_query_len,
            max_seq_len=self.max_seq_len,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            logits_indices_padded=self.logits_indices_padded,
            num_logits_indices=self.num_logits_indices,
            encoder_seq_lens=maybe_slice_reqs(self.encoder_seq_lens),
            encoder_seq_lens_cpu=maybe_slice_reqs(self.encoder_seq_lens_cpu),
            dcp_local_seq_lens=maybe_slice_reqs(self.dcp_local_seq_lens),
            dcp_local_seq_lens_cpu=maybe_slice_reqs(self.dcp_local_seq_lens_cpu),
            is_prefilling=maybe_slice_reqs(self.is_prefilling),
        )
|
|
|
|
|
|
M = TypeVar("M")
|
|
|
|
|
|
class AttentionCGSupport(Enum):
    """Constants for the cudagraph support of the attention backend.
    Here we do not consider the cascade attention, as currently
    it is never cudagraph supported. Higher values indicate broader
    support."""

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""
    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that are
    the same, this can be used for spec-decode
    i.e. "decodes" are 1 + num_speculative_tokens"""
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1 decodes"""
    NEVER = 0
    """NO cudagraph support"""
|
|
|
|
|
|
class AttentionMetadataBuilder(ABC, Generic[M]):
    """Per-backend factory that converts `CommonAttentionMetadata` into the
    backend-specific metadata type `M` consumed by the attention impl."""

    # Does this backend/builder support CUDA Graphs for attention (default: no).
    # Do not access directly. Call get_cudagraph_support() instead.
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: int | None = None
    # Does this backend/builder support updating the block table in existing
    # metadata
    supports_update_block_table: bool = False

    @abstractmethod
    def __init__(
        self,
        kv_cache_spec: "AttentionSpec",
        layer_names: list[str],
        vllm_config: "VllmConfig",
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @classmethod
    def get_cudagraph_support(
        cls: type["AttentionMetadataBuilder"],
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
    ) -> AttentionCGSupport:
        """Get the cudagraph support level of this builder class.

        Subclasses may override to make support conditional on the config
        or KV-cache spec; the default ignores both arguments.
        """
        return cls._cudagraph_support

    def _init_reorder_batch_threshold(
        self,
        reorder_batch_threshold: int | None = 1,
        supports_spec_as_decode: bool = False,
        supports_dcp_with_varlen: bool = False,
    ) -> None:
        """Initialize `reorder_batch_threshold`, widening it for
        spec-as-decode and clamping it for decode context parallelism."""
        self.reorder_batch_threshold = reorder_batch_threshold
        if self.reorder_batch_threshold is not None and supports_spec_as_decode:
            # If the backend supports spec-as-decode kernels, then we can set
            # the reorder_batch_threshold based on the number of speculative
            # tokens from the config.
            speculative_config = self.vllm_config.speculative_config
            if (
                speculative_config is not None
                and speculative_config.num_speculative_tokens is not None
            ):
                # 1 "real" token + speculative tokens (doubled when
                # parallel drafting is enabled).
                max_num_queries_for_spec = (
                    1
                    + (2 if speculative_config.parallel_drafting else 1)
                    * speculative_config.num_speculative_tokens
                )
                self.reorder_batch_threshold = max(
                    self.reorder_batch_threshold,
                    max_num_queries_for_spec,
                )

        # DCP without varlen support can only treat single-token queries
        # as decodes.
        if (
            self.vllm_config.parallel_config.decode_context_parallel_size > 1
            and not supports_dcp_with_varlen
        ):
            self.reorder_batch_threshold = 1

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def update_block_table(
        self,
        metadata: M,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> M:
        """
        Update the block table for the attention metadata.
        Faster when there's multiple kv-cache groups that create virtually the
        same metadata but just with different block tables.

        Only needs to be implemented if supports_update_block_table is True.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        # fast_build=True: drafting metadata is short-lived, so favor
        # build speed over execution speed.
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
        dcp_world_size: int,
    ) -> bool:
        # Default: never use cascade attention; backends that support it
        # override with a real heuristic.
        return False
|
|
|
|
|
|
class AttentionLayer(Protocol):
    """Structural (duck-typed) interface of the attention layer object that
    impls receive; exposes per-layer scale tensors/floats.

    The scales are presumably used for quantized (e.g. FP8) KV cache and
    attention paths — confirm against the concrete layer implementation.
    """

    # Tensor and cached-float forms of the q/k/v scales.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
|
|
|
|
|
|
class AttentionImplBase(ABC, Generic[T]):
    """Base class for attention implementations.

    Contains common attributes and initialization logic shared by both
    standard AttentionImpl and MLAAttentionImpl. Does not define a forward
    method - subclasses define their own forward interfaces.
    """

    # Required attributes that all impls should have
    num_heads: int
    head_size: int
    scale: float

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # Whether the attention impl supports Prefill Context Parallelism.
    supports_pcp: bool = False
    # Whether the attention impl (or ops) supports MTP
    # when cp_kv_cache_interleave_size > 1
    supports_mtp_with_cp_non_trivial_interleave_size: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # Whether this attention implementation supports pre-quantized query input.
    # When True, the attention layer will quantize queries before passing them
    # to this backend, allowing torch.compile to fuse the quantization with
    # previous operations. This is typically supported when using FP8 KV cache
    # with compatible attention kernels (e.g., TRT-LLM).
    # Subclasses should set this in __init__.
    # TODO add support to more backends:
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    # Decode context parallelism topology (set in __new__).
    dcp_world_size: int
    dcp_rank: int

    # Prefill context parallelism topology (set in __new__).
    pcp_world_size: int
    pcp_rank: int

    # Combined CP topology: pcp x dcp, with dcp as the fastest-varying rank.
    total_cp_world_size: int
    total_cp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        # (subclasses need not remember to call super().__init__()).
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        try:
            from vllm.distributed.parallel_state import get_pcp_group

            self.pcp_world_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group
        except AssertionError:
            # PCP might not be initialized in testing (see DCP above).
            self.pcp_world_size = 1
            self.pcp_rank = 0
        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

        # Only request the LSE when DCP actually needs it AND the impl can
        # produce it; impls may refine this afterwards.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # Optional hook; default is a no-op.
        pass
|
|
|
|
|
|
class AttentionImpl(AttentionImplBase[T], Generic[T]):
    """Standard attention implementation with forward method."""

    # KV-cache dtype string (e.g. "auto"); set by concrete impls' __init__.
    kv_cache_dtype: str

    @property
    def kv_quant_mode(self) -> "KVQuantMode":
        """Return the KV cache quantization mode for this layer."""
        return get_kv_quant_mode(self.kv_cache_dtype)

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey"):
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def fused_rope_kvcache_supported(self):
        """
        Does this attention implementation support RoPE+KVCache fusion.
        This is used by the RopeKVCacheFusionPass to only fuse the RoPE ops
        with the KV cache update for implementations that support it.
        """
        return False

    def do_rope_and_kv_cache_update(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        is_neox: bool,
        kv_cache: torch.Tensor,
        layer_slot_mapping: torch.Tensor,
    ):
        """
        If `fused_rope_kvcache_supported` returns True, this method will be called
        by torch.ops.vllm.fused_rope_and_unified_kv_cache_update
        to perform the inplace RoPE and KV cache update.
        """
        raise NotImplementedError
|
|
|
|
|
|
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
    """MLA attention implementation with forward_mqa and forward_mha methods."""

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        indexer: object | None = None,
        q_pad_num_heads: int | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward_mha(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """MHA-style prefill forward pass. Writes the result into `output`
        in place (returns None)."""
        raise NotImplementedError

    @abstractmethod
    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: T,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """MQA-style decode forward pass. Returns (output, lse-or-None)."""
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey"):
        """
        Does this attention implementation support fused output quantization.
        Since MLA quantization is done manually in forward_impl (common code),
        all MLA backends support it by default.
        """
        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        """Write the compressed KV and rotary key into the paged cache."""
        # An empty cache tensor means there is nothing to write into
        # (presumably a profiling/dummy run — confirm against callers).
        if kv_cache.numel() == 0:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
|
|
|
|
|
|
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
    """Sparse MLA attention implementation with only forward_mqa method.

    Sparse MLA implementations only support decode (MQA-style) attention.
    They do not support prefill (MHA-style) attention.
    """

    def fused_output_quant_supported(self, quant_key: "QuantKey"):
        """
        Does this attention implementation support fused output quantization.
        Since MLA quantization is done manually in forward_impl (common code),
        all MLA backends support it by default.
        """
        return quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments (signature mirrors MLAAttentionImpl)
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        indexer: object | None = None,
        q_pad_num_heads: int | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: T,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """MQA-style decode forward pass. Returns (output, lse-or-None)."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        """Write the compressed KV and rotary key into the paged cache
        (same logic as MLAAttentionImpl.do_kv_cache_update)."""
        # An empty cache tensor means there is nothing to write into
        # (presumably a profiling/dummy run — confirm against callers).
        if kv_cache.numel() == 0:
            return
        from vllm import _custom_ops as ops

        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
|
|
|
|
|
|
def subclass_attention_backend(
    name_prefix: str,
    attention_backend_cls: "type[AttentionBackend]",
    builder_cls: "type[AttentionMetadataBuilder[M]]",
) -> "type[AttentionBackend]":
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: prepended to the base class name to form the new
            class's name.
        attention_backend_cls: the backend class to derive from.
        builder_cls: the metadata-builder class the new backend reports.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore

    # Wrap the override in staticmethod() to match the base class's
    # `@staticmethod get_builder_cls` declaration: a bare lambda taking no
    # arguments would receive `self` (and raise TypeError) when the method
    # is looked up on an *instance* rather than the class.
    return type(
        name,
        (attention_backend_cls,),
        {"get_builder_cls": staticmethod(lambda: builder_cls)},
    )
|
|
|
|
|
|
def subclass_attention_backend_with_overrides(
    name_prefix: str,
    attention_backend_cls: type[AttentionBackend],
    overrides: dict[str, Any],
) -> type[AttentionBackend]:
    """
    Dynamically derive a new backend class from `attention_backend_cls`
    whose class attributes/methods are replaced by the entries in
    `overrides`. The new class is named `name_prefix` + the base name.
    """
    derived_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    return type(derived_name, (attention_backend_cls,), overrides)