[Spec Decode] Add hidden states extraction system (#33736)

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
This commit is contained in:
Fynn Schmitt-Ulms
2026-03-02 14:29:09 -05:00
committed by GitHub
parent d1a6e96d9e
commit 9433acb8df
16 changed files with 2102 additions and 38 deletions

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from safetensors import safe_open
from vllm import LLM, SamplingParams
# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
with tempfile.TemporaryDirectory() as tmpdirname:
    # The speculator method selects which target-model layers to capture and
    # the connector writes the captured hidden states to `tmpdirname`.
    llm = LLM(
        model="Qwen/Qwen3-8B",  # Your target model
        speculative_config={
            "method": "extract_hidden_states",
            "num_speculative_tokens": 1,
            "draft_model_config": {
                "hf_config": {
                    "eagle_aux_hidden_state_layer_ids": [  # Target model layer indices
                        1,
                        2,
                        3,
                        4,
                    ],
                }
            },
        },
        kv_transfer_config={
            "kv_connector": "ExampleHiddenStatesConnector",
            "kv_role": "kv_producer",
            "kv_connector_extra_config": {
                "shared_storage_path": tmpdirname,
            },
        },
    )

    prompts = ["Generate a sentence with hidden states", "Write a python function"]
    sampling_params = SamplingParams(max_tokens=1)
    outputs = llm.generate(prompts, sampling_params)

    # Read the files back while the temporary directory still exists.
    for output in outputs:
        print("\nPrompt:", output.prompt)
        print("Prompt token ids:", output.prompt_token_ids)

        # Fail with a clear assertion instead of an AttributeError on None
        # when the connector did not attach any transfer params.
        kv_transfer_params = output.kv_transfer_params
        assert kv_transfer_params is not None
        hidden_states_path = kv_transfer_params.get("hidden_states_path")
        assert hidden_states_path is not None
        print("Prompt hidden states path:", hidden_states_path)

        with safe_open(hidden_states_path, "pt") as f:
            token_ids = f.get_tensor("token_ids")
            hidden_states = f.get_tensor("hidden_states")

        print("Extracted token ids:", token_ids)  # Matches prompt token ids
        print(
            "Extracted hidden states shape:", hidden_states.shape
        )  # [prompt len, num_hidden_layers, hidden size]
        print("Extracted hidden states:", hidden_states)

View File

@@ -108,7 +108,7 @@ class _HfExamplesInfo:
use_original_num_layers: bool = False
"""
If True, use the original number of layers from the model config
If True, use the original number of layers from the model config
instead of minimal layers for testing.
"""
@@ -1156,6 +1156,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
min_transformers_version="5.1.0",
),
"ExtractHiddenStatesModel": _HfExamplesInfo(
"Qwen/Qwen3-8B",
speculative_method="extract_hidden_states",
),
"Glm4MoeMTPModel": _HfExamplesInfo(
"zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5",

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Predictable dummy model for testing extract_hidden_states.
Subclasses LlamaForCausalLM but overrides the model to produce deterministic
hidden states: layer i outputs values equal to (i).
"""
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import IntermediateTensors
class PredictableLlamaModel(nn.Module):
    """Backbone replacement whose outputs are constants instead of activations.

    The final hidden states are filled with ``num_hidden_layers`` and every
    requested auxiliary layer ``i`` is filled with ``i``, so tests can tell
    exactly which layer a captured tensor came from.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # Empty until some caller assigns the layers to capture.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

        from vllm.model_executor.layers.vocab_parallel_embedding import (
            VocabParallelEmbedding,
        )
        from vllm.model_executor.models.utils import (
            make_empty_intermediate_tensors_factory,
        )

        # Minimal embedding table, only used by embed_input_ids()
        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )

        # Required for pipeline parallelism
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def _constant_states(
        self, seq_len: int, device: torch.device, value: float
    ) -> torch.Tensor:
        """Build a ``(seq_len, hidden_size)`` bfloat16 tensor filled with ``value``."""
        return torch.full(
            (seq_len, self.config.hidden_size),
            fill_value=value,
            device=device,
            dtype=torch.bfloat16,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Produce constant-valued hidden states sized to the input.

        Returns:
            ``(hidden_states, aux_hidden_states)`` when
            ``aux_hidden_state_layers`` is non-empty, otherwise just
            ``hidden_states``.

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are None.
        """
        # inputs_embeds takes precedence when both are given
        if inputs_embeds is not None:
            seq_len = inputs_embeds.shape[0]
            device = inputs_embeds.device
        elif input_ids is not None:
            if input_ids.ndim == 1:
                seq_len = input_ids.shape[0]
            else:
                seq_len = input_ids.shape[-1]
            device = input_ids.device
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided")

        # Final hidden states carry the last layer's index
        hidden_states = self._constant_states(
            seq_len, device, float(self.config.num_hidden_layers)
        )
        if not self.aux_hidden_state_layers:
            return hidden_states

        # One tensor per requested layer, in the requested order
        aux_hidden_states = [
            self._constant_states(seq_len, device, float(layer_idx))
            for layer_idx in self.aux_hidden_state_layers
        ]
        return hidden_states, aux_hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Ignore ``weights`` and report that nothing was loaded."""
        return set()
class PredictableLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM variant backed by PredictableLlamaModel, for tests.

    Only ``_init_model`` and ``load_weights`` are overridden; everything else
    comes from LlamaForCausalLM.
    """

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] | None = None,
    ):
        """Build the predictable backbone; ``layer_type`` is not used."""
        predictable_backbone = PredictableLlamaModel(
            vllm_config=vllm_config, prefix=prefix
        )
        return predictable_backbone

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Ignore ``weights``; the dummy model has nothing to load."""
        loaded_names: set[str] = set()
        return loaded_names

View File

@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
import pytest
import torch
from safetensors import safe_open
from vllm import LLM, ModelRegistry, SamplingParams
def get_and_check_output(output, expected_shape):
    """Load the tensors saved for ``output`` and run the common sanity checks.

    Checks that the output points at an existing file, that the file holds
    ``token_ids`` and ``hidden_states``, that the token ids equal the prompt
    token ids, and that the hidden states have ``expected_shape`` and are not
    all zeros.

    Returns:
        The ``(token_ids, hidden_states)`` tensors read from the file.
    """
    transfer_params = output.kv_transfer_params
    assert transfer_params is not None
    saved_path = transfer_params.get("hidden_states_path")
    assert saved_path is not None
    assert os.path.exists(saved_path)

    # Load and verify the saved tensors
    with safe_open(saved_path, "pt") as reader:
        stored_names = reader.keys()
        for required_name in ("token_ids", "hidden_states"):
            assert required_name in stored_names
        token_ids = reader.get_tensor("token_ids")
        hidden_states = reader.get_tensor("hidden_states")

    expected_token_ids = torch.tensor(output.prompt_token_ids)
    assert torch.equal(token_ids, expected_token_ids)
    assert hidden_states.shape == expected_shape

    # All-zero hidden states would mean nothing was actually computed
    all_zero = torch.allclose(hidden_states, torch.zeros_like(hidden_states))
    assert not all_zero

    return token_ids, hidden_states
@pytest.fixture(scope="module")
def predictable_llama_config_path(tmp_path_factory):
    """Write a small Llama config and a tokenizer to a temp dir; return its path."""
    from transformers import LlamaConfig, LlamaTokenizerFast

    model_dir_path = tmp_path_factory.mktemp("predictable_llama")

    # Small dimensions; the architecture name selects the predictable model
    llama_kwargs = dict(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=24,  # Enough layers to test various layer_ids
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
        architectures=["PredictableLlamaForCausalLM"],
    )
    LlamaConfig(**llama_kwargs).save_pretrained(model_dir_path)

    # Reuse an existing tokenizer and store it alongside the config
    hf_cache_dir = os.path.expanduser("~/.cache/huggingface")
    tokenizer = LlamaTokenizerFast.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        cache_dir=hf_cache_dir,
    )
    tokenizer.save_pretrained(model_dir_path)

    return str(model_dir_path)
