feat: spec decode with draft models (#24322)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
|
||||
|
||||
|
||||
def test_bind_kv_cache_draft_model(default_vllm_config):
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
layer_names = [
|
||||
"model.layers.0.attn",
|
||||
"model.layers.1.attn",
|
||||
"draft_model.layers.0.attn",
|
||||
"draft_model.layers.1.attn",
|
||||
]
|
||||
ctx = {
|
||||
layer_name: Attention(32, 128, 0.1, prefix=layer_name)
|
||||
for layer_name in layer_names
|
||||
}
|
||||
kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
|
||||
assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
|
||||
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
|
||||
assert (
|
||||
ctx["draft_model.layers.0.attn"].kv_cache[0]
|
||||
is kv_cache["draft_model.layers.0.attn"]
|
||||
)
|
||||
assert (
|
||||
ctx["draft_model.layers.1.attn"].kv_cache[0]
|
||||
is kv_cache["draft_model.layers.1.attn"]
|
||||
)
|
||||
|
||||
# caches are ordered by layer_index, interleaving target and draft model
|
||||
assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
|
||||
assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
|
||||
assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]
|
||||
|
||||
Reference in New Issue
Block a user