[Spec Decode] Add hidden states extraction system (#33736)

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
This commit is contained in:
Fynn Schmitt-Ulms
2026-03-02 14:29:09 -05:00
committed by GitHub
parent d1a6e96d9e
commit 9433acb8df
16 changed files with 2102 additions and 38 deletions

View File

@@ -108,7 +108,7 @@ class _HfExamplesInfo:
use_original_num_layers: bool = False
"""
If True, use the original number of layers from the model config
If True, use the original number of layers from the model config
instead of minimal layers for testing.
"""
@@ -1156,6 +1156,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.1.0",
),
"ExtractHiddenStatesModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",
speculative_method="extract_hidden_states",
),
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Predictable dummy model for testing extract_hidden_states.
Subclasses LlamaForCausalLM but overrides the model to produce deterministic
hidden states: layer i outputs values equal to (i).
"""
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import IntermediateTensors
class PredictableLlamaModel(nn.Module):
    """Stand-in decoder that emits constant-valued hidden states.

    No decoder layers are built or run. ``forward`` returns a final hidden
    state filled with ``config.num_hidden_layers`` and, when
    ``aux_hidden_state_layers`` is non-empty, one extra tensor per listed
    layer index, filled with that index. Every tensor is bfloat16.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding table and the intermediate-tensor factory.

        Args:
            vllm_config: Source of the HF config (vocab size, hidden size,
                number of layers).
            prefix: Accepted for signature compatibility; not used here.
        """
        super().__init__()

        # Imported at function scope rather than at module level.
        from vllm.model_executor.layers.vocab_parallel_embedding import (
            VocabParallelEmbedding,
        )
        from vllm.model_executor.models.utils import (
            make_empty_intermediate_tensors_factory,
        )

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        # Empty until a caller assigns the layer indices to extract.
        self.aux_hidden_state_layers = tuple[int, ...]()

        # Minimal embedding so embed_input_ids() has something to call.
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # Required for pipeline parallelism.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids`` in ``embed_tokens``."""
        return self.embed_tokens(input_ids)

    def _constant_states(
        self, num_tokens: int, value: int, device: torch.device
    ) -> torch.Tensor:
        """Return a ``(num_tokens, hidden_size)`` bfloat16 tensor of ``value``."""
        return torch.full(
            (num_tokens, self.config.hidden_size),
            fill_value=float(value),
            device=device,
            dtype=torch.bfloat16,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Produce constant hidden states sized to the incoming batch.

        The token count and device are taken from ``inputs_embeds`` when it
        is given, otherwise from ``input_ids``. ``positions``,
        ``intermediate_tensors`` and any extra keyword arguments are ignored.

        Returns:
            ``hidden_states`` alone when ``aux_hidden_state_layers`` is empty,
            otherwise ``(hidden_states, aux_hidden_states)`` with one aux
            tensor per entry of ``aux_hidden_state_layers``, in that order.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        if inputs_embeds is not None:
            num_tokens = inputs_embeds.shape[0]
            device = inputs_embeds.device
        elif input_ids is not None:
            # For a 1-D tensor the last dimension is also the first one.
            num_tokens = input_ids.shape[-1]
            device = input_ids.device
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided")

        # The final hidden state carries the total layer count as its value.
        hidden_states = self._constant_states(
            num_tokens, self.config.num_hidden_layers, device
        )
        if len(self.aux_hidden_state_layers) == 0:
            return hidden_states

        # Each auxiliary tensor is filled with its own layer index.
        aux_hidden_states = [
            self._constant_states(num_tokens, layer_idx, device)
            for layer_idx in self.aux_hidden_state_layers
        ]
        return hidden_states, aux_hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Ignore ``weights`` and report that nothing was loaded."""
        return set()
class PredictableLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper whose inner model is ``PredictableLlamaModel``.

    ``_init_model`` is overridden so the inner model is the predictable
    stand-in, and weight loading is a no-op.
    """

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] | None = None,
    ):
        """Return a ``PredictableLlamaModel``; ``layer_type`` is ignored."""
        predictable_model = PredictableLlamaModel(
            vllm_config=vllm_config,
            prefix=prefix,
        )
        return predictable_model

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Ignore ``weights``; the dummy model reports no loaded parameters."""
        loaded_names: set[str] = set()
        return loaded_names

View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
import pytest
import torch
from safetensors import safe_open
from vllm import LLM, ModelRegistry, SamplingParams
def get_and_check_output(output, expected_shape):
    """Load the hidden-states file referenced by ``output`` and validate it.

    Asserts that ``output.kv_transfer_params`` carries an existing
    ``hidden_states_path``, that the file holds ``token_ids`` and
    ``hidden_states`` tensors, that the token ids equal the prompt token ids,
    that the hidden states have ``expected_shape``, and that they are not all
    (close to) zero.

    Returns:
        The ``(token_ids, hidden_states)`` tensors read from the file.
    """
    transfer_params = output.kv_transfer_params
    assert transfer_params is not None
    saved_path = transfer_params.get("hidden_states_path")
    assert saved_path is not None
    assert os.path.exists(saved_path)

    with safe_open(saved_path, "pt") as tensor_file:
        # Both tensors must have been written by the connector.
        stored_names = tensor_file.keys()
        assert "token_ids" in stored_names
        assert "hidden_states" in stored_names

        token_ids = tensor_file.get_tensor("token_ids")
        hidden_states = tensor_file.get_tensor("hidden_states")

        expected_token_ids = torch.tensor(output.prompt_token_ids)
        assert torch.equal(token_ids, expected_token_ids)
        assert hidden_states.shape == expected_shape

        # An all-zero buffer would mean nothing was actually written into it.
        all_zero = torch.zeros_like(hidden_states)
        assert not torch.allclose(hidden_states, all_zero)

    return token_ids, hidden_states
@pytest.fixture(scope="module")
def predictable_llama_config_path(tmp_path_factory):
    """Create a minimal model directory for PredictableLlamaForCausalLM.

    Writes a small ``LlamaConfig`` whose ``architectures`` entry names
    ``PredictableLlamaForCausalLM`` together with the TinyLlama tokenizer
    files into a fresh temporary directory.

    Returns:
        The directory path as a string.
    """
    from transformers import LlamaConfig, LlamaTokenizerFast

    config_dir = tmp_path_factory.mktemp("predictable_llama")

    # Small dimensions keep the dummy model cheap; 24 layers leaves room to
    # request a variety of layer ids.
    config = LlamaConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=24,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
        architectures=["PredictableLlamaForCausalLM"],
    )
    config.save_pretrained(config_dir)

    # Resolve the tokenizer through the default Hugging Face cache. Forcing
    # cache_dir to "~/.cache/huggingface" would bypass the standard hub cache
    # and ignore HF_HOME / HF_HUB_CACHE, so a pre-populated or redirected
    # cache would never be hit.
    tokenizer = LlamaTokenizerFast.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    )
    tokenizer.save_pretrained(config_dir)

    return str(config_dir)
@pytest.fixture(scope="module", autouse=True)
def register_predictable_model():
    """Make sure PredictableLlamaForCausalLM is known to the ModelRegistry.

    Registration is skipped when the architecture name is already listed.
    Nothing is undone after the yield.
    """
    from .predictable_llama import PredictableLlamaForCausalLM

    arch_name = "PredictableLlamaForCausalLM"
    already_registered = arch_name in ModelRegistry.get_supported_archs()
    if not already_registered:
        ModelRegistry.register_model(arch_name, PredictableLlamaForCausalLM)

    yield
def test_extract_hidden_states_with_predictable_dummy_model(
    predictable_llama_config_path, tmp_path
):
    """End-to-end extraction check against the predictable dummy model.

    The dummy model fills the hidden state of layer ``i`` with the value
    ``i``, so the saved tensors reveal which layer each slice came from.
    The test checks that:

    1. hidden states are saved for every prompt with the expected shape,
    2. each saved slice holds the value of the layer it was requested from,
    3. the requested (non-sequential) layer order is preserved,
    4. prompts of different lengths, including a repeated prompt, agree.
    """
    # Deliberately out of order so a sorted or misaligned result is caught.
    layer_ids = [5, 2, 10]

    speculative_config = {
        "method": "extract_hidden_states",
        "num_speculative_tokens": 1,
        "draft_model_config": {
            "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids}
        },
    }
    kv_transfer_config = {
        "kv_connector": "ExampleHiddenStatesConnector",
        "kv_role": "kv_producer",
        "kv_connector_extra_config": {"shared_storage_path": tmp_path},
    }
    llm = LLM(
        model=predictable_llama_config_path,
        speculative_config=speculative_config,
        kv_transfer_config=kv_transfer_config,
        max_model_len=128,
        enforce_eager=True,
        trust_remote_code=True,
        load_format="dummy",  # there are no real weights to load
    )

    long_prompt = "Much longer prompt with many tokens"
    prompts = [
        "Short",
        "Medium length",
        long_prompt,
        long_prompt,  # same prompt submitted twice
    ]
    sampling_params = SamplingParams(max_tokens=1, temperature=0.0)

    hidden_size = llm.llm_engine.model_config.get_hidden_size()
    outputs = llm.generate(prompts, sampling_params)

    # Drop the engine before inspecting the files it produced.
    del llm
    gc.collect()

    assert len(outputs) == len(prompts)

    for output in outputs:
        # Saved layout: [prompt_len, number of requested layers, hidden_size].
        expected_shape = (
            len(output.prompt_token_ids),
            len(layer_ids),
            hidden_size,
        )
        _, hidden_states = get_and_check_output(output, expected_shape)

        for idx, layer_id in enumerate(layer_ids):
            layer_hidden = hidden_states[:, idx, :]
            expected_values = torch.full_like(layer_hidden, layer_id)
            assert torch.allclose(layer_hidden, expected_values, atol=1e-5), (
                f"Layer {layer_id} at position {idx} should output {float(layer_id)}, "
                f"but got mean={layer_hidden.mean():.3f}, "
                f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}"
            )

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
)
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.platforms import current_platform
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
def _create_proposer(
    num_speculative_tokens: int = 1,
    layer_ids: list[int] | None = None,
) -> ExtractHiddenStatesProposer:
    """Build an ExtractHiddenStatesProposer on top of the ``model_dir`` model.

    Args:
        num_speculative_tokens: Forwarded to the speculative config.
        layer_ids: Values for ``eagle_aux_hidden_state_layer_ids``;
            ``[1, 2, 3, 4]`` is used when None.

    Returns:
        A proposer constructed for the current platform's device type.
    """
    aux_layer_ids = [1, 2, 3, 4] if layer_ids is None else layer_ids

    target_model_config = ModelConfig(
        model=model_dir, runner="generate", max_model_len=100
    )
    spec_config = SpeculativeConfig(
        target_model_config=target_model_config,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=num_speculative_tokens,
        draft_model_config={
            "hf_config": {
                "eagle_aux_hidden_state_layer_ids": aux_layer_ids,
            }
        },
    )

    device = current_platform.device_type
    scheduler_config = SchedulerConfig(
        max_model_len=target_model_config.max_model_len,
        is_encoder_decoder=target_model_config.is_encoder_decoder,
    )
    full_config = VllmConfig(
        model_config=target_model_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device=device),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=scheduler_config,
        attention_config=AttentionConfig(),
    )
    return ExtractHiddenStatesProposer(vllm_config=full_config, device=device)
def test_proposer_initialization():
    """A freshly built proposer reflects the requested layers and token count."""
    layer_ids = [1, 2, 3, 4]
    num_layers = len(layer_ids)
    proposer = _create_proposer(num_speculative_tokens=1, layer_ids=layer_ids)

    assert proposer.num_hidden_states == num_layers

    spec_config = proposer.vllm_config.speculative_config
    assert spec_config is not None
    assert spec_config.num_speculative_tokens == 1

    # The buffer holds one slot per token, per requested layer.
    buffer_shape = (proposer.max_num_tokens, num_layers, proposer.hidden_size)
    assert proposer.hidden_states.shape == buffer_shape
def test_proposer_initialization_missing_layer_ids():
    """Constructing the proposer without layer ids raises a ValueError."""
    target_model_config = ModelConfig(
        model=model_dir, runner="generate", max_model_len=100
    )
    # The draft hf_config deliberately omits eagle_aux_hidden_state_layer_ids.
    spec_config = SpeculativeConfig(
        target_model_config=target_model_config,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=1,
        draft_model_config={"hf_config": {}},
    )

    device = current_platform.device_type
    scheduler_config = SchedulerConfig(
        max_model_len=target_model_config.max_model_len,
        is_encoder_decoder=target_model_config.is_encoder_decoder,
    )
    full_config = VllmConfig(
        model_config=target_model_config,
        cache_config=CacheConfig(),
        speculative_config=spec_config,
        device_config=DeviceConfig(device=device),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=scheduler_config,
        attention_config=AttentionConfig(),
    )

    expected_message = "eagle_aux_hidden_state_layer_ids must be set"
    with pytest.raises(ValueError, match=expected_message):
        ExtractHiddenStatesProposer(vllm_config=full_config, device=device)
def test_prepare_next_token_ids_padded():
    """
    Check prepare_next_token_ids_padded for extract_hidden_states.

    With num_speculative_tokens == 1 the sampled ids have shape
    (batch_size, 1). Each request is expected to yield its sampled token when
    that token is valid and the request is not discarded, and the backup
    token from its request state otherwise.
    """
    device = torch.device(current_platform.device_type)
    num_requests = 4

    batch_spec = BatchSpec(
        seq_lens=[5] * num_requests,
        query_lens=[5] * num_requests,
    )
    req_ids = [f"req_{n}" for n in range(1, num_requests + 1)]

    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    # Backup tokens are 10, 20, 30, 40 for req_1 .. req_4.
    mock_requests = {}
    for ordinal, req_id in enumerate(req_ids, start=1):
        request_state = mock.MagicMock(spec=CachedRequestState)
        request_state.get_token_id.return_value = ordinal * 10
        mock_requests[req_id] = request_state

    # Only the last request is explicitly discarded.
    discarded_req_mask = torch.tensor(
        [False, False, False, True], dtype=torch.bool, device=device
    )

    sampled_token_ids = torch.tensor(
        [
            [1],  # valid -> 1
            [4],  # valid -> 4
            [-1],  # invalid -> backup token 30
            [2],  # discarded -> backup token 40
        ],
        dtype=torch.int32,
        device=device,
    )
    expected_next_token_ids = torch.tensor(
        [1, 4, 30, 40], dtype=torch.int32, device=device
    )
    # The expected count flags token validity only; the discarded request
    # still counts because its sampled token (2) is itself valid.
    expected_valid_counts = torch.tensor(
        [1, 1, 0, 1], dtype=torch.int32, device=device
    )

    proposer = _create_proposer(num_speculative_tokens=1)
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
        common_attn_metadata,
        sampled_token_ids,
        mock_requests,
        mock_input_batch,
        discarded_req_mask,
    )

    assert torch.equal(next_token_ids, expected_next_token_ids)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_counts)
def test_propose():
    """
    Exercise ExtractHiddenStatesProposer.propose() with a mocked model.

    Checks that propose():
    1. accepts per-layer target hidden states and sampled token ids,
    2. returns the sampled tokens as "draft" tokens of shape [batch_size, 1],
    3. calls the model exactly once,
    4. copies the stacked hidden states into the proposer's buffer.
    """
    device = torch.device(current_platform.device_type)

    query_lens = [3, 2]
    batch_size = len(query_lens)
    num_tokens = sum(query_lens)
    num_hidden_layers = 4

    proposer = _create_proposer(
        num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers))
    )
    hidden_size = proposer.hidden_size

    # Replace the model, layer names and metadata builder with mocks.
    model_mock = mock.MagicMock()
    proposer.model = model_mock
    proposer.attn_layer_names = ["cache_only_layers.28"]
    metadata_builder = mock.MagicMock()
    metadata_builder.build_for_drafting.return_value = mock.MagicMock()
    proposer.attn_metadata_builder = metadata_builder

    batch_spec = BatchSpec(seq_lens=[3, 2], query_lens=query_lens)
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # One [num_tokens, hidden_size] tensor per extracted layer.
    target_hidden_states = [
        torch.randn(num_tokens, hidden_size, dtype=proposer.dtype, device=device)
        for _ in range(num_hidden_layers)
    ]
    sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
    mock_scheduler_output = mock.MagicMock()

    # Run propose() with the KV-transfer group reported as absent.
    with mock.patch(
        "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group",
        return_value=False,
    ):
        draft_tokens, _kv_connector_output = proposer.propose(
            sampled_token_ids=sampled_token_ids,
            target_hidden_states=target_hidden_states,
            common_attn_metadata=common_attn_metadata,
            scheduler_output=mock_scheduler_output,
            slot_mappings=None,
        )

    # The "draft" is just the sampled token, one per request.
    assert draft_tokens.shape == (batch_size, 1)
    assert torch.equal(draft_tokens[:, 0], sampled_token_ids)

    model_mock.assert_called_once()

    # The buffer should hold the per-layer states stacked along dim 1, i.e.
    # shape [num_tokens, num_hidden_layers, hidden_size].
    expected_stacked = torch.stack(target_hidden_states, dim=1)
    buffered = proposer.hidden_states[:num_tokens]
    assert torch.allclose(buffered, expected_stacked, atol=1e-6)
@pytest.mark.parametrize("num_hidden_layers", [1, 4, 8])
def test_propose_different_layer_counts(num_hidden_layers):
    """propose() returns the sampled tokens for 1, 4 and 8 extracted layers."""
    device = torch.device(current_platform.device_type)

    query_lens = [3, 2]
    batch_size = len(query_lens)
    num_tokens = sum(query_lens)

    proposer = _create_proposer(
        num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers))
    )
    hidden_size = proposer.hidden_size

    # Replace the model, layer names and metadata builder with mocks.
    proposer.model = mock.MagicMock()
    proposer.attn_layer_names = ["cache_only_layers.28"]
    metadata_builder = mock.MagicMock()
    metadata_builder.build_for_drafting.return_value = mock.MagicMock()
    proposer.attn_metadata_builder = metadata_builder

    batch_spec = BatchSpec(seq_lens=[3, 2], query_lens=query_lens)
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # One [num_tokens, hidden_size] tensor per extracted layer.
    target_hidden_states = [
        torch.randn(num_tokens, hidden_size, dtype=proposer.dtype, device=device)
        for _ in range(num_hidden_layers)
    ]
    sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
    mock_scheduler_output = mock.MagicMock()

    with mock.patch(
        "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group",
        return_value=False,
    ):
        draft_tokens, _ = proposer.propose(
            sampled_token_ids=sampled_token_ids,
            target_hidden_states=target_hidden_states,
            common_attn_metadata=common_attn_metadata,
            scheduler_output=mock_scheduler_output,
            slot_mappings=None,
        )

    assert draft_tokens.shape == (batch_size, 1)
    assert torch.equal(draft_tokens[:, 0], sampled_token_ids)