@pytest.fixture(scope="module", autouse=True)
def register_predictable_model():
    """Make sure PredictableLlamaForCausalLM is known to the ModelRegistry."""
    from .predictable_llama import PredictableLlamaForCausalLM

    arch_name = "PredictableLlamaForCausalLM"
    already_registered = arch_name in ModelRegistry.get_supported_archs()
    # Skip registration when the architecture is already present
    if not already_registered:
        ModelRegistry.register_model(arch_name, PredictableLlamaForCausalLM)
    yield
def test_extract_hidden_states_with_predictable_dummy_model(
    predictable_llama_config_path, tmp_path
):
    """End-to-end extraction check against the predictable dummy model.

    PredictableLlamaForCausalLM emits hidden states whose value equals the
    layer index, which lets this test verify that:
    1. Hidden states are extracted from the requested layers
    2. The saved values follow the predictable pattern
    3. Layer ordering is preserved for non-sequential layer IDs
    4. Prompts of different lengths yield consistent layer values
    """
    # Deliberately out of order so a mix-up between layers is detectable
    layer_ids = [5, 2, 10]

    speculative_config = {
        "method": "extract_hidden_states",
        "num_speculative_tokens": 1,
        "draft_model_config": {
            "hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids}
        },
    }
    kv_transfer_config = {
        "kv_connector": "ExampleHiddenStatesConnector",
        "kv_role": "kv_producer",
        "kv_connector_extra_config": {"shared_storage_path": tmp_path},
    }
    llm = LLM(
        model=predictable_llama_config_path,
        speculative_config=speculative_config,
        kv_transfer_config=kv_transfer_config,
        max_model_len=128,
        enforce_eager=True,
        trust_remote_code=True,
        load_format="dummy",  # Don't try to load real weights
    )

    # Different lengths, plus one prompt submitted twice
    long_prompt = "Much longer prompt with many tokens"
    prompts = ["Short", "Medium length", long_prompt, long_prompt]
    sampling_params = SamplingParams(max_tokens=1, temperature=0.0)

    # Read the hidden size before the engine is released below
    hidden_size = llm.llm_engine.model_config.get_hidden_size()
    outputs = llm.generate(prompts, sampling_params)

    del llm
    gc.collect()

    assert len(outputs) == len(prompts)

    for output in outputs:
        # hidden_states shape is [prompt_len, num_hidden_layers, hidden_size]
        expected_shape = (
            len(output.prompt_token_ids),
            len(layer_ids),
            hidden_size,
        )
        _, hidden_states = get_and_check_output(output, expected_shape)

        for idx, layer_id in enumerate(layer_ids):
            layer_hidden = hidden_states[:, idx, :]
            expected_values = torch.full_like(layer_hidden, layer_id)
            values_match = torch.allclose(layer_hidden, expected_values, atol=1e-5)
            assert values_match, (
                f"Layer {layer_id} at position {idx} should output {float(layer_id)}, "
                f"but got mean={layer_hidden.mean():.3f}, "
                f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}"
            )

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
)
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.config.load import LoadConfig
from vllm.platforms import current_platform
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
def _create_proposer(
    num_speculative_tokens: int = 1,
    layer_ids: list[int] | None = None,
) -> ExtractHiddenStatesProposer:
    """Build an ExtractHiddenStatesProposer on top of default configs.

    Args:
        num_speculative_tokens: forwarded to the SpeculativeConfig.
        layer_ids: auxiliary hidden-state layer ids; ``[1, 2, 3, 4]`` when None.
    """
    aux_layer_ids = [1, 2, 3, 4] if layer_ids is None else layer_ids

    model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
    draft_overrides = {
        "hf_config": {
            "eagle_aux_hidden_state_layer_ids": aux_layer_ids,
        }
    }
    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=num_speculative_tokens,
        draft_model_config=draft_overrides,
    )

    device = current_platform.device_type
    scheduler_config = SchedulerConfig(
        max_model_len=model_config.max_model_len,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=device),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=scheduler_config,
        attention_config=AttentionConfig(),
    )

    return ExtractHiddenStatesProposer(vllm_config=vllm_config, device=device)
def test_proposer_initialization():
    """A freshly built proposer reflects the configured layers and token count."""
    layer_ids = [1, 2, 3, 4]
    num_layers = len(layer_ids)
    proposer = _create_proposer(num_speculative_tokens=1, layer_ids=layer_ids)

    assert proposer.num_hidden_states == num_layers

    spec_config = proposer.vllm_config.speculative_config
    assert spec_config is not None
    assert spec_config.num_speculative_tokens == 1

    # The buffer holds one slot per (token, layer) pair
    buffer_shape = proposer.hidden_states.shape
    assert buffer_shape == (
        proposer.max_num_tokens,
        num_layers,
        proposer.hidden_size,
    )
def test_proposer_initialization_missing_layer_ids():
    """Constructing the proposer without layer ids must raise ValueError."""
    model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)

    # hf_config deliberately lacks eagle_aux_hidden_state_layer_ids
    draft_without_layer_ids = {"hf_config": {}}
    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=1,
        draft_model_config=draft_without_layer_ids,
    )

    device = current_platform.device_type
    scheduler_config = SchedulerConfig(
        max_model_len=model_config.max_model_len,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=device),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=scheduler_config,
        attention_config=AttentionConfig(),
    )

    expected_message = "eagle_aux_hidden_state_layer_ids must be set"
    with pytest.raises(ValueError, match=expected_message):
        ExtractHiddenStatesProposer(vllm_config=vllm_config, device=device)
