[Bugfix] Fix extract_hidden_states crash with quantized KV cache dtype (#39160)
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
@@ -9,6 +9,7 @@ extract_hidden_states speculative decoding method.
 """
 
 from collections.abc import Iterable
+from dataclasses import replace
 from typing import ClassVar
 
 import torch
@@ -352,6 +353,10 @@ class ExtractHiddenStatesModel(nn.Module):
 
         cache_config = vllm_config.cache_config
 
+        # Hidden states dtype should be independent of KV cache dtype.
+        if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
+            cache_config = replace(cache_config, cache_dtype="auto")
+
         # Create a single cache-only attention layer
         # Note: We set num_heads <- self.num_hidden_states
         # and head_size <- hidden_size so that we can insert
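
For context, the added check keeps a quantized KV cache dtype (e.g. fp8) from leaking into the hidden-states cache layer: when the configured cache_dtype is quantized, the model builds its cache-only attention layer from a copy of the cache config with cache_dtype="auto", so hidden states are cached in the model dtype. Below is a minimal standalone sketch of that pattern; CacheConfig, is_quantized_kv_cache, and hidden_states_cache_config here are simplified stand-ins for illustration, not vLLM's actual definitions.

from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class CacheConfig:
    # Simplified stand-in for vLLM's CacheConfig; only the field used here.
    cache_dtype: str = "auto"  # e.g. "auto", "fp8", "fp8_e4m3"


def is_quantized_kv_cache(cache_dtype: str) -> bool:
    # Simplified stand-in: treat any non-"auto" dtype as quantized.
    return cache_dtype != "auto"


def hidden_states_cache_config(cache_config: Optional[CacheConfig]) -> Optional[CacheConfig]:
    # Hidden states are cached in the model dtype, so drop any quantized
    # KV cache dtype before building the cache-only attention layer.
    if cache_config is not None and is_quantized_kv_cache(cache_config.cache_dtype):
        return replace(cache_config, cache_dtype="auto")
    return cache_config


if __name__ == "__main__":
    cfg = CacheConfig(cache_dtype="fp8")
    print(hidden_states_cache_config(cfg))  # CacheConfig(cache_dtype='auto')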