[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
|
||||
@@ -62,6 +62,9 @@ class SpeculativeProposer(ABC):
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
# If set, this contains all sequence IDs that were assigned
|
||||
# bonus tokens in their last forward pass.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import weakref
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -40,6 +40,8 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass to generate sample_len future tokens.
|
||||
Returns the list of sampler output, one per layer, along with indicator
|
||||
@@ -97,12 +99,14 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
|
||||
return self._proposer.get_spec_proposals(execute_model_req)
|
||||
return self._proposer.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -20,6 +20,9 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass to generate sample_len future tokens.
|
||||
Returns the list of sampler output, one per layer, along with indicator
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
import weakref
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -51,6 +51,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
"""Run the model forward pass sample_len times. Returns the list of
|
||||
sampler output, one per model forward pass, along with indicator of
|
||||
@@ -60,44 +61,142 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
For multi step worker, this indicator shall be True.
|
||||
"""
|
||||
self._raise_if_unsupported(execute_model_req)
|
||||
|
||||
# Shallow copy input data so modifications (such as appending tokens)
|
||||
# do not cause side-effects.
|
||||
copied_seq_group_metadata_list = self._shallow_copy_inputs(
|
||||
execute_model_req.seq_group_metadata_list)
|
||||
copied_execute_model_req = execute_model_req.clone(
|
||||
copied_seq_group_metadata_list)
|
||||
|
||||
# Expand the batch for sequences with a bonus token.
|
||||
# Perform a forward pass on the expanded batch and filter the
|
||||
# response to retain only the original sequences' responses.
|
||||
expanded_request, indices_of_seq_with_bonus_tokens =\
|
||||
self._expand_execute_model_request(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
# Run model sample_len times.
|
||||
model_outputs: List[SamplerOutput] = []
|
||||
if isinstance(self.model_runner, TP1DraftModelRunner):
|
||||
copied_execute_model_req.num_steps = sample_len
|
||||
expanded_request.num_steps = sample_len
|
||||
model_outputs = self.execute_model(
|
||||
execute_model_req=copied_execute_model_req)
|
||||
execute_model_req=expanded_request)
|
||||
else:
|
||||
# TODO: Remove this branch once DraftModelRunner supports TP>1.
|
||||
for _ in range(sample_len):
|
||||
model_output: List[SamplerOutput] = super().execute_model(
|
||||
execute_model_req=copied_execute_model_req)
|
||||
execute_model_req=expanded_request)
|
||||
assert (len(model_output) == 1
|
||||
), "composing multistep workers not supported"
|
||||
model_output = model_output[0]
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
copied_seq_group_metadata_list)
|
||||
self._append_new_tokens(
|
||||
model_output, expanded_request.seq_group_metadata_list)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
return model_outputs, True
|
||||
filtered_model_outputs = self._filter_model_output(
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
return filtered_model_outputs, True
|
||||
|
||||
@staticmethod
def _expand_execute_model_request(
    execute_model_req: ExecuteModelRequest,
    seq_with_bonus_token_in_last_step: set,
) -> Tuple[ExecuteModelRequest, List[int]]:
    """
    Expand the execute model request based on sequences with bonus tokens.

    Every sequence group of the incoming request is kept (as a shallow
    copy). In addition, for each sequence group that contains at least one
    sequence ID listed in ``seq_with_bonus_token_in_last_step``, an extra
    sequence group is inserted immediately before it. That extra group is
    built by ``_copy_seq_metadata_excluding_last_token`` and holds the
    bonus-token sequences without their last output token.

    Args:
        execute_model_req (ExecuteModelRequest): The original execute
            model request. It is cloned, not modified.
        seq_with_bonus_token_in_last_step (set): Set of sequence IDs that
            contain bonus tokens.

    Returns:
        Tuple[ExecuteModelRequest, List[int]]: The cloned request whose
        ``seq_group_metadata_list`` is the expanded list, and the positions
        inside that expanded list of the copies of the original sequence
        groups.
    """
    updated_seq_group_metadata_list: List[SequenceGroupMetadata] = []
    updated_execute_model_req = execute_model_req.clone(
        updated_seq_group_metadata_list)
    indices_of_original_sequence_groups: List[int] = []
    for seq_group in execute_model_req.seq_group_metadata_list:
        # Only the sequence IDs are needed for the membership test, so
        # iterate the keys of seq_data directly; any() stops at the first
        # sequence that carries a bonus token.
        seq_group_has_bonus_tokens = any(
            seq_id in seq_with_bonus_token_in_last_step
            for seq_id in seq_group.seq_data)
        if seq_group_has_bonus_tokens:
            # Create new sequences without the last bonus token. These new
            # sequences have the same sequence id as the original sequence.
            # We create a new sequence group and add them there.
            updated_seq_group_without_bonus_token = \
                MultiStepWorker._copy_seq_metadata_excluding_last_token(
                    seq_group, seq_with_bonus_token_in_last_step)
            updated_seq_group_metadata_list.append(
                updated_seq_group_without_bonus_token)
        # Add the original sequence group.
        updated_seq_group_metadata_list.append(
            MultiStepWorker._shallow_copy_seq_group_metadata(seq_group))
        # Record the index of the original sequence group, i.e. the slot
        # that was just appended.
        indices_of_original_sequence_groups.append(
            len(updated_seq_group_metadata_list) - 1)

    updated_execute_model_req.seq_group_metadata_list =\
        updated_seq_group_metadata_list
    return updated_execute_model_req, indices_of_original_sequence_groups
|
||||
|
||||
@staticmethod
def _filter_model_output(
        expanded_batch_outputs: List[SamplerOutput],
        output_indices_to_retain: List[int]) -> List[SamplerOutput]:
    """
    Contract expanded-batch sampler outputs back to the requested entries.

    For every sampler output in ``expanded_batch_outputs`` a new
    ``SamplerOutput`` is built that keeps only the entries at
    ``output_indices_to_retain``.

    Args:
        expanded_batch_outputs (List[SamplerOutput]): The expanded output
            batch from the model, one element per forward pass.
        output_indices_to_retain (List[int]): Indices of the model outputs
            to retain.

    Returns:
        List[SamplerOutput]: A list containing the filtered model
        outputs for the specified indices.
    """

    def _retain(batched_field):
        # Optional fields stay None when the model did not produce them;
        # otherwise they are indexed with the whole list of indices at once
        # (presumably tensor fancy indexing - confirm against SamplerOutput).
        if batched_field is None:
            return None
        return batched_field[output_indices_to_retain]

    filtered_outputs: List[SamplerOutput] = []
    for step_output in expanded_batch_outputs:
        retained_outputs = [
            step_output.outputs[index] for index in output_indices_to_retain
        ]
        filtered_outputs.append(
            SamplerOutput(
                outputs=retained_outputs,
                sampled_token_probs=_retain(
                    step_output.sampled_token_probs),
                logprobs=_retain(step_output.logprobs),
                sampled_token_ids=_retain(step_output.sampled_token_ids),
            ))
    return filtered_outputs
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: set,
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
|
||||
return self._proposer.get_spec_proposals(execute_model_req)
|
||||
return self._proposer.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
@staticmethod
|
||||
def _append_new_tokens(
|
||||
@@ -123,9 +222,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
seq.update_num_computed_tokens(1)
|
||||
|
||||
@staticmethod
|
||||
def _shallow_copy_inputs(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
def _shallow_copy_seq_group_metadata(
|
||||
seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata:
|
||||
"""Copy input data structures to remove side-effects when input data
|
||||
structures are shared with other modules.
|
||||
|
||||
@@ -133,26 +231,62 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
The alternative is deep-copying (or other form of deep copy); this has
|
||||
performance downsides.
|
||||
"""
|
||||
|
||||
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
|
||||
# Shallow-copy the SequenceGroupMetadata. This allows us to
|
||||
# append tokens and change is_prompt without external side-effects.
|
||||
new_seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
new_seq_group_metadata = copy.copy(seq_group_metadata)
|
||||
|
||||
for old_seq_group_metadata in seq_group_metadata_list:
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
seq_group_metadata = copy.copy(old_seq_group_metadata)
|
||||
new_seq_group_metadata_list.append(seq_group_metadata)
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[seq_id].output_token_ids =\
|
||||
old_seq_data.output_token_ids[:]
|
||||
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_group_metadata.seq_data = new_seq_data
|
||||
return new_seq_group_metadata
|
||||
|
||||
@staticmethod
|
||||
def _copy_seq_metadata_excluding_last_token(
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_ids_to_copy: Set[int],
|
||||
) -> SequenceGroupMetadata:
|
||||
"""
|
||||
Creates a shallow copy of the given SequenceGroupMetadata, retaining
|
||||
only the sequence IDs specified in seq_ids_to_copy. For each of these
|
||||
sequence IDs, all output_token_ids except the last one are copied.
|
||||
Sequence IDs not in seq_ids_to_copy are excluded from the copy.
|
||||
|
||||
Parameters:
|
||||
seq_group_metadata (SequenceGroupMetadata): The original sequence
|
||||
group metadata.
|
||||
seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the
|
||||
copy.
|
||||
|
||||
Returns:
|
||||
SequenceGroupMetadata: A shallow copy of the sequence group metadata
|
||||
with the specified modifications.
|
||||
"""
|
||||
# Shallow-copy the SequenceGroupMetadata.
|
||||
new_seq_group_metadata = copy.copy(seq_group_metadata)
|
||||
# Shallow-copy seq_data and modify the output_token_ids.
|
||||
new_seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
if (seq_id in seq_ids_to_copy):
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[
|
||||
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
|
||||
|
||||
seq_group_metadata.seq_data = new_seq_data
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
# Copy all the output token ids except the last.
|
||||
# Also reduce num_computed_tokens by 1 since we are not
|
||||
# including the last output token.
|
||||
# NOTE: num_computed_tokens is not directly used by the
|
||||
# speculative decoding workers, as it is only relevant for
|
||||
# chunked prefill, which is disabled for speculative decoding.
|
||||
# However, to maintain consistency in num_computed_tokens,
|
||||
# we update it here.
|
||||
new_seq_data[seq_id].output_token_ids =\
|
||||
old_seq_data.output_token_ids[:-1]
|
||||
new_seq_data[seq_id].update_num_computed_tokens(-1)
|
||||
new_seq_group_metadata.seq_data = new_seq_data
|
||||
return new_seq_group_metadata
|
||||
|
||||
def _assert_enough_kv_space(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import weakref
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -48,6 +48,9 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# Unused parameter. NGramWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
|
||||
"""NGram match algo to pick proposal candidate. Returns the list of
|
||||
sampler output, one per SequenceGroupMetadata.
|
||||
@@ -133,12 +136,15 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
# Unused parameter. NGramWorker does not use the KV Cache and
|
||||
# therefore does not need this parameter.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
"""
|
||||
|
||||
return self._proposer.get_spec_proposals(execute_model_req)
|
||||
return self._proposer.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposer
|
||||
@@ -14,6 +14,13 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
# A set containing all sequence IDs that were assigned bonus tokens
|
||||
# in their last forward pass. This set is used to backfill the KV cache
|
||||
# with the key-value pairs of the penultimate token in the sequences.
|
||||
# This parameter is only used by the MultiStepWorker, which relies on
|
||||
# the KV cache for token generation. It is not used by workers that
|
||||
# do not utilize the KV cache.
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int]
|
||||
) -> Tuple[Optional[List[SamplerOutput]], bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -110,13 +110,17 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
sample_len: int,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> Tuple[List[SamplerOutput], bool]:
|
||||
# Do not check _is_dummy, as it's always called by get_spec_proposals
|
||||
return self._worker.sampler_output(execute_model_req, sample_len)
|
||||
return self._worker.sampler_output(
|
||||
execute_model_req, sample_len,
|
||||
seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Produce speculations given an input batch of sequences. The number of
|
||||
speculative tokens per sequence is determined by max_proposal_len.
|
||||
@@ -125,7 +129,8 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
return SpeculativeProposals(None, None, None)
|
||||
|
||||
with self._patch_tensor_parallel_group():
|
||||
return self._worker.get_spec_proposals(execute_model_req)
|
||||
return self._worker.get_spec_proposals(
|
||||
execute_model_req, seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -13,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
HiddenStates, SamplerOutput, SequenceGroupMetadata,
|
||||
get_all_seq_ids)
|
||||
get_all_seq_ids_and_request_ids)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
@@ -112,11 +113,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
|
||||
ngram_prompt_lookup_min = (
|
||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||
|
||||
disable_bonus_tokens = True
|
||||
|
||||
if ngram_prompt_lookup_max > 0:
|
||||
disable_bonus_tokens = False
|
||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
||||
ngram_prompt_lookup_max)
|
||||
@@ -128,11 +125,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
if draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "mlp_speculator":
|
||||
disable_bonus_tokens = False
|
||||
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
|
||||
elif draft_worker_kwargs[
|
||||
"model_config"].hf_config.model_type == "medusa":
|
||||
disable_bonus_tokens = False
|
||||
proposer_worker = MedusaWorker(**draft_worker_kwargs)
|
||||
else:
|
||||
if draft_tp == 1:
|
||||
@@ -149,10 +144,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
spec_decode_sampler: SpecDecodeBaseSampler = None
|
||||
if draft_token_acceptance_method == "rejection_sampler":
|
||||
spec_decode_sampler = RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens, )
|
||||
disable_bonus_tokens=False, )
|
||||
elif draft_token_acceptance_method == "typical_acceptance_sampler":
|
||||
spec_decode_sampler = TypicalAcceptanceSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
disable_bonus_tokens=False,
|
||||
posterior_threshold=\
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
|
||||
@@ -200,6 +195,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self._metrics = AsyncMetricsCollector(
|
||||
self.spec_decode_sampler
|
||||
) if metrics_collector is None else metrics_collector
|
||||
# Tracks the sequence IDs that received a bonus token ID in
|
||||
# their last forward pass. Needed only if KV cache is being
|
||||
# used for token generation such as in the case of MultiStepWorker.
|
||||
self._seq_with_bonus_token_in_last_step: Set[int] = set()
|
||||
# Tracks the currently active request ids and the sequence IDs
|
||||
# corresponding to them
|
||||
self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)
|
||||
# Tracks if the proposer worker uses the KV cache or not.
|
||||
|
||||
self.probs_dtype = self.spec_decode_sampler.probs_dtype
|
||||
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
|
||||
# Lazy initialization.
|
||||
@@ -307,6 +311,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return []
|
||||
|
||||
self._track_finished_requests(execute_model_req)
|
||||
disable_all_speculation = self._should_disable_all_speculation(
|
||||
execute_model_req)
|
||||
num_lookahead_slots = execute_model_req.num_lookahead_slots
|
||||
@@ -453,7 +458,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.previous_hidden_states = None
|
||||
|
||||
# Generate proposals using draft worker.
|
||||
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
|
||||
proposals = self.proposer_worker.get_spec_proposals(
|
||||
execute_model_req, self._seq_with_bonus_token_in_last_step)
|
||||
|
||||
proposal_scores = self.scorer.score_proposals(
|
||||
execute_model_req,
|
||||
@@ -585,7 +591,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
# Get the sequence ids and num_logprobs (sampling parameter) in the
|
||||
# batch.
|
||||
seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
|
||||
seq_group_metadata_list)
|
||||
|
||||
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
|
||||
|
||||
# Serialize all tensors to CPU Python lists.
|
||||
@@ -608,7 +616,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
for sequence_index in range(batch_size):
|
||||
# Each sequence may have a different num_logprobs; retrieve it.
|
||||
num_logprobs = num_logprobs_per_seq[sequence_index]
|
||||
|
||||
step_output_token_ids.append(
|
||||
create_sequence_group_output(
|
||||
token_id=accepted_token_ids_by_step[step_index]
|
||||
@@ -623,18 +630,48 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
topk_logprobs=topk_logprobs_by_step[step_index]
|
||||
[sequence_index][:num_logprobs],
|
||||
))
|
||||
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
|
||||
# Populate the data structures needed to keep track of sequences with
|
||||
# bonus tokens.
|
||||
self._track_sequences_with_bonus_tokens(seq_ids,
|
||||
request_ids_seq_ids_mapping,
|
||||
accepted_token_ids_by_step)
|
||||
maybe_rejsample_metrics = (
|
||||
self._metrics.maybe_collect_rejsample_metrics(k))
|
||||
if maybe_rejsample_metrics is not None:
|
||||
sampler_output_list[
|
||||
0].spec_decode_worker_metrics = maybe_rejsample_metrics
|
||||
|
||||
return sampler_output_list
|
||||
|
||||
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
|
||||
"""
|
||||
Removes the finished requests and their associated sequence ids from
|
||||
internal book keeping data structures.
|
||||
"""
|
||||
for finished_request in execute_model_req.finished_requests_ids:
|
||||
for seq_id in self._request_id_seq_id_mapping[finished_request]:
|
||||
self._seq_with_bonus_token_in_last_step.discard(seq_id)
|
||||
del self._request_id_seq_id_mapping[finished_request]
|
||||
|
||||
def _track_sequences_with_bonus_tokens(
|
||||
self, seq_ids: List[int],
|
||||
request_ids_seq_ids_mapping: Dict[str, Set[int]],
|
||||
accepted_token_ids_by_step: List[List[int]]):
|
||||
"""
|
||||
Updates the internal data structures which keep track of sequences
|
||||
which have been assigned bonus tokens in their last forward pass.
|
||||
"""
|
||||
for seq_index, seq_id in enumerate(seq_ids):
|
||||
last_token_id = accepted_token_ids_by_step[-1][seq_index]
|
||||
if last_token_id == -1:
|
||||
self._seq_with_bonus_token_in_last_step.discard(seq_id)
|
||||
else:
|
||||
self._seq_with_bonus_token_in_last_step.add(seq_id)
|
||||
for request_id, sequences in request_ids_seq_ids_mapping.items():
|
||||
self._request_id_seq_id_mapping[request_id].update(sequences)
|
||||
|
||||
@cached_property
|
||||
def _vocab_size(self) -> int:
|
||||
"""Get the vocab size of the model and make sure it's consistent between
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -42,6 +42,7 @@ class Top1Proposer(SpeculativeProposer):
|
||||
def get_spec_proposals(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
seq_ids_with_bonus_token_in_last_step: Set[int],
|
||||
) -> SpeculativeProposals:
|
||||
"""Get speculative proposals given the input batch.
|
||||
|
||||
@@ -76,6 +77,8 @@ class Top1Proposer(SpeculativeProposer):
|
||||
maybe_sampler_output, transposed = self._worker.sampler_output(
|
||||
execute_model_req=nonzero_execute_model_req,
|
||||
sample_len=proposal_len,
|
||||
seq_ids_with_bonus_token_in_last_step=\
|
||||
seq_ids_with_bonus_token_in_last_step,
|
||||
)
|
||||
(
|
||||
proposal_lens,
|
||||
|
||||
Reference in New Issue
Block a user