[Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
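For context, a rough usage sketch, not part of this diff: prompt-lookup speculation is selected whenever the draft worker kwargs carry a nonzero ngram_prompt_lookup_max (see create_worker below). The engine-argument names in this sketch are assumed to mirror those kwargs and may differ from the final user-facing API.

# Hedged sketch, not from this commit: argument names assumed from the
# draft_worker_kwargs keys introduced below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="JackFram/llama-68m",  # not used for drafting once ngram lookup is active
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,  # assumed engine-arg name
    ngram_prompt_lookup_min=1,  # assumed engine-arg name
)
print(llm.generate(["The quick brown fox"], SamplingParams(temperature=0.0)))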
vllm/spec_decode/batch_expansion.py
@@ -333,13 +333,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         sampler_output.sampled_token_probs = spec_probs
         sampler_output.sampled_token_ids = spec_sampled_tokens
         target_token_ids, target_probs = sampler_output_to_torch(
-            [sampler_output])
+            [sampler_output], True)

         # Convert non-speculative output tokens to tensors.
         sampler_output.sampled_token_probs = non_spec_probs
         sampler_output.sampled_token_ids = non_spec_sampled_tokens
         non_spec_target_token_ids, non_spec_target_probs = (
-            sampler_output_to_torch([sampler_output]))
+            sampler_output_to_torch([sampler_output], True))

         return (target_token_ids, target_probs, non_spec_target_token_ids,
                 non_spec_target_probs)
vllm/spec_decode/multi_step_worker.py
@@ -1,12 +1,11 @@
 import copy
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple

 import torch

 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.spec_decode.interfaces import (SpeculativeProposals,
-                                         SpeculativeProposer)
-from vllm.spec_decode.util import sampler_output_to_torch
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker


@@ -26,29 +25,37 @@ class MultiStepWorker(Worker):
         super().__init__(*args, **kwargs)

         # Lazy initialization list.
-        self._proposer: DraftModelTop1Proposer
+        self._proposer: Top1Proposer

     def init_device(self):
         super().init_device()

-        self._proposer = DraftModelTop1Proposer(
+        self._proposer = Top1Proposer(
             self,
             self.device,
-            self.max_model_len,
             self.vocab_size,
+            max_proposal_len=self.max_model_len,
         )

+    def set_include_gpu_probs_tensor(self):
+        # Need include_gpu_probs_tensor for multi_step_worker
+        self.model_runner.model.sampler.include_gpu_probs_tensor = True
+
     @torch.inference_mode()
-    def execute_model_multi_step(
+    def sampler_output(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-        num_steps: int,
-    ) -> List[SamplerOutput]:
-        """Run the model forward pass num_steps times. Returns the list of
-        sampler output, one per model forward pass.
+        sample_len: int,
+    ) -> Tuple[List[SamplerOutput], bool]:
+        """Run the model forward pass sample_len times. Returns the list of
+        sampler output, one per model forward pass, along with an indicator of
+        whether the torch tensors in the sampler output need to be transposed
+        later by the sampler_output_to_torch logic.
+
+        For the multi-step worker, this indicator is always True.
         """
         self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                    blocks_to_swap_out, blocks_to_copy)
@@ -58,12 +65,12 @@ class MultiStepWorker(Worker):
         copied_seq_group_metadata_list = self._shallow_copy_inputs(
             seq_group_metadata_list)

-        # Assert enough KV space for num_steps tokens per sequence.
-        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
+        # Assert enough KV space for sample_len tokens per sequence.
+        self._assert_enough_kv_space(seq_group_metadata_list, sample_len)

-        # Run model num_steps times.
+        # Run model sample_len times.
         model_outputs = []
-        for _ in range(num_steps):
+        for _ in range(sample_len):
             model_output = super().execute_model(
                 seq_group_metadata_list=copied_seq_group_metadata_list,
                 blocks_to_swap_in=blocks_to_swap_in,
@@ -78,7 +85,7 @@ class MultiStepWorker(Worker):
                 copied_seq_group_metadata_list)
             model_outputs.append(model_output)

-        return model_outputs
+        return model_outputs, True

     def get_spec_proposals(
         self,
@@ -206,171 +213,3 @@ class MultiStepWorker(Worker):
                 for seq_group_metadata in seq_group_metadata_list):
             raise NotImplementedError(
                 "MultiStepWorker does not support beam search.")
-
-
-class DraftModelTop1Proposer(SpeculativeProposer):
-    """Helper class which separates out sequences which would exceed the max
-    model length when speculated upon.
-
-    This allows combinations of models such as JackFram/llama-68m draft with
-    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
-    2048 while Llama2-13b has max_position_embeddings of 4096.
-
-    We treat the sequences which exceed the proposal draft model length as
-    "non-spec sequences". Essentially they skip the draft model and go through
-    normal decoding in the target model.
-
-    Currently, only proposal_lens of 0 and k are supported, where k is a global
-    batch proposal length. In the future vLLM should support per-sequence
-    proposal lengths.
-    """
-
-    def __init__(
-        self,
-        draft_worker: MultiStepWorker,
-        device: str,
-        max_model_len: int,
-        vocab_size: int,
-    ):
-        self._draft_worker = draft_worker
-        self._device = device
-        self._max_model_len = max_model_len
-        self._vocab_size = vocab_size
-
-    def get_proposals(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
-    ) -> SpeculativeProposals:
-        """Get speculative proposals given the input batch.
-
-        Sequences which would exceed the max model length are skipped during
-        speculation.
-        """
-
-        # Split speculative- and non-speculative- sequences.
-        (proposal_lens, nonzero_proposal_len_seqs,
-         nonzero_proposal_len_indices) = self._split_by_max_model_len(
-             seq_group_metadata_list, max_proposal_len)
-
-        if nonzero_proposal_len_seqs:
-            # Speculate tokens using the draft worker for the speculative
-            # sequences.
-            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
-                seq_group_metadata_list=nonzero_proposal_len_seqs,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-                num_steps=max_proposal_len,
-            )
-        else:
-            # If no sequences can be speculated, set sampler output to None.
-            maybe_sampler_output = None
-
-        # Combine speculative- and non-speculative sequences into the same
-        # representation.
-        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
-            batch_size=len(seq_group_metadata_list),
-            max_proposal_len=max_proposal_len,
-            maybe_sampler_output=maybe_sampler_output,
-            proposal_lens=proposal_lens,
-            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
-        )
-
-        proposals = SpeculativeProposals(
-            proposal_token_ids=proposal_tokens,
-            proposal_probs=proposal_probs,
-            proposal_lens=proposal_lens,
-        )
-
-        return proposals
-
-    def _split_by_max_model_len(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        max_proposal_len: int,
-    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
-        """Determine which sequences would exceed the max model length.
-        """
-
-        proposal_lens: List[int] = []
-        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
-        nonzero_proposal_len_indices: List[int] = []
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            seq_data = next(iter(seq_group_metadata.seq_data.values()))
-            seq_len = seq_data.get_len()
-
-            # Currently only proposal lens of 0 or the global batch proposal len
-            # are supported.
-            if seq_len + max_proposal_len < self._max_model_len:
-                proposal_lens.append(max_proposal_len)
-                nonzero_proposal_len_seqs.append(seq_group_metadata)
-                nonzero_proposal_len_indices.append(i)
-            else:
-                proposal_lens.append(0)
-
-        return (proposal_lens, nonzero_proposal_len_seqs,
-                nonzero_proposal_len_indices)
-
-    def _merge_outputs(
-        self,
-        batch_size: int,
-        max_proposal_len: int,
-        maybe_sampler_output: Optional[SamplerOutput],
-        proposal_lens: List[int],
-        nonzero_proposal_len_indices: List[int],
-    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
-        """After speculations are produced, merge the speculation results with
-        the skipped sequences.
-        """
-        if maybe_sampler_output is None:
-            # If no speculative tokens, the sampler output will be None.
-            # In this case we return empty proposals.
-            proposal_tokens = torch.full(size=(
-                batch_size,
-                max_proposal_len,
-            ),
-                                         fill_value=-1,
-                                         dtype=torch.long,
-                                         device=self._device)
-            proposal_probs = torch.zeros(batch_size,
-                                         max_proposal_len,
-                                         self._vocab_size,
-                                         dtype=torch.float32,
-                                         device=self._device)
-            proposal_lens_tensor = torch.zeros(len(proposal_lens),
-                                               dtype=torch.long,
-                                               device=self._device)
-            return proposal_tokens, proposal_probs, proposal_lens_tensor
-
-        sampler_output = maybe_sampler_output
-        proposal_tokens, proposal_probs = sampler_output_to_torch(
-            sampler_output)
-
-        # Now, reformat the output GPU tensors such that each sequence has
-        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
-
-        entire_proposal_tokens = torch.full(size=(batch_size,
-                                                  *proposal_tokens.shape[1:]),
-                                            fill_value=-1,
-                                            dtype=torch.long,
-                                            device=self._device)
-        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
-        entire_proposal_probs = torch.zeros(batch_size,
-                                            *proposal_probs.shape[1:],
-                                            dtype=torch.float32,
-                                            device=self._device)
-        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
-
-        proposal_tokens, proposal_probs = (entire_proposal_tokens,
-                                           entire_proposal_probs)
-
-        proposal_lens_tensor = torch.zeros(batch_size,
-                                           dtype=torch.long,
-                                           device=self._device)
-        proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
-
-        return proposal_tokens, proposal_probs, proposal_lens_tensor
vllm/spec_decode/ngram_worker.py (new file, 190 lines)
@@ -0,0 +1,190 @@
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase


class NGramWorker(LoraNotSupportedWorkerBase):
    """NGramWorker provides a light drafter without the need for a model.

    Currently NGramWorker only implements prompt lookup decoding; in the
    future we may also add RAG-style drafters and other scenarios which
    don't rely on an LLM model to give proposals.
    """

    def __init__(self, *args, **kwargs):
        # Get local_rank/vocab_size from kwargs
        self.local_rank = kwargs["local_rank"]
        self.vocab_size = kwargs["model_config"].get_vocab_size()

        # Lazy initialization list.
        self._proposer: Top1Proposer

    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
                              ngram_prompt_lookup_max: int):
        # Search valid candidate window between
        # ngram_prompt_lookup_min/ngram_prompt_lookup_max
        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

    def init_device(self):
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.load_model = lambda *args, **kwargs: None

        # Currently only Top1Proposer is supported
        self._proposer = Top1Proposer(
            self,
            device=self.device,
            vocab_size=self.vocab_size,
        )

    def set_include_gpu_probs_tensor(self):
        # NGram doesn't need a GPU sampler
        pass

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> None:
        """NGram doesn't depend on model execution, so this is a no-op."""
        pass

    def determine_num_available_blocks(self) -> None:
        """NGram doesn't depend on model execution, no need to check blocks."""
        pass

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """As there is no cache to handle, this is a no-op."""
        pass

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes."""
        return 0

    def sampler_output(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        sample_len: int,
    ) -> Tuple[Optional[List[SamplerOutput]], bool]:
        """NGram match algorithm to pick proposal candidates. Returns the list
        of sampler output, one per SequenceGroupMetadata.

        For the ngram worker, the required transpose has already been done
        internally, so the indicator passed to sampler_output_to_torch is
        False.
        """
        self._raise_if_unsupported(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )

        arr = []
        has_spec_out = False
        for seq_group_metadata in seq_group_metadata_list:
            seq_data = next(iter(seq_group_metadata.seq_data.values()))

            input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                        dtype=torch.long,
                                        device=self.device)
            input_length = seq_data.get_len()

            for ngram_size in range(
                    min(self.ngram_prompt_lookup_max, input_length - 1),
                    self.ngram_prompt_lookup_min,
                    -1,
            ):
                ngram_tensor = input_ids[-1 * ngram_size:]
                windows = input_ids.unfold(dimension=0,
                                           size=ngram_size,
                                           step=1)
                matches = (windows == ngram_tensor).all(dim=1)
                match_indices = matches.nonzero(as_tuple=True)[0]
                if match_indices.size()[0] > 1:
                    has_spec_out = True
                    res = seq_data.get_token_ids()
                    res = res[match_indices[0] + ngram_size:match_indices[0] +
                              ngram_size + sample_len]
                    res_len = len(res)
                    # pad with 0s since sample_len tokens are required
                    res += [0] * (sample_len - res_len)

                    break
            else:
                # if no candidate is found, fill with 0s
                res = [0] * sample_len

            arr.append(res)

        if not has_spec_out:
            return None, False

        outputs = []
        token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
        indices = token_ids.unsqueeze(2)

        token_probs = torch.zeros(
            (len(seq_group_metadata_list), sample_len, self.vocab_size),
            dtype=torch.float32,
            device=self.device,
        )
        token_probs.scatter_(2, indices, 1)
        for i in range(len(seq_group_metadata_list)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_probs[i],
                    sampled_token_ids=token_ids[i],
                ))
        return outputs, False

    def get_spec_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        max_proposal_len: int,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            max_proposal_len,
        )

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """NGramWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "NGramWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "NGramWorker does not support beam search.")
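The matching loop above walks n-gram sizes from large to small and copies the tokens that followed the first earlier occurrence of the current suffix. A minimal standalone sketch of the same idea follows; the helper name and values are illustrative, not part of the commit.

# Hypothetical helper, not in vLLM: illustrates prompt-lookup matching.
from typing import List

def propose_by_prompt_lookup(token_ids: List[int], ngram_max: int,
                             ngram_min: int, sample_len: int) -> List[int]:
    """Return up to sample_len draft tokens copied from after an earlier
    occurrence of the trailing n-gram, or [] if nothing matches."""
    n = len(token_ids)
    for ngram_size in range(min(ngram_max, n - 1), ngram_min, -1):
        suffix = token_ids[-ngram_size:]
        # Scan earlier windows only; the trailing window always matches itself.
        for start in range(n - ngram_size):
            if token_ids[start:start + ngram_size] == suffix:
                return token_ids[start + ngram_size:start + ngram_size + sample_len]
    return []

# Example: the suffix [7, 8] occurred earlier, so [9, 1] is proposed.
print(propose_by_prompt_lookup([5, 7, 8, 9, 1, 7, 8],
                               ngram_max=3, ngram_min=0, sample_len=2))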
vllm/spec_decode/spec_decode_worker.py
@@ -12,6 +12,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
                                    split_batch_by_proposal_len)
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
@@ -48,8 +49,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     """

     @classmethod
-    def from_workers(cls, proposer_worker: MultiStepWorker,
-                     scorer_worker: WorkerBase) -> "SpecDecodeWorker":
+    def create_worker(
+        cls,
+        scorer_worker: WorkerBase,
+        draft_worker_kwargs,
+    ) -> "SpecDecodeWorker":
+
+        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
+            ngram_prompt_lookup_max = (
+                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+            ngram_prompt_lookup_min = (
+                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
+        else:
+            ngram_prompt_lookup_max = 0
+
+        if ngram_prompt_lookup_max > 0:
+            proposer_worker = NGramWorker(**draft_worker_kwargs)
+            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                   ngram_prompt_lookup_max)
+        else:
+            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
+
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
@@ -59,7 +79,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

     def __init__(
         self,
-        proposer_worker: MultiStepWorker,
+        proposer_worker: WorkerBase,
         scorer_worker: WorkerBase,
         rejection_sampler: RejectionSampler,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
@@ -134,8 +154,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
-        (self.proposer_worker.model_runner.model.sampler.
-         include_gpu_probs_tensor) = True
+        self.proposer_worker.set_include_gpu_probs_tensor()

     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.
@@ -183,8 +202,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 "speculative decoding "
                 "requires non-None seq_group_metadata_list")

-        logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
-                    num_lookahead_slots)
+        #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
+        #            num_lookahead_slots)

         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.
@@ -216,7 +235,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposer and scorer model so that the KV cache is consistent between the
         two.
         """
-        logger.info("run proposer worker no spec")
+        #logger.info("run proposer worker no spec")

         self.proposer_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
@@ -225,7 +244,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             blocks_to_copy=blocks_to_copy,
         )

-        logger.info("run target worker no spec")
+        #logger.info("run target worker no spec")
         sampler_output = self.scorer_worker.execute_model(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
@@ -259,7 +278,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         sequence.
         """

-        logger.info("get spec proposals")
+        #logger.info("get spec proposals")
         # Generate proposals using draft worker.
         assert blocks_to_swap_in is not None
         assert blocks_to_swap_out is not None
@@ -268,7 +287,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
             blocks_to_copy, k)

-        logger.info("score proposals")
+        #logger.info("score proposals")
         proposal_scores = self.scorer.score_proposals(
             seq_group_metadata_list,
             blocks_to_swap_in,
@@ -278,11 +297,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposals,
         )

-        logger.info("verify proposals")
+        #logger.info("verify proposals")
         accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
                                                  proposal_scores, proposals, k)

-        logger.info("create output list")
+        #logger.info("create output list")
         return self._create_output_sampler_list(seq_group_metadata_list,
                                                 accepted_token_ids, k)
vllm/spec_decode/top1_proposer.py (new file, 200 lines)
@@ -0,0 +1,200 @@
from typing import Dict, List, Optional, Tuple

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker_base import WorkerBase


class Top1Proposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        worker: WorkerBase,
        device: str,
        vocab_size: int,
        max_proposal_len: Optional[int] = None,
    ):
        self._worker = worker
        self._device = device
        self.max_proposal_len = max_proposal_len
        self._vocab_size = vocab_size

    def get_proposals(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        proposal_len: int,
    ) -> SpeculativeProposals:
        """Get speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.
        """

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids are laid out as proposal_len entries of [batch] shape;
            # if it is false, they are batch entries of [proposal_len] shape.
            maybe_sampler_output, transposed = self._worker.sampler_output(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
                sample_len=proposal_len,
            )
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
        )

        return proposals

    def _split_by_max_model_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Determine which sequences would exceed the max model length."""

        proposal_lens: List[int] = []
        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
        nonzero_proposal_len_indices: List[int] = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            seq_len = seq_data.get_len()

            # Currently only proposal lens of 0 or the global batch proposal len
            # are supported.
            # If max_proposal_len is defined, then we shall not exceed this
            # quota for nonzero proposals.
            if (self.max_proposal_len is None
                    or seq_len + proposal_len < self.max_proposal_len):
                proposal_lens.append(proposal_len)
                nonzero_proposal_len_seqs.append(seq_group_metadata)
                nonzero_proposal_len_indices.append(i)
            else:
                proposal_lens.append(0)

        return (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        )

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[SamplerOutput],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals.
            proposal_tokens = torch.full(
                size=(
                    batch_size,
                    proposal_len,
                ),
                fill_value=-1,
                dtype=torch.long,
                device=self._device,
            )
            proposal_probs = torch.zeros(
                batch_size,
                proposal_len,
                self._vocab_size,
                dtype=torch.float32,
                device=self._device,
            )
            proposal_lens_tensor = torch.zeros(len(proposal_lens),
                                               dtype=torch.long,
                                               device=self._device)
            return proposal_tokens, proposal_probs, proposal_lens_tensor

        sampler_output = maybe_sampler_output
        proposal_tokens, proposal_probs = sampler_output_to_torch(
            sampler_output, sampler_transposed)

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. The proposal can be empty, e.g. [-1, -1, -1]

        entire_proposal_tokens = torch.full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
            dtype=torch.long,
            device=self._device,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
        entire_proposal_probs = torch.zeros(
            batch_size,
            *proposal_probs.shape[1:],
            dtype=torch.float32,
            device=self._device,
        )
        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs

        proposal_tokens, proposal_probs = (
            entire_proposal_tokens,
            entire_proposal_probs,
        )

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

        return proposal_tokens, proposal_probs, proposal_lens_tensor
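For illustration, a small sketch (invented numbers, not from this commit) of the split/merge bookkeeping Top1Proposer performs when one sequence in a batch of three would exceed the length budget: the skipped sequence gets proposal_len 0 and is re-inserted as an empty (-1) proposal.

# Hypothetical illustration of Top1Proposer's split/merge bookkeeping.
import torch

batch_size, k = 3, 2
proposal_lens = [k, 0, k]                  # sequence 1 exceeds the length budget
nonzero_proposal_len_indices = [0, 2]

# The draft worker only ran on sequences 0 and 2.
draft_tokens = torch.tensor([[4, 4], [7, 1]])

# Merge back into a full-batch tensor; skipped rows stay at -1.
entire_tokens = torch.full((batch_size, k), -1, dtype=torch.long)
entire_tokens[nonzero_proposal_len_indices] = draft_tokens
print(entire_tokens)   # tensor([[ 4,  4], [-1, -1], [ 7,  1]])

lens = torch.zeros(batch_size, dtype=torch.long)
lens[nonzero_proposal_len_indices] = k
print(lens)            # tensor([2, 0, 2])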
vllm/spec_decode/util.py
@@ -49,10 +49,13 @@ def split_batch_by_proposal_len(


 def sampler_output_to_torch(
-        sampler_output_list: List[SamplerOutput],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+        sampler_output_list: List[SamplerOutput],
+        sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
     """Utility function which converts a list of SamplerOutput to tensors.

+    sampler_transposed is used as an indicator of whether we need to do an
+    additional tensor transpose here.
+
     Returns:
         sampled_token_ids: torch.Tensor
             shape: [batch_size, len(sampler_output_list)]
@@ -68,7 +71,10 @@ def sampler_output_to_torch(
             for sampler_output in sampler_output_list
         ],
         dim=0,
-    ).transpose(0, 1)
+    )
+
+    if sampler_transposed:
+        sampled_token_probs = sampled_token_probs.transpose(0, 1)

     # shape: [batch_size, num_sampler_output]
     sampled_token_ids = torch.stack(
@@ -77,7 +83,9 @@ def sampler_output_to_torch(
             for sampler_output in sampler_output_list
         ],
         dim=0,
-    ).transpose(0, 1)
+    )
+    if sampler_transposed:
+        sampled_token_ids = sampled_token_ids.transpose(0, 1)

     return sampled_token_ids, sampled_token_probs
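A brief sketch (shapes illustrative, not from this commit) of why the sampler_transposed flag exists: the multi-step draft worker produces one SamplerOutput per decode step, so stacking along dim 0 gives [steps, batch] and needs a transpose to [batch, steps], while the ngram worker emits one output per sequence and is already in [batch, steps] layout.

# Illustrative only: the two stacking layouts behind sampler_transposed.
import torch

batch, steps, vocab = 2, 3, 5

# Multi-step worker: one tensor of sampled ids per step, each shaped [batch].
per_step_ids = [torch.randint(vocab, (batch,)) for _ in range(steps)]
stacked = torch.stack(per_step_ids, dim=0)        # [steps, batch]
ids_multi_step = stacked.transpose(0, 1)          # -> [batch, steps]; sampler_transposed=True

# NGram worker: one tensor per sequence, each already shaped [steps].
per_seq_ids = [torch.randint(vocab, (steps,)) for _ in range(batch)]
ids_ngram = torch.stack(per_seq_ids, dim=0)       # already [batch, steps]; sampler_transposed=False

assert ids_multi_step.shape == ids_ngram.shape == (batch, steps)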