feat: spec decode with draft models (#24322)

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
This commit is contained in:
Tomas Ruiz
2026-01-19 15:05:46 -06:00
committed by GitHub
parent 73f2a81c75
commit 4a5299c93f
21 changed files with 897 additions and 115 deletions

View File

@@ -55,3 +55,38 @@ def test_bind_kv_cache_non_attention(default_vllm_config):
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
def test_bind_kv_cache_draft_model(default_vllm_config):
from vllm.attention.layer import Attention
layer_names = [
"model.layers.0.attn",
"model.layers.1.attn",
"draft_model.layers.0.attn",
"draft_model.layers.1.attn",
]
ctx = {
layer_name: Attention(32, 128, 0.1, prefix=layer_name)
for layer_name in layer_names
}
kv_cache = {layer_name: torch.zeros((1,)) for layer_name in layer_names}
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"]
assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"]
assert (
ctx["draft_model.layers.0.attn"].kv_cache[0]
is kv_cache["draft_model.layers.0.attn"]
)
assert (
ctx["draft_model.layers.1.attn"].kv_cache[0]
is kv_cache["draft_model.layers.1.attn"]
)
# caches are ordered by layer_index, interleaving target and draft model
assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"]
assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"]
assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"]
assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"]