[Bugfix][Spec Decode] Fix extract_hidden_states for VLM models (#38987)
Signed-off-by: Aaron Batilo <abatilo@coreweave.com>
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, PretrainedConfig
|
||||
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
@@ -23,6 +25,10 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_hf_text_config
|
||||
from vllm.transformers_utils.configs.extract_hidden_states import (
|
||||
ExtractHiddenStatesConfig,
|
||||
)
|
||||
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
@@ -323,3 +329,160 @@ def test_propose_different_layer_counts(num_hidden_layers):
|
||||
|
||||
assert draft_tokens.shape == (batch_size, 1)
|
||||
assert torch.equal(draft_tokens, sampled_token_ids)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VLM / composite config tests for ExtractHiddenStatesConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DummyVLMConfig(PretrainedConfig):
    """Minimal composite (multi-modal) config used by the VLM tests.

    Mimics VLM configs such as LLaVA or Kimi-K2.5: the decoder's
    parameters (hidden_size, num_attention_heads, ...) exist only on the
    nested ``text_config``; the top-level config carries none of them.
    """

    model_type = "test_vlm"

    def __init__(self, text_config: PretrainedConfig, **kwargs):
        # Attach the nested decoder config before PretrainedConfig's
        # __init__ consumes the remaining kwargs.
        self.text_config = text_config
        super().__init__(architectures=["LlamaForCausalLM"], **kwargs)

    def get_text_config(self, decoder: bool = False) -> PretrainedConfig:
        # There is only one nested text config, so the decoder flag
        # is irrelevant here.
        del decoder
        return self.text_config
|
||||
|
||||
|
||||
def test_extract_hidden_states_text_only_config_regression():
    """Regression: text-only models (no nested text_config) keep working."""
    target = ModelConfig(model=model_dir, runner="generate", max_model_len=100)

    spec_cfg = SpeculativeConfig(
        target_model_config=target,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=1,
        draft_model_config={
            "hf_config": {"eagle_aux_hidden_state_layer_ids": [1, 2, 3, 4]},
        },
    )

    draft = spec_cfg.draft_model_config
    assert draft is not None
    # A text-only model has no nested text config, so hf_text_config
    # must resolve to the top-level hf_config object itself.
    assert draft.hf_text_config is draft.hf_config
    assert (
        draft.hf_text_config.num_attention_heads
        == target.hf_text_config.num_attention_heads
    )
|
||||
|
||||
|
||||
def test_extract_hidden_states_config_preserves_vlm_text_config():
    """A real VLM config (LLaVA) with nested text_config must be preserved."""
    llama_cfg = LlamaConfig(
        vocab_size=32000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
    )
    llava_cfg = LlavaConfig(
        vision_config=CLIPVisionConfig(),
        text_config=llama_cfg,
    )

    # Precondition for the regression: to_dict() flattens the nested
    # config down to a plain dict.
    assert isinstance(llava_cfg.to_dict()["text_config"], dict)

    cfg = ExtractHiddenStatesConfig(
        llava_cfg,
        eagle_aux_hidden_state_layer_ids=[1, 2],
    )

    # The fix under test: the nested text_config survives as a
    # PretrainedConfig instance rather than being flattened to a dict.
    assert isinstance(cfg.text_config, LlamaConfig)

    resolved = get_hf_text_config(cfg)
    assert resolved is cfg.text_config
    assert resolved.num_attention_heads == llama_cfg.num_attention_heads
    assert resolved.hidden_size == llama_cfg.hidden_size

    # Serialization must still round-trip the nested config as a dict.
    as_dict = cfg.to_dict()
    assert isinstance(as_dict["text_config"], dict)
    assert as_dict["text_config"]["num_attention_heads"] == (
        llama_cfg.num_attention_heads
    )

    as_json = json.loads(cfg.to_json_string())
    assert as_json["text_config"]["num_attention_heads"] == (
        llama_cfg.num_attention_heads
    )
|
||||
|
||||
|
||||
def test_extract_hidden_states_speculative_config_vlm():
    """SpeculativeConfig with a VLM target must build without errors."""
    inner_text_cfg = LlamaConfig(
        vocab_size=32000,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
    )

    target = ModelConfig(
        model=model_dir,
        runner="generate",
        max_model_len=100,
    )
    # Swap the real text-only config for a composite VLM-style config so
    # the target looks like a multi-modal model to SpeculativeConfig.
    target.hf_config = _DummyVLMConfig(text_config=inner_text_cfg)
    target.hf_text_config = inner_text_cfg

    spec_cfg = SpeculativeConfig(
        target_model_config=target,
        target_parallel_config=ParallelConfig(),
        method="extract_hidden_states",
        num_speculative_tokens=1,
        draft_model_config={
            "hf_config": {"eagle_aux_hidden_state_layer_ids": [1, 2]},
        },
    )

    draft = spec_cfg.draft_model_config
    assert draft is not None
    # The nested text config must survive as a PretrainedConfig ...
    assert isinstance(draft.hf_config.text_config, LlamaConfig)
    # ... and hf_text_config must point at that exact nested object.
    assert draft.hf_text_config is draft.hf_config.text_config
    assert (
        draft.hf_text_config.num_attention_heads
        == inner_text_cfg.num_attention_heads
    )
|
||||
|
||||
|
||||
def test_extract_hidden_states_config_invalid_text_config():
    """A nested text_config missing required attrs must still be rejected."""
    incomplete = PretrainedConfig(hidden_size=128)
    composite = _DummyVLMConfig(text_config=incomplete)

    cfg = ExtractHiddenStatesConfig(
        composite,
        eagle_aux_hidden_state_layer_ids=[1],
    )

    # The nested object is carried through unchanged (not flattened) ...
    assert cfg.text_config is incomplete
    # ... yet validation still flags the missing attribute.
    with pytest.raises(ValueError, match="num_attention_heads"):
        get_hf_text_config(cfg)
|
||||
|
||||
@@ -23,10 +23,14 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
|
||||
|
||||
if isinstance(model, dict):
|
||||
model_dict = model
|
||||
source_text_config = None
|
||||
elif isinstance(model, PretrainedConfig):
|
||||
model_dict = model.to_dict()
|
||||
text_config = model.get_text_config()
|
||||
source_text_config = text_config if text_config is not model else None
|
||||
else:
|
||||
model_dict = {}
|
||||
source_text_config = None
|
||||
|
||||
# Combine: model_dict first, then kwargs override
|
||||
combined = {**model_dict, **kwargs}
|
||||
@@ -35,6 +39,12 @@ class ExtractHiddenStatesConfig(PretrainedConfig):
|
||||
|
||||
combined["architectures"] = ["ExtractHiddenStatesModel"]
|
||||
|
||||
# to_dict() and kwargs both flatten text_config to a plain dict;
|
||||
# downstream get_hf_text_config() needs it as a PretrainedConfig
|
||||
# for attribute access. Re-insert the original object.
|
||||
if source_text_config is not None:
|
||||
combined["text_config"] = source_text_config
|
||||
|
||||
super().__init__(**combined)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user