# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.v1.worker.utils import bind_kv_cache


def test_bind_kv_cache(default_vllm_config):
    """Bind four consecutively numbered attention layers.

    After ``bind_kv_cache`` runs, each layer's ``kv_cache[0]`` must be the
    very tensor registered under that layer's name, and ``runner_kv_caches``
    must hold those tensors in layer order.
    """
    # Imported inside the test, as in the other tests of this module.
    from vllm.model_executor.layers.attention import Attention

    names = (
        "layers.0.self_attn",
        "layers.1.self_attn",
        "layers.2.self_attn",
        "layers.3.self_attn",
    )
    ctx = {name: Attention(32, 128, 0.1, prefix=name) for name in names}
    kv_cache = {name: torch.zeros((1,)) for name in names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (not equality) checks: the layer must reference the same
    # tensor object that was passed in.
    for name in names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    for position, name in enumerate(names):
        assert runner_kv_caches[position] is kv_cache[name]


def test_bind_kv_cache_non_attention(default_vllm_config):
    """Bind attention layers whose layer indices are not contiguous.

    Only layers 20 and 28 are present; the runner list is still expected to
    be filled densely, in ascending layer order.
    """
    from vllm.model_executor.layers.attention import Attention

    # example from Jamba PP=2
    names = (
        "model.layers.20.attn",
        "model.layers.28.attn",
    )
    ctx = {name: Attention(32, 128, 0.1, prefix=name) for name in names}
    kv_cache = {name: torch.zeros((1,)) for name in names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    for name in names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    for position, name in enumerate(names):
        assert runner_kv_caches[position] is kv_cache[name]


def test_bind_kv_cache_draft_model(default_vllm_config):
    """Bind target-model and draft-model layers that share layer indices."""
    from vllm.model_executor.layers.attention import Attention

    names = (
        "model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.0.attn",
        "draft_model.layers.1.attn",
    )
    ctx = {name: Attention(32, 128, 0.1, prefix=name) for name in names}
    kv_cache = {name: torch.zeros((1,)) for name in names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    for name in names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # caches are ordered by layer_index, interleaving target and draft model
    expected_order = (
        "model.layers.0.attn",
        "draft_model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.1.attn",
    )
    for position, name in enumerate(expected_order):
        assert runner_kv_caches[position] is kv_cache[name]