[Bugfix] Fix extract_hidden_states crash with quantized KV cache dtype (#39160)

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Yubo Wang
2026-04-07 11:18:33 -07:00
committed by GitHub
parent 0102bd2f4c
commit 08bfedc152


@@ -9,6 +9,7 @@ extract_hidden_states speculative decoding method.
 """
 from collections.abc import Iterable
+from dataclasses import replace
 from typing import ClassVar
 
 import torch
@@ -352,6 +353,10 @@ class ExtractHiddenStatesModel(nn.Module):
         cache_config = vllm_config.cache_config
+        # Hidden states dtype should be independent of KV cache dtype.
+        if cache_config is not None and is_quantized_kv_cache(
+                cache_config.cache_dtype):
+            cache_config = replace(cache_config, cache_dtype="auto")
         # Create a single cache-only attention layer
         # Note: We set num_heads <- self.num_hidden_states
         # and head_size <- hidden_size so that we can insert
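The fix hinges on `dataclasses.replace` returning a new config object with only the named field overridden, so the shared `vllm_config.cache_config` used elsewhere in the engine keeps its quantized dtype while this layer sees `"auto"`. Below is a minimal sketch of that pattern; the `CacheConfig` dataclass and `is_quantized_kv_cache` helper here are simplified stand-ins for illustration, not vLLM's actual definitions:

```python
from dataclasses import dataclass, replace


# Simplified stand-in; the real vLLM CacheConfig has many more fields.
@dataclass(frozen=True)
class CacheConfig:
    cache_dtype: str = "auto"
    block_size: int = 16


def is_quantized_kv_cache(cache_dtype: str) -> bool:
    # Hypothetical check mirroring the spirit of vLLM's helper
    # (not its exact logic): treat fp8 variants as quantized.
    return cache_dtype.startswith("fp8")


shared_config = CacheConfig(cache_dtype="fp8_e4m3")

cache_config = shared_config
if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
    # replace() builds a new instance with only cache_dtype overridden;
    # shared_config, still referenced by the rest of the engine, is untouched.
    cache_config = replace(cache_config, cache_dtype="auto")

print(cache_config)   # CacheConfig(cache_dtype='auto', block_size=16)
print(shared_config)  # CacheConfig(cache_dtype='fp8_e4m3', block_size=16)
```

This copy-on-override approach avoids mutating a config object that other components may already hold a reference to, which is why the crash fix does not risk changing the KV cache dtype seen by the rest of the model.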