def test_prepare_next_token_ids_padded():
    """
    Exercise prepare_next_token_ids_padded for extract_hidden_states.

    With num_speculative_tokens == 1, sampled_token_ids has shape
    (batch_size, 1). Each request should end up with its sampled token when
    that token is valid and the request is not discarded, and with the backup
    token from the request state otherwise.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    per_request_lens = [5] * num_requests
    batch_spec = BatchSpec(
        seq_lens=per_request_lens,
        query_lens=list(per_request_lens),
    )

    req_ids = [f"req_{number}" for number in range(1, num_requests + 1)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    # Backup next-token ids are 10, 20, 30, 40 for req_1 .. req_4
    mock_requests = {}
    for number, req_id in enumerate(req_ids, start=1):
        request_state = mock.MagicMock(spec=CachedRequestState)
        request_state.get_token_id.return_value = number * 10
        mock_requests[req_id] = request_state

    # Only the last request is explicitly discarded
    discarded_req_mask = torch.tensor(
        [False, False, False, True], dtype=torch.bool, device=device
    )

    # One sampled token per request: shape [batch_size, 1]
    sampled_token_ids = torch.tensor(
        [
            [1],  # valid, use 1
            [4],  # valid, use 4
            [-1],  # invalid, use backup token "30"
            [2],  # explicitly discarded, use backup token "40"
        ],
        dtype=torch.int32,
        device=device,
    )

    expected_next_token_ids_tensor = torch.tensor(
        [1, 4, 30, 40], dtype=torch.int32, device=device
    )
    # Validity only reflects the token itself (not -1 and inside the vocab
    # range); it does not depend on whether the request is discarded
    expected_valid_sampled_tokens_count = torch.tensor(
        [1, 1, 0, 1], dtype=torch.int32, device=device
    )

    proposer = _create_proposer(num_speculative_tokens=1)
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    result = proposer.prepare_next_token_ids_padded(
        common_attn_metadata,
        sampled_token_ids,
        mock_requests,
        mock_input_batch,
        discarded_req_mask,
    )
    next_token_ids, valid_sampled_tokens_count = result

    assert torch.equal(next_token_ids, expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
def test_propose():
    """
    Check ExtractHiddenStatesProposer.propose() end to end with a mocked model.

    propose() is expected to:
    1. Accept target hidden states and sampled token IDs
    2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1])
    3. Cache the hidden states in the model's KV cache
    """
    device = torch.device(current_platform.device_type)
    batch_size, num_tokens, num_hidden_layers = 2, 5, 4

    proposer = _create_proposer(
        num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers))
    )
    hidden_size = proposer.hidden_size

    # Swap in mocks for the model, its attention layer and metadata builder
    model_mock = mock.MagicMock()
    proposer.model = model_mock
    proposer.attn_layer_names = ["cache_only_layers.28"]
    metadata_builder = mock.MagicMock()
    metadata_builder.build_for_drafting.return_value = mock.MagicMock()
    proposer.attn_metadata_builder = metadata_builder

    # Two requests with 3 and 2 tokens
    request_lens = [3, 2]
    batch_spec = BatchSpec(
        seq_lens=request_lens,
        query_lens=list(request_lens),
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # One [num_tokens, hidden_size] tensor per layer
    target_hidden_states = []
    for _ in range(num_hidden_layers):
        layer_states = torch.randn(
            num_tokens, hidden_size, dtype=proposer.dtype, device=device
        )
        target_hidden_states.append(layer_states)

    # Tokens the target model sampled for each request
    sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
    mock_scheduler_output = mock.MagicMock()

    # Pretend no KV transfer group exists while proposing
    with mock.patch(
        "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group",
        return_value=False,
    ):
        draft_tokens, _ = proposer.propose(
            sampled_token_ids=sampled_token_ids,
            target_hidden_states=target_hidden_states,
            common_attn_metadata=common_attn_metadata,
            scheduler_output=mock_scheduler_output,
            slot_mappings=None,
        )

    # Draft tokens are the sampled tokens, shaped [batch_size, 1]
    assert draft_tokens.shape == (batch_size, 1)
    assert torch.equal(draft_tokens[:, 0], sampled_token_ids)

    # The model must have been invoked exactly once
    model_mock.assert_called_once()

    # The hidden states were copied to the buffer. The stacked hidden states
    # have shape [num_tokens, num_hidden_layers, hidden_size]
    expected_stacked = torch.stack(target_hidden_states, dim=1)
    buffered_states = proposer.hidden_states[:num_tokens]
    assert torch.allclose(buffered_states, expected_stacked, atol=1e-6)
@pytest.mark.parametrize("num_hidden_layers", [1, 4, 8])
def test_propose_different_layer_counts(num_hidden_layers):
    """propose() returns the sampled tokens for 1, 4 and 8 hidden layers."""
    device = torch.device(current_platform.device_type)
    batch_size, num_tokens = 2, 5

    proposer = _create_proposer(
        num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers))
    )
    hidden_size = proposer.hidden_size

    # Mock out the model, attention layer names and metadata builder
    proposer.model = mock.MagicMock()
    proposer.attn_layer_names = ["cache_only_layers.28"]
    metadata_builder = mock.MagicMock()
    metadata_builder.build_for_drafting.return_value = mock.MagicMock()
    proposer.attn_metadata_builder = metadata_builder

    request_lens = [3, 2]
    batch_spec = BatchSpec(
        seq_lens=request_lens,
        query_lens=list(request_lens),
    )
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # One [num_tokens, hidden_size] tensor per layer
    target_hidden_states = []
    for _ in range(num_hidden_layers):
        layer_states = torch.randn(
            num_tokens, hidden_size, dtype=proposer.dtype, device=device
        )
        target_hidden_states.append(layer_states)

    sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
    mock_scheduler_output = mock.MagicMock()

    with mock.patch(
        "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group",
        return_value=False,
    ):
        proposal = proposer.propose(
            sampled_token_ids=sampled_token_ids,
            target_hidden_states=target_hidden_states,
            common_attn_metadata=common_attn_metadata,
            scheduler_output=mock_scheduler_output,
            slot_mappings=None,
        )
    draft_tokens, _ = proposal

    assert draft_tokens.shape == (batch_size, 1)
    assert torch.equal(draft_tokens[:, 0], sampled_token_ids)

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import copy
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
@@ -181,9 +182,22 @@ class SpeculativeConfig:
the final hidden states.
"""
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
factors.append(uses_aux_hidden_states)
# The specific layers used also affect the computation graph
if uses_aux_hidden_states and self.draft_model_config is not None:
layer_ids = getattr(
self.draft_model_config.hf_config,
"eagle_aux_hidden_state_layer_ids",
None,
)
if layer_ids is not None:
# Convert to tuple to make it hashable
factors.append(tuple(layer_ids))
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@@ -352,6 +366,8 @@ class SpeculativeConfig:
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
elif self.method == "extract_hidden_states":
self.model = "extract_hidden_states"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -394,6 +410,34 @@ class SpeculativeConfig:
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
elif self.method == "extract_hidden_states":
from vllm.transformers_utils.configs.extract_hidden_states import (
ExtractHiddenStatesConfig,
)
# ExtractHiddenStatesModel is instantiated manually in load_model()
# We just need to store the target model config for KV cache shape info
self.model = "extract_hidden_states"
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if hasattr(self.draft_model_config, "hf_config"):
hf_config = self.draft_model_config.hf_config.to_dict()
elif (
isinstance(self.draft_model_config, dict)
and "hf_config" in self.draft_model_config
):
hf_config = self.draft_model_config["hf_config"]
else:
hf_config = {}
self.draft_model_config = copy.copy(self.target_model_config)
self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
self.draft_model_config.hf_config, **hf_config
)
self.update_arch_()
self.draft_parallel_config = self.target_parallel_config
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -478,23 +522,8 @@ class SpeculativeConfig:
method=self.method,
model_type="eagle",
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
self.update_arch_()
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
@@ -671,6 +700,24 @@ class SpeculativeConfig:
)
return speculative_draft_tensor_parallel_size
def update_arch_(self):
    """
    Refresh every architecture-derived field of self.draft_model_config.

    EagleConfig and ExtractHiddenStatesConfig change the architectures, so the
    text config, model arch config, model info and architecture are recomputed
    from the draft config's current hf_config. Mutates the draft config in place.
    """
    draft_config = self.draft_model_config

    draft_config.hf_text_config = get_hf_text_config(draft_config.hf_config)
    draft_config.model_arch_config = draft_config.get_model_arch_config()

    # Re-resolve the model class from the (possibly changed) architectures
    inspected = draft_config.registry.inspect_model_cls(
        draft_config.architectures,
        draft_config,
    )
    model_info, arch = inspected
    draft_config._model_info = model_info
    draft_config._architecture = arch
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
@@ -718,7 +765,7 @@ class SpeculativeConfig:
self.draft_parallel_config
)
eagle3_target_supported = [
aux_hidden_states_supported = [
"llama",
"qwen",
"minicpm",
@@ -729,16 +776,16 @@ class SpeculativeConfig:
"nemotron_h",
]
if (
self.method == "eagle3"
self.method in ("eagle3", "extract_hidden_states")
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported
for supported_model in aux_hidden_states_supported
)
):
raise ValueError(
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}"
f"{self.method} is only supported for {aux_hidden_states_supported}"
f" models. Got {self.target_model_config.hf_text_config.model_type=}"
)
self.verify_equal_vocab_size_if_draft_model()
return self
@@ -782,8 +829,15 @@ class SpeculativeConfig:
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
    """Return True when the configured method is "extract_hidden_states"."""
    is_extraction_method = self.method == "extract_hidden_states"
    return is_extraction_method
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
model = (
None
if method in ("ngram", "suffix", "extract_hidden_states")
else self.draft_model_config.model
)
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"

View File

@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
def clear_events(self) -> None:
raise NotImplementedError
def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
    """Add every event reported by ``other`` to this object and return self."""
    incoming_events = other.get_all_events()
    self.add_events(incoming_events)
    return self
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism

View File

@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
"ExampleConnector",
)
# Register the hidden-states example connector under the name
# "ExampleHiddenStatesConnector"; the second and third arguments are the
# dotted module path and the class name defined in that module.
KVConnectorFactory.register_connector(
"ExampleHiddenStatesConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
"ExampleHiddenStatesConnector",
)
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",

View File

@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
def extract_from_kv_cache(
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    num_tokens: int,
) -> torch.Tensor:
    """Gather the cache entries of the first ``num_tokens`` mapped slots.

    Assumes ``kv_cache`` has shape (num_pages, page_size, num_heads, head_size).

    Args:
        kv_cache: paged cache tensor.
        slot_mapping: flat slot indices into the page-flattened cache.
        num_tokens: number of leading gathered rows to keep.

    Returns:
        Tensor of shape [num_tokens, num_heads, head_size].
    """
    # Merge the page and in-page dimensions so every slot is one row
    slots = torch.flatten(kv_cache, start_dim=0, end_dim=1)
    # shape: [len(slot_mapping), num_heads, head_size]
    gathered = slots[slot_mapping]
    # Drop the rows beyond the requested token count
    return gathered[:num_tokens]
