feat: spec decode with draft models (#24322)
Signed-off-by: Tomas Ruiz <tomas.ruiz.te@gmail.com>
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
 
@@ -329,6 +329,16 @@ class CommonAttentionMetadata:
 
     _num_computed_tokens_cache: torch.Tensor | None = None
 
+    def batch_size(self) -> int:
+        return self.seq_lens.shape[0]
+
+    def naive_query_lens(self) -> torch.Tensor:
+        """Naive because it assumes that query ends where the next query starts."""
+        return self.query_start_loc[1:] - self.query_start_loc[:-1]
+
+    def replace(self, **kwargs) -> "CommonAttentionMetadata":
+        return replace(self, **kwargs)
+
     @property
     @deprecated(
         """
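A toy batch makes the new helpers concrete. The sketch below is illustrative only and not part of the diff; the tensor values are made up. It mirrors what `batch_size()` and `naive_query_lens()` compute on a `CommonAttentionMetadata`-like pair of tensors:

```python
import torch

# Hypothetical batch of 3 requests with query lengths 3, 1, and 4.
# query_start_loc holds cumulative token offsets: a leading 0 plus one
# boundary per request, so its length is batch_size + 1.
query_start_loc = torch.tensor([0, 3, 4, 8], dtype=torch.int32)
seq_lens = torch.tensor([10, 7, 12], dtype=torch.int32)

batch_size = seq_lens.shape[0]                            # 3, as in batch_size()
query_lens = query_start_loc[1:] - query_start_loc[:-1]   # tensor([3, 1, 4])

# "Naive" because this only holds when queries are packed back to back,
# i.e. each query ends exactly where the next one starts.
assert query_lens.sum().item() == query_start_loc[-1].item()
```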
@@ -818,3 +818,35 @@ def get_dcp_local_seq_lens(
     )
     dcp_local_seq_lens = base + remainder
     return dcp_local_seq_lens.squeeze(1)
+
+
+def extend_all_queries_by_1(
+    common_attn_metadata: CommonAttentionMetadata,
+    arange: torch.Tensor,
+    new_slot_mapping: torch.Tensor,
+) -> CommonAttentionMetadata:
+    """
+    Creates a new CommonAttentionMetadata with all query lengths increased by 1.
+    All seq lens are also increased by 1.
+    This is useful e.g. in speculative decoding with draft models, where we
+    extend each sequence by 1 token.
+    The slot mapping is computed externally, as it requires more information.
+    """
+    cad = common_attn_metadata
+    # query_start_loc must be increased by [+0, +1, +2, ..., +batch_size]
+    new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)]
+    new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange(
+        len(cad.query_start_loc_cpu), dtype=torch.int32
+    )
+    new_cad = cad.replace(
+        query_start_loc=new_query_start_loc,
+        query_start_loc_cpu=new_query_start_loc_cpu,
+        seq_lens=cad.seq_lens + 1,
+        # each request is extended by 1 token -> batch_size tokens are added
+        num_actual_tokens=cad.num_actual_tokens + cad.batch_size(),
+        # All query lens increase by 1, so max query len increases by 1
+        max_query_len=cad.max_query_len + 1,
+        max_seq_len=cad.max_seq_len + 1,
+        slot_mapping=new_slot_mapping,
+    )
+    return new_cad
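The key trick in `extend_all_queries_by_1` is shifting the i-th entry of `query_start_loc` by +i: request 0 still starts at offset 0, request 1 starts one token later, and so on, which grows every query by exactly one token. The sketch below (illustrative only, with made-up values) reproduces that offset arithmetic on plain tensors:

```python
import torch

# Hypothetical batch of 3 requests with query lengths 2, 1, and 3.
query_start_loc = torch.tensor([0, 2, 3, 6], dtype=torch.int32)
seq_lens = torch.tensor([5, 8, 4], dtype=torch.int32)

# Adding [0, 1, 2, ..., batch_size] shifts the i-th boundary by +i.
arange = torch.arange(len(query_start_loc), dtype=torch.int32)
new_query_start_loc = query_start_loc + arange  # tensor([0, 3, 5, 9])

# Each query length grew by exactly 1: [2, 1, 3] -> [3, 2, 4].
new_query_lens = new_query_start_loc[1:] - new_query_start_loc[:-1]
assert (new_query_lens == torch.tensor([3, 2, 4])).all()

# seq_lens grow by 1 as well, and the total token count grows by batch_size.
new_seq_lens = seq_lens + 1
num_added_tokens = seq_lens.shape[0]  # 3 extra tokens overall
```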