[Spec Decode] Add hidden states extraction system (#33736)

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
This commit is contained in:
Fynn Schmitt-Ulms
2026-03-02 14:29:09 -05:00
committed by GitHub
parent d1a6e96d9e
commit 9433acb8df
16 changed files with 2102 additions and 38 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import copy
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
@@ -181,9 +182,22 @@ class SpeculativeConfig:
the final hidden states.
"""
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
factors.append(uses_aux_hidden_states)
# The specific layers used also affect the computation graph
if uses_aux_hidden_states and self.draft_model_config is not None:
layer_ids = getattr(
self.draft_model_config.hf_config,
"eagle_aux_hidden_state_layer_ids",
None,
)
if layer_ids is not None:
# Convert to tuple to make it hashable
factors.append(tuple(layer_ids))
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@@ -352,6 +366,8 @@ class SpeculativeConfig:
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
elif self.method == "extract_hidden_states":
self.model = "extract_hidden_states"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -394,6 +410,34 @@ class SpeculativeConfig:
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
elif self.method == "extract_hidden_states":
from vllm.transformers_utils.configs.extract_hidden_states import (
ExtractHiddenStatesConfig,
)
# ExtractHiddenStatesModel is instantiated manually in load_model()
# We just need to store the target model config for KV cache shape info
self.model = "extract_hidden_states"
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if hasattr(self.draft_model_config, "hf_config"):
hf_config = self.draft_model_config.hf_config.to_dict()
elif (
isinstance(self.draft_model_config, dict)
and "hf_config" in self.draft_model_config
):
hf_config = self.draft_model_config["hf_config"]
else:
hf_config = {}
self.draft_model_config = copy.copy(self.target_model_config)
self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
self.draft_model_config.hf_config, **hf_config
)
self.update_arch_()
self.draft_parallel_config = self.target_parallel_config
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -478,23 +522,8 @@ class SpeculativeConfig:
method=self.method,
model_type="eagle",
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
self.update_arch_()
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
@@ -671,6 +700,24 @@ class SpeculativeConfig:
)
return speculative_draft_tensor_parallel_size
def update_arch_(self):
"""
EagleConfig and ExtractHiddenStatesConfig update architectures, so update all
architectures-related fields in self.draft_model_config
"""
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
@@ -718,7 +765,7 @@ class SpeculativeConfig:
self.draft_parallel_config
)
eagle3_target_supported = [
aux_hidden_states_supported = [
"llama",
"qwen",
"minicpm",
@@ -729,16 +776,16 @@ class SpeculativeConfig:
"nemotron_h",
]
if (
self.method == "eagle3"
self.method in ("eagle3", "extract_hidden_states")
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported
for supported_model in aux_hidden_states_supported
)
):
raise ValueError(
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}"
f"{self.method} is only supported for {aux_hidden_states_supported}"
f" models. Got {self.target_model_config.hf_text_config.model_type=}"
)
self.verify_equal_vocab_size_if_draft_model()
return self
@@ -782,8 +829,15 @@ class SpeculativeConfig:
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
return self.method == "extract_hidden_states"
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
model = (
None
if method in ("ngram", "suffix", "extract_hidden_states")
else self.draft_model_config.model
)
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"