@dataclass
class ReqMeta:
    """Per-request data needed to pull a request's entries out of the cache."""

    # Request ID
    req_id: str
    # Path of the file this request's tensors are written to
    filename: str
    # Request tokens
    token_ids: torch.Tensor
    # Flat slot indices covering every slot of the request's blocks
    # (num_blocks * block_size entries), so it can be longer than token_ids
    slot_mapping: torch.Tensor
    # Whether this request is a new request or partially computed already
    new_req: bool

    @staticmethod
    def make_meta(
        req_id: str,
        filename: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        new_req: bool,
    ) -> "ReqMeta":
        """Build a ReqMeta, expanding ``block_ids`` into flat slot indices.

        Block ``b`` contributes the slots ``b * block_size`` through
        ``b * block_size + block_size - 1``, in the order the blocks are given.

        Args:
            req_id: request identifier.
            filename: output path stored on the result.
            token_ids: token ids of the request.
            block_ids: cache block ids assigned to the request.
            block_size: number of slots per block.
            new_req: stored on the result unchanged.
        """
        # Pin the dtype: torch.tensor([]) would otherwise infer float32, and a
        # float slot mapping cannot be used to index the cache.
        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
        block_ids_tensor = torch.tensor(block_ids, dtype=torch.long)

        block_offsets = torch.arange(block_size)
        # Broadcast (num_blocks, 1) block starts against (block_size,) offsets
        block_starts = block_ids_tensor.unsqueeze(1) * block_size
        slot_mapping = (block_starts + block_offsets).flatten()

        return ReqMeta(
            req_id=req_id,
            filename=filename,
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
            new_req=new_req,
        )
@dataclass
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
    """Connector metadata holding one ReqMeta per added request."""

    requests: list[ReqMeta] = field(default_factory=list)

    def add_request(
        self,
        req_id: str,
        filename: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        new_req: bool = True,
    ) -> None:
        """Build a ReqMeta from the arguments and append it to ``requests``."""
        request_meta = ReqMeta.make_meta(
            req_id=req_id,
            filename=filename,
            token_ids=token_ids,
            block_ids=block_ids,
            block_size=block_size,
            new_req=new_req,
        )
        self.requests.append(request_meta)
class ExampleHiddenStatesConnector(KVConnectorBase_V1):
"""
Simple debug implementation of a HiddenStatesConnector.
Simply extracts the hidden states from the kv cache and stores them to disk.
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
"""
@property
def prefer_cross_layer_blocks(self) -> bool:
    """
    Whether this connector prefers KV blocks that hold the KV data of all
    layers (which can speed up KV data transfers). Always False here.
    """
    # Must be False so that drafter kv cache isn't merged with verifier's
    wants_cross_layer_blocks = False
    return wants_cross_layer_blocks
def __init__(
    self,
    vllm_config: "VllmConfig",
    role: KVConnectorRole,
    kv_cache_config: Optional["KVCacheConfig"] = None,
):
    """Set up the connector's storage path and per-request bookkeeping.

    The storage path comes from the ``shared_storage_path`` extra config entry
    and falls back to ``/tmp``. A speculative config must be present.
    """
    super().__init__(
        vllm_config=vllm_config,
        role=role,
        kv_cache_config=kv_cache_config,
    )

    self._block_size = vllm_config.cache_config.block_size
    self._storage_path = self._kv_transfer_config.get_from_extra_config(
        "shared_storage_path", "/tmp"
    )
    self.cache_layers: list[str] = []  # set by self.register_kv_caches

    logger.info(self._kv_transfer_config)
    logger.info("Shared storage path is %s", self._storage_path)

    speculative_config = self._vllm_config.speculative_config
    assert speculative_config is not None, (
        "ExampleHiddenStatesConnector only works when using "
        "'extract_hidden_states' speculative method"
    )

    # One hidden state per configured auxiliary layer id (none if unset)
    draft_hf_config = speculative_config.draft_model_config.hf_config
    aux_layer_ids = getattr(draft_hf_config, "eagle_aux_hidden_state_layer_ids", [])
    self.num_hidden_states = len(aux_layer_ids)

    # Per-request state, keyed by request id
    self._request_filenames: dict[str, str] = {}
    self._active_requests: dict[str, NewRequestData] = {}
    self._req_blocks: dict[str, list[int]] = {}
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, *args, **kwargs: Any) -> None:
    """Do nothing; empty implementation of the abstract method."""
    return None
def wait_for_layer_load(self, layer_name: str) -> None:
    """Do nothing; empty implementation of the abstract method."""
    return None
def wait_for_save(self):
    """Do nothing; empty implementation of the abstract method."""
    return None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Record which of the given cache layers are CacheOnlyAttentionLayers.

    Stores the matching layer names in ``self.cache_layers`` and asserts that
    exactly one was found.
    """
    from vllm.model_executor.models.extract_hidden_states import (
        CacheOnlyAttentionLayer,
    )

    # Only look at layers that actually have a registered cache
    candidate_names = list(kv_caches)
    cache_only_layers = get_layers_from_vllm_config(
        self._vllm_config, CacheOnlyAttentionLayer, candidate_names
    )

    self.cache_layers = list(cache_only_layers.keys())
    found_exactly_one = len(self.cache_layers) == 1
    assert found_exactly_one, (
        f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
    )
def save_kv_layer(
    self,
    layer_name: str,
    kv_layer: torch.Tensor,
    attn_metadata: AttentionMetadata,
    **kwargs: Any,
) -> None:
    """Write the hidden states of every scheduled request to its file.

    Layers that were not recorded in ``self.cache_layers`` are ignored.
    For the cache-only layer, each request listed in the connector
    metadata gets one safetensors file containing two tensors,
    ``hidden_states`` and ``token_ids``, both moved to the CPU.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged buffer of the current layer.
        attn_metadata (AttentionMetadata): the attention metadata; must be
            a CacheOnlyAttentionMetadata.
        **kwargs: additional arguments for the save operation (unused).
    """
    if layer_name not in self.cache_layers:
        return

    # Imported at call time, as in the original code.
    from vllm.model_executor.models.extract_hidden_states import (
        CacheOnlyAttentionMetadata,
    )

    assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
        "ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
    )

    step_metadata = self._get_connector_metadata()
    assert isinstance(step_metadata, ExampleHiddenStatesConnectorMetadata)

    os.makedirs(self._storage_path, exist_ok=True)

    for req_meta in step_metadata.requests:
        num_request_tokens = req_meta.token_ids.shape[0]
        request_hidden_states = extract_from_kv_cache(
            kv_layer, req_meta.slot_mapping, num_request_tokens
        )
        safetensors.torch.save_file(
            {
                "hidden_states": request_hidden_states.detach().cpu(),
                "token_ids": req_meta.token_ids.detach().cpu(),
            },
            req_meta.filename,
        )
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# This connector is store-only, so we don't need to load any tokens
return 0, False
def update_state_after_alloc(
    self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
    """Check that no externally loaded tokens were allocated for.

    This hook usually handles newly allocated blocks for requests that load
    tokens from the connector's external KV cache. This connector never
    loads from an external cache, so there is nothing to do beyond the
    sanity check.

    Raises:
        AssertionError: If ``num_external_tokens`` is not 0.
    """
    assert num_external_tokens == 0, "This connector is store-only"
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function does not modify any fields in the scheduler_output. It
    does update the connector's own bookkeeping: every newly scheduled
    request is registered, and the block list of every tracked request
    that is scheduled again is extended with its newly allocated blocks.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        Metadata holding one entry per scheduled request tracked by this
        connector.
    """
    meta = ExampleHiddenStatesConnectorMetadata()

    for new_req in scheduler_output.scheduled_new_reqs:
        token_ids = new_req.prompt_token_ids or []
        filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
        meta.add_request(
            new_req.req_id,
            filename=filename,
            token_ids=token_ids,
            # Index 0 selects the first group of block ids.
            block_ids=new_req.block_ids[0],
            block_size=self._block_size,
        )
        self._request_filenames[new_req.req_id] = filename
        self._active_requests[new_req.req_id] = new_req
        # Private copy, so later extensions never touch the scheduler's list.
        self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        if req_id not in self._active_requests:
            continue

        cached_req = self._active_requests[req_id]
        req_block_ids = self._req_blocks[req_id]

        # An entry of None is treated as "no blocks were added for this
        # request in this step"; the accumulated block list is then left
        # unchanged instead of failing the whole step.
        new_block_ids = cached_reqs.new_block_ids[i]
        if new_block_ids is not None:
            req_block_ids.extend(new_block_ids[0])

        filename = os.path.join(self._storage_path, f"{req_id}.safetensors")
        meta.add_request(
            req_id=req_id,
            filename=filename,
            token_ids=cached_req.prompt_token_ids or [],
            block_ids=req_block_ids,
            block_size=self._block_size,
            new_req=False,
        )

    return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
