From 08bfedc152f064d8e84f85c4f42b810e5a564229 Mon Sep 17 00:00:00 2001
From: Yubo Wang
Date: Tue, 7 Apr 2026 11:18:33 -0700
Subject: [PATCH] [Bugfix] Fix extract_hidden_states crash with quantized KV
 cache dtype (#39160)

Signed-off-by: Yubo Wang
---
 vllm/model_executor/models/extract_hidden_states.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py
index 608e93d6a..3f1e7e693 100644
--- a/vllm/model_executor/models/extract_hidden_states.py
+++ b/vllm/model_executor/models/extract_hidden_states.py
@@ -9,6 +9,7 @@ extract_hidden_states speculative decoding method.
 """
 
 from collections.abc import Iterable
+from dataclasses import replace
 from typing import ClassVar
 
 import torch
@@ -352,6 +353,10 @@ class ExtractHiddenStatesModel(nn.Module):
 
         cache_config = vllm_config.cache_config
 
+        # Hidden states dtype should be independent of KV cache dtype.
+        if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
+            cache_config = replace(cache_config, cache_dtype="auto")
+
         # Create a single cache-only attention layer
         # Note: We set num_heads <- self.num_hidden_states
         # and head_size <- hidden_size so that we can insert
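
A minimal standalone sketch of the guard's behavior follows. CacheConfig and
is_quantized_kv_cache below are simplified stand-ins for vLLM's real
definitions, and hidden_states_cache_config is a hypothetical wrapper added
only for illustration; none of these names are introduced by the patch
itself.

    # Simplified stand-ins; the real CacheConfig carries many more fields.
    from dataclasses import dataclass, replace


    @dataclass
    class CacheConfig:
        cache_dtype: str = "auto"


    def is_quantized_kv_cache(cache_dtype: str) -> bool:
        # Stand-in: treat anything other than "auto" (e.g. fp8 variants)
        # as a quantized KV cache dtype.
        return cache_dtype != "auto"


    def hidden_states_cache_config(cache_config):
        # Hypothetical wrapper around the patched logic. replace() returns
        # a copy, so the engine-wide config consulted by the regular
        # attention layers is left untouched.
        if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
            return replace(cache_config, cache_dtype="auto")
        return cache_config


    shared = CacheConfig(cache_dtype="fp8_e4m3")
    local = hidden_states_cache_config(shared)
    assert shared.cache_dtype == "fp8_e4m3"  # unchanged for other layers
    assert local.cache_dtype == "auto"       # full precision for hidden states

Using dataclasses.replace rather than mutating cache_config in place keeps
the override local to the hidden-states attention layer, which is what the
comment in the hunk is getting at.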