[V0 Deprecation] Refactor kv cache from list to element (#37487)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -670,8 +670,8 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
- layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
- layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
+ layer_0_kv = vllm_ctx[layer_0].kv_cache
|
||||
+ layer_1_kv = vllm_ctx[layer_1].kv_cache
|
||||
# check layer 1 kv cache does NOT share memory with layer 0
|
||||
assert id(layer_1_kv) != id(layer_0_kv)
|
||||
|
||||
@@ -740,8 +740,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
kv_cache_config_after_init = runner.kv_cache_config
|
||||
|
||||
- layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
|
||||
- layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
|
||||
+ layer_0_kv = vllm_ctx[layer_0].kv_cache
|
||||
+ layer_1_kv = vllm_ctx[layer_1].kv_cache
|
||||
# check layer 1 kv cache shares memory with layer 0
|
||||
assert id(layer_1_kv) == id(layer_0_kv)
|
||||
|
||||
@@ -864,9 +864,9 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
np.random.shuffle(ind)
|
||||
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
|
||||
|
||||
- attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
|
||||
- conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
|
||||
- ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
|
||||
+ attn_shape = vllm_ctx[layer_0].kv_cache.shape
|
||||
+ conv_shape = vllm_ctx[layer_2].kv_cache[0].shape
|
||||
+ ssm_shape = vllm_ctx[layer_2].kv_cache[1].shape
|
||||
|
||||
# assert we are using FlashInfer
|
||||
assert attn_shape[0] % num_blocks == 0
|
||||
@@ -905,21 +905,21 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
|
||||
|
||||
for layer in [layer_0, layer_1]:
|
||||
- # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
|
||||
+ # attention: kv_cache[kernel_block_idx, kv_idx, ...]
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
- vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
|
||||
+ vllm_ctx[layer].kv_cache[kernel_block, :] = attn_blocks_constant[i]
|
||||
|
||||
# fill mamba blocks with constants using kernel block indices
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
- # mamba: kv_cache[0][component][kernel_block_idx, ...]
|
||||
+ # mamba: kv_cache[component][kernel_block_idx, ...]
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
- vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
|
||||
- vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
|
||||
+ vllm_ctx[layer].kv_cache[0][kv_block, :] = conv_blocks_constant[i]
|
||||
+ vllm_ctx[layer].kv_cache[1][kv_block, :] = ssm_blocks_constant[i]
|
||||
|
||||
# verify attention and mamba contents are correct
|
||||
for layer in [layer_0, layer_1]:
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
- actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
|
||||
+ actual_kv = vllm_ctx[layer].kv_cache[kernel_block, :]
|
||||
expected = attn_blocks_constant[i]
|
||||
|
||||
# Check K and V separately
|
||||
@@ -928,8 +928,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
- actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
- actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
+ actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
|
||||
+ actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
|
||||
@@ -938,8 +938,8 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
- actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
- actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
+ actual_conv = vllm_ctx[layer].kv_cache[0][kv_block, :]
|
||||
+ actual_ssm = vllm_ctx[layer].kv_cache[1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
|
||||
Reference in New Issue
Block a user