req_filename = self._request_filenames.pop(req_id, None)
_ = self._active_requests.pop(req_id, None)
_ = self._req_blocks.pop(req_id, None)
return False, {"hidden_states_path": req_filename}
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
# NHD means we have (num_tokens, num_heads)
# HND means we have (num_heads, num_tokens)
# For now, we only support NHD layout since this keeps the
# hidden states for each token together in memory.
# HND is primarily used when sharding heads across devices.
return "NHD"

View File

@@ -0,0 +1,394 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Hidden States Extractor Model.
This model extracts and caches hidden states from the target model
without performing actual token generation. It's used with the
extract_hidden_states speculative decoding method.
"""
from collections.abc import Iterable
from typing import ClassVar
import torch
import torch.nn as nn
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.attention.attention import set_default_quant_scales
from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.utils import maybe_prefix
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
)
########## Custom Ops ########
def unified_kv_cache_update(
    to_cache: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Write ``to_cache`` into the named layer's cache and return a dummy.

    The layer, its cache and the per-layer slot mappings are all looked up
    in the current forward context. The returned empty tensor carries no
    data: it is handed to the follow-up op to signal the side effect and to
    create a data dependency, so that torch.compile preserves the ordering.

    Args:
        to_cache: Data to insert into the layer's cache.
        layer_name: Name under which the layer is registered in the
            forward context.

    Returns:
        An empty tensor on the cache's device with the cache's dtype.

    Raises:
        AssertionError: If the context's slot mapping is not a dict, or if
            the layer's impl has no ``do_kv_cache_update`` method.
    """
    forward_context = get_forward_context()
    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )

    layer_slot_mapping = slot_mapping.get(layer_name)
    if layer_slot_mapping is None:
        # No mapping for this layer: nothing is written, only the dummy
        # is produced.
        return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)

    assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
        f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
    )
    attn_layer.impl.do_kv_cache_update(
        attn_layer,
        to_cache,
        kv_cache,
        layer_slot_mapping,
    )
    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
@maybe_transfer_kv_layer
def dummy_attention(layer_name, _placeholder):
    """Return ``_placeholder`` unchanged.

    The function body does no work of its own; it exists so that the
    ``maybe_transfer_kv_layer`` decorator wraps a call for this layer.
    ``layer_name`` is not used in the body but is required by that
    decorator.
    """
    return _placeholder
def basic_cache(
    to_cache: torch.Tensor,  # shape: [seq_len, num_heads, head_size]
    kv_cache: torch.Tensor,  # shape: [num_blocks, block_size, num_heads, head_size]
    slot_mapping: torch.Tensor,  # shape: [seq_len]
):
    """Scatter per-token data into a paged cache, in place.

    The cache is viewed as one flat list of ``num_blocks * block_size``
    token slots, and row ``i`` of ``to_cache`` is written to slot
    ``slot_mapping[i]``.

    Args:
        to_cache: Per-token data to store.
        kv_cache: 4-D paged cache that is modified in place.
        slot_mapping: Flat slot index for each row of ``to_cache``.
    """
    # NOTE(review): indexing follows normal tensor semantics, so a negative
    # slot id (e.g. a -1 padding value) would address slots from the end of
    # the cache rather than being skipped -- confirm callers never rely on
    # padding being ignored.
    blocks, slots_per_block, heads, width = kv_cache.shape
    flat_slots = kv_cache.view(blocks * slots_per_block, heads, width)
    flat_slots[slot_mapping] = to_cache
######### CacheOnlyAttentionBackend ########
class CacheOnlyAttentionBackend(AttentionBackend):
    """Attention backend that only caches KV without computing attention."""

    # Capability flags and supported dtypes advertised by this backend.
    accept_output_buffer: bool = False
    forward_includes_kv_cache_update: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the backend's identifier string."""
        return "CACHE_ONLY_ATTN"

    @staticmethod
    def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
        """Return the implementation class paired with this backend."""
        return CacheOnlyAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return CacheOnlyAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the 4-D cache shape; ``cache_dtype_str`` is not used.

        We set `num_kv_heads = num_hidden_layers` and
        `head_size = hidden_size`, and there is no separate k/v (2)
        dimension.
        """
        cache_shape = (num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used, whatever the arguments."""
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Return True only for the decoder attention type."""
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        """Always report support for multimodal prefixes."""
        return True

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return an empty list of head sizes."""
        return []
class CacheOnlyAttentionMetadata:
    """Attention metadata that carries nothing but a slot mapping."""

    def __init__(self, slot_mapping: torch.Tensor):
        # Stored by reference; the tensor is neither copied nor validated.
        self.slot_mapping = slot_mapping
