[Bugfix] Fix extract_hidden_states crash with quantized KV cache dtype (#39160)
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ extract_hidden_states speculative decoding method.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import replace
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
@@ -352,6 +353,10 @@ class ExtractHiddenStatesModel(nn.Module):
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
# Hidden states dtype should be independent of KV cache dtype.
|
||||
if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
|
||||
cache_config = replace(cache_config, cache_dtype="auto")
|
||||
|
||||
# Create a single cache-only attention layer
|
||||
# Note: We set num_heads <- self.num_hidden_states
|
||||
# and head_size <- hidden_size so that we can insert
|
||||
|
||||
Reference in New Issue
Block a user