[Spec Decode] Add hidden states extraction system (#33736)
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
This commit is contained in:
committed by
GitHub
parent
d1a6e96d9e
commit
9433acb8df
58
examples/offline_inference/extract_hidden_states.py
Normal file
58
examples/offline_inference/extract_hidden_states.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Example: Using the custom "extract_hidden_states" speculator method and
|
||||
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen3-8B", # Your target model
|
||||
speculative_config={
|
||||
"method": "extract_hidden_states",
|
||||
"num_speculative_tokens": 1,
|
||||
"draft_model_config": {
|
||||
"hf_config": {
|
||||
"eagle_aux_hidden_state_layer_ids": [ # Target model layer indices
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
],
|
||||
}
|
||||
},
|
||||
},
|
||||
kv_transfer_config={
|
||||
"kv_connector": "ExampleHiddenStatesConnector",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_connector_extra_config": {
|
||||
"shared_storage_path": tmpdirname,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
prompts = ["Generate a sentence with hidden states", "Write a python function"]
|
||||
sampling_params = SamplingParams(max_tokens=1)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output in outputs:
|
||||
print("\nPrompt:", output.prompt)
|
||||
print("Prompt token ids:", output.prompt_token_ids)
|
||||
|
||||
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
|
||||
assert hidden_states_path is not None
|
||||
print("Prompt hidden states path:", hidden_states_path)
|
||||
|
||||
with safe_open(hidden_states_path, "pt") as f:
|
||||
token_ids = f.get_tensor("token_ids")
|
||||
hidden_states = f.get_tensor("hidden_states")
|
||||
|
||||
print("Extracted token ids:", token_ids) # Matches prompt token ids
|
||||
print(
|
||||
"Extracted hidden states shape:", hidden_states.shape
|
||||
) # [num_hidden_layers, prompt len, hidden size]
|
||||
print("Extracted hidden states:", hidden_states)
|
||||
@@ -108,7 +108,7 @@ class _HfExamplesInfo:
|
||||
|
||||
use_original_num_layers: bool = False
|
||||
"""
|
||||
If True, use the original number of layers from the model config
|
||||
If True, use the original number of layers from the model config
|
||||
instead of minimal layers for testing.
|
||||
"""
|
||||
|
||||
@@ -1156,6 +1156,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
speculative_model="LGAI-EXAONE/K-EXAONE-236B-A23B",
|
||||
min_transformers_version="5.1.0",
|
||||
),
|
||||
"ExtractHiddenStatesModel": _HfExamplesInfo(
|
||||
"Qwen/Qwen3-8B",
|
||||
speculative_method="extract_hidden_states",
|
||||
),
|
||||
"Glm4MoeMTPModel": _HfExamplesInfo(
|
||||
"zai-org/GLM-4.5",
|
||||
speculative_model="zai-org/GLM-4.5",
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Predictable dummy model for testing extract_hidden_states.
|
||||
|
||||
Subclasses LlamaForCausalLM but overrides the model to produce deterministic
|
||||
hidden states: layer i outputs values equal to (i).
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class PredictableLlamaModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
# Create minimal embed_tokens for embedding
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
)
|
||||
|
||||
# Required for pipeline parallelism
|
||||
from vllm.model_executor.models.utils import (
|
||||
make_empty_intermediate_tensors_factory,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], self.config.hidden_size
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Embed input IDs."""
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**extra_layer_kwargs,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""Forward pass that produces predictable outputs.
|
||||
|
||||
Returns:
|
||||
If aux_hidden_state_layers is set: (hidden_states, aux_hidden_states)
|
||||
Otherwise: hidden_states
|
||||
"""
|
||||
# Determine sequence length
|
||||
if inputs_embeds is not None:
|
||||
seq_len = inputs_embeds.shape[0]
|
||||
device = inputs_embeds.device
|
||||
elif input_ids is not None:
|
||||
seq_len = input_ids.shape[0] if input_ids.ndim == 1 else input_ids.shape[-1]
|
||||
device = input_ids.device
|
||||
else:
|
||||
raise ValueError("Either input_ids or inputs_embeds must be provided")
|
||||
|
||||
# Final hidden states (last layer value)
|
||||
hidden_states = torch.full(
|
||||
(seq_len, self.config.hidden_size),
|
||||
fill_value=float(self.config.num_hidden_layers),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Check if we need auxiliary hidden states
|
||||
if len(self.aux_hidden_state_layers) > 0:
|
||||
aux_hidden_states = []
|
||||
for layer_idx in self.aux_hidden_state_layers:
|
||||
# Fill with (layer_idx) for predictability
|
||||
layer_hidden = torch.full(
|
||||
(seq_len, self.config.hidden_size),
|
||||
fill_value=float(layer_idx),
|
||||
device=device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
aux_hidden_states.append(layer_hidden)
|
||||
|
||||
return hidden_states, aux_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Skip weight loading."""
|
||||
return set()
|
||||
|
||||
|
||||
class PredictableLlamaForCausalLM(LlamaForCausalLM):
|
||||
"""Predictable Llama model for testing.
|
||||
|
||||
Overrides _init_model to use PredictableLlamaModel instead of LlamaModel.
|
||||
"""
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] | None = None,
|
||||
):
|
||||
"""Initialize with predictable model."""
|
||||
return PredictableLlamaModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Skip weight loading for dummy model."""
|
||||
return set()
|
||||
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from vllm import LLM, ModelRegistry, SamplingParams
|
||||
|
||||
|
||||
def get_and_check_output(output, expected_shape):
|
||||
assert output.kv_transfer_params is not None
|
||||
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
|
||||
assert hidden_states_path is not None
|
||||
assert os.path.exists(hidden_states_path)
|
||||
|
||||
# Load and verify the saved tensors
|
||||
with safe_open(hidden_states_path, "pt") as f:
|
||||
# Check that token_ids and hidden_states are present
|
||||
tensor_names = f.keys()
|
||||
assert "token_ids" in tensor_names
|
||||
assert "hidden_states" in tensor_names
|
||||
|
||||
token_ids = f.get_tensor("token_ids")
|
||||
hidden_states = f.get_tensor("hidden_states")
|
||||
|
||||
prompt_token_ids = output.prompt_token_ids
|
||||
assert torch.equal(token_ids, torch.tensor(prompt_token_ids))
|
||||
|
||||
assert hidden_states.shape == expected_shape
|
||||
|
||||
# Verify hidden_states are not all zeros (i.e., they were actually computed)
|
||||
assert not torch.allclose(hidden_states, torch.zeros_like(hidden_states))
|
||||
|
||||
return token_ids, hidden_states
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def predictable_llama_config_path(tmp_path_factory):
|
||||
"""Create a minimal LlamaConfig for PredictableLlamaForCausalLM."""
|
||||
from transformers import LlamaConfig, LlamaTokenizerFast
|
||||
|
||||
config_dir = tmp_path_factory.mktemp("predictable_llama")
|
||||
|
||||
# Create a minimal Llama config with small dimensions
|
||||
config = LlamaConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=256,
|
||||
intermediate_size=512,
|
||||
num_hidden_layers=24, # Enough layers to test various layer_ids
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=4,
|
||||
max_position_embeddings=128,
|
||||
architectures=["PredictableLlamaForCausalLM"],
|
||||
)
|
||||
|
||||
# Save config
|
||||
config.save_pretrained(config_dir)
|
||||
|
||||
# Create a simple tokenizer
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
cache_dir=os.path.expanduser("~/.cache/huggingface"),
|
||||
)
|
||||
tokenizer.save_pretrained(config_dir)
|
||||
|
||||
return str(config_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def register_predictable_model():
|
||||
"""Register the PredictableLlamaForCausalLM model."""
|
||||
from .predictable_llama import PredictableLlamaForCausalLM
|
||||
|
||||
if "PredictableLlamaForCausalLM" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model(
|
||||
"PredictableLlamaForCausalLM", PredictableLlamaForCausalLM
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
def test_extract_hidden_states_with_predictable_dummy_model(
|
||||
predictable_llama_config_path, tmp_path
|
||||
):
|
||||
"""Comprehensive test using a predictable dummy model with synthetic weights.
|
||||
|
||||
The PredictableLlamaForCausalLM outputs deterministic hidden states where
|
||||
each layer produces values equal to (layer_index). This test verifies:
|
||||
1. Hidden states are correctly extracted from requested layers
|
||||
2. Values match the expected predictable pattern
|
||||
3. Layer ordering is preserved correctly (non-sequential layer IDs)
|
||||
4. Multiple prompts of different lengths produce consistent layer values
|
||||
"""
|
||||
# Test with non-sequential layer ordering to verify correct association
|
||||
layer_ids = [5, 2, 10]
|
||||
num_layers = len(layer_ids)
|
||||
|
||||
llm = LLM(
|
||||
model=predictable_llama_config_path,
|
||||
speculative_config={
|
||||
"method": "extract_hidden_states",
|
||||
"num_speculative_tokens": 1,
|
||||
"draft_model_config": {
|
||||
"hf_config": {"eagle_aux_hidden_state_layer_ids": layer_ids}
|
||||
},
|
||||
},
|
||||
kv_transfer_config={
|
||||
"kv_connector": "ExampleHiddenStatesConnector",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_connector_extra_config": {"shared_storage_path": tmp_path},
|
||||
},
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
load_format="dummy", # Don't try to load real weights
|
||||
)
|
||||
|
||||
# Test with multiple prompts of different lengths
|
||||
prompts = [
|
||||
"Short",
|
||||
"Medium length",
|
||||
"Much longer prompt with many tokens",
|
||||
"Much longer prompt with many tokens", # repeated prompt
|
||||
]
|
||||
sampling_params = SamplingParams(max_tokens=1, temperature=0.0)
|
||||
hidden_size = llm.llm_engine.model_config.get_hidden_size()
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
assert len(outputs) == len(prompts)
|
||||
|
||||
for output in outputs:
|
||||
# hidden_states shape is [prompt_len, num_hidden_layers, hidden_size]
|
||||
expected_shape = (
|
||||
len(output.prompt_token_ids),
|
||||
num_layers,
|
||||
hidden_size,
|
||||
)
|
||||
_token_ids, hidden_states = get_and_check_output(output, expected_shape)
|
||||
|
||||
for idx, layer_id in enumerate(layer_ids):
|
||||
layer_hidden = hidden_states[:, idx, :]
|
||||
assert torch.allclose(
|
||||
layer_hidden,
|
||||
torch.full_like(layer_hidden, layer_id),
|
||||
atol=1e-5,
|
||||
), (
|
||||
f"Layer {layer_id} at position {idx} should output {float(layer_id)}, "
|
||||
f"but got mean={layer_hidden.mean():.3f}, "
|
||||
f"min={layer_hidden.min():.3f}, max={layer_hidden.max():.3f}"
|
||||
)
|
||||
346
tests/v1/spec_decode/test_extract_hidden_states.py
Normal file
346
tests/v1/spec_decode/test_extract_hidden_states.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
)
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
|
||||
|
||||
def _create_proposer(
|
||||
num_speculative_tokens: int = 1,
|
||||
layer_ids: list[int] | None = None,
|
||||
) -> ExtractHiddenStatesProposer:
|
||||
"""Create an ExtractHiddenStatesProposer for testing."""
|
||||
if layer_ids is None:
|
||||
layer_ids = [1, 2, 3, 4]
|
||||
|
||||
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
|
||||
|
||||
speculative_config = SpeculativeConfig(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=ParallelConfig(),
|
||||
method="extract_hidden_states",
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
draft_model_config={
|
||||
"hf_config": {
|
||||
"eagle_aux_hidden_state_layer_ids": layer_ids,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
device = current_platform.device_type
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
speculative_config=speculative_config,
|
||||
device_config=DeviceConfig(device=device),
|
||||
parallel_config=ParallelConfig(),
|
||||
load_config=LoadConfig(),
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_model_len=model_config.max_model_len,
|
||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||
),
|
||||
attention_config=AttentionConfig(),
|
||||
)
|
||||
|
||||
return ExtractHiddenStatesProposer(vllm_config=vllm_config, device=device)
|
||||
|
||||
|
||||
def test_proposer_initialization():
|
||||
"""Test that the proposer initializes correctly with the right parameters."""
|
||||
layer_ids = [1, 2, 3, 4]
|
||||
proposer = _create_proposer(num_speculative_tokens=1, layer_ids=layer_ids)
|
||||
|
||||
assert proposer.num_hidden_states == len(layer_ids)
|
||||
assert proposer.vllm_config.speculative_config is not None
|
||||
assert proposer.vllm_config.speculative_config.num_speculative_tokens == 1
|
||||
|
||||
# Verify the hidden states buffer is correctly shaped
|
||||
expected_shape = (
|
||||
proposer.max_num_tokens,
|
||||
len(layer_ids),
|
||||
proposer.hidden_size,
|
||||
)
|
||||
assert proposer.hidden_states.shape == expected_shape
|
||||
|
||||
|
||||
def test_proposer_initialization_missing_layer_ids():
|
||||
"""Test that initialization fails when layer_ids are not provided."""
|
||||
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
|
||||
|
||||
speculative_config = SpeculativeConfig(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=ParallelConfig(),
|
||||
method="extract_hidden_states",
|
||||
num_speculative_tokens=1,
|
||||
draft_model_config={
|
||||
"hf_config": {} # Missing eagle_aux_hidden_state_layer_ids
|
||||
},
|
||||
)
|
||||
|
||||
device = current_platform.device_type
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
speculative_config=speculative_config,
|
||||
device_config=DeviceConfig(device=device),
|
||||
parallel_config=ParallelConfig(),
|
||||
load_config=LoadConfig(),
|
||||
scheduler_config=SchedulerConfig(
|
||||
max_model_len=model_config.max_model_len,
|
||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||
),
|
||||
attention_config=AttentionConfig(),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="eagle_aux_hidden_state_layer_ids must be set"
|
||||
):
|
||||
ExtractHiddenStatesProposer(vllm_config=vllm_config, device=device)
|
||||
|
||||
|
||||
def test_prepare_next_token_ids_padded():
|
||||
"""
|
||||
Test for prepare_next_token_ids_padded with extract_hidden_states.
|
||||
|
||||
Since num_speculative_tokens == 1, sampled_token_ids has shape (batch_size, 1).
|
||||
For each request we either use the sampled token (if valid and not discarded)
|
||||
or a backup token from the request state.
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
num_requests = 4
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[5] * num_requests,
|
||||
query_lens=[5] * num_requests,
|
||||
)
|
||||
|
||||
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
|
||||
mock_input_batch = mock.MagicMock(spec=InputBatch)
|
||||
mock_input_batch.req_ids = req_ids
|
||||
mock_input_batch.num_reqs = num_requests
|
||||
mock_input_batch.vocab_size = 100
|
||||
|
||||
mock_requests = {}
|
||||
for req_id in req_ids:
|
||||
mock_request = mock.MagicMock(spec=CachedRequestState)
|
||||
# Each request will have a backup next token id of 10, 20, 30, 40
|
||||
mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
|
||||
mock_requests[req_id] = mock_request
|
||||
|
||||
# explicitly discard the last request
|
||||
discarded_req_mask = torch.tensor(
|
||||
[False, False, False, True], dtype=torch.bool, device=device
|
||||
)
|
||||
|
||||
# With num_speculative_tokens=1, sampled_token_ids has shape [batch_size, 1]
|
||||
sampled_token_ids = torch.tensor(
|
||||
[
|
||||
[1], # valid, use 1
|
||||
[4], # valid, use 4
|
||||
[-1], # invalid, use backup token "30"
|
||||
[2], # explicitly discarded, use backup token "40"
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expected_next_token_ids_cpu = [1, 4, 30, 40]
|
||||
expected_next_token_ids_tensor = torch.tensor(
|
||||
expected_next_token_ids_cpu, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
proposer = _create_proposer(num_speculative_tokens=1)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
|
||||
# It doesn't depend on whether the request is discarded
|
||||
expected_valid_sampled_tokens_count = torch.tensor(
|
||||
[1, 1, 0, 1], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
|
||||
common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
mock_requests,
|
||||
mock_input_batch,
|
||||
discarded_req_mask,
|
||||
)
|
||||
|
||||
assert torch.equal(next_token_ids, expected_next_token_ids_tensor)
|
||||
assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count)
|
||||
|
||||
|
||||
def test_propose():
|
||||
"""
|
||||
Test the propose() method of ExtractHiddenStatesProposer.
|
||||
|
||||
This should:
|
||||
1. Accept target hidden states and sampled token IDs
|
||||
2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1])
|
||||
3. Cache the hidden states in the model's KV cache
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
# Setup test parameters
|
||||
batch_size = 2
|
||||
num_tokens = 5
|
||||
num_hidden_layers = 4
|
||||
|
||||
proposer = _create_proposer(
|
||||
num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers))
|
||||
)
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
# Create mock model
|
||||
model_mock = mock.MagicMock()
|
||||
proposer.model = model_mock
|
||||
|
||||
# Mock attention layer names
|
||||
proposer.attn_layer_names = ["cache_only_layers.28"]
|
||||
|
||||
# Mock attention metadata builder
|
||||
mock_attn_metadata = mock.MagicMock()
|
||||
mock_attn_metadata_builder = mock.MagicMock()
|
||||
mock_attn_metadata_builder.build_for_drafting.return_value = mock_attn_metadata
|
||||
proposer.attn_metadata_builder = mock_attn_metadata_builder
|
||||
|
||||
# Create input tensors
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[3, 2],
|
||||
query_lens=[3, 2],
|
||||
)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Create target hidden states: list of tensors, one per layer
|
||||
# Each tensor has shape [num_tokens, hidden_size]
|
||||
target_hidden_states = [
|
||||
torch.randn(num_tokens, hidden_size, dtype=proposer.dtype, device=device)
|
||||
for _ in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
# Sampled token IDs from target model
|
||||
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
|
||||
|
||||
# Mock scheduler output
|
||||
mock_scheduler_output = mock.MagicMock()
|
||||
|
||||
# Call propose
|
||||
with mock.patch(
|
||||
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
|
||||
) as mock_has_kv:
|
||||
mock_has_kv.return_value = False
|
||||
|
||||
draft_tokens, kv_connector_output = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
scheduler_output=mock_scheduler_output,
|
||||
slot_mappings=None,
|
||||
)
|
||||
|
||||
# Verify draft tokens match sampled tokens
|
||||
# Shape should be [batch_size, 1] for num_speculative_tokens=1
|
||||
assert draft_tokens.shape == (batch_size, 1)
|
||||
assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
|
||||
|
||||
# Verify the model was called
|
||||
model_mock.assert_called_once()
|
||||
|
||||
# Verify hidden states were copied to the buffer The stacked hidden states
|
||||
# should have shape [num_tokens, num_hidden_layers, hidden_size]
|
||||
expected_stacked = torch.stack(target_hidden_states, dim=1)
|
||||
assert torch.allclose(
|
||||
proposer.hidden_states[:num_tokens], expected_stacked, atol=1e-6
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_hidden_layers", [1, 4, 8])
|
||||
def test_propose_different_layer_counts(num_hidden_layers):
|
||||
"""Test that propose works correctly with different numbers of hidden layers."""
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
batch_size = 2
|
||||
num_tokens = 5
|
||||
|
||||
proposer = _create_proposer(
|
||||
num_speculative_tokens=1, layer_ids=list(range(num_hidden_layers))
|
||||
)
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
# Setup mocks
|
||||
model_mock = mock.MagicMock()
|
||||
proposer.model = model_mock
|
||||
proposer.attn_layer_names = ["cache_only_layers.28"]
|
||||
|
||||
mock_attn_metadata_builder = mock.MagicMock()
|
||||
mock_attn_metadata_builder.build_for_drafting.return_value = mock.MagicMock()
|
||||
proposer.attn_metadata_builder = mock_attn_metadata_builder
|
||||
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[3, 2],
|
||||
query_lens=[3, 2],
|
||||
)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Create target hidden states
|
||||
target_hidden_states = [
|
||||
torch.randn(num_tokens, hidden_size, dtype=proposer.dtype, device=device)
|
||||
for _ in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
|
||||
mock_scheduler_output = mock.MagicMock()
|
||||
|
||||
with mock.patch(
|
||||
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
|
||||
) as mock_has_kv:
|
||||
mock_has_kv.return_value = False
|
||||
|
||||
draft_tokens, _ = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
scheduler_output=mock_scheduler_output,
|
||||
slot_mappings=None,
|
||||
)
|
||||
|
||||
assert draft_tokens.shape == (batch_size, 1)
|
||||
assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
|
||||
"pangu_ultra_moe_mtp",
|
||||
"step3p5_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
@@ -181,9 +182,22 @@ class SpeculativeConfig:
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
# Eagle3 and extract_hidden_states affect the computation graph because
|
||||
# they return intermediate hidden states in addition to the final hidden state.
|
||||
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
|
||||
factors.append(uses_aux_hidden_states)
|
||||
|
||||
# The specific layers used also affect the computation graph
|
||||
if uses_aux_hidden_states and self.draft_model_config is not None:
|
||||
layer_ids = getattr(
|
||||
self.draft_model_config.hf_config,
|
||||
"eagle_aux_hidden_state_layer_ids",
|
||||
None,
|
||||
)
|
||||
if layer_ids is not None:
|
||||
# Convert to tuple to make it hashable
|
||||
factors.append(tuple(layer_ids))
|
||||
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@@ -352,6 +366,8 @@ class SpeculativeConfig:
|
||||
self.model = "ngram"
|
||||
elif self.method == "suffix":
|
||||
self.model = "suffix"
|
||||
elif self.method == "extract_hidden_states":
|
||||
self.model = "extract_hidden_states"
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided but without speculative model."
|
||||
@@ -394,6 +410,34 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "suffix":
|
||||
self._validate_suffix_decoding()
|
||||
elif self.method == "extract_hidden_states":
|
||||
from vllm.transformers_utils.configs.extract_hidden_states import (
|
||||
ExtractHiddenStatesConfig,
|
||||
)
|
||||
|
||||
# ExtractHiddenStatesModel is instantiated manually in load_model()
|
||||
# We just need to store the target model config for KV cache shape info
|
||||
self.model = "extract_hidden_states"
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if hasattr(self.draft_model_config, "hf_config"):
|
||||
hf_config = self.draft_model_config.hf_config.to_dict()
|
||||
elif (
|
||||
isinstance(self.draft_model_config, dict)
|
||||
and "hf_config" in self.draft_model_config
|
||||
):
|
||||
hf_config = self.draft_model_config["hf_config"]
|
||||
else:
|
||||
hf_config = {}
|
||||
|
||||
self.draft_model_config = copy.copy(self.target_model_config)
|
||||
self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
|
||||
self.draft_model_config.hf_config, **hf_config
|
||||
)
|
||||
self.update_arch_()
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
@@ -478,23 +522,8 @@ class SpeculativeConfig:
|
||||
method=self.method,
|
||||
model_type="eagle",
|
||||
)
|
||||
# EAGLEConfig primarily updates architectures, so update
|
||||
# all architectures-related fields in draft_model_config
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = (
|
||||
self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
self.update_arch_()
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
@@ -671,6 +700,24 @@ class SpeculativeConfig:
|
||||
)
|
||||
return speculative_draft_tensor_parallel_size
|
||||
|
||||
def update_arch_(self):
|
||||
"""
|
||||
EagleConfig and ExtractHiddenStatesConfig update architectures, so update all
|
||||
architectures-related fields in self.draft_model_config
|
||||
"""
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
|
||||
@staticmethod
|
||||
def create_draft_parallel_config(
|
||||
target_parallel_config: ParallelConfig,
|
||||
@@ -718,7 +765,7 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config
|
||||
)
|
||||
|
||||
eagle3_target_supported = [
|
||||
aux_hidden_states_supported = [
|
||||
"llama",
|
||||
"qwen",
|
||||
"minicpm",
|
||||
@@ -729,16 +776,16 @@ class SpeculativeConfig:
|
||||
"nemotron_h",
|
||||
]
|
||||
if (
|
||||
self.method == "eagle3"
|
||||
self.method in ("eagle3", "extract_hidden_states")
|
||||
and self.target_model_config
|
||||
and not any(
|
||||
supported_model in self.target_model_config.hf_text_config.model_type
|
||||
for supported_model in eagle3_target_supported
|
||||
for supported_model in aux_hidden_states_supported
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||
f"Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
f"{self.method} is only supported for {aux_hidden_states_supported}"
|
||||
f" models. Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
)
|
||||
self.verify_equal_vocab_size_if_draft_model()
|
||||
return self
|
||||
@@ -782,8 +829,15 @@ class SpeculativeConfig:
|
||||
def uses_draft_model(self) -> bool:
|
||||
return self.method == "draft_model"
|
||||
|
||||
def uses_extract_hidden_states(self) -> bool:
|
||||
return self.method == "extract_hidden_states"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
model = (
|
||||
None
|
||||
if method in ("ngram", "suffix", "extract_hidden_states")
|
||||
else self.draft_model_config.model
|
||||
)
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
|
||||
@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
|
||||
def clear_events(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
|
||||
self.add_events(other.get_all_events())
|
||||
return self
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
|
||||
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
|
||||
"ExampleConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"ExampleHiddenStatesConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
|
||||
"ExampleHiddenStatesConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"P2pNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def extract_from_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor:
|
||||
"""Extract data from KV cache
|
||||
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
|
||||
"""
|
||||
|
||||
padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
|
||||
# shape: [len(slot_mapping), num_heads, head_size]
|
||||
return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request ID
|
||||
req_id: str
|
||||
# Request filename
|
||||
filename: str
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Whether this request is a new request or partially computed already
|
||||
new_req: bool
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
req_id: str,
|
||||
filename: str,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
new_req: bool,
|
||||
) -> "ReqMeta":
|
||||
token_ids_tensor = torch.tensor(token_ids)
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape((1, block_size))
|
||||
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
return ReqMeta(
|
||||
req_id=req_id,
|
||||
filename=filename,
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
new_req=new_req,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
|
||||
requests: list[ReqMeta] = field(default_factory=list)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
filename: str,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
new_req: bool = True,
|
||||
) -> None:
|
||||
self.requests.append(
|
||||
ReqMeta.make_meta(
|
||||
req_id, filename, token_ids, block_ids, block_size, new_req
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ExampleHiddenStatesConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
Simple debug implementation of a HiddenStatesConnector.
|
||||
|
||||
Simply extracts the hidden states from the kv cache and stores them to disk.
|
||||
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
|
||||
"""
|
||||
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
"""
|
||||
Indicates whether this connector prefers KV blocks that hold KV data for all
|
||||
layers, which can speed up KV data transfers. Defaults to False.
|
||||
"""
|
||||
# Must be False so that drafter kv cache isn't merged with verifier's
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
self._block_size = vllm_config.cache_config.block_size
|
||||
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
||||
"shared_storage_path", "/tmp"
|
||||
)
|
||||
self.cache_layers: list[str] = [] # set by self.register_kv_caches
|
||||
logger.info(self._kv_transfer_config)
|
||||
logger.info("Shared storage path is %s", self._storage_path)
|
||||
|
||||
assert self._vllm_config.speculative_config is not None, (
|
||||
"ExampleHiddenStatesConnector only works when using "
|
||||
"'extract_hidden_states' speculative method"
|
||||
)
|
||||
spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
|
||||
self.num_hidden_states = len(
|
||||
getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
|
||||
)
|
||||
|
||||
self._request_filenames: dict[str, str] = {}
|
||||
self._active_requests: dict[str, NewRequestData] = {}
|
||||
self._req_blocks: dict[str, list[int]] = {}
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def start_load_kv(self, *args, **kwargs: Any) -> None:
|
||||
pass # Empty implementation of abstract method
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass # Empty implementation of abstract method
|
||||
|
||||
def wait_for_save(self):
|
||||
pass # Empty implementation of abstract method
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
from vllm.model_executor.models.extract_hidden_states import (
|
||||
CacheOnlyAttentionLayer,
|
||||
)
|
||||
|
||||
# Filter layers to only include CacheOnlyAttentionLayers
|
||||
layers = get_layers_from_vllm_config(
|
||||
self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
|
||||
)
|
||||
self.cache_layers = list(layers.keys())
|
||||
assert len(self.cache_layers) == 1, (
|
||||
f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
|
||||
)
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start saving the KV cache of the layer from vLLM's paged buffer
|
||||
to the connector.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
if layer_name not in self.cache_layers:
|
||||
return
|
||||
|
||||
from vllm.model_executor.models.extract_hidden_states import (
|
||||
CacheOnlyAttentionMetadata,
|
||||
)
|
||||
|
||||
assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
|
||||
"ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
|
||||
)
|
||||
|
||||
connector_metadata = self._get_connector_metadata()
|
||||
assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)
|
||||
|
||||
os.makedirs(self._storage_path, exist_ok=True)
|
||||
for request in connector_metadata.requests:
|
||||
hidden_states = extract_from_kv_cache(
|
||||
kv_layer, request.slot_mapping, request.token_ids.shape[0]
|
||||
)
|
||||
tensors = {
|
||||
"hidden_states": hidden_states.detach().cpu(),
|
||||
"token_ids": request.token_ids.detach().cpu(),
|
||||
}
|
||||
safetensors.torch.save_file(tensors, request.filename)
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
# This connector is store-only, so we don't need to load any tokens
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
# Usually used to handle allocation of new blocks for requests that are loading
|
||||
# tokens from connector's external kv cache. We never load from external cache
|
||||
# so this is a no-op.
|
||||
assert num_external_tokens == 0, "This connector is store-only"
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
) -> KVConnectorMetadata:
|
||||
"""Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify any fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
meta = ExampleHiddenStatesConnectorMetadata()
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
token_ids = new_req.prompt_token_ids or []
|
||||
filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
|
||||
meta.add_request(
|
||||
new_req.req_id,
|
||||
filename=filename,
|
||||
token_ids=token_ids,
|
||||
block_ids=new_req.block_ids[0],
|
||||
block_size=self._block_size,
|
||||
)
|
||||
self._request_filenames[new_req.req_id] = filename
|
||||
self._active_requests[new_req.req_id] = new_req
|
||||
self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
if req_id not in self._active_requests:
|
||||
continue
|
||||
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
|
||||
cached_req = self._active_requests[req_id]
|
||||
req_block_ids = self._req_blocks[req_id]
|
||||
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
req_block_ids.extend(block_ids)
|
||||
filename = os.path.join(self._storage_path, f"{req_id}.safetensors")
|
||||
|
||||
meta.add_request(
|
||||
req_id=req_id,
|
||||
filename=filename,
|
||||
token_ids=cached_req.prompt_token_ids or [],
|
||||
block_ids=req_block_ids,
|
||||
block_size=self._block_size,
|
||||
new_req=False,
|
||||
)
|
||||
|
||||
return meta
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished, before its blocks are
|
||||
freed.
|
||||
|
||||
The connector may assumes responsibility for freeing the blocks
|
||||
asynchronously by returning True.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
req_filename = self._request_filenames.pop(req_id, None)
|
||||
_ = self._active_requests.pop(req_id, None)
|
||||
_ = self._req_blocks.pop(req_id, None)
|
||||
|
||||
return False, {"hidden_states_path": req_filename}
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
|
||||
"""
|
||||
Get the required KV cache layout for this connector.
|
||||
Args:
|
||||
vllm_config (VllmConfig): the vllm config.
|
||||
|
||||
Returns:
|
||||
str: the required KV cache layout. e.g. HND, or NHD.
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
|
||||
if cls is KVConnectorBase_V1:
|
||||
raise TypeError(
|
||||
"get_required_kvcache_layout should not be called "
|
||||
"on the abstract base class"
|
||||
)
|
||||
# NHD means we have (num_tokens, num_heads)
|
||||
# HND means we have (num_heads, num_tokens)
|
||||
# For now, we only support NHD layout since this keeps the
|
||||
# hidden states for each token together in memory.
|
||||
# HND is primarily used when sharding heads across devices.
|
||||
return "NHD"
|
||||
394
vllm/model_executor/models/extract_hidden_states.py
Normal file
394
vllm/model_executor/models/extract_hidden_states.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Hidden States Extractor Model.
|
||||
|
||||
This model extracts and caches hidden states from the target model
|
||||
without performing actual token generation. It's used with the
|
||||
extract_hidden_states speculative decoding method.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.attention.attention import set_default_quant_scales
|
||||
from vllm.model_executor.layers.attention.kv_transfer_utils import (
|
||||
maybe_transfer_kv_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
KVCacheSpec,
|
||||
MLAAttentionSpec,
|
||||
)
|
||||
|
||||
########## Custom Ops ########
|
||||
|
||||
|
||||
def unified_kv_cache_update(
|
||||
to_cache: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a dummy that is passed to unified_attention to signal a side effect and
|
||||
the data dependency between them to ensure torch.compile preserves ordering.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
)
|
||||
layer_slot_mapping = slot_mapping.get(layer_name)
|
||||
if layer_slot_mapping is not None:
|
||||
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
|
||||
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
|
||||
)
|
||||
attn_layer.impl.do_kv_cache_update(
|
||||
attn_layer,
|
||||
to_cache,
|
||||
kv_cache,
|
||||
layer_slot_mapping,
|
||||
)
|
||||
|
||||
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def dummy_attention(layer_name, _placeholder):
|
||||
# Note: layer_name arg required by @maybe_transfer_kv_layer
|
||||
return _placeholder
|
||||
|
||||
|
||||
def basic_cache(
|
||||
to_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size]
|
||||
kv_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size]
|
||||
slot_mapping: torch.Tensor, # shape: [seq_len]
|
||||
):
|
||||
num_blocks, block_size, num_heads, head_size = kv_cache.shape
|
||||
token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size)
|
||||
token_kv_cache[slot_mapping] = to_cache
|
||||
|
||||
|
||||
######### CacheOnlyAttentionBackend ########
|
||||
|
||||
|
||||
class CacheOnlyAttentionBackend(AttentionBackend):
|
||||
"""Attention backend that only caches KV without computing attention."""
|
||||
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float32,
|
||||
]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
"bfloat16",
|
||||
]
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CACHE_ONLY_ATTN"
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
return attn_type == AttentionType.DECODER
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
|
||||
return CacheOnlyAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
# We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
|
||||
# We also don't use a k/v (2) dim
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
|
||||
return CacheOnlyAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return []
|
||||
|
||||
|
||||
class CacheOnlyAttentionMetadata:
|
||||
def __init__(self, slot_mapping: torch.Tensor):
|
||||
self.slot_mapping = slot_mapping
|
||||
|
||||
|
||||
class CacheOnlyAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> CacheOnlyAttentionMetadata:
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
raise NotImplementedError(
|
||||
"Cascade attention not supported by CacheOnlyAttention"
|
||||
)
|
||||
causal = common_attn_metadata.causal
|
||||
if not causal:
|
||||
raise NotImplementedError(
|
||||
"Non-causal attention not supported by CacheOnlyAttention"
|
||||
)
|
||||
|
||||
return CacheOnlyAttentionMetadata(
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
)
|
||||
|
||||
|
||||
class CacheOnlyAttentionImpl(AttentionImpl):
|
||||
"""Attention implementation that only caches KV states."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_torch_dtype: torch.dtype,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_cache_torch_dtype = kv_cache_torch_dtype
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError(f"Unsupported attention type: {attn_type}")
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
raise NotImplementedError("Quantized KV cache not supported")
|
||||
|
||||
self.num_queries_per_kv = 1
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer,
|
||||
to_cache,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
):
|
||||
assert to_cache.dtype == self.kv_cache_torch_dtype, (
|
||||
f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
|
||||
)
|
||||
assert kv_cache.dtype == self.kv_cache_torch_dtype, (
|
||||
f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
|
||||
)
|
||||
|
||||
basic_cache(to_cache, kv_cache, slot_mapping)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
# Empty implementation of abstract method
|
||||
pass
|
||||
|
||||
|
||||
############## CacheOnlyAttentionLayer (replaces Attention) ############
|
||||
|
||||
|
||||
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
|
||||
"""Attention layer that only caches key/value states without computing attention."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.layer_name = prefix
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
# KV cache configuration
|
||||
cache_config = cache_config or vllm_config.cache_config
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
self.block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
self.block_size = 16
|
||||
|
||||
assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
|
||||
"CacheOnlyAttentionLayer doesn't currently support quantized kv cache but"
|
||||
f"kv cache dtype was set to {kv_cache_dtype}"
|
||||
)
|
||||
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
|
||||
kv_cache_dtype, vllm_config.model_config
|
||||
)
|
||||
|
||||
# Initialize KV cache quantization attributes
|
||||
set_default_quant_scales(self, register_buffer=True)
|
||||
|
||||
# Attention backend
|
||||
self.attn_backend = CacheOnlyAttentionBackend
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
self.kv_cache_torch_dtype,
|
||||
attn_type,
|
||||
)
|
||||
|
||||
assert not self.attn_backend.forward_includes_kv_cache_update, (
|
||||
"KV cache update should be independent of forward"
|
||||
)
|
||||
|
||||
# Placeholder KV cache (replaced by bind_kv_cache)
|
||||
self.kv_cache = [
|
||||
torch.tensor([])
|
||||
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
# Register in compilation context
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
|
||||
"""Cache hidden states as KV pairs without computing attention.
|
||||
|
||||
Args:
|
||||
to_cache: The tensor to insert into the kv cache.
|
||||
shape [num_tokens, num_heads, head_size]
|
||||
|
||||
Returns:
|
||||
Dummy output tensor (not used)
|
||||
"""
|
||||
# Note: we set num_heads to num_hidden_layers and
|
||||
# head_size to hidden_size for hidden states storage
|
||||
output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)
|
||||
|
||||
# Note: dummy_out is used to force torch.compile to preserve ordering between
|
||||
# cache update and attention op (which triggers kv_connector transfer)
|
||||
dummy_out = unified_kv_cache_update(to_cache, self.layer_name)
|
||||
|
||||
# Triggers kv_connector transfer via decorator
|
||||
_ = dummy_attention(self.layer_name, dummy_out)
|
||||
|
||||
return output
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Note: we use MLAAttentionSpec here to because it will
|
||||
# produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
|
||||
# whereas FullAttentionSpec will add an additional factor of 2
|
||||
return MLAAttentionSpec(
|
||||
block_size=self.block_size,
|
||||
num_kv_heads=self.num_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
############ ExtractHiddenStatesModel definition ##########
|
||||
|
||||
|
||||
class ExtractHiddenStatesModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
self.hidden_size = vllm_config.model_config.get_hidden_size()
|
||||
self.target_num_hidden_layers = (
|
||||
vllm_config.model_config.get_total_num_hidden_layers()
|
||||
)
|
||||
self.num_hidden_states = len(
|
||||
getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
|
||||
)
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
# Create a single cache-only attention layer
|
||||
# Note: We set num_heads <- self.num_hidden_states
|
||||
# and head_size <- hidden_size so that we can insert
|
||||
# the hidden states directly into the cache without
|
||||
# reshaping
|
||||
self.cache_only_layers = nn.ModuleDict(
|
||||
{
|
||||
str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
|
||||
num_heads=self.num_hidden_states,
|
||||
head_size=self.hidden_size,
|
||||
cache_config=cache_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
|
||||
),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> None:
|
||||
"""Process and cache hidden states.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states from target model
|
||||
shape: [num_tokens, num_hidden_states, hidden_size]
|
||||
|
||||
Returns:
|
||||
Tuple of (dummy_output, dummy_output) - both unused
|
||||
"""
|
||||
|
||||
# Call dummy attention layer to cache hidden states
|
||||
# Output is ignored - we only care about the KV cache side effects
|
||||
_ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""No weights to load for this dummy model."""
|
||||
return set()
|
||||
@@ -512,6 +512,7 @@ _MULTIMODAL_MODELS = {
|
||||
}
|
||||
|
||||
_SPECULATIVE_DECODING_MODELS = {
|
||||
"ExtractHiddenStatesModel": ("extract_hidden_states", "ExtractHiddenStatesModel"),
|
||||
"MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
|
||||
"EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
|
||||
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
|
||||
|
||||
53
vllm/transformers_utils/configs/extract_hidden_states.py
Normal file
53
vllm/transformers_utils/configs/extract_hidden_states.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Config definitions for ExtractHiddenStatesModel, to be used with
|
||||
the extract_hidden_states spec decoding method."""
|
||||
|
||||
import os
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class ExtractHiddenStatesConfig(PretrainedConfig):
|
||||
model_type = "extract_hidden_states"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PretrainedConfig | dict | None = None,
|
||||
method: str | None = "extract_hidden_states",
|
||||
**kwargs,
|
||||
):
|
||||
assert method == "extract_hidden_states"
|
||||
|
||||
if isinstance(model, dict):
|
||||
model_dict = model
|
||||
elif isinstance(model, PretrainedConfig):
|
||||
model_dict = model.to_dict()
|
||||
else:
|
||||
model_dict = {}
|
||||
|
||||
# Combine: model_dict first, then kwargs override
|
||||
combined = {**model_dict, **kwargs}
|
||||
# Remove architectures from the base, we'll set it explicitly
|
||||
combined = {k: v for k, v in combined.items() if k != "architectures"}
|
||||
|
||||
combined["architectures"] = ["ExtractHiddenStatesModel"]
|
||||
|
||||
super().__init__(**combined)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
**kwargs,
|
||||
) -> "ExtractHiddenStatesConfig":
|
||||
config_dict, kwargs = cls.get_config_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
def to_json_string(self, use_diff: bool = True) -> str:
|
||||
# we override use_diff to False as initializing
|
||||
# ExtractHiddenStatesConfig with default arguments is not supported
|
||||
del use_diff
|
||||
return super().to_json_string(use_diff=False)
|
||||
@@ -2,8 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
|
||||
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -120,6 +121,20 @@ class SamplerOutput:
|
||||
logprobs_tensors: LogprobsTensors | None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _combine_non_none(f: Callable[[T, T], T], items: list[T | None]) -> T | None:
|
||||
non_none = [item for item in items if item is not None]
|
||||
if len(non_none) == 0:
|
||||
return None
|
||||
|
||||
combined = non_none[0]
|
||||
for item in non_none[1:]:
|
||||
combined = f(combined, item)
|
||||
return combined
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVConnectorOutput:
|
||||
# [req_ids]
|
||||
@@ -146,6 +161,43 @@ class KVConnectorOutput:
|
||||
and not self.invalid_block_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, *outputs: "KVConnectorOutput"):
|
||||
assert len(outputs) > 0, "Cannot merge empty outputs"
|
||||
finished_sending = _combine_non_none(
|
||||
set.union, [output.finished_sending for output in outputs]
|
||||
)
|
||||
finished_recving = _combine_non_none(
|
||||
set.union, [output.finished_recving for output in outputs]
|
||||
)
|
||||
kv_connector_stats = _combine_non_none(
|
||||
lambda x, y: x.aggregate(y),
|
||||
[output.kv_connector_stats for output in outputs],
|
||||
)
|
||||
kv_cache_events = _combine_non_none(
|
||||
lambda x, y: x.merge(y),
|
||||
[output.kv_cache_events for output in outputs],
|
||||
)
|
||||
invalid_block_ids = _combine_non_none(
|
||||
set.union, [output.invalid_block_ids for output in outputs]
|
||||
)
|
||||
assert invalid_block_ids is not None
|
||||
|
||||
assert all(
|
||||
output.expected_finished_count == outputs[0].expected_finished_count
|
||||
for output in outputs
|
||||
)
|
||||
expected_finished_count = outputs[0].expected_finished_count
|
||||
|
||||
return cls(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
kv_connector_stats=kv_connector_stats,
|
||||
kv_cache_events=kv_cache_events,
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
expected_finished_count=expected_finished_count,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ECConnectorOutput:
|
||||
|
||||
395
vllm/v1/spec_decode/extract_hidden_states.py
Normal file
395
vllm/v1/spec_decode/extract_hidden_states.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer import has_kv_transfer_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
class ExtractHiddenStatesProposer:
|
||||
def __init__(self, vllm_config: VllmConfig, device):
|
||||
assert vllm_config.speculative_config is not None
|
||||
|
||||
assert vllm_config.speculative_config.num_speculative_tokens == 1
|
||||
if vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||
raise ValueError(
|
||||
"disable_padded_drafter_batch is not supported with "
|
||||
"extract_hidden_states method"
|
||||
)
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Model and attention layer tracking (initialized in load_model)
|
||||
self.model: nn.Module | None = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
|
||||
# Maximum number of tokens for buffers
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
|
||||
)
|
||||
|
||||
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
|
||||
if not layer_ids:
|
||||
raise ValueError(
|
||||
"eagle_aux_hidden_state_layer_ids must be set in the draft "
|
||||
"model config for extract_hidden_states method"
|
||||
)
|
||||
self.num_hidden_states = len(layer_ids)
|
||||
self.hidden_size = vllm_config.model_config.get_hidden_size()
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_tokens, self.num_hidden_states, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
|
||||
self._slot_mapping_buffer = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
def propose(
|
||||
self,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
target_hidden_states: list[torch.Tensor],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
scheduler_output: SchedulerOutput,
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
| list[dict[str, torch.Tensor]]
|
||||
| None = None,
|
||||
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
|
||||
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
|
||||
|
||||
The ExtractHiddenStatesModel caches the hidden states in the KV cache
|
||||
without performing actual attention computation. This allows us to
|
||||
extract and store hidden states for later use (e.g., KV transfer).
|
||||
|
||||
This proposer doesn't actually perform speculation - it returns the
|
||||
sampled tokens as "draft" tokens, ensuring they always verify (match).
|
||||
The main purpose is to cache hidden states, not to speculate.
|
||||
|
||||
Args:
|
||||
sampled_token_ids: Sampled token IDs from the target model
|
||||
target_hidden_states: List of hidden state tensors from target model
|
||||
(one per aux hidden state layer)
|
||||
common_attn_metadata: Attention metadata
|
||||
scheduler_output: Scheduler output for KV connector
|
||||
slot_mappings: Slot mappings for KV cache (unused, provided for
|
||||
interface compatibility)
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Draft tokens matching sampled tokens, shape [batch_size, 1]
|
||||
- KV connector output (if KV transfer is active), else None
|
||||
"""
|
||||
assert self.model is not None and isinstance(target_hidden_states, list)
|
||||
|
||||
# target_hidden_states is a list of tensors (one per layer)
|
||||
# Each tensor has shape [num_tokens, hidden_size]
|
||||
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
|
||||
stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
|
||||
num_tokens = stacked_hidden_states.shape[0]
|
||||
|
||||
# Copy hidden states to buffer
|
||||
self.hidden_states[:num_tokens] = stacked_hidden_states
|
||||
|
||||
assert self.attn_metadata_builder is not None
|
||||
attn_metadata = self.attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
|
||||
# We assume all cache-only layers belong to the same KV cache group,
|
||||
# thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(num_tokens)
|
||||
)
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
with (
|
||||
set_forward_context(
|
||||
per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
num_input_tokens, common_attn_metadata.slot_mapping
|
||||
),
|
||||
),
|
||||
(
|
||||
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
|
||||
if has_kv_transfer_group()
|
||||
else nullcontext()
|
||||
) as kv_connector_output,
|
||||
):
|
||||
self.model(
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
|
||||
# Return the sampled tokens as "draft" tokens
|
||||
# Shape: [batch_size, 1] to match num_speculative_tokens=1
|
||||
return sampled_token_ids.unsqueeze(-1), kv_connector_output
|
||||
|
||||
def _get_slot_mapping(
|
||||
self,
|
||||
num_tokens: int,
|
||||
slot_mapping: torch.Tensor | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Return slot_mapping dict for cache-only attention layers.
|
||||
|
||||
If slot_mapping is provided, copies it into the buffer first.
|
||||
"""
|
||||
if slot_mapping is not None:
|
||||
num_actual = slot_mapping.shape[0]
|
||||
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
|
||||
if num_tokens > num_actual:
|
||||
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
|
||||
|
||||
view = self._slot_mapping_buffer[:num_tokens]
|
||||
return {name: view for name in self.attn_layer_names}
|
||||
|
||||
def _determine_batch_execution_and_padding(
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens,
|
||||
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
|
||||
)
|
||||
num_tokens_padded = batch_desc.num_tokens
|
||||
|
||||
# Extra coordination when running data-parallel since we need to
|
||||
# coordinate across ranks
|
||||
# TODO(Flechman): support DBO ubatching
|
||||
should_ubatch, num_tokens_across_dp = False, None
|
||||
if self.vllm_config.parallel_config.data_parallel_size > 1:
|
||||
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
allow_microbatching=False,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
cudagraph_mode=cudagraph_mode.value,
|
||||
)
|
||||
)
|
||||
assert not should_ubatch, (
|
||||
"DBO ubatching not implemented for extract_hidden_states"
|
||||
)
|
||||
|
||||
# Extract DP-synced values
|
||||
if num_tokens_across_dp is not None:
|
||||
dp_rank = self.dp_rank
|
||||
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
|
||||
# Re-dispatch with DP padding so we have the correct
|
||||
# batch_descriptor
|
||||
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens_padded,
|
||||
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
|
||||
)
|
||||
# Assert to make sure the agreed upon token count is correct
|
||||
# otherwise num_tokens_across_dp will no-longer be valid
|
||||
assert batch_desc.num_tokens == num_tokens_padded
|
||||
num_tokens_across_dp[dp_rank] = num_tokens_padded
|
||||
|
||||
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
"""Initialize cudagraph dispatcher keys.
|
||||
|
||||
Only supports PIECEWISE cudagraphs (via mixed_mode).
|
||||
Should be called after adjust_cudagraph_sizes_for_spec_decode.
|
||||
"""
|
||||
assert self.vllm_config.speculative_config is not None
|
||||
if (
|
||||
not self.vllm_config.speculative_config.enforce_eager
|
||||
and cudagraph_mode.mixed_mode()
|
||||
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
|
||||
):
|
||||
proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
else:
|
||||
proposer_cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
|
||||
|
||||
@torch.inference_mode()
|
||||
def dummy_run(
|
||||
self,
|
||||
num_tokens: int,
|
||||
use_cudagraphs: bool = True,
|
||||
is_graph_capturing: bool = False,
|
||||
slot_mappings: dict[str, torch.Tensor] | None = None,
|
||||
) -> None:
|
||||
assert self.model is not None, "Model must be initialized before dummy_run"
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens, use_cudagraphs=use_cudagraphs
|
||||
)
|
||||
)
|
||||
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[self.dp_rank] = num_input_tokens
|
||||
|
||||
# Use our own slot mapping buffer during cudagraph capture.
|
||||
if (
|
||||
self.attn_layer_names
|
||||
and slot_mappings is not None
|
||||
and self.attn_layer_names[0] in slot_mappings
|
||||
):
|
||||
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
|
||||
else:
|
||||
slot_mapping_dict = slot_mappings or {}
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=slot_mapping_dict,
|
||||
):
|
||||
self.model(
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
)
|
||||
|
||||
def _build_attn_metadata_builder(
|
||||
self, draft_attn_layers: dict[str, AttentionLayerBase]
|
||||
) -> AttentionMetadataBuilder:
|
||||
"""Build the attention metadata builder from draft attention layers."""
|
||||
if not draft_attn_layers:
|
||||
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
|
||||
layer = next(iter(draft_attn_layers.values()))
|
||||
attn_backend = layer.get_attn_backend()
|
||||
return attn_backend.get_builder_cls()(
|
||||
layer.get_kv_cache_spec(self.vllm_config),
|
||||
self.attn_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
discard_request_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare next token IDs for speculative decoding.
|
||||
|
||||
Since num_speculative_tokens == 1, sampled_token_ids has shape
|
||||
(batch_size, 1). For each request we either use the sampled token
|
||||
(if valid and not discarded) or a backup token from the request state.
|
||||
"""
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
device = sampled_token_ids.device
|
||||
|
||||
# Compute backup tokens for discarded / invalid requests
|
||||
backup_tokens_gpu = torch.tensor(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||
common_attn_metadata.seq_lens_cpu[i].item()
|
||||
)
|
||||
for i in range(num_reqs)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert discard_request_mask.dtype == torch.bool
|
||||
|
||||
# With num_speculative_tokens == 1, there is exactly one token
|
||||
sampled = sampled_token_ids[:, 0]
|
||||
is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
|
||||
valid_sampled_tokens_count = is_valid.to(torch.int32)
|
||||
|
||||
use_sampled = is_valid & ~discard_request_mask[:num_reqs]
|
||||
next_token_ids = torch.where(
|
||||
use_sampled, sampled.to(torch.int32), backup_tokens_gpu
|
||||
)
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
"""Load the ExtractHiddenStatesModel model.
|
||||
|
||||
This method instantiates the ExtractHiddenStatesModel model which is used
|
||||
to cache hidden states during speculative decoding. The model uses
|
||||
cache-only attention (no computation, just caching KV states).
|
||||
|
||||
Args:
|
||||
target_model: The target model (passed for compatibility with
|
||||
EagleProposer interface, but not used here)
|
||||
"""
|
||||
# Get the target model's attention layers before loading draft model
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() # type: ignore[type-abstract]
|
||||
)
|
||||
|
||||
assert self.vllm_config.speculative_config is not None
|
||||
draft_model_config = self.vllm_config.speculative_config.draft_model_config
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
with set_model_tag("extract_hidden_states"):
|
||||
self.model = get_model(
|
||||
vllm_config=self.vllm_config, model_config=draft_model_config
|
||||
)
|
||||
|
||||
# Identify draft model's attention layers (difference from target)
|
||||
all_attn_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
)
|
||||
draft_attn_layers = {
|
||||
name: layer
|
||||
for name, layer in all_attn_layers.items()
|
||||
if name not in target_attn_layer_names
|
||||
}
|
||||
self.attn_layer_names = list(draft_attn_layers.keys())
|
||||
assert len(draft_attn_layers) == 1, (
|
||||
"ExtractHiddenStatesModel should have exactly one "
|
||||
f"attention layer, found {len(draft_attn_layers)}"
|
||||
)
|
||||
self.attn_metadata_builder = self._build_attn_metadata_builder(
|
||||
draft_attn_layers
|
||||
)
|
||||
|
||||
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Validate all drafting layers belong to the same KV cache group.
|
||||
|
||||
With exactly one attention layer (asserted in load_model), this is
|
||||
trivially satisfied.
|
||||
"""
|
||||
assert len(self.attn_layer_names) == 1
|
||||
@@ -159,6 +159,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||||
@@ -495,6 +496,7 @@ class GPUModelRunner(
|
||||
| EagleProposer
|
||||
| DraftModelProposer
|
||||
| MedusaProposer
|
||||
| ExtractHiddenStatesProposer
|
||||
)
|
||||
if self.speculative_config.method == "ngram":
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
@@ -518,6 +520,11 @@ class GPUModelRunner(
|
||||
self.drafter = MedusaProposer(
|
||||
vllm_config=self.vllm_config, device=self.device
|
||||
)
|
||||
elif self.speculative_config.method == "extract_hidden_states":
|
||||
self.drafter = ExtractHiddenStatesProposer(
|
||||
vllm_config=self.vllm_config, device=self.device
|
||||
)
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown speculative decoding method: "
|
||||
@@ -3693,10 +3700,9 @@ class GPUModelRunner(
|
||||
def sample_tokens(
|
||||
self, grammar_output: "GrammarOutput | None"
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||
kv_connector_output = self.kv_connector_output
|
||||
self.kv_connector_output = None
|
||||
|
||||
if self.execute_model_state is None:
|
||||
kv_connector_output = self.kv_connector_output
|
||||
self.kv_connector_output = None
|
||||
# receive sampled token ids from the last PP rank.
|
||||
if self.use_async_scheduling and get_pp_group().world_size > 1:
|
||||
self._pp_receive_prev_sampled_token_ids_to_input_batch()
|
||||
@@ -3778,12 +3784,17 @@ class GPUModelRunner(
|
||||
<= self.effective_drafter_max_model_len
|
||||
)
|
||||
use_gpu_toks = (
|
||||
spec_config.use_eagle() or spec_config.uses_draft_model()
|
||||
spec_config.use_eagle()
|
||||
or spec_config.uses_draft_model()
|
||||
or spec_config.uses_extract_hidden_states()
|
||||
) and not spec_config.disable_padded_drafter_batch
|
||||
if use_gpu_toks:
|
||||
# EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
|
||||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
|
||||
)
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
if input_fits_in_drafter:
|
||||
propose_draft_token_ids(sampled_token_ids)
|
||||
@@ -3842,6 +3853,10 @@ class GPUModelRunner(
|
||||
with record_function_or_nullcontext("gpu_model_runner: eplb"):
|
||||
self.eplb_step()
|
||||
|
||||
# self.kv_connector_output may be modified during drafting
|
||||
kv_connector_output = self.kv_connector_output
|
||||
self.kv_connector_output = None
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
|
||||
if self.model_config.enable_return_routed_experts:
|
||||
capturer = RoutedExpertsCapturer.get_instance()
|
||||
@@ -4068,6 +4083,48 @@ class GPUModelRunner(
|
||||
sampling_metadata=sampling_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
elif spec_config.uses_extract_hidden_states():
|
||||
assert isinstance(self.drafter, ExtractHiddenStatesProposer)
|
||||
assert isinstance(sampled_token_ids, torch.Tensor), (
|
||||
"sampled_token_ids should be a torch.Tensor for "
|
||||
"extract_hidden_states method."
|
||||
)
|
||||
if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
|
||||
raise ValueError(
|
||||
"aux_hidden_states are required when using `extract_hidden_states`"
|
||||
)
|
||||
target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
|
||||
|
||||
draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
scheduler_output=scheduler_output,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
# Combine KVConnectorOutputs or select the non-empty one
|
||||
if self.kv_connector_output and drafter_kv_connector_output:
|
||||
self.kv_connector_output = KVConnectorOutput.merge(
|
||||
self.kv_connector_output, drafter_kv_connector_output
|
||||
)
|
||||
else:
|
||||
self.kv_connector_output = (
|
||||
self.kv_connector_output or drafter_kv_connector_output
|
||||
)
|
||||
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
self.discard_request_mask.gpu,
|
||||
)
|
||||
)
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
|
||||
elif spec_config.use_eagle() or spec_config.uses_draft_model():
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
|
||||
@@ -4946,8 +5003,12 @@ class GPUModelRunner(
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
or self.speculative_config.uses_extract_hidden_states()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
|
||||
)
|
||||
assert self.speculative_config is not None
|
||||
# Eagle currently only supports PIECEWISE cudagraphs.
|
||||
# Therefore only use cudagraphs if the main model uses PIECEWISE
|
||||
@@ -5656,9 +5717,12 @@ class GPUModelRunner(
|
||||
cudagraph_mode, self.uniform_decode_query_len
|
||||
)
|
||||
|
||||
# Initialize eagle's cudagraph dispatcher if using eagle spec decode.
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# Initialize drafter's cudagraph dispatcher if using spec decode.
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_extract_hidden_states()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer)
|
||||
self.drafter.initialize_cudagraph_keys(cudagraph_mode)
|
||||
|
||||
def calculate_reorder_batch_threshold(self) -> None:
|
||||
@@ -6025,8 +6089,12 @@ class GPUModelRunner(
|
||||
if self.speculative_config and (
|
||||
self.speculative_config.use_eagle()
|
||||
or self.speculative_config.uses_draft_model()
|
||||
or self.speculative_config.uses_extract_hidden_states()
|
||||
):
|
||||
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
|
||||
assert isinstance(
|
||||
self.drafter,
|
||||
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
|
||||
)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
Reference in New Issue
Block a user