# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.v1.worker.utils import bind_kv_cache

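# bind_kv_cache should attach each layer's cache tensor to the matching
# Attention module (via its kv_cache list) and register the tensors with the
# runner in layer order.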
def test_bind_kv_cache(default_vllm_config):
    from vllm.attention.layer import Attention

    ctx = {
        "layers.0.self_attn": Attention(32, 128, 0.1, prefix="layers.0.self_attn"),
        "layers.1.self_attn": Attention(32, 128, 0.1, prefix="layers.1.self_attn"),
        "layers.2.self_attn": Attention(32, 128, 0.1, prefix="layers.2.self_attn"),
        "layers.3.self_attn": Attention(32, 128, 0.1, prefix="layers.3.self_attn"),
    }
    kv_cache = {
        "layers.0.self_attn": torch.zeros((1,)),
        "layers.1.self_attn": torch.zeros((1,)),
        "layers.2.self_attn": torch.zeros((1,)),
        "layers.3.self_attn": torch.zeros((1,)),
    }
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
    assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"]
    assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"]
    assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"]
    assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"]

    assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"]
    assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"]
    assert runner_kv_caches[2] is kv_cache["layers.2.self_attn"]
    assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]

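# Models with non-attention layers (e.g. Jamba's Mamba blocks) have gaps in
# the attention layer indices; bind_kv_cache must still bind every attention
# layer and keep runner_kv_caches free of holes.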
def test_bind_kv_cache_non_attention(default_vllm_config):
    from vllm.attention.layer import Attention

    # example from Jamba PP=2
    ctx = {
        "model.layers.20.attn": Attention(32, 128, 0.1, prefix="model.layers.20.attn"),
        "model.layers.28.attn": Attention(32, 128, 0.1, prefix="model.layers.28.attn"),
    }
    kv_cache = {
        "model.layers.20.attn": torch.zeros((1,)),
        "model.layers.28.attn": torch.zeros((1,)),
    }

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"]
    assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"]

    assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
    assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]

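# With a draft model (speculative decoding), target and draft layers share the
# same forward context; each module still gets its own cache tensor, and
# runner_kv_caches groups the caches per layer index (see the ordering asserts
# below).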
def test_bind_kv_cache_draft_model(default_vllm_config):
    from vllm.attention.layer import Attention

    layer_names = [
        "model.layers.0.attn",
        "model.layers.1.attn",
        "draft_model.layers.0.attn",
        "draft_model.layers.1.attn",
    ]
    ctx = {
        layer_name: Attention(32, 128, 0.1, prefix=layer_name)
        for layer_name in layer_names
    }
    kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
    assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
    assert (
        ctx["draft_model.layers.0.attn"].kv_cache[0]
        is kv_cache["draft_model.layers.0.attn"]
    )
    assert (
        ctx["draft_model.layers.1.attn"].kv_cache[0]
        is kv_cache["draft_model.layers.1.attn"]
    )

    # caches are ordered by layer_index, interleaving target and draft model
    assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
    assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
    assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
    assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]