[Spec Decode] Unified Parallel Drafting (#32887)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-02-05 12:37:18 -05:00
committed by GitHub
parent 5b2a9422f0
commit af3162d3aa
14 changed files with 1085 additions and 392 deletions

View File

@@ -116,9 +116,16 @@ class SpeculativeConfig:
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
# Alternative drafting strategies
speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation.
"""
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""

View File

@@ -604,10 +604,13 @@ class VllmConfig:
# Currently, async scheduling only supports EAGLE/MTP and draft-model
# speculative decoding.
if self.speculative_config is not None:
if self.speculative_config.method not in get_args(EagleModelTypes):
if (
self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method != "draft_model"
):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding."
"with EAGLE/MTP/Draft Model kind of speculative decoding."
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
@@ -1298,16 +1301,21 @@ class VllmConfig:
computed_compile_ranges_split_points = []
# The upper bound of the compile ranges is the max_num_batched_tokens.
# For speculative decoding with draft model, the compile range must be extended
# by 1 for each sequence.
# For speculative decoding, the compile range must be extended by:
# - Sequential: + 1 * max_num_seqs (one draft token per iteration)
# - Parallel draft: + num_speculative_tokens * max_num_seqs
compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
do_extend: bool = (
self.speculative_config is not None
and self.speculative_config.uses_draft_model()
)
if do_extend:
compile_range_end += self.scheduler_config.max_num_seqs
if self.speculative_config is not None and (
self.speculative_config.uses_draft_model()
or self.speculative_config.use_eagle()
):
multiplier = (
self.speculative_config.num_speculative_tokens
if self.speculative_config.parallel_drafting
else 1
)
compile_range_end += multiplier * self.scheduler_config.max_num_seqs
computed_compile_ranges_split_points.append(compile_range_end)