[V0 Deprecation] Refactor kv cache from list to element (#37487)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-23 23:10:11 -04:00
committed by GitHub
parent de99d91ece
commit c59a132f96
27 changed files with 70 additions and 85 deletions

View File

@@ -670,8 +670,8 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
     runner.initialize_kv_cache(kv_cache_config)
-    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
-    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    layer_0_kv = vllm_ctx[layer_0].kv_cache
+    layer_1_kv = vllm_ctx[layer_1].kv_cache
     # check layer 1 kv cache does NOT share memory with layer 0
     assert id(layer_1_kv) != id(layer_0_kv)
@@ -740,8 +740,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
     runner.initialize_kv_cache(kv_cache_config)
     kv_cache_config_after_init = runner.kv_cache_config
-    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
-    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    layer_0_kv = vllm_ctx[layer_0].kv_cache
+    layer_1_kv = vllm_ctx[layer_1].kv_cache
     # check layer 1 kv cache shares memory with layer 0
     assert id(layer_1_kv) == id(layer_0_kv)
@@ -864,9 +864,9 @@ def test_hybrid_attention_mamba_tensor_shapes():
     np.random.shuffle(ind)
     blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
-    attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
-    conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
-    ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
+    attn_shape = vllm_ctx[layer_0].kv_cache.shape
+    conv_shape = vllm_ctx[layer_2].kv_cache[0].shape
+    ssm_shape = vllm_ctx[layer_2].kv_cache[1].shape
     # assert we are using FlashInfer
     assert attn_shape[0] % num_blocks == 0
@@ -905,21 +905,21 @@ def test_hybrid_attention_mamba_tensor_shapes():
     kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
     for layer in [layer_0, layer_1]:
-        # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
+        # attention: kv_cache[kernel_block_idx, kv_idx, ...]
         for i, kernel_block in enumerate(kernel_blocks_for_attention):
-            vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[kernel_block, :] = attn_blocks_constant[i]
     # fill mamba blocks with constants using kernel block indices
     for layer in [layer_2, layer_3, layer_4, layer_5]:
-        # mamba: kv_cache[0][component][kernel_block_idx, ...]
+        # mamba: kv_cache[component][kernel_block_idx, ...]
         for i, kv_block in enumerate(kv_blocks_for_mamba):
-            vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
-            vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[0][kv_block, :] = conv_blocks_constant[i]
+            vllm_ctx[layer].kv_cache[1][kv_block, :] = ssm_blocks_constant[i]
     # verify attention and mamba contents are correct
     for layer in [layer_0, layer_1]:
         for i, kernel_block in enumerate(kernel_blocks_for_attention):
-            actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
+            actual_kv = vllm_ctx[layer].kv_cache[kernel_block, :]
             expected = attn_blocks_constant[i]
             # Check K and V separately
@@ -928,8 +928,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
     for layer in [layer_2, layer_3, layer_4, layer_5]:
         for i, kv_block in enumerate(kv_blocks_for_mamba):
-            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
-            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
+            actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
+            actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
             expected_conv = conv_blocks_constant[i]
             expected_ssm = ssm_blocks_constant[i]
@@ -938,8 +938,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
     for layer in [layer_2, layer_3, layer_4, layer_5]:
         for i, kv_block in enumerate(kv_blocks_for_mamba):
-            actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
-            actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
+            actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
+            actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
             expected_conv = conv_blocks_constant[i]
             expected_ssm = ssm_blocks_constant[i]
     assert torch.equal(actual_conv, expected_conv)