feat: spec decode with draft models (#24322)

Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
Author: Tomas Ruiz
Date: 2026-01-19 15:05:46 -06:00
Committed by: GitHub

parent 73f2a81c75
commit 4a5299c93f

21 changed files with 897 additions and 115 deletions


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
     _num_computed_tokens_cache: torch.Tensor | None = None
 
+    def batch_size(self) -> int:
+        return self.seq_lens.shape[0]
+
+    def naive_query_lens(self) -> torch.Tensor:
+        """Naive because it assumes that each query ends where the next query starts."""
+        return self.query_start_loc[1:] - self.query_start_loc[:-1]
+
+    def replace(self, **kwargs) -> "CommonAttentionMetadata":
+        return replace(self, **kwargs)
+
     @property
     @deprecated(
         """

@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
     )
     dcp_local_seq_lens = base + remainder
     return dcp_local_seq_lens.squeeze(1)
+
+
+def extend_all_queries_by_1(
+    common_attn_metadata: CommonAttentionMetadata,
+    arange: torch.Tensor,
+    new_slot_mapping: torch.Tensor,
+) -> CommonAttentionMetadata:
+    """
+    Creates a new CommonAttentionMetadata with all query lengths increased
+    by 1. All seq lens are also increased by 1.
+
+    This is useful e.g. in speculative decoding with draft models, where we
+    extend each sequence by 1 token. The slot mapping is computed externally,
+    as it requires more information.
+    """
+    cad = common_attn_metadata
+    # query_start_loc must be increased by [+0, +1, +2, ..., +batch_size]
+    new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)]
+    new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange(
+        len(cad.query_start_loc_cpu), dtype=torch.int32
+    )
+    new_cad = cad.replace(
+        query_start_loc=new_query_start_loc,
+        query_start_loc_cpu=new_query_start_loc_cpu,
+        seq_lens=cad.seq_lens + 1,
+        # each request is extended by 1 token -> batch_size tokens are added
+        num_actual_tokens=cad.num_actual_tokens + cad.batch_size(),
+        # all query lens increase by 1, so max query len increases by 1
+        max_query_len=cad.max_query_len + 1,
+        max_seq_len=cad.max_seq_len + 1,
+        slot_mapping=new_slot_mapping,
+    )
+    return new_cad
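
To make the offset arithmetic concrete, here is a toy walkthrough of the transformation using bare tensors rather than a real CommonAttentionMetadata (values are illustrative, not from the PR):

    import torch

    # 3 requests with query lengths [4, 1, 2] -> query_start_loc [0, 4, 5, 7].
    query_start_loc = torch.tensor([0, 4, 5, 7], dtype=torch.int32)
    seq_lens = torch.tensor([10, 6, 8], dtype=torch.int32)
    batch_size = seq_lens.shape[0]

    # Extending every query by 1 token shifts the i-th cumulative offset
    # by +i, i.e. add [0, 1, 2, ..., batch_size] elementwise.
    arange = torch.arange(batch_size + 1, dtype=torch.int32)
    new_query_start_loc = query_start_loc + arange
    print(new_query_start_loc)  # tensor([0, 5, 7, 10]) -> lengths [5, 2, 3]

    # Every sequence also grows by 1, and batch_size new tokens are added.
    new_seq_lens = seq_lens + 1     # tensor([11, 7, 9])
    num_added_tokens = batch_size   # 3

Note that the real caller passes a preallocated arange tensor for the GPU side; the torch.arange above merely stands in for that buffer.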