[Spec Decode] Add hidden states extraction system (#33736)
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
parent d1a6e96d9e
commit 9433acb8df

vllm/model_executor/models/extract_hidden_states.py  (new file, +394 lines)
@@ -0,0 +1,394 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Hidden States Extractor Model.

This model extracts and caches hidden states from the target model
without performing actual token generation. It's used with the
extract_hidden_states speculative decoding method.
"""

from collections.abc import Iterable
from typing import ClassVar

import torch
import torch.nn as nn

from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention.attention import set_default_quant_scales
from vllm.model_executor.layers.attention.kv_transfer_utils import (
    maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.utils import maybe_prefix
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadataBuilder,
    AttentionType,
    CommonAttentionMetadata,
    is_quantized_kv_cache,
)
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
)

########## Custom Ops ########


def unified_kv_cache_update(
    to_cache: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Returns a dummy tensor that is passed to the dummy attention op to signal the
    side effect and the data dependency between them, ensuring torch.compile
    preserves their ordering.
    """
    forward_context = get_forward_context()
    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}."
    )
    layer_slot_mapping = slot_mapping.get(layer_name)
    if layer_slot_mapping is not None:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer,
            to_cache,
            kv_cache,
            layer_slot_mapping,
        )

    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
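
# How the ordering trick works: unified_kv_cache_update performs an in-place
# write into the paged cache and returns an empty tensor. That tensor is then
# fed into dummy_attention below, so torch.compile sees a data dependency
# between the cache write and the (kv-connector-triggering) attention op and
# will not reorder them.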


@maybe_transfer_kv_layer
def dummy_attention(layer_name, _placeholder):
    # Note: the layer_name arg is required by @maybe_transfer_kv_layer
    return _placeholder


def basic_cache(
    to_cache: torch.Tensor,  # shape: [seq_len, num_heads, head_size]
    kv_cache: torch.Tensor,  # shape: [num_blocks, block_size, num_heads, head_size]
    slot_mapping: torch.Tensor,  # shape: [seq_len]
):
    num_blocks, block_size, num_heads, head_size = kv_cache.shape
    token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size)
    token_kv_cache[slot_mapping] = to_cache
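
# For concreteness (illustrative numbers, not from the diff): with
# block_size=16, a token assigned slot 35 lives in block 35 // 16 = 2 at
# offset 35 % 16 = 3; the flat view above turns the write into a single
# indexed assignment over all slots in slot_mapping.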


######### CacheOnlyAttentionBackend ########


class CacheOnlyAttentionBackend(AttentionBackend):
    """Attention backend that only caches KV without computing attention."""

    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]
    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "CACHE_ONLY_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return True

    @staticmethod
    def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
        return CacheOnlyAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # We set `num_kv_heads` to the number of extracted hidden states
        # and `head_size = hidden_size`.
        # We also don't use a k/v (2) dim.
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
        return CacheOnlyAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []


class CacheOnlyAttentionMetadata:
    def __init__(self, slot_mapping: torch.Tensor):
        self.slot_mapping = slot_mapping


class CacheOnlyAttentionMetadataBuilder(
    AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
):
    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CacheOnlyAttentionMetadata:
        use_cascade = common_prefix_len > 0
        if use_cascade:
            raise NotImplementedError(
                "Cascade attention not supported by CacheOnlyAttention"
            )
        causal = common_attn_metadata.causal
        if not causal:
            raise NotImplementedError(
                "Non-causal attention not supported by CacheOnlyAttention"
            )

        return CacheOnlyAttentionMetadata(
            slot_mapping=common_attn_metadata.slot_mapping,
        )


class CacheOnlyAttentionImpl(AttentionImpl):
    """Attention implementation that only caches KV states."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        kv_cache_dtype: str,
        kv_cache_torch_dtype: torch.dtype,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_cache_torch_dtype = kv_cache_torch_dtype

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(f"Unsupported attention type: {attn_type}")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("Quantized KV cache not supported")

        self.num_queries_per_kv = 1

    def do_kv_cache_update(
        self,
        layer,
        to_cache,
        kv_cache,
        slot_mapping,
    ):
        assert to_cache.dtype == self.kv_cache_torch_dtype, (
            f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
        )
        assert kv_cache.dtype == self.kv_cache_torch_dtype, (
            f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
        )

        basic_cache(to_cache, kv_cache, slot_mapping)

    def forward(self, *args, **kwargs):
        # Empty implementation of abstract method
        pass


############## CacheOnlyAttentionLayer (replaces Attention) ############


class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
    """Attention layer that only caches key/value states without computing attention."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.layer_name = prefix

        vllm_config = get_current_vllm_config()

        # KV cache configuration
        cache_config = cache_config or vllm_config.cache_config
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            self.block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            self.block_size = 16

        assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
            "CacheOnlyAttentionLayer doesn't currently support quantized kv cache "
            f"but kv cache dtype was set to {kv_cache_dtype}"
        )
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        # Initialize KV cache quantization attributes
        set_default_quant_scales(self, register_buffer=True)

        # Attention backend
        self.attn_backend = CacheOnlyAttentionBackend
        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            kv_cache_dtype,
            self.kv_cache_torch_dtype,
            attn_type,
        )

        assert not self.attn_backend.forward_includes_kv_cache_update, (
            "KV cache update should be independent of forward"
        )

        # Placeholder KV cache (replaced by bind_kv_cache)
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Register in compilation context
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
        """Cache hidden states as KV pairs without computing attention.

        Args:
            to_cache: The tensor to insert into the kv cache.
                shape [num_tokens, num_heads, head_size]

        Returns:
            Dummy output tensor (not used)
        """
        # Note: num_heads is the number of extracted hidden states and
        # head_size is hidden_size, so hidden states are stored directly
        output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)

        # Note: dummy_out is used to force torch.compile to preserve ordering between
        # cache update and attention op (which triggers kv_connector transfer)
        dummy_out = unified_kv_cache_update(to_cache, self.layer_name)

        # Triggers kv_connector transfer via decorator
        _ = dummy_attention(self.layer_name, dummy_out)

        return output

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Note: we use MLAAttentionSpec here because it produces page sizes of
        # (block_size * num_kv_heads * head_size * dtype_size), whereas
        # FullAttentionSpec would add an additional factor of 2 for the K/V dim
        return MLAAttentionSpec(
            block_size=self.block_size,
            num_kv_heads=self.num_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
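
# Page-size arithmetic for the spec above, with illustrative (not prescriptive)
# numbers: block_size=16, num_kv_heads=3 extracted hidden states, and
# head_size=4096 in bfloat16 (2 bytes) give 16 * 3 * 4096 * 2 = 393,216 bytes
# per block; FullAttentionSpec would report twice that because it reserves
# space for both K and V.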


############ ExtractHiddenStatesModel definition ##########


class ExtractHiddenStatesModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.vllm_config = vllm_config
        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        self.target_num_hidden_layers = (
            vllm_config.model_config.get_total_num_hidden_layers()
        )
        self.num_hidden_states = len(
            getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
        )

        cache_config = vllm_config.cache_config

        # Create a single cache-only attention layer
        # Note: We set num_heads <- self.num_hidden_states
        # and head_size <- hidden_size so that we can insert
        # the hidden states directly into the cache without
        # reshaping
        self.cache_only_layers = nn.ModuleDict(
            {
                str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
                    num_heads=self.num_hidden_states,
                    head_size=self.hidden_size,
                    cache_config=cache_config,
                    prefix=maybe_prefix(
                        prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
                    ),
                )
            }
        )

    def forward(self, hidden_states: torch.Tensor) -> None:
        """Process and cache hidden states.

        Args:
            hidden_states: Hidden states from the target model,
                shape: [num_tokens, num_hidden_states, hidden_size]

        Returns:
            None. The hidden states are written into the layer's KV cache as a
            side effect; the layer's dummy output is discarded.
        """
        # Call dummy attention layer to cache hidden states
        # Output is ignored - we only care about the KV cache side effects
        _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """No weights to load for this dummy model."""
        return set()
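
To make the hidden-states cache layout concrete, the following standalone sketch mirrors CacheOnlyAttentionBackend.get_kv_cache_shape and basic_cache from the new file. All dimensions and slot values are illustrative assumptions, not values taken from the diff.

import torch

# Illustrative dimensions (assumptions)
num_blocks, block_size = 8, 16          # paged cache geometry
num_hidden_states, hidden_size = 3, 64  # aux hidden states per token, model width

# Same layout as get_kv_cache_shape:
# (num_blocks, block_size, num_kv_heads=num_hidden_states, head_size=hidden_size)
cache = torch.zeros(num_blocks, block_size, num_hidden_states, hidden_size)

# Hidden states for 4 tokens, as passed to ExtractHiddenStatesModel.forward
hidden_states = torch.randn(4, num_hidden_states, hidden_size)

# Slot mapping assigns each token a flat slot = block_id * block_size + offset
slot_mapping = torch.tensor([0, 1, 35, 36])

# This is exactly what basic_cache does: flatten blocks and scatter by slot
flat = cache.view(num_blocks * block_size, num_hidden_states, hidden_size)
flat[slot_mapping] = hidden_states

# Reading back: the token assigned slot 35 lives in block 2, offset 3
assert torch.equal(cache[2, 3], hidden_states[2])

Storing one "head" per extracted hidden state keeps the write a single indexed assignment and lets the paged KV-cache machinery manage block allocation for hidden states exactly as it does for keys and values.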
@@ -512,6 +512,7 @@ _MULTIMODAL_MODELS = {
}

_SPECULATIVE_DECODING_MODELS = {
    "ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
    "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
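
Each registry entry maps an architecture name to a (module_name, class_name) pair resolved against vllm/model_executor/models/, which is where the new file in this commit lives. The snippet below is only a rough sketch of that lookup, not the actual loader code:

import importlib

module_name, class_name = ("extract_hidden_states", "ExtractHiddenStatesModel")
module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
model_cls = getattr(module, class_name)  # -> ExtractHiddenStatesModel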