[Bug] Refactor max_num_batched_tokens to account for drafting (#34898)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett
2026-02-22 11:18:46 -05:00
committed by GitHub
parent b9c2a565cc
commit 682566b18e
5 changed files with 63 additions and 20 deletions


@@ -822,6 +822,8 @@ class VllmConfig:
            self.speculative_config is None
        )
        self._set_max_num_scheduled_tokens()
        if current_platform.support_static_graph_mode():
            # if cudagraph_mode has full cudagraphs, we need to check support
            if model_config := self.model_config:
@@ -1185,6 +1187,37 @@ class VllmConfig:
            if size % self.parallel_config.tensor_parallel_size == 0
        ]

    def _set_max_num_scheduled_tokens(self):
        """
        In most cases, the scheduler may schedule a batch with as many tokens
        as the worker is configured to handle. However, for some speculative
        decoding methods, the drafter model may insert additional slots into
        the batch when drafting. To account for this, we decrease
        max_num_scheduled_tokens by an upper bound on the number of slots
        that can be added.
        """
        if self.speculative_config is not None:
            scheduled_token_delta = (
                self.speculative_config.max_num_new_slots_for_drafting
                * self.scheduler_config.max_num_seqs
            )
            max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
            if self.scheduler_config.max_num_scheduled_tokens is None:
                self.scheduler_config.max_num_scheduled_tokens = (
                    max_num_batched_tokens - scheduled_token_delta
                )
            max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
            if max_num_batched_tokens < max_num_scheduled_tokens + scheduled_token_delta:
                raise ValueError(
                    "VllmConfig received max_num_scheduled_tokens, but it does"
                    " not leave enough slots to support the speculative"
                    " decoding settings. max_num_batched_tokens should exceed"
                    f" it by at least {scheduled_token_delta}, but got"
                    f" {max_num_batched_tokens=} and {max_num_scheduled_tokens=}."
                )

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
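
For intuition, a minimal standalone sketch of the budget arithmetic that _set_max_num_scheduled_tokens implements; the helper function and the example numbers are illustrative, not part of the commit:

def effective_scheduled_token_budget(
    max_num_batched_tokens: int,
    max_num_seqs: int,
    max_num_new_slots_for_drafting: int,
) -> int:
    # Reserve max_num_new_slots_for_drafting slots per sequence so that,
    # even if the drafter inserts slots for every sequence, the batch
    # never exceeds what the worker was configured to handle.
    delta = max_num_new_slots_for_drafting * max_num_seqs
    budget = max_num_batched_tokens - delta
    if budget <= 0:
        raise ValueError(
            f"batch budget {max_num_batched_tokens} cannot cover the "
            f"{delta} slots reserved for drafting"
        )
    return budget

# E.g. an 8192-token batch, 128 sequences, one extra slot per sequence:
# the scheduler is capped at 8192 - 128 * 1 = 8064 tokens.
assert effective_scheduled_token_budget(8192, 128, 1) == 8064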
@@ -1347,22 +1380,8 @@ class VllmConfig:
        computed_compile_ranges_split_points = []
        # The upper bound of the compile ranges is max_num_batched_tokens.
        # For speculative decoding, the compile range must be extended:
        # - Sequential: + 1 * max_num_seqs (one draft token per iteration)
        # - Parallel draft: + num_speculative_tokens * max_num_seqs
        compile_range_end = self.scheduler_config.max_num_batched_tokens
        if compile_range_end is not None:
            if self.speculative_config is not None and (
                self.speculative_config.uses_draft_model()
                or self.speculative_config.use_eagle()
            ):
                multiplier = (
                    self.speculative_config.num_speculative_tokens
                    if self.speculative_config.parallel_drafting
                    else 1
                )
                compile_range_end += multiplier * self.scheduler_config.max_num_seqs
            computed_compile_ranges_split_points.append(compile_range_end)
        # Add the compile ranges for flashinfer
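
As a companion sketch, the compile-range extension above amounts to the following multiplier rule; the helper is hypothetical and only mirrors the logic of the hunk, it is not the vLLM implementation:

def extended_compile_range_end(
    max_num_batched_tokens: int,
    max_num_seqs: int,
    num_speculative_tokens: int,
    parallel_drafting: bool,
) -> int:
    # Sequential drafting adds at most one draft token per sequence per
    # iteration; parallel drafting can add all speculative tokens at once.
    multiplier = num_speculative_tokens if parallel_drafting else 1
    return max_num_batched_tokens + multiplier * max_num_seqs

# Sequential: 8192 + 1 * 128 = 8320 tokens.
assert extended_compile_range_end(8192, 128, 4, parallel_drafting=False) == 8320
# Parallel draft: 8192 + 4 * 128 = 8704 tokens.
assert extended_compile_range_end(8192, 128, 4, parallel_drafting=True) == 8704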