[Speculative decoding] Support target-model logprobs (#4378)

This commit is contained in:
Cade Daniel
2024-05-03 15:52:01 -07:00
committed by GitHub
parent 43c413ec57
commit ab50275111
15 changed files with 727 additions and 86 deletions

View File

@@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
all_tokens, all_probs, spec_logprobs = self._contract_batch(
contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
@@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=spec_logprobs,
)
def _expand_batch(
@@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
def _contract_batch(
self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
"""
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
(target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
@@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size)
all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
@@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self._vocab_size,
device=self._device,
dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
return all_tokens, all_probs
return all_tokens, all_probs, all_logprobs
def _create_scoring_model_input(
self,
@@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
"""Split the target model output into speculative and non-speculative
output.
"""
@@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
target_token_ids, target_probs = sampler_output_to_torch(
[sampler_output], True)
sampler_output.logprobs = spec_logprobs
(target_token_ids, target_probs,
target_logprobs) = sampler_output_to_torch([sampler_output], True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
non_spec_target_token_ids, non_spec_target_probs = (
sampler_output_to_torch([sampler_output], True))
sampler_output.logprobs = non_spec_logprobs
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
True)
return (target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs)
return (target_token_ids, target_probs, target_logprobs,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs)
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:

View File

@@ -38,6 +38,11 @@ class SpeculativeScores:
# Probabilities of the speculative tokens according to the scoring model.
probs: torch.Tensor
# Log-probabilities of the speculative tokens according to the scoring
# model. These values can be used to generate Logprob objects that are
# returned to the user.
logprobs: torch.Tensor
# Token ids sampled from the scoring model. Used for speculative bonus
# tokens and also non-speculative normal decoding.
token_ids: torch.Tensor

View File

@@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
device=self.device,
)
token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
for i in range(len(seq_group_metadata_list)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_probs[i],
logprobs=token_logprobs,
sampled_token_ids=token_ids[i],
))
return outputs, False

View File

@@ -5,15 +5,16 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
@@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
sampler_output.logprobs = None
return [sampler_output]
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
@@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
)
#logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
accepted_token_ids, target_logprobs = self._verify_tokens(
seq_group_metadata_list, proposal_scores, proposals, k)
#logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)
return self._create_output_sampler_list(
seq_group_metadata_list,
accepted_token_ids,
target_logprobs=target_logprobs,
k=k)
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
@@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list = proposals.proposal_lens.tolist()
@@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
accepted_token_ids[original_indices] = accepted_token_ids.clone()
return accepted_token_ids
return accepted_token_ids, logprobs
def _create_output_sampler_list(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
k: int,
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
@@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
seq_ids = get_all_seq_ids(seq_group_metadata_list)
batch_size, num_steps = accepted_token_ids.shape
# shape: [k+1, batch_size]
accepted_token_ids_by_step = accepted_token_ids.transpose(0,
1).tolist()
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids = get_all_seq_ids(seq_group_metadata_list)
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list = []
for token_ids_by_step in accepted_token_ids_by_step:
if all(token_id == -1 for token_id in token_ids_by_step):
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for token_id, seq_id in zip(token_ids_by_step, seq_ids):
for sequence_index in range(batch_size):
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
SequenceGroupOutput(
samples=[
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
# TODO Add verifier logprobs.
logprobs={token_id: Logprob(0.0)},
)
],
prompt_logprobs=None,
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))

View File

@@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs = sampler_output_to_torch(
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
sampler_output, sampler_transposed)
# Now, reformat the output GPU tensors such that each sequence has

View File

@@ -1,10 +1,11 @@
from contextlib import contextmanager
from itertools import chain
from typing import List, Tuple
from typing import Dict, List, Tuple
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
SeqId = int
@@ -21,6 +22,89 @@ def get_all_seq_ids(
]))
def get_all_num_logprobs(
seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
"""Given a list of SequenceGroupMetadata, create a list of all num_logprobs.
If the sampling params do not call for any logprobs, return 0 for that
sequence.
"""
all_num_logprobs = []
for seq_group_metadata in seq_group_metadata_list:
num_logprobs = seq_group_metadata.sampling_params.logprobs
if seq_group_metadata.sampling_params.logprobs is None:
num_logprobs = 0
all_num_logprobs.append(num_logprobs)
return all_num_logprobs
def get_sampled_token_logprobs(
# shape [num_steps, batch_size, vocab_size]
logprob_tensor: torch.Tensor,
sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
"""
num_steps, batch_size, vocab_size = logprob_tensor.shape
selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids, ]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >=
expanded_selected_logprobs).sum(-1)
return sampled_token_ids_ranks, selected_logprobs
def create_sequence_group_output(
token_id: int,
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
) -> SequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
Args:
token_id (int): The sampled token for the sequence.
token_id_logprob_rank (int): The logprob rank of the sampled token.
token_id_logprob (float): The logprob value of the sampled token.
seq_id (int): The sequence id.
topk_token_ids (List[int]): The list of top-k token ids.
topk_logprobs (List[float]): The list of top-k logprobs.
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,
),
}
logprobs.update({
topk_token_ids[topk_logprob_index]: Logprob(
logprob=topk_logprobs[topk_logprob_index],
rank=topk_logprob_index + 1,
)
for topk_logprob_index, _ in enumerate(topk_token_ids)
})
return SequenceGroupOutput(
samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs=logprobs)
],
# TODO add prompt logprobs support.
prompt_logprobs=None,
)
def split_batch_by_proposal_len(
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_lens: List[int], select_proposal_len_zero: bool
@@ -49,8 +133,8 @@ def split_batch_by_proposal_len(
def sampler_output_to_torch(
sampler_output_list: List[SamplerOutput],
sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
@@ -76,6 +160,15 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_probs = sampled_token_probs.transpose(0, 1)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs = torch.stack(
[sampler_output.logprobs for sampler_output in sampler_output_list],
dim=0,
)
if sampler_transposed:
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
# shape: [batch_size, num_sampler_output]
sampled_token_ids = torch.stack(
[
@@ -87,7 +180,7 @@ def sampler_output_to_torch(
if sampler_transposed:
sampled_token_ids = sampled_token_ids.transpose(0, 1)
return sampled_token_ids, sampled_token_probs
return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,