[Spec Decode] Unified Parallel Drafting (#32887)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
5b2a9422f0
commit
af3162d3aa
@@ -116,9 +116,16 @@ class SpeculativeConfig:
|
||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||
provided. Defaults to 1."""
|
||||
|
||||
# Alternative drafting strategies
|
||||
speculative_token_tree: str | None = None
|
||||
"""Specifies the tree structure for speculative token generation.
|
||||
"""
|
||||
parallel_drafting: bool = False
|
||||
"""Enable parallel drafting, where all speculative tokens are generated
|
||||
in parallel rather than sequentially. This can improve performance but
|
||||
requires the speculative model be trained to support parallel drafting.
|
||||
Only compatible with EAGLE and draft model methods."""
|
||||
|
||||
# required configuration params passed from engine
|
||||
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
"""The configuration of the target model."""
|
||||
|
||||
@@ -604,10 +604,13 @@ class VllmConfig:
|
||||
# Currently, async scheduling only supports eagle speculative
|
||||
# decoding.
|
||||
if self.speculative_config is not None:
|
||||
if self.speculative_config.method not in get_args(EagleModelTypes):
|
||||
if (
|
||||
self.speculative_config.method not in get_args(EagleModelTypes)
|
||||
and self.speculative_config.method != "draft_model"
|
||||
):
|
||||
raise ValueError(
|
||||
"Currently, async scheduling is only supported "
|
||||
"with EAGLE/MTP kind of speculative decoding."
|
||||
"with EAGLE/MTP/Draft Model kind of speculative decoding."
|
||||
)
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
raise ValueError(
|
||||
@@ -1298,16 +1301,21 @@ class VllmConfig:
|
||||
computed_compile_ranges_split_points = []
|
||||
|
||||
# The upper bound of the compile ranges is the max_num_batched_tokens.
|
||||
# For speculative decoding with draft model, the compile range must be extended
|
||||
# by 1 for each sequence.
|
||||
# For speculative decoding, the compile range must be extended as follows:
|
||||
# - Sequential: + 1 * max_num_seqs (one draft token per iteration)
|
||||
# - Parallel draft: + num_speculative_tokens * max_num_seqs
|
||||
compile_range_end = self.scheduler_config.max_num_batched_tokens
|
||||
if compile_range_end is not None:
|
||||
do_extend: bool = (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.uses_draft_model()
|
||||
)
|
||||
if do_extend:
|
||||
compile_range_end += self.scheduler_config.max_num_seqs
|
||||
if self.speculative_config is not None and (
|
||||
self.speculative_config.uses_draft_model()
|
||||
or self.speculative_config.use_eagle()
|
||||
):
|
||||
multiplier = (
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config.parallel_drafting
|
||||
else 1
|
||||
)
|
||||
compile_range_end += multiplier * self.scheduler_config.max_num_seqs
|
||||
|
||||
computed_compile_ranges_split_points.append(compile_range_end)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user