[Spec Decode] Add hidden states extraction system (#33736)
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
This commit is contained in:
committed by
GitHub
parent
d1a6e96d9e
commit
9433acb8df
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
|
||||
"pangu_ultra_moe_mtp",
|
||||
"step3p5_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
@@ -181,9 +182,22 @@ class SpeculativeConfig:
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
# Eagle3 and extract_hidden_states affect the computation graph because
|
||||
# they return intermediate hidden states in addition to the final hidden state.
|
||||
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
|
||||
factors.append(uses_aux_hidden_states)
|
||||
|
||||
# The specific layers used also affect the computation graph
|
||||
if uses_aux_hidden_states and self.draft_model_config is not None:
|
||||
layer_ids = getattr(
|
||||
self.draft_model_config.hf_config,
|
||||
"eagle_aux_hidden_state_layer_ids",
|
||||
None,
|
||||
)
|
||||
if layer_ids is not None:
|
||||
# Convert to tuple to make it hashable
|
||||
factors.append(tuple(layer_ids))
|
||||
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@@ -352,6 +366,8 @@ class SpeculativeConfig:
|
||||
self.model = "ngram"
|
||||
elif self.method == "suffix":
|
||||
self.model = "suffix"
|
||||
elif self.method == "extract_hidden_states":
|
||||
self.model = "extract_hidden_states"
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided but without speculative model."
|
||||
@@ -394,6 +410,34 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "suffix":
|
||||
self._validate_suffix_decoding()
|
||||
elif self.method == "extract_hidden_states":
|
||||
from vllm.transformers_utils.configs.extract_hidden_states import (
|
||||
ExtractHiddenStatesConfig,
|
||||
)
|
||||
|
||||
# ExtractHiddenStatesModel is instantiated manually in load_model()
|
||||
# We just need to store the target model config for KV cache shape info
|
||||
self.model = "extract_hidden_states"
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if hasattr(self.draft_model_config, "hf_config"):
|
||||
hf_config = self.draft_model_config.hf_config.to_dict()
|
||||
elif (
|
||||
isinstance(self.draft_model_config, dict)
|
||||
and "hf_config" in self.draft_model_config
|
||||
):
|
||||
hf_config = self.draft_model_config["hf_config"]
|
||||
else:
|
||||
hf_config = {}
|
||||
|
||||
self.draft_model_config = copy.copy(self.target_model_config)
|
||||
self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
|
||||
self.draft_model_config.hf_config, **hf_config
|
||||
)
|
||||
self.update_arch_()
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
@@ -478,23 +522,8 @@ class SpeculativeConfig:
|
||||
method=self.method,
|
||||
model_type="eagle",
|
||||
)
|
||||
# EAGLEConfig primarily updates architectures, so update
|
||||
# all architectures-related fields in draft_model_config
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = (
|
||||
self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
self.update_arch_()
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
@@ -671,6 +700,24 @@ class SpeculativeConfig:
|
||||
)
|
||||
return speculative_draft_tensor_parallel_size
|
||||
|
||||
def update_arch_(self):
|
||||
"""
|
||||
EagleConfig and ExtractHiddenStatesConfig update architectures, so update all
|
||||
architectures-related fields in self.draft_model_config
|
||||
"""
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
|
||||
@staticmethod
|
||||
def create_draft_parallel_config(
|
||||
target_parallel_config: ParallelConfig,
|
||||
@@ -718,7 +765,7 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config
|
||||
)
|
||||
|
||||
eagle3_target_supported = [
|
||||
aux_hidden_states_supported = [
|
||||
"llama",
|
||||
"qwen",
|
||||
"minicpm",
|
||||
@@ -729,16 +776,16 @@ class SpeculativeConfig:
|
||||
"nemotron_h",
|
||||
]
|
||||
if (
|
||||
self.method == "eagle3"
|
||||
self.method in ("eagle3", "extract_hidden_states")
|
||||
and self.target_model_config
|
||||
and not any(
|
||||
supported_model in self.target_model_config.hf_text_config.model_type
|
||||
for supported_model in eagle3_target_supported
|
||||
for supported_model in aux_hidden_states_supported
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||
f"Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
f"{self.method} is only supported for {aux_hidden_states_supported}"
|
||||
f" models. Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
)
|
||||
self.verify_equal_vocab_size_if_draft_model()
|
||||
return self
|
||||
@@ -782,8 +829,15 @@ class SpeculativeConfig:
|
||||
def uses_draft_model(self) -> bool:
|
||||
return self.method == "draft_model"
|
||||
|
||||
def uses_extract_hidden_states(self) -> bool:
|
||||
return self.method == "extract_hidden_states"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
model = (
|
||||
None
|
||||
if method in ("ngram", "suffix", "extract_hidden_states")
|
||||
else self.draft_model_config.model
|
||||
)
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
|
||||
Reference in New Issue
Block a user