class CacheOnlyAttentionMetadataBuilder(
    AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
):
    """Builds CacheOnlyAttentionMetadata from the common attention metadata."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Forward every argument unchanged to the base builder."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CacheOnlyAttentionMetadata:
        """Wrap the common slot mapping in cache-only metadata.

        Args:
            common_prefix_len: Must be 0; a positive value requests cascade
                attention, which is rejected.
            common_attn_metadata: Source of the ``causal`` flag and of the
                slot mapping.
            fast_build: Not used.

        Raises:
            NotImplementedError: For cascade or non-causal attention.
        """
        if common_prefix_len > 0:
            raise NotImplementedError(
                "Cascade attention not supported by CacheOnlyAttention"
            )
        if not common_attn_metadata.causal:
            raise NotImplementedError(
                "Non-causal attention not supported by CacheOnlyAttention"
            )

        slots = common_attn_metadata.slot_mapping
        return CacheOnlyAttentionMetadata(slot_mapping=slots)
class CacheOnlyAttentionImpl(AttentionImpl):
    """Attention implementation that only caches KV states."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        kv_cache_dtype: str,
        kv_cache_torch_dtype: torch.dtype,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and record it on the instance.

        Raises:
            NotImplementedError: If ``attn_type`` is not the decoder type,
                or if ``kv_cache_dtype`` names a quantized cache.
        """
        # The checks keep their original relative order, so the same error
        # wins when both would fail.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(f"Unsupported attention type: {attn_type}")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("Quantized KV cache not supported")

        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_cache_torch_dtype = kv_cache_torch_dtype
        self.num_queries_per_kv = 1

    def do_kv_cache_update(
        self,
        layer,
        to_cache,
        kv_cache,
        slot_mapping,
    ):
        """Check both dtypes, then scatter ``to_cache`` into ``kv_cache``.

        ``layer`` is accepted but not used.

        Raises:
            AssertionError: If either tensor's dtype differs from the
                configured cache dtype.
        """
        expected_dtype = self.kv_cache_torch_dtype
        assert to_cache.dtype == expected_dtype, (
            f"Data to cache must be {expected_dtype}, got {to_cache.dtype}"
        )
        assert kv_cache.dtype == expected_dtype, (
            f"KV cache must be {expected_dtype}, got {kv_cache.dtype}"
        )
        basic_cache(to_cache, kv_cache, slot_mapping)

    def forward(self, *args, **kwargs):
        """Do nothing; exists only to satisfy the abstract interface."""
        return None
############## CacheOnlyAttentionLayer (replaces Attention) ############
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
    """Attention layer that only caches key/value states without computing attention."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ):
        """Create the layer and register it in the static forward context.

        Args:
            num_heads: Size of the cache's "heads" dimension.
            head_size: Size of the cache's innermost dimension.
            cache_config: Cache settings; falls back to the current vLLM
                config's cache config when omitted.
            prefix: Layer name, also used as the registration key.
            attn_type: Attention type passed on to the implementation.

        Raises:
            AssertionError: If the cache dtype is not one of "auto",
                "bfloat16" or "float16".
            ValueError: If a layer named ``prefix`` is already registered.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.layer_name = prefix

        vllm_config = get_current_vllm_config()

        # KV cache configuration
        cache_config = cache_config or vllm_config.cache_config
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            self.block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            self.block_size = 16

        # Fix: the two message fragments used to be joined without a space
        # ("...cache butkv cache dtype...").
        assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
            "CacheOnlyAttentionLayer doesn't currently support quantized kv cache "
            f"but kv cache dtype was set to {kv_cache_dtype}"
        )
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        # Initialize KV cache quantization attributes
        set_default_quant_scales(self, register_buffer=True)

        # Attention backend
        self.attn_backend = CacheOnlyAttentionBackend
        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            kv_cache_dtype,
            self.kv_cache_torch_dtype,
            attn_type,
        )

        assert not self.attn_backend.forward_includes_kv_cache_update, (
            "KV cache update should be independent of forward"
        )

        # Placeholder KV cache (replaced by bind_kv_cache)
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Register in compilation context
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
        """Cache hidden states as KV pairs without computing attention.

        Args:
            to_cache: The tensor to insert into the kv cache.
                shape [num_tokens, num_heads, head_size]

        Returns:
            Dummy output tensor (not used)
        """
        # Note: we set num_heads to num_hidden_layers and
        # head_size to hidden_size for hidden states storage
        output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)

        # Note: dummy_out is used to force torch.compile to preserve ordering between
        # cache update and attention op (which triggers kv_connector transfer)
        dummy_out = unified_kv_cache_update(to_cache, self.layer_name)

        # Triggers kv_connector transfer via decorator
        _ = dummy_attention(self.layer_name, dummy_out)

        return output

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's cache; ``vllm_config`` is not used."""
        # Note: we use MLAAttentionSpec here because it will
        # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
        # whereas FullAttentionSpec will add an additional factor of 2
        return MLAAttentionSpec(
            block_size=self.block_size,
            num_kv_heads=self.num_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
############ ExtractHiddenStatesModel definition ##########
class ExtractHiddenStatesModel(nn.Module):
    """Weight-free model whose only job is to cache the given hidden states."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Read the sizes from the config and create the cache-only layer.

        Args:
            vllm_config: Engine configuration; its speculative config must
                be set, since the draft model's HF config is read from it.
            prefix: Name prefix for the cache-only layer.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config

        model_config = vllm_config.model_config
        self.hidden_size = model_config.get_hidden_size()
        self.target_num_hidden_layers = model_config.get_total_num_hidden_layers()

        aux_layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
        self.num_hidden_states = len(aux_layer_ids)

        cache_config = vllm_config.cache_config

        # A single cache-only attention layer, keyed by the target model's
        # layer count. num_heads <- self.num_hidden_states and
        # head_size <- hidden_size, so that the hidden states can be
        # inserted directly into the cache without reshaping.
        layer_key = str(self.target_num_hidden_layers)
        cache_layer = CacheOnlyAttentionLayer(
            num_heads=self.num_hidden_states,
            head_size=self.hidden_size,
            cache_config=cache_config,
            prefix=maybe_prefix(
                prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
            ),
        )
        self.cache_only_layers = nn.ModuleDict({layer_key: cache_layer})

    def forward(self, hidden_states: torch.Tensor) -> None:
        """Process and cache hidden states.

        Args:
            hidden_states: Hidden states from target model
                shape: [num_tokens, num_hidden_states, hidden_size]

        Returns:
            None. The layer's dummy output is discarded; only the cache
            side effects matter.
        """
        cache_layer = self.cache_only_layers[str(self.target_num_hidden_layers)]
        cache_layer(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """No weights to load for this dummy model; always returns an empty set."""
        return set()

View File

@@ -512,6 +512,7 @@ _MULTIMODAL_MODELS = {
}
_SPECULATIVE_DECODING_MODELS = {
"ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),

View File

@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Config definitions for ExtractHiddenStatesModel, to be used with
the extract_hidden_states spec decoding method."""
import os
from transformers import PretrainedConfig
class ExtractHiddenStatesConfig(PretrainedConfig):
    """HF-style config whose architecture is always ExtractHiddenStatesModel."""

    model_type = "extract_hidden_states"

    def __init__(
        self,
        model: PretrainedConfig | dict | None = None,
        method: str | None = "extract_hidden_states",
        **kwargs,
    ):
        """Merge the base model's settings with ``kwargs``.

        Args:
            model: Base settings as a dict or a PretrainedConfig; any other
                value is treated as an empty base.
            method: Must equal "extract_hidden_states".
            **kwargs: Extra settings; they override entries from ``model``.

        Raises:
            AssertionError: If ``method`` has any other value.
        """
        assert method == "extract_hidden_states"

        if isinstance(model, dict):
            base_settings = model
        elif isinstance(model, PretrainedConfig):
            base_settings = model.to_dict()
        else:
            base_settings = {}

        # Base settings first, then kwargs override.
        merged = {**base_settings, **kwargs}

        # Whatever architecture came in is discarded and replaced.
        merged.pop("architectures", None)
        merged["architectures"] = ["ExtractHiddenStatesModel"]

        super().__init__(**merged)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        **kwargs,
    ) -> "ExtractHiddenStatesConfig":
        """Load the config dict for the given name/path and build from it."""
        loaded_dict, remaining_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return cls.from_dict(loaded_dict, **remaining_kwargs)

    def to_json_string(self, use_diff: bool = True) -> str:
        """Serialize to JSON, always with ``use_diff=False``.

        The ``use_diff`` argument is ignored because initializing
        ExtractHiddenStatesConfig with default arguments is not supported.
        """
        del use_diff
        return super().to_json_string(use_diff=False)

View File

@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypeVar
import numpy as np
import torch
@@ -120,6 +121,20 @@ class SamplerOutput:
logprobs_tensors: LogprobsTensors | None
T = TypeVar("T")
def _combine_non_none(f: Callable[[T, T], T], items: list[T | None]) -> T | None:
non_none = [item for item in items if item is not None]
if len(non_none) == 0:
return None
combined = non_none[0]
for item in non_none[1:]:
combined = f(combined, item)
return combined
@dataclass
class KVConnectorOutput:
# [req_ids]
@@ -146,6 +161,43 @@ class KVConnectorOutput:
and not self.invalid_block_ids
)
@classmethod
def merge(cls, *outputs: "KVConnectorOutput"):
    """Combine several connector outputs into one.

    Set-valued fields are unioned, stats are combined with ``aggregate``
    and cache events with ``merge``; in every case None entries are
    skipped, and a field that is None in all outputs stays None.

    Raises:
        AssertionError: If no outputs are given, if the merged
            ``invalid_block_ids`` is None, or if the outputs disagree on
            ``expected_finished_count``.
    """
    assert len(outputs) > 0, "Cannot merge empty outputs"

    def union_of(field_name):
        # Union of one set-valued field across all outputs.
        return _combine_non_none(
            set.union, [getattr(output, field_name) for output in outputs]
        )

    finished_sending = union_of("finished_sending")
    finished_recving = union_of("finished_recving")
    kv_connector_stats = _combine_non_none(
        lambda x, y: x.aggregate(y),
        [output.kv_connector_stats for output in outputs],
    )
    kv_cache_events = _combine_non_none(
        lambda x, y: x.merge(y),
        [output.kv_cache_events for output in outputs],
    )
    invalid_block_ids = union_of("invalid_block_ids")
    assert invalid_block_ids is not None

    # All outputs must agree on the expected count; the first one is used.
    expected_finished_count = outputs[0].expected_finished_count
    assert all(
        output.expected_finished_count == outputs[0].expected_finished_count
        for output in outputs
    )

    return cls(
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        kv_connector_stats=kv_connector_stats,
        kv_cache_events=kv_cache_events,
        invalid_block_ids=invalid_block_ids,
        expected_finished_count=expected_finished_count,
    )
@dataclass
class ECConnectorOutput:

