[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)

This commit is contained in:
sroy745
2024-07-20 23:58:58 -07:00
committed by GitHub
parent d7f4178dd9
commit 14f91fe67c
8 changed files with 333 additions and 64 deletions

View File

@@ -14,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids)
get_all_seq_ids, get_all_seq_ids_and_request_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
assert speculative_config is not None
target_worker = Worker(*args, **kwargs)
draft_worker_kwargs = kwargs.copy()
kwargs["model_runner_cls"] = TargetModelRunner
target_worker = Worker(*args, **kwargs)
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
target_worker.model_runner.disable_logprobs =\
speculative_config.disable_logprobs
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
typical_acceptance_sampler_posterior_threshold=speculative_config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=speculative_config.
typical_acceptance_sampler_posterior_alpha)
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=speculative_config.disable_logprobs)
return spec_decode_worker
@@ -107,6 +115,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
@@ -161,6 +170,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_logprobs=disable_logprobs,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -170,6 +180,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_logprobs: bool,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
allow_zero_draft_token_step: Optional[bool] = True,
@@ -189,6 +200,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
@@ -222,6 +236,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# Hidden states from target model to pass to proposer
# in the subsequent step.
self.previous_hidden_states: Optional[HiddenStates] = None
self._disable_logprobs = disable_logprobs
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@@ -357,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
) == 0 or disable_all_speculation:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
@@ -391,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
def _serialize_sampler_output_no_logprobs(
self, execute_model_req: ExecuteModelRequest,
sampler_output: SamplerOutput) -> SamplerOutput:
"""
Creates and returns a `SamplerOutput` with only the sampled token IDs
being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
All other parameters in `CompletionSequenceGroupOutput` related to log
probabilities are skipped.
Args:
execute_model_req (ExecuteModelRequest): The model request that
was executed.
sampler_output (SamplerOutput): The output from the sampler with
only GPU tensors populated.
Returns:
SamplerOutput: A new `SamplerOutput` instance containing a list of
`CompletionSequenceGroupOutput` objects with only sampled token
IDs populated.
"""
seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
completion_seq_group_output_list: List[
CompletionSequenceGroupOutput] = []
for index, seq_id in enumerate(seq_ids):
completion_seq_group_output_list.append(
create_sequence_group_output(
token_id=sampled_token_ids_list[index][0],
token_id_logprob_rank=-1,
token_id_logprob=0.0,
seq_id=seq_id,
topk_token_ids=[],
topk_logprobs=[],
))
return SamplerOutput(outputs=completion_seq_group_output_list)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
@@ -417,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.previous_hidden_states.update(
execute_model_req.seq_group_metadata_list, hidden_states)
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
sampler_output)
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.probs = None
sampler_output.sampled_tokens = None
sampler_output.sampled_token_probs = None
sampler_output.sampled_token_ids = None
sampler_output.logprobs = None
return [sampler_output]
return [sampler_output_to_return]
def _run_non_driver_rank(self) -> bool:
"""Run proposer and verifier model in non-driver workers. This is used
@@ -480,7 +535,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
execute_model_req,
proposals,
)
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
@@ -601,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the logprob of
# the accepted token).
(topk_logprobs_by_step,
topk_indices_by_step) = target_logprobs_by_step.topk(
k=self.scorer_worker.model_config.max_logprobs,
dim=-1,
)
if self._disable_logprobs:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_dummy_logprob_lists(
batch_size, num_steps,
self.scorer_worker.model_config.max_logprobs)
else:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
# Serialize all tensors into Python lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_logprob_lists_from_tensors(
target_logprobs_by_step, accepted_token_ids_by_step,
self.scorer_worker.model_config.max_logprobs)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
@@ -628,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize all tensors to CPU Python lists.
# Serialize tensor to CPU Python list.
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step.tolist())
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
topk_indices_by_step = topk_indices_by_step.tolist()
# Construct the output on a per-step, per-sequence basis.
sampler_output_list: List[SamplerOutput] = []
@@ -677,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
0].spec_decode_worker_metrics = maybe_rejsample_metrics
return sampler_output_list
def _create_dummy_logprob_lists(
self,
batch_size: int,
num_steps: int,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four dummy lists representing token probabilities
and their ranks.
This method initializes and returns:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
batch_size (int): The size of the batch.
num_steps (int): The number of steps in the sequence.
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing four dummy lists as described above.
"""
accepted_token_id_ranks_by_step = [[-1] * batch_size
for _ in range(num_steps)]
accepted_token_id_logprobs_by_step = [[0.0] * batch_size
for _ in range(num_steps)]
topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
topk_indices_by_step: List[List[List[Optional[int]]]] = [[
[None] * num_top_k for _ in range(batch_size)
] for _ in range(num_steps)]
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _create_logprob_lists_from_tensors(
self,
target_logprobs_by_step: torch.Tensor,
accepted_token_ids_by_step: torch.Tensor,
num_top_k: int,
) -> Tuple[List[List[int]], List[List[float]],
List[List[List[Optional[float]]]],
List[List[List[Optional[int]]]]]:
"""
Creates and returns four lists representing token probabilities and
their ranks.
This method initializes and returns four lists containing:
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
- The log probabilities of the accepted tokens,
shaped (num_steps, batch_size)
- The log probabilities of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
- The token IDs of the top k tokens,
shaped (num_steps, batch_size, num_top_k)
Args:
target_logprobs_by_step (torch.Tensor): Tensor representing the
log probabilities of the target model,
shaped (num_steps, batch_size, vocab_size)
accepted_token_ids_by_step (torch.Tensor): Tensor representing
the accepted token_ids, shaped (num_steps, batch_size)
num_top_k (int): The number of top-k token log probabilities to
return.
Returns:
A tuple containing the lists as described above.
"""
# Serialize all tensors to CPU Python lists.
# Get the logprobs/rank of the accepted tokens.
(accepted_token_id_ranks_by_step_tensor,
accepted_token_id_logprobs_by_step_tensor
) = get_sampled_token_logprobs(
logprob_tensor=target_logprobs_by_step,
sampled_token_ids=accepted_token_ids_by_step,
)
# Get the top-k logprobs (which may or may not include the
# logprob of the accepted token).
(topk_logprobs_by_step_tensor,
topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
k=num_top_k,
dim=-1,
)
accepted_token_id_ranks_by_step = (
accepted_token_id_ranks_by_step_tensor.tolist())
accepted_token_id_logprobs_by_step = (
accepted_token_id_logprobs_by_step_tensor.tolist())
topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
topk_indices_by_step = topk_indices_by_step_tensor.tolist()
return (accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
topk_indices_by_step)
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
"""
Removes the finished requests and their associated sequence ids from

View File

@@ -0,0 +1,69 @@
from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
class TargetModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
that the time spent in the log probability calculation of the target model
is time wasted, since we calculate log probabilities after deciding which
tokens are accepted. For this reason disabling log probabilities in the
target model will make decode faster. The model runner sets the
SamplingMetadata parameters according to whether log probabilities are
requested or not.
"""
def __init__(self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
self.disable_logprobs = True
super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
model_input: ModelInputForGPUWithSamplingMetadata = super(
).prepare_model_input(seq_group_metadata_list, virtual_engine,
finished_requests_ids)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
# sampling related tensors which includes the logprobs tensors.
model_input.sampling_metadata.skip_sampler_cpu_output = (
self.disable_logprobs)
return model_input

View File

@@ -1,5 +1,5 @@
from contextlib import contextmanager
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
@@ -53,8 +53,8 @@ def create_sequence_group_output(
token_id_logprob_rank: int,
token_id_logprob: float,
seq_id: SeqId,
topk_token_ids: List[int],
topk_logprobs: List[float],
topk_token_ids: List[Optional[int]],
topk_logprobs: List[Optional[float]],
) -> CompletionSequenceGroupOutput:
"""Create a SequenceGroupOutput given the sampling results.
@@ -68,7 +68,7 @@ def create_sequence_group_output(
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs: Dict[int, Logprob] = {
logprobs: Dict[Optional[int], Logprob] = {
token_id: Logprob(
logprob=token_id_logprob,
rank=token_id_logprob_rank,