[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)
This commit is contained in:
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
|
||||
HiddenStates, SamplerOutput, SequenceGroupMetadata,
|
||||
get_all_seq_ids_and_request_ids)
|
||||
get_all_seq_ids, get_all_seq_ids_and_request_ids)
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||
@@ -26,6 +26,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
|
||||
from vllm.spec_decode.target_model_runner import TargetModelRunner
|
||||
from vllm.spec_decode.util import (create_sequence_group_output,
|
||||
get_all_num_logprobs,
|
||||
get_sampled_token_logprobs, nvtx_range,
|
||||
@@ -44,9 +45,15 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
|
||||
assert speculative_config is not None
|
||||
|
||||
target_worker = Worker(*args, **kwargs)
|
||||
|
||||
draft_worker_kwargs = kwargs.copy()
|
||||
|
||||
kwargs["model_runner_cls"] = TargetModelRunner
|
||||
target_worker = Worker(*args, **kwargs)
|
||||
# Set the disable_logprobs variable in the TargetModelRunner instance
|
||||
# as per its value specified in the SpeculativeConfig.
|
||||
target_worker.model_runner.disable_logprobs =\
|
||||
speculative_config.disable_logprobs
|
||||
|
||||
# Override draft-model specific worker args.
|
||||
draft_worker_kwargs.update(
|
||||
model_config=speculative_config.draft_model_config,
|
||||
@@ -67,7 +74,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
||||
typical_acceptance_sampler_posterior_threshold=speculative_config.
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=speculative_config.
|
||||
typical_acceptance_sampler_posterior_alpha)
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
disable_logprobs=speculative_config.disable_logprobs)
|
||||
|
||||
return spec_decode_worker
|
||||
|
||||
@@ -107,6 +115,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: float,
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
disable_logprobs: bool,
|
||||
) -> "SpecDecodeWorker":
|
||||
|
||||
allow_zero_draft_token_step = True
|
||||
@@ -161,6 +170,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
return SpecDecodeWorker(
|
||||
proposer_worker,
|
||||
scorer_worker,
|
||||
disable_logprobs=disable_logprobs,
|
||||
disable_by_batch_size=disable_by_batch_size,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
allow_zero_draft_token_step=allow_zero_draft_token_step)
|
||||
@@ -170,6 +180,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
proposer_worker: ProposerWorkerBase,
|
||||
scorer_worker: WorkerBase,
|
||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||
disable_logprobs: bool,
|
||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||
disable_by_batch_size: Optional[int] = None,
|
||||
allow_zero_draft_token_step: Optional[bool] = True,
|
||||
@@ -189,6 +200,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
types of sampler namely RejectionSampler and
|
||||
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
|
||||
instance of RejectionSampler or TypicalAcceptanceSampler.
|
||||
disable_logprobs: If set to True, token log probabilities will
|
||||
not be output in both the draft worker and the target worker.
|
||||
If set to False, log probabilities will be output by both.
|
||||
disable_by_batch_size: If the batch size is larger than this,
|
||||
disable speculative decoding for new incoming requests.
|
||||
metrics_collector: Helper class for collecting metrics; can be set
|
||||
@@ -222,6 +236,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# Hidden states from target model to pass to proposer
|
||||
# in the subsequent step.
|
||||
self.previous_hidden_states: Optional[HiddenStates] = None
|
||||
self._disable_logprobs = disable_logprobs
|
||||
|
||||
def init_device(self) -> None:
|
||||
"""Initialize both scorer and proposer models.
|
||||
@@ -357,7 +372,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
) == 0 or disable_all_speculation:
|
||||
return self._run_no_spec(execute_model_req,
|
||||
skip_proposer=disable_all_speculation)
|
||||
|
||||
return self._run_speculative_decoding_step(execute_model_req,
|
||||
num_lookahead_slots)
|
||||
|
||||
@@ -391,6 +405,42 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
# this state within spec decode worker.
|
||||
seq_group_metadata.num_speculative_tokens = 0
|
||||
|
||||
def _serialize_sampler_output_no_logprobs(
|
||||
self, execute_model_req: ExecuteModelRequest,
|
||||
sampler_output: SamplerOutput) -> SamplerOutput:
|
||||
"""
|
||||
Creates and returns a `SamplerOutput` with only the sampled token IDs
|
||||
being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
|
||||
All other parameters in `CompletionSequenceGroupOutput` related to log
|
||||
probabilities are skipped.
|
||||
|
||||
Args:
|
||||
execute_model_req (ExecuteModelRequest): The model request that
|
||||
was executed.
|
||||
sampler_output (SamplerOutput): The output from the sampler with
|
||||
only GPU tensors populated.
|
||||
|
||||
Returns:
|
||||
SamplerOutput: A new `SamplerOutput` instance containing a list of
|
||||
`CompletionSequenceGroupOutput` objects with only sampled token
|
||||
IDs populated.
|
||||
"""
|
||||
seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
|
||||
sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
|
||||
completion_seq_group_output_list: List[
|
||||
CompletionSequenceGroupOutput] = []
|
||||
for index, seq_id in enumerate(seq_ids):
|
||||
completion_seq_group_output_list.append(
|
||||
create_sequence_group_output(
|
||||
token_id=sampled_token_ids_list[index][0],
|
||||
token_id_logprob_rank=-1,
|
||||
token_id_logprob=0.0,
|
||||
seq_id=seq_id,
|
||||
topk_token_ids=[],
|
||||
topk_logprobs=[],
|
||||
))
|
||||
return SamplerOutput(outputs=completion_seq_group_output_list)
|
||||
|
||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
|
||||
skip_proposer: bool) -> List[SamplerOutput]:
|
||||
@@ -417,12 +467,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
self.previous_hidden_states.update(
|
||||
execute_model_req.seq_group_metadata_list, hidden_states)
|
||||
|
||||
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
|
||||
execute_model_req=execute_model_req, sampler_output=sampler_output)
|
||||
if self._disable_logprobs else
|
||||
sampler_output)
|
||||
|
||||
# Clear device tensors from sampler output. This reduces communication
|
||||
# overhead when the engine runs in a different process than the workers.
|
||||
sampler_output.probs = None
|
||||
sampler_output.sampled_tokens = None
|
||||
sampler_output.sampled_token_probs = None
|
||||
sampler_output.sampled_token_ids = None
|
||||
sampler_output.logprobs = None
|
||||
return [sampler_output]
|
||||
return [sampler_output_to_return]
|
||||
|
||||
def _run_non_driver_rank(self) -> bool:
|
||||
"""Run proposer and verifier model in non-driver workers. This is used
|
||||
@@ -480,7 +535,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
execute_model_req,
|
||||
proposals,
|
||||
)
|
||||
|
||||
accepted_token_ids, target_logprobs = self._verify_tokens(
|
||||
execute_model_req.seq_group_metadata_list, proposal_scores,
|
||||
proposals, execute_model_req.num_lookahead_slots)
|
||||
@@ -601,25 +655,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
the same number of outputs.
|
||||
"""
|
||||
batch_size, num_steps = accepted_token_ids.shape
|
||||
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
target_logprobs_by_step = target_logprobs.transpose(0, 1)
|
||||
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
|
||||
|
||||
# Get the logprobs/rank of the accepted tokens.
|
||||
(accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
|
||||
logprob_tensor=target_logprobs_by_step,
|
||||
sampled_token_ids=accepted_token_ids_by_step,
|
||||
)
|
||||
|
||||
# Get the top-k logprobs (which may or may not include the logprob of
|
||||
# the accepted token).
|
||||
(topk_logprobs_by_step,
|
||||
topk_indices_by_step) = target_logprobs_by_step.topk(
|
||||
k=self.scorer_worker.model_config.max_logprobs,
|
||||
dim=-1,
|
||||
)
|
||||
if self._disable_logprobs:
|
||||
# We are skipping the logprobs. Hence don't serialize the
|
||||
# logprobs related tensors from the GPU. Instead create
|
||||
# empty/dummy lists.
|
||||
(accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step,
|
||||
topk_logprobs_by_step, topk_indices_by_step) =\
|
||||
self._create_dummy_logprob_lists(
|
||||
batch_size, num_steps,
|
||||
self.scorer_worker.model_config.max_logprobs)
|
||||
else:
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
target_logprobs_by_step = target_logprobs.transpose(0, 1)
|
||||
# Serialize all tensors into Python lists.
|
||||
(accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step,
|
||||
topk_logprobs_by_step, topk_indices_by_step) =\
|
||||
self._create_logprob_lists_from_tensors(
|
||||
target_logprobs_by_step, accepted_token_ids_by_step,
|
||||
self.scorer_worker.model_config.max_logprobs)
|
||||
|
||||
# Get the sequence ids and num_logprobs (sampling parameter) in the
|
||||
# batch.
|
||||
@@ -628,14 +684,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
|
||||
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
|
||||
|
||||
# Serialize all tensors to CPU Python lists.
|
||||
# Serialize tensor to CPU Python list.
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
accepted_token_id_ranks_by_step = (
|
||||
accepted_token_id_ranks_by_step.tolist())
|
||||
accepted_token_id_logprobs_by_step = (
|
||||
accepted_token_id_logprobs_by_step.tolist())
|
||||
topk_logprobs_by_step = topk_logprobs_by_step.tolist()
|
||||
topk_indices_by_step = topk_indices_by_step.tolist()
|
||||
|
||||
# Construct the output on a per-step, per-sequence basis.
|
||||
sampler_output_list: List[SamplerOutput] = []
|
||||
@@ -677,6 +727,108 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
0].spec_decode_worker_metrics = maybe_rejsample_metrics
|
||||
return sampler_output_list
|
||||
|
||||
def _create_dummy_logprob_lists(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_steps: int,
|
||||
num_top_k: int,
|
||||
) -> Tuple[List[List[int]], List[List[float]],
|
||||
List[List[List[Optional[float]]]],
|
||||
List[List[List[Optional[int]]]]]:
|
||||
"""
|
||||
Creates and returns four dummy lists representing token probabilities
|
||||
and their ranks.
|
||||
|
||||
This method initializes and returns:
|
||||
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
|
||||
- The log probabilities of the accepted tokens,
|
||||
shaped (num_steps, batch_size)
|
||||
- The log probabilities of the top k tokens,
|
||||
shaped (num_steps, batch_size, num_top_k)
|
||||
- The token IDs of the top k tokens,
|
||||
shaped (num_steps, batch_size, num_top_k)
|
||||
|
||||
Args:
|
||||
batch_size (int): The size of the batch.
|
||||
num_steps (int): The number of steps in the sequence.
|
||||
num_top_k (int): The number of top-k token log probabilities to
|
||||
return.
|
||||
|
||||
Returns:
|
||||
A tuple containing four dummy lists as described above.
|
||||
"""
|
||||
accepted_token_id_ranks_by_step = [[-1] * batch_size
|
||||
for _ in range(num_steps)]
|
||||
accepted_token_id_logprobs_by_step = [[0.0] * batch_size
|
||||
for _ in range(num_steps)]
|
||||
topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
|
||||
[None] * num_top_k for _ in range(batch_size)
|
||||
] for _ in range(num_steps)]
|
||||
topk_indices_by_step: List[List[List[Optional[int]]]] = [[
|
||||
[None] * num_top_k for _ in range(batch_size)
|
||||
] for _ in range(num_steps)]
|
||||
return (accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
|
||||
topk_indices_by_step)
|
||||
|
||||
def _create_logprob_lists_from_tensors(
|
||||
self,
|
||||
target_logprobs_by_step: torch.Tensor,
|
||||
accepted_token_ids_by_step: torch.Tensor,
|
||||
num_top_k: int,
|
||||
) -> Tuple[List[List[int]], List[List[float]],
|
||||
List[List[List[Optional[float]]]],
|
||||
List[List[List[Optional[int]]]]]:
|
||||
"""
|
||||
Creates and returns four lists representing token probabilities and
|
||||
their ranks.
|
||||
|
||||
This method initializes and returns four lists containing:
|
||||
- The ranks of the accepted tokens, shaped (num_steps, batch_size)
|
||||
- The log probabilities of the accepted tokens,
|
||||
shaped (num_steps, batch_size)
|
||||
- The log probabilities of the top k tokens,
|
||||
shaped (num_steps, batch_size, num_top_k)
|
||||
- The token IDs of the top k tokens,
|
||||
shaped (num_steps, batch_size, num_top_k)
|
||||
|
||||
Args:
|
||||
target_logprobs_by_step (torch.Tensor): Tensor representing the
|
||||
log probabilities of the target model,
|
||||
shaped (num_steps, batch_size, vocab_size)
|
||||
accepted_token_ids_by_step (torch.Tensor): Tensor representing
|
||||
the accepted token_ids, shaped (num_steps, batch_size)
|
||||
num_top_k (int): The number of top-k token log probabilities to
|
||||
return.
|
||||
|
||||
Returns:
|
||||
A tuple containing the lists as described above.
|
||||
"""
|
||||
# Serialize all tensors to CPU Python lists.
|
||||
# Get the logprobs/rank of the accepted tokens.
|
||||
(accepted_token_id_ranks_by_step_tensor,
|
||||
accepted_token_id_logprobs_by_step_tensor
|
||||
) = get_sampled_token_logprobs(
|
||||
logprob_tensor=target_logprobs_by_step,
|
||||
sampled_token_ids=accepted_token_ids_by_step,
|
||||
)
|
||||
# Get the top-k logprobs (which may or may not include the
|
||||
# logprob of the accepted token).
|
||||
(topk_logprobs_by_step_tensor,
|
||||
topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
|
||||
k=num_top_k,
|
||||
dim=-1,
|
||||
)
|
||||
accepted_token_id_ranks_by_step = (
|
||||
accepted_token_id_ranks_by_step_tensor.tolist())
|
||||
accepted_token_id_logprobs_by_step = (
|
||||
accepted_token_id_logprobs_by_step_tensor.tolist())
|
||||
topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
|
||||
topk_indices_by_step = topk_indices_by_step_tensor.tolist()
|
||||
return (accepted_token_id_ranks_by_step,
|
||||
accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
|
||||
topk_indices_by_step)
|
||||
|
||||
def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
|
||||
"""
|
||||
Removes the finished requests and their associated sequence ids from
|
||||
|
||||
69
vllm/spec_decode/target_model_runner.py
Normal file
69
vllm/spec_decode/target_model_runner.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
|
||||
ModelRunner)
|
||||
|
||||
|
||||
class TargetModelRunner(ModelRunner):
|
||||
"""Specialized model runner for speculative decoding target model.
|
||||
In speculative decoding, the log probabilities selected finally may not
|
||||
be the same ones as selected by the target model sampling. This means
|
||||
that the time spent in the log probability calculation of the target model
|
||||
is time wasted, since we calculate log probabilities after deciding which
|
||||
tokens are accepted. For this reason disabling log probabilities in the
|
||||
target model will make decode faster. The model runner sets the
|
||||
SamplingMetadata parameters according to whether log probabilities are
|
||||
requested or not.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
cache_config: CacheConfig,
|
||||
load_config: LoadConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
|
||||
multimodal_config: Optional[MultiModalConfig] = None,
|
||||
return_hidden_states: bool = False):
|
||||
# An internal boolean member variable to indicate if token log
|
||||
# probabilities are needed or not.
|
||||
self.disable_logprobs = True
|
||||
super().__init__(
|
||||
model_config=model_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
lora_config=lora_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
multimodal_config=multimodal_config,
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
return_hidden_states=return_hidden_states,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
model_input: ModelInputForGPUWithSamplingMetadata = super(
|
||||
).prepare_model_input(seq_group_metadata_list, virtual_engine,
|
||||
finished_requests_ids)
|
||||
# If token log probabilities is disabled then skip generating sampler
|
||||
# CPU output. We directly serialize the GPU sampled_token_id tensors
|
||||
# as needed. If log probabilities is enabled then synchronize all the
|
||||
# sampling related tensors which includes the logprobs tensors.
|
||||
model_input.sampling_metadata.skip_sampler_cpu_output = (
|
||||
self.disable_logprobs)
|
||||
return model_input
|
||||
@@ -1,5 +1,5 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -53,8 +53,8 @@ def create_sequence_group_output(
|
||||
token_id_logprob_rank: int,
|
||||
token_id_logprob: float,
|
||||
seq_id: SeqId,
|
||||
topk_token_ids: List[int],
|
||||
topk_logprobs: List[float],
|
||||
topk_token_ids: List[Optional[int]],
|
||||
topk_logprobs: List[Optional[float]],
|
||||
) -> CompletionSequenceGroupOutput:
|
||||
"""Create a SequenceGroupOutput given the sampling results.
|
||||
|
||||
@@ -68,7 +68,7 @@ def create_sequence_group_output(
|
||||
"""
|
||||
# vLLM logprobs always include the sampled token. In addition, the user may
|
||||
# request topk-logprobs (where top-k varies per user up to max_logprobs).
|
||||
logprobs: Dict[int, Logprob] = {
|
||||
logprobs: Dict[Optional[int], Logprob] = {
|
||||
token_id: Logprob(
|
||||
logprob=token_id_logprob,
|
||||
rank=token_id_logprob_rank,
|
||||
|
||||
Reference in New Issue
Block a user