View File

@@ -0,0 +1,395 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1
class ExtractHiddenStatesProposer:
def __init__(self, vllm_config: VllmConfig, device):
    """Validate the speculative config and allocate the persistent buffers.

    Args:
        vllm_config: Engine configuration; its speculative config must be
            set and must request exactly one speculative token.
        device: Device on which the buffers are allocated.

    Raises:
        AssertionError: If the speculative config is missing or asks for a
            number of speculative tokens other than 1.
        ValueError: If ``disable_padded_drafter_batch`` is set, or if the
            draft model config lists no aux hidden state layer ids.
    """
    speculative_config = vllm_config.speculative_config
    assert speculative_config is not None
    assert speculative_config.num_speculative_tokens == 1
    if speculative_config.disable_padded_drafter_batch:
        raise ValueError(
            "disable_padded_drafter_batch is not supported with "
            "extract_hidden_states method"
        )

    self.vllm_config = vllm_config
    self.device = device
    self.dtype = vllm_config.model_config.dtype
    self.dp_rank = vllm_config.parallel_config.data_parallel_rank

    # Model and attention layer tracking (initialized in load_model)
    self.model: nn.Module | None = None
    self.attn_layer_names: list[str] = []
    self.attn_metadata_builder: AttentionMetadataBuilder | None = None

    # Buffer capacity in tokens: the batched-token budget plus the
    # maximum number of sequences.
    scheduler_config = vllm_config.scheduler_config
    self.max_num_tokens = (
        scheduler_config.max_num_batched_tokens + scheduler_config.max_num_seqs
    )

    self.hf_config = speculative_config.draft_model_config.hf_config
    aux_layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
    if not aux_layer_ids:
        raise ValueError(
            "eagle_aux_hidden_state_layer_ids must be set in the draft "
            "model config for extract_hidden_states method"
        )
    self.num_hidden_states = len(aux_layer_ids)
    self.hidden_size = vllm_config.model_config.get_hidden_size()

    # Persistent input buffer: [max tokens, hidden states per token, hidden size]
    buffer_shape = (self.max_num_tokens, self.num_hidden_states, self.hidden_size)
    self.hidden_states = torch.zeros(buffer_shape, dtype=self.dtype, device=device)

    self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

    # Persistent slot-mapping buffer, one int64 entry per token.
    self._slot_mapping_buffer = torch.zeros(
        self.max_num_tokens, dtype=torch.int64, device=device
    )
def propose(
    self,
    sampled_token_ids: torch.Tensor,
    target_hidden_states: list[torch.Tensor],
    common_attn_metadata: CommonAttentionMetadata,
    scheduler_output: SchedulerOutput,
    slot_mappings: dict[str, torch.Tensor]
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
    """Cache the target hidden states and echo the sampled tokens as drafts.

    No real speculation takes place. The stacked hidden states are run
    through the ExtractHiddenStatesModel, which stores them in the KV cache
    without computing attention, and the sampled tokens are returned as
    the "draft" tokens so that they always verify.

    Args:
        sampled_token_ids: Sampled token IDs from the target model
        target_hidden_states: List of hidden state tensors from target
            model (one per aux hidden state layer), each of shape
            [num_tokens, hidden_size]
        common_attn_metadata: Attention metadata
        scheduler_output: Scheduler output for KV connector
        slot_mappings: Unused, provided for interface compatibility

    Returns:
        Tuple of:
            - ``sampled_token_ids`` with a trailing dimension of size 1
            - KV connector output (if KV transfer is active), else None
    """
    assert self.model is not None and isinstance(target_hidden_states, list)

    # Stack the per-layer tensors to [num_tokens, num_hidden_states,
    # hidden_size] and copy them into the persistent buffer.
    stacked = torch.stack(target_hidden_states, dim=1)
    num_tokens = stacked.shape[0]
    self.hidden_states[:num_tokens] = stacked

    assert self.attn_metadata_builder is not None
    attn_metadata = self.attn_metadata_builder.build_for_drafting(
        common_attn_metadata=common_attn_metadata, draft_index=0
    )

    # We assume all cache-only layers belong to the same KV cache group,
    # thus using the same attention metadata.
    per_layer_attn_metadata = dict.fromkeys(self.attn_layer_names, attn_metadata)

    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(num_tokens)
    )
    if num_tokens_across_dp is not None:
        num_tokens_across_dp[self.dp_rank] = num_input_tokens

    layer_slot_mappings = self._get_slot_mapping(
        num_input_tokens, common_attn_metadata.slot_mapping
    )

    with (
        set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=layer_slot_mappings,
        ),
        (
            KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
            if has_kv_transfer_group()
            else nullcontext()
        ) as kv_connector_output,
    ):
        self.model(
            hidden_states=self.hidden_states[:num_input_tokens],
        )

    # The sampled tokens double as the single "draft" token per request.
    draft_token_ids = sampled_token_ids.unsqueeze(-1)
    return draft_token_ids, kv_connector_output
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names}
def _determine_batch_execution_and_padding(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
    """Pick the cudagraph mode and the padded token count for a batch.

    Args:
        num_tokens: Unpadded number of tokens in the batch.
        use_cudagraphs: When False, dispatch is restricted to
            CUDAGraphMode.NONE.

    Returns:
        Tuple of (cudagraph mode, padded token count, per-rank token counts
        or None). The last element is None when data parallelism is off or
        when coordination returned no per-rank counts.
    """
    allowed_modes = None if use_cudagraphs else {CUDAGraphMode.NONE}
    cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
        num_tokens,
        valid_modes=allowed_modes,
    )
    num_tokens_padded = batch_desc.num_tokens

    parallel_config = self.vllm_config.parallel_config
    if not parallel_config.data_parallel_size > 1:
        # Single rank: nothing to coordinate.
        return cudagraph_mode, num_tokens_padded, None

    # Extra coordination when running data-parallel since we need to
    # coordinate across ranks
    # TODO(Flechman): support DBO ubatching
    should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
        coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            num_tokens_padded=num_tokens_padded,
            cudagraph_mode=cudagraph_mode.value,
        )
    )
    assert not should_ubatch, (
        "DBO ubatching not implemented for extract_hidden_states"
    )

    if num_tokens_across_dp is None:
        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp

    # Extract DP-synced values
    dp_rank = self.dp_rank
    num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())

    # Re-dispatch with DP padding so we have the correct batch_descriptor
    cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
        num_tokens_padded,
        valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
    )

    # Assert to make sure the agreed upon token count is correct
    # otherwise num_tokens_across_dp will no-longer be valid
    assert batch_desc.num_tokens == num_tokens_padded
    num_tokens_across_dp[dp_rank] = num_tokens_padded

    return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Initialize cudagraph dispatcher keys.

    The proposer itself only ever runs with PIECEWISE cudagraphs or with
    none: PIECEWISE is chosen when eager mode is not enforced and the mixed
    mode of ``cudagraph_mode`` is PIECEWISE or FULL, otherwise NONE.

    Should be called after adjust_cudagraph_sizes_for_spec_decode.
    """
    speculative_config = self.vllm_config.speculative_config
    assert speculative_config is not None

    graph_capable_modes = (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
    # Short-circuits: mixed_mode() is only consulted when eager is off.
    wants_cudagraphs = (
        not speculative_config.enforce_eager
        and cudagraph_mode.mixed_mode() in graph_capable_modes
    )
    proposer_cudagraph_mode = (
        CUDAGraphMode.PIECEWISE if wants_cudagraphs else CUDAGraphMode.NONE
    )

    self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
@torch.inference_mode()
def dummy_run(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
    is_graph_capturing: bool = False,
    slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
    """Run the model once on the first rows of the persistent buffer.

    Args:
        num_tokens: Unpadded token count to run with.
        use_cudagraphs: Forwarded to the batch/padding decision.
        is_graph_capturing: Not used by this method.
        slot_mappings: Optional per-layer slot mappings. When they contain
            an entry for this proposer's first attention layer, the
            proposer's own slot-mapping buffer is used instead.
    """
    assert self.model is not None, "Model must be initialized before dummy_run"

    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(
            num_tokens, use_cudagraphs=use_cudagraphs
        )
    )
    if num_tokens_across_dp is not None:
        num_tokens_across_dp[self.dp_rank] = num_input_tokens

    # Use our own slot mapping buffer during cudagraph capture.
    covers_our_layer = (
        self.attn_layer_names
        and slot_mappings is not None
        and self.attn_layer_names[0] in slot_mappings
    )
    if covers_our_layer:
        slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
    else:
        slot_mapping_dict = slot_mappings or {}

    dummy_inputs = self.hidden_states[:num_input_tokens]
    with set_forward_context(
        None,
        self.vllm_config,
        num_tokens=num_input_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        slot_mapping=slot_mapping_dict,
    ):
        self.model(
            hidden_states=dummy_inputs,
        )
def _build_attn_metadata_builder(
    self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
    """Build the attention metadata builder from draft attention layers.

    Only the first layer of the mapping is consulted for the backend and
    for the KV cache spec.

    Raises:
        ValueError: If ``draft_attn_layers`` is empty.
    """
    if not draft_attn_layers:
        raise ValueError("No attention layers found for ExtractHiddenStatesModel")

    first_layer = next(iter(draft_attn_layers.values()))
    builder_cls = first_layer.get_attn_backend().get_builder_cls()
    kv_cache_spec = first_layer.get_kv_cache_spec(self.vllm_config)
    return builder_cls(
        kv_cache_spec,
        self.attn_layer_names,
        self.vllm_config,
        self.device,
    )
def prepare_next_token_ids_padded(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: torch.Tensor,
    requests: dict[str, CachedRequestState],
    gpu_input_batch: InputBatch,
    discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare next token IDs for speculative decoding.

    Only column 0 of ``sampled_token_ids`` is read (one token per request,
    as num_speculative_tokens == 1). For each request the sampled token is
    used when it lies in ``[0, vocab_size)`` and the request is not
    discarded; otherwise a backup token taken from the request state is
    used.

    Returns:
        Tuple of (next token ids as int32, per-request 0/1 int32 flag
        telling whether the sampled token was in the valid range).
    """
    num_reqs = gpu_input_batch.num_reqs
    device = sampled_token_ids.device

    # Compute backup tokens for discarded / invalid requests
    backup_token_ids = []
    for req_index in range(num_reqs):
        req_state = requests[gpu_input_batch.req_ids[req_index]]
        seq_len = common_attn_metadata.seq_lens_cpu[req_index].item()
        backup_token_ids.append(req_state.get_token_id(seq_len))
    backup_tokens_gpu = torch.tensor(
        backup_token_ids,
        dtype=torch.int32,
        device=device,
    )

    assert discard_request_mask.dtype == torch.bool

    # With num_speculative_tokens == 1, there is exactly one token
    sampled = sampled_token_ids[:, 0]
    in_vocab = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
    valid_sampled_tokens_count = in_vocab.to(torch.int32)

    keep_request = ~discard_request_mask[:num_reqs]
    next_token_ids = torch.where(
        in_vocab & keep_request, sampled.to(torch.int32), backup_tokens_gpu
    )

    return next_token_ids, valid_sampled_tokens_count
