[Speculative decoding] Support target-model logprobs (#4378)
This commit is contained in:
@@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
assert len(target_sampler_output) == 1, "expected single-step output"
|
||||
target_sampler_output = target_sampler_output[0]
|
||||
|
||||
all_tokens, all_probs = self._contract_batch(
|
||||
all_tokens, all_probs, spec_logprobs = self._contract_batch(
|
||||
contracted_bs=len(seq_group_metadata_list),
|
||||
target_sampler_output=target_sampler_output,
|
||||
proposals=proposals,
|
||||
@@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
return SpeculativeScores(
|
||||
probs=all_probs,
|
||||
token_ids=all_tokens,
|
||||
logprobs=spec_logprobs,
|
||||
)
|
||||
|
||||
def _expand_batch(
|
||||
@@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
|
||||
num_scoring_tokens)
|
||||
|
||||
def _contract_batch(self, contracted_bs: int,
|
||||
target_sampler_output: List[SamplerOutput],
|
||||
proposals: SpeculativeProposals,
|
||||
num_scoring_tokens: int, non_spec_indices: List[int],
|
||||
spec_indices: List[int],
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _contract_batch(
|
||||
self, contracted_bs: int,
|
||||
target_sampler_output: List[SamplerOutput],
|
||||
proposals: SpeculativeProposals, num_scoring_tokens: int,
|
||||
non_spec_indices: List[int], spec_indices: List[int],
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Contract the expanded batch back into its original size.
|
||||
This maps the scores of speculative tokens back to their original
|
||||
sequences.
|
||||
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
contracted_bs is the original batch size, and the batch size that the
|
||||
target_sampler_output will be contracted to.
|
||||
"""
|
||||
(target_token_ids, target_probs, non_spec_target_token_ids,
|
||||
non_spec_target_probs) = self._split_scoring_output(
|
||||
(target_token_ids, target_probs, target_logprobs,
|
||||
non_spec_target_token_ids, non_spec_target_probs,
|
||||
non_spec_target_logprobs) = self._split_scoring_output(
|
||||
target_sampler_output, num_scoring_tokens)
|
||||
|
||||
# Map distinct sequences used to score each token
|
||||
@@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
spec_expanded_bs, k + 1)
|
||||
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
|
||||
self._vocab_size)
|
||||
target_logprobs = target_logprobs.squeeze().reshape(
|
||||
spec_expanded_bs, k + 1, self._vocab_size)
|
||||
|
||||
all_tokens = torch.full(size=(contracted_bs, k + 1),
|
||||
fill_value=-1,
|
||||
@@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
self._vocab_size,
|
||||
device=self._device,
|
||||
dtype=torch.float32)
|
||||
all_logprobs = torch.full(size=(
|
||||
contracted_bs,
|
||||
k + 1,
|
||||
self._vocab_size,
|
||||
),
|
||||
fill_value=-float("inf"),
|
||||
device=self._device,
|
||||
dtype=torch.float32)
|
||||
|
||||
if non_spec_indices:
|
||||
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
|
||||
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
|
||||
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
|
||||
|
||||
if spec_indices:
|
||||
all_tokens[spec_indices] = target_token_ids
|
||||
all_probs[spec_indices] = target_probs
|
||||
all_logprobs[spec_indices] = target_logprobs
|
||||
|
||||
return all_tokens, all_probs
|
||||
return all_tokens, all_probs, all_logprobs
|
||||
|
||||
def _create_scoring_model_input(
|
||||
self,
|
||||
@@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
|
||||
def _split_scoring_output(
|
||||
self, sampler_output: SamplerOutput, num_scoring_tokens: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
torch.Tensor, torch.Tensor]:
|
||||
"""Split the target model output into speculative and non-speculative
|
||||
output.
|
||||
"""
|
||||
@@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
) = sampler_output.sampled_token_probs.split(split_sizes)
|
||||
(spec_sampled_tokens, non_spec_sampled_tokens
|
||||
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
|
||||
(
|
||||
spec_logprobs,
|
||||
non_spec_logprobs,
|
||||
) = sampler_output.logprobs.split(split_sizes)
|
||||
|
||||
# Convert scores to tensors.
|
||||
sampler_output.sampled_token_probs = spec_probs
|
||||
sampler_output.sampled_token_ids = spec_sampled_tokens
|
||||
target_token_ids, target_probs = sampler_output_to_torch(
|
||||
[sampler_output], True)
|
||||
sampler_output.logprobs = spec_logprobs
|
||||
(target_token_ids, target_probs,
|
||||
target_logprobs) = sampler_output_to_torch([sampler_output], True)
|
||||
|
||||
# Convert non-speculative output tokens to tensors.
|
||||
sampler_output.sampled_token_probs = non_spec_probs
|
||||
sampler_output.sampled_token_ids = non_spec_sampled_tokens
|
||||
non_spec_target_token_ids, non_spec_target_probs = (
|
||||
sampler_output_to_torch([sampler_output], True))
|
||||
sampler_output.logprobs = non_spec_logprobs
|
||||
(non_spec_target_token_ids, non_spec_target_probs,
|
||||
non_spec_target_logprobs) = sampler_output_to_torch([sampler_output],
|
||||
True)
|
||||
|
||||
return (target_token_ids, target_probs, non_spec_target_token_ids,
|
||||
non_spec_target_probs)
|
||||
return (target_token_ids, target_probs, target_logprobs,
|
||||
non_spec_target_token_ids, non_spec_target_probs,
|
||||
non_spec_target_logprobs)
|
||||
|
||||
def _create_target_seq_id_iterator(
|
||||
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
|
||||
|
||||
@@ -38,6 +38,11 @@ class SpeculativeScores:
|
||||
# Probabilities of the speculative tokens according to the scoring model.
|
||||
probs: torch.Tensor
|
||||
|
||||
# Log-probabilities of the speculative tokens according to the scoring
|
||||
# model. These values can be used to generate Logprob objects that are
|
||||
# returned to the user.
|
||||
logprobs: torch.Tensor
|
||||
|
||||
# Token ids sampled from the scoring model. Used for speculative bonus
|
||||
# tokens and also non-speculative normal decoding.
|
||||
token_ids: torch.Tensor
|
||||
|
||||
@@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase):
|
||||
device=self.device,
|
||||
)
|
||||
token_probs.scatter_(2, indices, 1)
|
||||
token_logprobs = torch.zeros(
|
||||
(len(seq_group_metadata_list), sample_len, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
for i in range(len(seq_group_metadata_list)):
|
||||
outputs.append(
|
||||
SamplerOutput(
|
||||
outputs=None,
|
||||
sampled_token_probs=token_probs[i],
|
||||
logprobs=token_logprobs,
|
||||
sampled_token_ids=token_ids[i],
|
||||
))
|
||||
return outputs, False
|
||||
|
||||
@@ -5,15 +5,16 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
SpeculativeScorer, SpeculativeScores)
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
|
||||
from vllm.spec_decode.util import (create_sequence_group_output,
|
||||
get_all_num_logprobs, get_all_seq_ids,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
split_batch_by_proposal_len)
|
||||
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
|
||||
|
||||
@@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
sampler_output.probs = None
|
||||
sampler_output.sampled_tokens = None
|
||||
sampler_output.logprobs = None
|
||||
return [sampler_output]
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
|
||||
@@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
)
|
||||
|
||||
#logger.info("verify proposals")
|
||||
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
|
||||
proposal_scores, proposals, k)
|
||||
accepted_token_ids, target_logprobs = self._verify_tokens(
|
||||
seq_group_metadata_list, proposal_scores, proposals, k)
|
||||
|
||||
#logger.info("create output list")
|
||||
return self._create_output_sampler_list(seq_group_metadata_list,
|
||||
accepted_token_ids, k)
|
||||
return self._create_output_sampler_list(
|
||||
seq_group_metadata_list,
|
||||
accepted_token_ids,
|
||||
target_logprobs=target_logprobs,
|
||||
k=k)
|
||||
|
||||
@nvtx_range("spec_decode_worker._verify_tokens")
|
||||
def _verify_tokens(
|
||||
@@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposal_scores: SpeculativeScores,
|
||||
proposals: SpeculativeProposals,
|
||||
max_proposal_len: int,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Determine which speculative tokens are accepted using the
|
||||
probabilities of each token according to the proposer and scorer models.
|
||||
|
||||
Returns a tuple of Tensors, one for the accepted token ids and one for
|
||||
the logprobs according to the scoring model.
|
||||
"""
|
||||
proposal_lens_list = proposals.proposal_lens.tolist()
|
||||
|
||||
@@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
non_spec_token_ids[:, 1:] = -1
|
||||
accepted_token_ids = torch.cat(
|
||||
[accepted_token_ids, non_spec_token_ids])
|
||||
logprobs = proposal_scores.logprobs
|
||||
|
||||
# Rearrange so that results are in the order of the original seq group
|
||||
# metadata.
|
||||
accepted_token_ids[original_indices] = accepted_token_ids.clone()
|
||||
|
||||
return accepted_token_ids
|
||||
return accepted_token_ids, logprobs
|
||||
|
||||
def _create_output_sampler_list(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
|
||||
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
|
||||
k: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Given the accepted token ids, create a list of SamplerOutput.
|
||||
@@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
The output is padded with -1 tokens such that each sequence has
|
||||
the same number of outputs.
|
||||
"""
|
||||
seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
batch_size, num_steps = accepted_token_ids.shape
|
||||
|
||||
# shape: [k+1, batch_size]
|
||||
accepted_token_ids_by_step = accepted_token_ids.transpose(0,
|
||||
1).tolist()
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
target_logprobs_by_step = target_logprobs.transpose(0, 1)
|
||||
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
|
||||
|
||||
# Get the logprobs/rank of the accepted tokens.
|
||||
(accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
|
||||
logprob_tensor=target_logprobs_by_step,
|
||||
sampled_token_ids=accepted_token_ids_by_step,
|
||||
)
|
||||
|
||||
# Get the top-k logprobs (which may or may not include the logprob of
|
||||
# the accepted token).
|
||||
(topk_logprobs_by_step,
|
||||
topk_indices_by_step) = target_logprobs_by_step.topk(
|
||||
k=self.scorer_worker.model_config.max_logprobs,
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Get the sequence ids and num_logprobs (sampling parameter) in the
|
||||
# batch.
|
||||
seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
|
||||
|
||||
# Serialize all tensors to CPU Python lists.
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
accepted_token_id_ranks_by_step = (
|
||||
accepted_token_id_ranks_by_step.tolist())
|
||||
accepted_token_id_logprobs_by_step = (
|
||||
accepted_token_id_logprobs_by_step.tolist())
|
||||
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
|
||||
topk_indices_by_step = topk_indices_by_step.tolist()
|
||||
|
||||
# Construct the output on a per-step, per-sequence basis.
|
||||
sampler_output_list = []
|
||||
for token_ids_by_step in accepted_token_ids_by_step:
|
||||
if all(token_id == -1 for token_id in token_ids_by_step):
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == -1
|
||||
for token_id in accepted_token_ids_by_step[step_index]):
|
||||
break
|
||||
|
||||
step_output_token_ids = []
|
||||
for token_id, seq_id in zip(token_ids_by_step, seq_ids):
|
||||
for sequence_index in range(batch_size):
|
||||
# Each sequence may have a different num_logprobs; retrieve it.
|
||||
num_logprobs = num_logprobs_per_seq[sequence_index]
|
||||
|
||||
step_output_token_ids.append(
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
# TODO Add verifier logprobs.
|
||||
logprobs={token_id: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
create_sequence_group_output(
|
||||
token_id=accepted_token_ids_by_step[step_index]
|
||||
[sequence_index],
|
||||
token_id_logprob_rank=accepted_token_id_ranks_by_step[
|
||||
step_index][sequence_index],
|
||||
token_id_logprob=accepted_token_id_logprobs_by_step[
|
||||
step_index][sequence_index],
|
||||
seq_id=seq_ids[sequence_index],
|
||||
topk_token_ids=topk_indices_by_step[step_index]
|
||||
[sequence_index][:num_logprobs],
|
||||
topk_logprobs=topk_logprobs_by_step[step_index]
|
||||
[sequence_index][:num_logprobs],
|
||||
))
|
||||
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer):
|
||||
return proposal_tokens, proposal_probs, proposal_lens_tensor
|
||||
|
||||
sampler_output = maybe_sampler_output
|
||||
proposal_tokens, proposal_probs = sampler_output_to_torch(
|
||||
proposal_tokens, proposal_probs, _ = sampler_output_to_torch(
|
||||
sampler_output, sampler_transposed)
|
||||
|
||||
# Now, reformat the output GPU tensors such that each sequence has
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
|
||||
SequenceGroupOutput, SequenceOutput)
|
||||
|
||||
SeqId = int
|
||||
|
||||
@@ -21,6 +22,89 @@ def get_all_seq_ids(
|
||||
]))
|
||||
|
||||
|
||||
def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Collect the requested logprob count of every sequence group.

    The result has one entry per element of ``seq_group_metadata_list``, in
    the same order. A sequence group whose sampling params leave ``logprobs``
    as None contributes 0.
    """
    requested_counts = (metadata.sampling_params.logprobs
                        for metadata in seq_group_metadata_list)
    # None means "no logprobs requested"; normalise it to 0 so callers can
    # use the value directly as a slice length.
    return [0 if count is None else count for count in requested_counts]
|
||||
|
||||
|
||||
def get_sampled_token_logprobs(
        # shape [num_steps, batch_size, vocab_size]
        logprob_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the logprobs for the sampled tokens. Returns the ranks and logprobs.

    Args:
        logprob_tensor: Log-probabilities over the vocabulary, shaped
            [num_steps, batch_size, vocab_size].
        sampled_token_ids: The token id chosen at each (step, sequence)
            position, shaped [num_steps, batch_size].

    Returns:
        A tuple ``(ranks, logprobs)``, both shaped [num_steps, batch_size].
        ``logprobs`` holds the log-probability of each sampled token and
        ``ranks`` holds its 1-based rank, i.e. the number of vocabulary
        entries whose log-probability is >= that of the sampled token.
    """
    num_steps, batch_size, _ = logprob_tensor.shape

    # Build the index tensors on the same device as the logprobs so the
    # advanced-indexing lookup does not need an implicit cross-device copy.
    device = logprob_tensor.device
    step_index = torch.arange(num_steps, device=device).unsqueeze(1)
    sequence_index = torch.arange(batch_size, device=device)

    # NOTE(review): a negative id (e.g. -1 padding) is treated as a regular
    # negative index and selects from the end of the vocab axis -- confirm
    # that callers ignore the values produced for padded positions.
    selected_logprobs = logprob_tensor[step_index, sequence_index,
                                       sampled_token_ids]

    # Broadcasting against a trailing singleton axis compares every vocab
    # entry with the sampled token's logprob; no explicit expand is needed.
    sampled_token_ids_ranks = (logprob_tensor >=
                               selected_logprobs.unsqueeze(-1)).sum(-1)

    return sampled_token_ids_ranks, selected_logprobs
|
||||
|
||||
|
||||
def create_sequence_group_output(
        token_id: int,
        token_id_logprob_rank: int,
        token_id_logprob: float,
        seq_id: SeqId,
        topk_token_ids: List[int],
        topk_logprobs: List[float],
) -> SequenceGroupOutput:
    """Build the SequenceGroupOutput for a single sampled token.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[int]): The list of top-k token ids.
        topk_logprobs (List[float]): The list of top-k logprobs.

    Returns:
        A SequenceGroupOutput holding one SequenceOutput whose ``logprobs``
        dict contains the sampled token plus every top-k token.
    """
    # The sampled token is always reported, and is inserted first. The top-k
    # entries follow in list order; the length of the top-k lists is decided
    # by the caller.
    logprobs: Dict[int, Logprob] = {}
    logprobs[token_id] = Logprob(
        logprob=token_id_logprob,
        rank=token_id_logprob_rank,
    )

    # If the sampled token also appears in the top-k list, its entry is
    # overwritten here with the top-k logprob and a position-based rank.
    for position, candidate_token_id in enumerate(topk_token_ids):
        logprobs[candidate_token_id] = Logprob(
            logprob=topk_logprobs[position],
            rank=position + 1,
        )

    sample = SequenceOutput(parent_seq_id=seq_id,
                            output_token=token_id,
                            logprobs=logprobs)

    # TODO add prompt logprobs support.
    return SequenceGroupOutput(samples=[sample], prompt_logprobs=None)
|
||||
|
||||
|
||||
def split_batch_by_proposal_len(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
proposal_lens: List[int], select_proposal_len_zero: bool
|
||||
@@ -49,8 +133,8 @@ def split_batch_by_proposal_len(
|
||||
|
||||
|
||||
def sampler_output_to_torch(
|
||||
sampler_output_list: List[SamplerOutput],
|
||||
sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
sampler_output_list: List[SamplerOutput], sampler_transposed: bool
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Utility function which converts a list of SamplerOutput to tensors.
|
||||
|
||||
sampler_transposed here is used as the indicator for whether
|
||||
@@ -76,6 +160,15 @@ def sampler_output_to_torch(
|
||||
if sampler_transposed:
|
||||
sampled_token_probs = sampled_token_probs.transpose(0, 1)
|
||||
|
||||
# shape: [batch_size, num_sampler_output, vocab_size]
|
||||
sampled_token_logprobs = torch.stack(
|
||||
[sampler_output.logprobs for sampler_output in sampler_output_list],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if sampler_transposed:
|
||||
sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)
|
||||
|
||||
# shape: [batch_size, num_sampler_output]
|
||||
sampled_token_ids = torch.stack(
|
||||
[
|
||||
@@ -87,7 +180,7 @@ def sampler_output_to_torch(
|
||||
if sampler_transposed:
|
||||
sampled_token_ids = sampled_token_ids.transpose(0, 1)
|
||||
|
||||
return sampled_token_ids, sampled_token_probs
|
||||
return sampled_token_ids, sampled_token_probs, sampled_token_logprobs
|
||||
|
||||
|
||||
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
|
||||
|
||||
Reference in New Issue
Block a user