# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
"""Hidden States Extractor Model.
|
|
|
|
This model extracts and caches hidden states from the target model
|
|
without performing actual token generation. It's used with the
|
|
extract_hidden_states speculative decoding method.
|
|
"""
|
|
|
|
from collections.abc import Iterable
|
|
from typing import ClassVar
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
|
from vllm.config.cache import CacheDType
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.layers.attention.attention import set_default_quant_scales
|
|
from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
|
maybe_transfer_kv_layer,
|
|
)
|
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
from vllm.model_executor.models.utils import maybe_prefix
|
|
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
|
|
from vllm.v1.attention.backend import (
|
|
AttentionBackend,
|
|
AttentionImpl,
|
|
AttentionMetadataBuilder,
|
|
AttentionType,
|
|
CommonAttentionMetadata,
|
|
is_quantized_kv_cache,
|
|
)
|
|
from vllm.v1.kv_cache_interface import (
|
|
AttentionSpec,
|
|
KVCacheSpec,
|
|
MLAAttentionSpec,
|
|
)
|
|
|
|
########## Custom Ops ########
|
|
|
|
|
|
def unified_kv_cache_update(
    to_cache: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Write `to_cache` into the KV cache of the layer named `layer_name`.

    Returns a dummy that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.
    """
    ctx = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]

    slot_mapping = ctx.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )

    # A layer with no slot-mapping entry has nothing scheduled to cache.
    per_layer_slots = slot_mapping.get(layer_name)
    if per_layer_slots is not None:
        assert hasattr(layer.impl, "do_kv_cache_update"), (
            f"{layer.impl.__class__.__name__} does not support kv cache update"
        )
        layer.impl.do_kv_cache_update(layer, to_cache, kv_cache, per_layer_slots)

    # Empty tensor whose only purpose is to carry the data dependency forward.
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
|
|
|
|
|
@maybe_transfer_kv_layer
def dummy_attention(layer_name, _placeholder):
    # No-op "attention" op: it exists solely so the @maybe_transfer_kv_layer
    # decorator can trigger the KV-connector transfer for `layer_name`. The
    # placeholder is returned unchanged to preserve the data dependency that
    # keeps torch.compile from reordering it past the cache update.
    # Note: layer_name arg required by @maybe_transfer_kv_layer
    return _placeholder
|
|
|
|
|
|
def basic_cache(
    to_cache: torch.Tensor,  # shape: [seq_len, num_heads, head_size]
    kv_cache: torch.Tensor,  # shape: [num_blocks, block_size, num_heads, head_size]
    slot_mapping: torch.Tensor,  # shape: [seq_len]
):
    """Scatter per-token entries of `to_cache` into the paged `kv_cache`.

    The paged cache is viewed as a flat [num_blocks * block_size, num_heads,
    head_size] tensor so each token's data lands at the global slot index
    given by `slot_mapping`.

    Fix: the original shape comments on `to_cache` and `kv_cache` were
    swapped — the code unpacks a 4-D shape from `kv_cache`, so `kv_cache`
    is the paged tensor and `to_cache` is the per-token tensor.
    """
    num_blocks, block_size, num_heads, head_size = kv_cache.shape
    token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size)
    token_kv_cache[slot_mapping] = to_cache
|
|
|
|
|
|
######### CacheOnlyAttentionBackend ########
|
|
|
|
|
|
class CacheOnlyAttentionBackend(AttentionBackend):
    """Attention backend that only caches KV without computing attention."""

    # No attention output is produced, so an output buffer is never accepted.
    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]
    # The cache write is issued separately (unified_kv_cache_update), never
    # from inside a forward() call; CacheOnlyAttentionLayer asserts on this.
    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "CACHE_ONLY_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Only decoder-style attention is supported."""
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        # Presumably signals multimodal-prefix compatibility; trivially true
        # since this backend never computes attention — TODO confirm semantics.
        return True

    @staticmethod
    def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
        """Return the impl class that performs the cache-only update."""
        return CacheOnlyAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
        # We also don't use a k/v (2) dim
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
        """Return the metadata-builder class paired with this backend."""
        return CacheOnlyAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used (no attention is computed at all)."""
        return False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        # Empty list — presumably meaning "no head-size restriction", since
        # head_size is repurposed as hidden_size here; TODO confirm contract.
        return []
|
|
|
|
|
|
class CacheOnlyAttentionMetadata:
    """Minimal metadata for cache-only 'attention': just the slot mapping."""

    # Global slot index per token, used to scatter into the paged cache.
    slot_mapping: torch.Tensor

    def __init__(self, slot_mapping: torch.Tensor):
        self.slot_mapping = slot_mapping
|
|
|
|
|
|
class CacheOnlyAttentionMetadataBuilder(
    AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
):
    """Builds the minimal metadata required by the cache-only backend."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # No builder-specific state: defer entirely to the base class.
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CacheOnlyAttentionMetadata:
        """Build metadata carrying only the slot mapping.

        Raises:
            NotImplementedError: for cascade batches (common_prefix_len > 0)
                or non-causal batches, neither of which this backend handles.
        """
        if common_prefix_len > 0:
            raise NotImplementedError(
                "Cascade attention not supported by CacheOnlyAttention"
            )
        if not common_attn_metadata.causal:
            raise NotImplementedError(
                "Non-causal attention not supported by CacheOnlyAttention"
            )

        return CacheOnlyAttentionMetadata(
            slot_mapping=common_attn_metadata.slot_mapping,
        )
|
|
|
|
|
|
class CacheOnlyAttentionImpl(AttentionImpl):
    """Attention implementation that only caches KV states."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        kv_cache_dtype: str,
        kv_cache_torch_dtype: torch.dtype,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_cache_torch_dtype = kv_cache_torch_dtype

        # Only plain decoder caching is implemented.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(f"Unsupported attention type: {attn_type}")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("Quantized KV cache not supported")

        # Entries are stored 1:1 — no GQA-style query/KV head sharing.
        self.num_queries_per_kv = 1

    def do_kv_cache_update(self, layer, to_cache, kv_cache, slot_mapping):
        """Validate dtypes, then scatter `to_cache` into the paged cache."""
        expected = self.kv_cache_torch_dtype
        assert to_cache.dtype == expected, (
            f"Data to cache must be {expected}, got {to_cache.dtype}"
        )
        assert kv_cache.dtype == expected, (
            f"KV cache must be {expected}, got {kv_cache.dtype}"
        )
        basic_cache(to_cache, kv_cache, slot_mapping)

    def forward(self, *args, **kwargs):
        # Abstract-method stub: the cache write goes through
        # do_kv_cache_update instead of a real attention forward.
        pass
|
|
|
|
|
|
############## CacheOnlyAttentionLayer (replaces Attention) ############
|
|
|
|
|
|
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
    """Attention layer that only caches key/value states without computing attention."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ):
        """Build a cache-only layer and register it in the forward context.

        Args:
            num_heads: Stored as the cache's num_kv_heads. Repurposed by
                ExtractHiddenStatesModel as the number of hidden states per
                token.
            head_size: Stored as the cache's head_size (repurposed as
                hidden_size).
            cache_config: Optional cache config; falls back to the current
                vllm config's cache_config, then to "auto" dtype with
                block_size 16.
            prefix: Unique layer name, used both as self.layer_name and as
                the key in the compilation static forward context.
            attn_type: Must be AttentionType.DECODER (enforced by the impl).

        Raises:
            ValueError: if `prefix` is already registered in the compilation
                context.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.layer_name = prefix

        vllm_config = get_current_vllm_config()

        # KV cache configuration: prefer the explicit argument, then the
        # current vllm config, then conservative defaults.
        cache_config = cache_config or vllm_config.cache_config
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            self.block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            self.block_size = 16

        # Fix: added the missing space after "but" — the original message
        # rendered as "...kv cache butkv cache dtype...".
        # NOTE(review): "float16" is accepted here although
        # CacheOnlyAttentionBackend.supported_kv_cache_dtypes lists only
        # "auto" and "bfloat16" — confirm which list is authoritative.
        assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
            "CacheOnlyAttentionLayer doesn't currently support quantized kv cache but "
            f"kv cache dtype was set to {kv_cache_dtype}"
        )
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        # Initialize KV cache quantization attributes (registered as buffers).
        set_default_quant_scales(self, register_buffer=True)

        # Attention backend + impl that performs the actual cache write.
        self.attn_backend = CacheOnlyAttentionBackend
        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            kv_cache_dtype,
            self.kv_cache_torch_dtype,
            attn_type,
        )

        # forward() must never write the cache itself; the write is issued
        # by unified_kv_cache_update so torch.compile can order it explicitly.
        assert not self.attn_backend.forward_includes_kv_cache_update, (
            "KV cache update should be independent of forward"
        )

        # Placeholder KV cache (replaced by bind_kv_cache), one entry per
        # pipeline-parallel virtual engine.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Register in the compilation context; names must be unique.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
        """Cache hidden states as KV pairs without computing attention.

        Args:
            to_cache: The tensor to insert into the kv cache.
                shape [num_tokens, num_heads, head_size]

        Returns:
            Dummy output tensor (not used)
        """
        # Note: we set num_heads to num_hidden_layers and
        # head_size to hidden_size for hidden states storage
        output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)

        # Note: dummy_out is used to force torch.compile to preserve ordering between
        # cache update and attention op (which triggers kv_connector transfer)
        dummy_out = unified_kv_cache_update(to_cache, self.layer_name)

        # Triggers kv_connector transfer via decorator
        _ = dummy_attention(self.layer_name, dummy_out)

        return output

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class (CacheOnlyAttentionBackend)."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Note: we use MLAAttentionSpec here because it will
        # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
        # whereas FullAttentionSpec will add an additional factor of 2
        return MLAAttentionSpec(
            block_size=self.block_size,
            num_kv_heads=self.num_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
|
|
|
|
|
|
############ ExtractHiddenStatesModel definition ##########
|
|
|
|
|
|
class ExtractHiddenStatesModel(nn.Module):
    """Weightless "draft model" that writes target-model hidden states into
    the KV cache via a single CacheOnlyAttentionLayer. No tokens are
    generated; the cache side effect is the entire purpose.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.vllm_config = vllm_config
        # HF config of the draft model declared in the speculative config.
        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        self.target_num_hidden_layers = (
            vllm_config.model_config.get_total_num_hidden_layers()
        )
        # Number of auxiliary hidden states captured per token; 0 when the
        # HF config declares no eagle_aux_hidden_state_layer_ids.
        self.num_hidden_states = len(
            getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
        )

        cache_config = vllm_config.cache_config

        # Create a single cache-only attention layer
        # Note: We set num_heads <- self.num_hidden_states
        # and head_size <- hidden_size so that we can insert
        # the hidden states directly into the cache without
        # reshaping
        self.cache_only_layers = nn.ModuleDict(
            {
                str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
                    num_heads=self.num_hidden_states,
                    head_size=self.hidden_size,
                    cache_config=cache_config,
                    prefix=maybe_prefix(
                        prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
                    ),
                )
            }
        )

    def forward(self, hidden_states: torch.Tensor) -> None:
        """Process and cache hidden states.

        Args:
            hidden_states: Hidden states from target model
                shape: [num_tokens, num_hidden_states, hidden_size]

        Returns:
            None. (The original docstring claimed a tuple of dummy outputs,
            but the function returns nothing; the layer's dummy output is
            discarded and only the KV-cache write matters.)
        """

        # Call dummy attention layer to cache hidden states
        # Output is ignored - we only care about the KV cache side effects
        _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """No weights to load for this dummy model."""
        return set()
|