def load_model(self, target_model: nn.Module) -> None:
    """Load the ExtractHiddenStatesModel model.

    The attention layers registered before and after loading are compared;
    the layers that are new are taken to be the draft model's, and the
    attention metadata builder is created from them.

    Args:
        target_model: The target model (passed for compatibility with
            EagleProposer interface, but not used here)

    Raises:
        AssertionError: If the speculative config is missing, or if the
            load did not add exactly one attention layer.
    """
    # Snapshot of the attention layer names that exist before the draft
    # model is loaded.
    layers_before = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)  # type: ignore[type-abstract]
    target_attn_layer_names = set(layers_before.keys())

    assert self.vllm_config.speculative_config is not None
    draft_model_config = self.vllm_config.speculative_config.draft_model_config

    from vllm.compilation.backends import set_model_tag

    with set_model_tag("extract_hidden_states"):
        self.model = get_model(
            vllm_config=self.vllm_config, model_config=draft_model_config
        )

    # Identify draft model's attention layers (difference from target)
    layers_after = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    draft_attn_layers = {}
    for name, layer in layers_after.items():
        if name in target_attn_layer_names:
            continue
        draft_attn_layers[name] = layer

    self.attn_layer_names = list(draft_attn_layers)

    assert len(draft_attn_layers) == 1, (
        "ExtractHiddenStatesModel should have exactly one "
        f"attention layer, found {len(draft_attn_layers)}"
    )

    self.attn_metadata_builder = self._build_attn_metadata_builder(
        draft_attn_layers
    )
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
    """Check that the drafting layers share a single KV cache group.

    The only thing verified is that ``self.attn_layer_names`` holds exactly
    one name; with a single layer there is nothing to disagree about, so
    ``kv_cache_config`` is not inspected.

    Raises:
        AssertionError: If the number of recorded attention layer names is
            not exactly one.
    """
    num_drafting_layers = len(self.attn_layer_names)
    assert num_drafting_layers == 1

View File

@@ -159,6 +159,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
@@ -495,6 +496,7 @@ class GPUModelRunner(
| EagleProposer
| DraftModelProposer
| MedusaProposer
| ExtractHiddenStatesProposer
)
if self.speculative_config.method == "ngram":
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -518,6 +520,11 @@ class GPUModelRunner(
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
)
elif self.speculative_config.method == "extract_hidden_states":
self.drafter = ExtractHiddenStatesProposer(
vllm_config=self.vllm_config, device=self.device
)
self.use_aux_hidden_state_outputs = True
else:
raise ValueError(
"Unknown speculative decoding method: "
@@ -3693,10 +3700,9 @@ class GPUModelRunner(
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
if self.execute_model_state is None:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
# receive sampled token ids from the last PP rank.
if self.use_async_scheduling and get_pp_group().world_size > 1:
self._pp_receive_prev_sampled_token_ids_to_input_batch()
@@ -3778,12 +3784,17 @@ class GPUModelRunner(
<= self.effective_drafter_max_model_len
)
use_gpu_toks = (
spec_config.use_eagle() or spec_config.uses_draft_model()
spec_config.use_eagle()
or spec_config.uses_draft_model()
or spec_config.uses_extract_hidden_states()
) and not spec_config.disable_padded_drafter_batch
if use_gpu_toks:
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids)
@@ -3842,6 +3853,10 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()
# self.kv_connector_output may be modified during drafting
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
if self.model_config.enable_return_routed_experts:
capturer = RoutedExpertsCapturer.get_instance()
@@ -4068,6 +4083,48 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata,
slot_mappings=slot_mappings,
)
elif spec_config.uses_extract_hidden_states():
assert isinstance(self.drafter, ExtractHiddenStatesProposer)
assert isinstance(sampled_token_ids, torch.Tensor), (
"sampled_token_ids should be a torch.Tensor for "
"extract_hidden_states method."
)
if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
raise ValueError(
"aux_hidden_states are required when using `extract_hidden_states`"
)
target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
scheduler_output=scheduler_output,
slot_mappings=slot_mappings,
)
# Combine KVConnectorOutputs or select the non-empty one
if self.kv_connector_output and drafter_kv_connector_output:
self.kv_connector_output = KVConnectorOutput.merge(
self.kv_connector_output, drafter_kv_connector_output
)
else:
self.kv_connector_output = (
self.kv_connector_output or drafter_kv_connector_output
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
@@ -4946,8 +5003,12 @@ class GPUModelRunner(
if self.speculative_config and (
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
or self.speculative_config.uses_extract_hidden_states()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
)
assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE
@@ -5656,9 +5717,12 @@ class GPUModelRunner(
cudagraph_mode, self.uniform_decode_query_len
)
# Initialize eagle's cudagraph dispatcher if using eagle spec decode.
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# Initialize drafter's cudagraph dispatcher if using spec decode.
if self.speculative_config and (
self.speculative_config.use_eagle()
or self.speculative_config.uses_extract_hidden_states()
):
assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer)
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
def calculate_reorder_batch_threshold(self) -> None:
@@ -6025,8 +6089,12 @@ class GPUModelRunner(
if self.speculative_config and (
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
or self.speculative_config.uses_extract_hidden_states()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
assert isinstance(
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)