[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
@@ -6,8 +6,8 @@ import torch
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
-                                   nvtx_range, sampler_output_to_torch,
+from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
+                                   sampler_output_to_torch,
                                    split_batch_by_proposal_len)
 from vllm.worker.worker_base import WorkerBase
@@ -72,10 +72,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         proposal_lens_list = proposals.proposal_lens.tolist()
         proposal_token_ids_list = proposals.proposal_token_ids.tolist()

+        # Filter the list to ignore -1 proposals.
+        proposal_token_ids_list_without_skips = [
+            proposals for proposals in proposal_token_ids_list
+            if -1 not in proposals
+        ]
+
         (spec_indices, non_spec_indices, target_seq_group_metadata_list,
          num_scoring_tokens) = self._expand_batch(
              seq_group_metadata_list=seq_group_metadata_list,
-             proposal_token_ids_list=proposal_token_ids_list,
+             proposal_token_ids_list=proposal_token_ids_list_without_skips,
              proposal_lens_list=proposal_lens_list,
          )
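For context on the filter added above: `proposal_token_ids` uses -1 as a padding sentinel for sequences that were not speculated on, so any row containing -1 must be dropped before batch expansion. A minimal sketch with toy data (the tensor values are illustrative, not taken from the PR):

```python
import torch

# Toy proposals: rows are sequences, columns are k draft tokens.
# A row containing -1 marks a sequence with no (or skipped) proposals.
proposal_token_ids = torch.tensor([
    [5, 9, 2],     # speculated on
    [-1, -1, -1],  # skipped
    [7, 7, 1],     # speculated on
])

proposal_token_ids_list = proposal_token_ids.tolist()

# Same filter as in the diff: drop any row containing a -1 entry.
proposal_token_ids_list_without_skips = [
    proposals for proposals in proposal_token_ids_list
    if -1 not in proposals
]

assert proposal_token_ids_list_without_skips == [[5, 9, 2], [7, 7, 1]]
```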
@@ -89,7 +95,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         target_sampler_output = target_sampler_output[0]

         all_tokens, all_probs = self._contract_batch(
-            original_bs=len(seq_group_metadata_list),
+            contracted_bs=len(seq_group_metadata_list),
             target_sampler_output=target_sampler_output,
             proposals=proposals,
             num_scoring_tokens=num_scoring_tokens,
@@ -128,14 +134,21 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
                 select_proposal_len_zero=True)

         target_seq_group_metadata_list = self._create_scoring_model_input(
-            spec_seqs, proposal_token_ids_list)
+            seq_group_metadata_list=spec_seqs,
+            proposal_token_ids=proposal_token_ids_list,
+            # NOTE: We determine the seq ids in the expanded batch using the
+            # full seq_group_metadata_list, instead of only spec_seqs.
+            target_seq_ids_iter=self._create_target_seq_id_iterator(
+                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
+        )

         num_scoring_tokens = len(target_seq_group_metadata_list)
         target_seq_group_metadata_list.extend(non_spec_seqs)

         return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                 num_scoring_tokens)

-    def _contract_batch(self, original_bs: int,
+    def _contract_batch(self, contracted_bs: int,
                         target_sampler_output: List[SamplerOutput],
                         proposals: SpeculativeProposals,
                         num_scoring_tokens: int, non_spec_indices: List[int],
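The batch expansion this scorer performs is easiest to see with plain lists: each sequence with a proposal of length k is scored via k+1 target sequences, one per draft-token prefix. A rough sketch of the idea (helper name and data are illustrative, not vLLM's API):

```python
from typing import List

def expand_one_sequence(prompt_tokens: List[int],
                        draft_tokens: List[int]) -> List[List[int]]:
    """Expand one sequence into k+1 target sequences for scoring.

    Each expanded sequence appends one more draft token than the last,
    so the target model scores every draft position (plus a bonus slot)
    in a single forward pass over the expanded batch.
    """
    return [
        prompt_tokens + draft_tokens[:i]
        for i in range(len(draft_tokens) + 1)
    ]

expanded = expand_one_sequence([10, 11], draft_tokens=[3, 4])
assert expanded == [[10, 11], [10, 11, 3], [10, 11, 3, 4]]
```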
@@ -144,42 +157,41 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         """Contract the expanded batch back into its original size.
         This maps the scores of speculative tokens back to their original
         sequences.
+
+        contracted_bs is the original batch size, and the batch size that the
+        target_sampler_output will be contracted to.
         """

-        # We mock the device tensors until PR 7/9 is merged (e2e correctness).
-        # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
-        maybe_mock_device_tensors(
-            sampler_output=target_sampler_output,
-            batch_size=len(non_spec_indices) + num_scoring_tokens,
-            vocab_size=self._vocab_size,
-            device=self._device,
-        )
-
         (target_token_ids, target_probs, non_spec_target_token_ids,
          non_spec_target_probs) = self._split_scoring_output(
              target_sampler_output, num_scoring_tokens)

         # Map distinct sequences used to score each token
         # of shape [batch_size * k + 1] back to [batch_size, k + 1].
-        batch_size, k = proposals.proposal_token_ids.shape
+        expanded_batch_size, k = proposals.proposal_token_ids.shape
+
+        # The number of tokens in the expanded batch used for speculation is
+        # equal to the total expanded batch size minus the number of samples for
+        # non-speculative sequences.
+        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

         target_token_ids = target_token_ids.squeeze().reshape(
-            batch_size, k + 1)
-        target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
+            spec_expanded_bs, k + 1)
+        target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
                                                       self._vocab_size)

-        all_tokens = torch.full(size=(original_bs, k + 1),
+        all_tokens = torch.full(size=(contracted_bs, k + 1),
                                 fill_value=-1,
                                 device=self._device,
                                 dtype=torch.long)
-        all_probs = torch.zeros(original_bs,
+        all_probs = torch.zeros(contracted_bs,
                                 k + 1,
                                 self._vocab_size,
                                 device=self._device,
                                 dtype=torch.float32)

         if non_spec_indices:
-            all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
+            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
             all_probs[non_spec_indices, :1, :] = non_spec_target_probs

         if spec_indices:
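The shape bookkeeping above can be checked with standalone tensors: the expanded batch holds spec_expanded_bs * (k + 1) speculative scoring rows plus one row per non-speculative sequence, and contraction reshapes the former back to [spec_expanded_bs, k + 1] before scattering into the contracted batch. A sketch with made-up sizes:

```python
import torch

contracted_bs = 3       # original batch size
k = 2                   # proposal length
spec_indices = [0, 2]   # sequences that were speculated on
non_spec_indices = [1]  # sequences that were not

spec_expanded_bs = len(spec_indices)
vocab_size = 8

# Flat target-model output over the expanded speculative rows.
target_token_ids = torch.arange(spec_expanded_bs * (k + 1))
target_probs = torch.rand(spec_expanded_bs * (k + 1), vocab_size)

# Contract: [spec_expanded_bs * (k+1)] -> [spec_expanded_bs, k+1].
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
target_probs = target_probs.reshape(spec_expanded_bs, k + 1, vocab_size)
assert target_probs.shape == (spec_expanded_bs, k + 1, vocab_size)

# Scatter into the contracted batch; -1 marks unused slots.
all_tokens = torch.full((contracted_bs, k + 1), -1, dtype=torch.long)
all_tokens[spec_indices] = target_token_ids
assert all_tokens[1].tolist() == [-1, -1, -1]  # non-spec row untouched
```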
@@ -189,20 +201,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         return all_tokens, all_probs

     def _create_scoring_model_input(
-            self,
-            seq_group_metadata_list: List[SequenceGroupMetadata],
-            proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
+        target_seq_ids_iter: Iterator[TargetSeqId],
     ) -> List[SequenceGroupMetadata]:
         """Given the original input sequences and proposed tokens from the draft
         model, create a list of target sequences that can be used for scoring.
+
+        target_seq_ids_iter provides sequence ids for the expanded batch,
+        fulfilling the requirement that no seq id in the expanded batch is equal
+        to the seq id in the original batch.
         """

         if not seq_group_metadata_list:
             return []

-        target_seq_ids_iter = self._create_target_seq_id_iterator(
-            get_all_seq_ids(seq_group_metadata_list))
-
         target_seq_group_metadata = list(
             chain.from_iterable(
                 self._create_target_seq_group_metadata(
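One way to satisfy the requirement that expanded-batch seq ids never collide with those in the original batch is to count upward from the largest existing id. This mirrors what a helper like `_create_target_seq_id_iterator` could do; the implementation below is a guess for illustration, not copied from the PR:

```python
from itertools import count
from typing import Iterator, List

def create_target_seq_id_iterator(seq_ids: List[int]) -> Iterator[int]:
    # Start counting just past the largest existing id, so every id
    # drawn for the expanded batch is distinct from the originals.
    return count(start=max(seq_ids) + 1)

it = create_target_seq_id_iterator([0, 3, 7])
assert [next(it) for _ in range(3)] == [8, 9, 10]
```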
@@ -24,9 +24,9 @@ class SpeculativeProposals:

     def __repr__(self):
         return (f"SpeculativeProposals("
-                f"proposal_token_ids={self.proposal_token_ids.shape}, "
+                f"proposal_token_ids={self.proposal_token_ids}, "
                 f"proposal_probs={self.proposal_probs.shape}, "
-                f"proposal_lens={self.proposal_lens.shape})")
+                f"proposal_lens={self.proposal_lens})")


 @dataclass
@@ -147,15 +147,16 @@ class AsyncMetricsCollector:
         emitted_tokens = self._aggregate_num_emitted_tokens.item()
         draft_tokens = self._aggregate_num_draft_tokens

-        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
+        max_num_emitted_tokens = self.get_max_num_emitted_tokens(
+            draft_tokens, k)

         if draft_tokens > 0:
             draft_acceptance_rate = accepted_tokens / draft_tokens
         else:
             draft_acceptance_rate = float("nan")

-        if num_possible_tokens > 0:
-            system_efficiency = emitted_tokens / num_possible_tokens
+        if max_num_emitted_tokens > 0:
+            system_efficiency = emitted_tokens / max_num_emitted_tokens
         else:
             system_efficiency = float("nan")
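In terms of the two ratios computed here: draft acceptance rate is accepted draft tokens over all draft tokens, and system efficiency is emitted tokens over the best-case emitted count. A toy calculation (the numbers are invented for illustration):

```python
# Suppose 2 sequences were speculated on with k = 3.
k = 3
draft_tokens = 6       # 2 sequences * 3 draft tokens each
accepted_tokens = 4    # draft tokens the rejection sampler kept
emitted_tokens = 6     # accepted + recovered/bonus tokens actually emitted

draft_acceptance_rate = accepted_tokens / draft_tokens       # 0.666...
max_num_emitted_tokens = (draft_tokens // k) * (k + 1)       # 2 * 4 = 8
system_efficiency = emitted_tokens / max_num_emitted_tokens  # 0.75

assert max_num_emitted_tokens == 8
assert round(system_efficiency, 2) == 0.75
```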
@@ -169,8 +170,22 @@ class AsyncMetricsCollector:
         )

     @staticmethod
-    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
-        # Divide by k since batch size can be variable.
-        total_num_spec_seqs = draft_tokens / k
-        num_accepted_per_seq_if_all_accepted = k + 1
-        return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)
+    def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
+        """Calculate the number of emitted tokens, assuming all tokens are
+        accepted.
+
+        This is equal to the number of sequences that have been speculated on,
+        times (speculation len + 1). The +1 comes from the bonus token.
+        """
+        # Determine the number of sequences that have been speculated on. Since
+        # the batch size can be variable, we divide by k.
+        assert draft_tokens % k == 0
+        total_num_spec_seqs = draft_tokens // k
+
+        # A single sequence may emit k accepted tokens and one bonus token in
+        # the best case.
+        num_emitted_per_seq_if_all_accepted = k + 1
+
+        # The max num of emitted tokens is the number of speculated sequences
+        # times the max emitted per seq.
+        return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
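A quick check of the formula, and of the bug the rename fixes: the old `get_max_num_accepted_tokens` divided by k + 1 instead of multiplying, so `int(2 / 4)` would have yielded 0 for the example below.

```python
def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
    # Same arithmetic as the new static method above.
    assert draft_tokens % k == 0
    total_num_spec_seqs = draft_tokens // k
    return total_num_spec_seqs * (k + 1)

# 6 draft tokens at k=3 means 2 speculated sequences; each can emit
# at most k accepted tokens plus 1 bonus token -> 2 * 4 = 8.
assert get_max_num_emitted_tokens(draft_tokens=6, k=3) == 8
```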
@@ -6,8 +6,7 @@ import torch
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
-from vllm.spec_decode.util import (maybe_mock_device_tensors,
-                                   sampler_output_to_torch)
+from vllm.spec_decode.util import sampler_output_to_torch
 from vllm.worker.worker import Worker
@@ -329,12 +328,15 @@ class DraftModelTop1Proposer(SpeculativeProposer):
         """
         if maybe_sampler_output is None:
             # If no speculative tokens, the sampler output will be None.
-            # In this case we return empty tensors.
-            proposal_tokens = torch.zeros(0,
-                                          max_proposal_len,
-                                          dtype=torch.long,
-                                          device=self._device)
-            proposal_probs = torch.zeros(0,
+            # In this case we return empty proposals.
+            proposal_tokens = torch.full(size=(
+                batch_size,
+                max_proposal_len,
+            ),
+                                         fill_value=-1,
+                                         dtype=torch.long,
+                                         device=self._device)
+            proposal_probs = torch.zeros(batch_size,
                                          max_proposal_len,
                                          self._vocab_size,
                                          dtype=torch.float32,
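The change above keeps the batch dimension when there are no proposals: instead of zero-row tensors, every sequence gets a full row of -1 tokens, the same sentinel the scorer filters out. A sketch of the convention with toy sizes (values are illustrative):

```python
import torch

batch_size, max_proposal_len, vocab_size = 2, 3, 8
device = "cpu"  # illustrative; the worker uses its own device

# All-(-1) token rows signal "no proposal" for every sequence, so
# downstream code sees a consistent [batch_size, k] shape either way.
proposal_tokens = torch.full(size=(batch_size, max_proposal_len),
                             fill_value=-1,
                             dtype=torch.long,
                             device=device)
proposal_probs = torch.zeros(batch_size,
                             max_proposal_len,
                             vocab_size,
                             dtype=torch.float32,
                             device=device)

assert proposal_tokens.shape == (batch_size, max_proposal_len)
assert bool((proposal_tokens == -1).all())
```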
@@ -345,17 +347,6 @@ class DraftModelTop1Proposer(SpeculativeProposer):
             return proposal_tokens, proposal_probs, proposal_lens_tensor

         sampler_output = maybe_sampler_output
-
-        # We mock the device tensors until PR 7/9 is merged (e2e correctness).
-        # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
-        for step_output in sampler_output:
-            maybe_mock_device_tensors(
-                sampler_output=step_output,
-                batch_size=len(proposal_lens),
-                vocab_size=self._vocab_size,
-                device=self._device,
-            )
-
         proposal_tokens, proposal_probs = sampler_output_to_torch(
             sampler_output)
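With the mocking removed, `sampler_output_to_torch` consumes real per-step outputs. Conceptually it stacks k per-step results (each covering the whole batch) into [batch_size, k] tokens and [batch_size, k, vocab] probs; a simplified stand-in (not vLLM's actual implementation) could look like:

```python
from typing import List, Tuple
import torch

def stack_step_outputs(
    step_token_ids: List[torch.Tensor],  # k tensors of [batch_size]
    step_probs: List[torch.Tensor],      # k tensors of [batch_size, vocab]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack k per-step sampler results into batch-major tensors."""
    # [k, batch_size] -> [batch_size, k]
    tokens = torch.stack(step_token_ids, dim=0).transpose(0, 1)
    # [k, batch_size, vocab] -> [batch_size, k, vocab]
    probs = torch.stack(step_probs, dim=0).transpose(0, 1)
    return tokens, probs

tokens, probs = stack_step_outputs(
    [torch.tensor([1, 2]), torch.tensor([3, 4])],
    [torch.rand(2, 8), torch.rand(2, 8)],
)
assert tokens.shape == (2, 2) and probs.shape == (2, 2, 8)
```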
@@ -111,6 +111,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             device=self.device,
             vocab_size=self._vocab_size)

+        self._configure_model_sampler_for_spec_decode()
+
+    def _configure_model_sampler_for_spec_decode(self):
+        """Configure model sampler to emit GPU tensors. This allows spec decode
+        to keep data on device without transferring to CPU and serializing,
+        which significantly reduces overhead of rejection sampling.
+
+        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
+        design is to have the "move to CPU and serialize" sampling decision be
+        done outside of the model/sampler; this way the "last-mile" worker
+        object which interfaces with the scheduler can serialize and incur the
+        performance hit as necessary. This allows us to run the worker several
+        iterations in a row without incurring the "move to CPU and serialize"
+        performance penalty.
+
+        Since this requires a large change to vLLM, we defer it to later and
+        temporarily accept this broken abstraction boundary.
+
+        NOTE(cade): This will require a special check if the proposer worker
+        does not have a sampler (e.g. ngram speculation).
+        """
+        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
+         ) = True
+        (self.proposer_worker.model_runner.model.sampler.
+         include_gpu_probs_tensor) = True
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.
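The docstring above argues for keeping sampled probabilities on the GPU between worker iterations and serializing only at the scheduler boundary. The shape of that design can be sketched independently of vLLM (the class and names below are illustrative, not vLLM's sampler):

```python
import torch

class TinySampler:
    """Toy sampler with the same opt-in switch as the diff above."""

    def __init__(self) -> None:
        self.include_gpu_probs_tensor = False

    def __call__(self, logits: torch.Tensor):
        probs = torch.softmax(logits, dim=-1)
        token_ids = torch.argmax(probs, dim=-1)
        if self.include_gpu_probs_tensor:
            # Stay on device: no .cpu()/.tolist() serialization here,
            # so rejection sampling can consume the tensor directly.
            return token_ids, probs
        # Default path: move to CPU for the scheduler-facing worker.
        return token_ids.cpu().tolist(), None

sampler = TinySampler()
sampler.include_gpu_probs_tensor = True
tokens, probs = sampler(torch.rand(2, 8))
assert probs is not None and tokens.shape == (2,)
```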
@@ -286,15 +312,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 select_proposal_len_zero=True)
         original_indices = spec_indices + non_spec_indices

-        proposal_probs = proposal_scores.probs[spec_indices, :-1]
-        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
+        # Get probabilities of target model, excluding bonus token.
+        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
+
+        # Get non-speculative sampled tokens from target model.
         non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
+
+        # Get bonus tokens from target model.
+        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
+
+        # Get probabilities according to proposal method.
+        proposal_probs = proposals.proposal_probs[spec_indices]
+
+        # Get proposed tokens.
+        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

         accepted_token_ids = self.rejection_sampler(
-            proposal_probs,
-            bonus_token_ids,
-            proposals.proposal_probs,
-            proposals.proposal_token_ids,
+            target_probs=proposal_verifier_probs,
+            bonus_token_ids=bonus_token_ids,
+            draft_probs=proposal_probs,
+            draft_token_ids=proposal_token_ids,
         )

         # Append output tokens from non-speculative sequences to
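For reference, the four keyword arguments in the call above line up shape-wise as follows; the sketch below is a consistency check with dummy tensors (sizes are illustrative):

```python
import torch

batch_size, k, vocab_size = 2, 3, 8

# Target-model probabilities for the k draft positions (bonus excluded).
target_probs = torch.rand(batch_size, k, vocab_size)
# Bonus token sampled by the target model after the last draft position.
bonus_token_ids = torch.randint(vocab_size, (batch_size, 1))
# Draft-model probabilities and tokens for the same k positions.
draft_probs = torch.rand(batch_size, k, vocab_size)
draft_token_ids = torch.randint(vocab_size, (batch_size, k))

# Modified rejection sampling accepts a prefix of each draft row and
# emits up to k + 1 tokens per sequence (k accepted + 1 bonus).
assert target_probs.shape == draft_probs.shape
assert bonus_token_ids.shape == (batch_size, 1)
assert draft_token_ids.shape == (batch_size, k)
```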