[BugFix] Assign page_size_padded when unifying kv cache spec. (#32283)
Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
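unify_hybrid_kv_cache_specs rewrites SlidingWindowSpec and ChunkedLocalAttentionSpec layers into FullAttentionSpec so a hybrid model can share one unified KV cache layout. Before this change the rewrite dropped the layer's page_size_padded; this commit forwards it and adds a regression test. A minimal usage sketch of the fixed behavior, mirroring case 1 of the new test below (import paths and constructor arguments are taken from this diff and may differ in other vLLM versions):

import torch

from vllm.v1.core import kv_cache_utils
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

# Sketch only: constructor arguments mirror the test helpers added in this diff.
specs = {
    # At least one full-attention layer must be present; with none at all,
    # unification raises ValueError (case 4 of the new test).
    "layer_1": FullAttentionSpec(
        block_size=16,
        num_kv_heads=2,
        head_size=64,
        dtype=torch.float32,
        page_size_padded=None,
        sliding_window=None,
        attention_chunk_size=None,
    ),
    # A sliding-window layer whose padded page size must survive unification.
    "layer_2": SlidingWindowSpec(
        block_size=16,
        num_kv_heads=2,
        head_size=64,
        dtype=torch.float32,
        page_size_padded=32 * 1024,
        sliding_window=1024,
    ),
}

kv_cache_utils.unify_hybrid_kv_cache_specs(specs)

# layer_2 is now a FullAttentionSpec; with the fix it keeps page_size_padded
# instead of falling back to the default (None).
assert isinstance(specs["layer_2"], FullAttentionSpec)
assert specs["layer_2"].page_size_padded == 32 * 1024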
@@ -36,6 +36,7 @@ from vllm.v1.core.kv_cache_utils import (
    tensor_data,
)
from vllm.v1.kv_cache_interface import (
    ChunkedLocalAttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
@@ -102,29 +103,57 @@ def new_kv_cache_spec(
    num_kv_heads=2,
    head_size=64,
    dtype=torch.float32,
    page_size_padded=None,
    sliding_window=None,
    attention_chunk_size=None,
):
    return FullAttentionSpec(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        page_size_padded=page_size_padded,
        sliding_window=sliding_window,
        attention_chunk_size=attention_chunk_size,
    )


def new_sliding_window_spec(
    block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, sliding_window=1
    block_size=16,
    num_kv_heads=2,
    head_size=64,
    dtype=torch.float32,
    page_size_padded=None,
    sliding_window=1,
):
    return SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        page_size_padded=page_size_padded,
        sliding_window=sliding_window,
    )


def new_chunked_local_attention_spec(
    block_size=16,
    num_kv_heads=2,
    head_size=64,
    dtype=torch.float32,
    page_size_padded=None,
    attention_chunk_size=4,
):
    return ChunkedLocalAttentionSpec(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        page_size_padded=page_size_padded,
        attention_chunk_size=attention_chunk_size,
    )


@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
    import vllm.v1.core.kv_cache_utils
@@ -1855,3 +1884,70 @@ def test_auto_fit_max_model_len_not_triggered():
        vllm_config, [kv_cache_specs], [mem_per_block_per_layer * 2 * 32]
    )
    assert vllm_config.model_config.max_model_len == 16


def test_unify_hybrid_kv_cache_specs():
    # 1. has_full_attention and has_sliding_window
    before_spec_1 = new_kv_cache_spec()
    before_spec_2 = new_sliding_window_spec(
        page_size_padded=32 * 1024, sliding_window=1024
    )
    kv_cache_spec = {
        "layer_1": before_spec_1,
        "layer_2": before_spec_2,
    }
    kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
    expected_spec_1 = new_kv_cache_spec()
    expected_spec_2 = new_kv_cache_spec(page_size_padded=32 * 1024, sliding_window=1024)
    assert kv_cache_spec["layer_1"] == expected_spec_1
    assert kv_cache_spec["layer_2"] == expected_spec_2

    # 2. has_full_attention and has_chunked_local_attention
    before_spec_1 = new_kv_cache_spec()
    before_spec_2 = new_chunked_local_attention_spec(
        page_size_padded=32 * 1024, attention_chunk_size=512
    )
    kv_cache_spec = {
        "layer_1": before_spec_1,
        "layer_2": before_spec_2,
    }
    kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
    expected_spec_1 = new_kv_cache_spec()
    expected_spec_2 = new_kv_cache_spec(
        page_size_padded=32 * 1024, attention_chunk_size=512
    )

    assert kv_cache_spec["layer_1"] == expected_spec_1
    assert kv_cache_spec["layer_2"] == expected_spec_2

    # 3. has_full_attention, has_sliding_window and has_chunked_local_attention
    before_spec_1 = new_kv_cache_spec()
    before_spec_2 = new_sliding_window_spec(
        page_size_padded=32 * 1024, sliding_window=1024
    )
    before_spec_3 = new_chunked_local_attention_spec(
        page_size_padded=32 * 1024, attention_chunk_size=512
    )
    kv_cache_spec = {
        "layer_1": before_spec_1,
        "layer_2": before_spec_2,
        "layer_3": before_spec_3,
    }
    kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
    expected_spec_1 = new_kv_cache_spec()
    expected_spec_2 = new_kv_cache_spec(page_size_padded=32 * 1024, sliding_window=1024)
    expected_spec_3 = new_kv_cache_spec(
        page_size_padded=32 * 1024, attention_chunk_size=512
    )
    assert kv_cache_spec["layer_1"] == expected_spec_1
    assert kv_cache_spec["layer_2"] == expected_spec_2
    assert kv_cache_spec["layer_3"] == expected_spec_3

    # 4. No FullAttentionSpec, should not convert
    kv_cache_spec = {
        "layer_1": new_sliding_window_spec(sliding_window=1024),
        "layer_2": new_chunked_local_attention_spec(attention_chunk_size=512),
    }

    with pytest.raises(ValueError):
        kv_cache_utils.unify_hybrid_kv_cache_specs(kv_cache_spec)
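The fix itself is in unify_hybrid_kv_cache_specs (the hunks below; judging by the imports above, this is vllm/v1/core/kv_cache_utils.py): when a sliding-window or chunked-local-attention spec is rewritten into a FullAttentionSpec, page_size_padded is now forwarded instead of being dropped. A condensed sketch of the two conversion branches after the change; the block_size and num_kv_heads arguments are not visible in the hunks and are assumed here:

# Sketch of the conversion after the fix. Only the page_size_padded lines are
# new; block_size / num_kv_heads are assumed, the rest mirrors the hunks below.
if isinstance(spec, SlidingWindowSpec):
    kv_cache_spec[layer_name] = FullAttentionSpec(
        block_size=spec.block_size,
        num_kv_heads=spec.num_kv_heads,
        head_size=spec.head_size,
        dtype=spec.dtype,
        sliding_window=spec.sliding_window,
        page_size_padded=spec.page_size_padded,  # previously dropped
    )
elif isinstance(spec, ChunkedLocalAttentionSpec):
    kv_cache_spec[layer_name] = FullAttentionSpec(
        block_size=spec.block_size,
        num_kv_heads=spec.num_kv_heads,
        head_size=spec.head_size,
        dtype=spec.dtype,
        attention_chunk_size=spec.attention_chunk_size,
        page_size_padded=spec.page_size_padded,  # previously dropped
    )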
@@ -1185,6 +1185,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
                head_size=spec.head_size,
                dtype=spec.dtype,
                sliding_window=spec.sliding_window,
                page_size_padded=spec.page_size_padded,
            )
        elif isinstance(spec, ChunkedLocalAttentionSpec):
            kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1193,6 +1194,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
                head_size=spec.head_size,
                dtype=spec.dtype,
                attention_chunk_size=spec.attention_chunk_size,
                page_size_padded=spec.page_size_padded,
            )

    if not (