[Bug] Refactor max_num_batched_tokens to account for drafting (#34898)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett
2026-02-22 11:18:46 -05:00
committed by GitHub
parent b9c2a565cc
commit 682566b18e
5 changed files with 63 additions and 20 deletions


@@ -46,12 +46,19 @@ class SchedulerConfig:
"""The runner type to launch for the model."""
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens to be processed in a single iteration.
"""Maximum number of tokens that can be processed in a single iteration.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_scheduled_tokens: int | None = Field(default=None)
"""Maximum number of tokens that the scheduler may issue in a single iteration.
This is usually equal to max_num_batched_tokens, but can be smaller in cases
when the model might append tokens into the batch (such as speculative decoding).
Defaults to max_num_batched_tokens."""
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration.


@@ -750,6 +750,22 @@ class SpeculativeConfig:
f"errors during speculative decoding."
)
@property
def max_num_new_slots_for_drafting(self) -> int:
"""
Calculate the maximum number of new slots that might be added to the batch
when drafting.
"""
slots_per_req = 0 # for serial non-draft-model methods, no change needed
if self.parallel_drafting:
# For parallel drafting, we need one new slot per 'masked' token
slots_per_req = self.num_speculative_tokens - 1
if self.uses_draft_model():
# For draft model-based speculation, we need one new slot per request
# Since we do not slice the draft tokens
slots_per_req += 1
return slots_per_req
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
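As a sanity check on the accounting above, the same per-request logic as a standalone function with a couple of worked cases; the function name and the argument values are made up for illustration:

def max_new_slots_per_request(
    num_speculative_tokens: int,
    parallel_drafting: bool,
    uses_draft_model: bool,
) -> int:
    """Mirror of the slot accounting in max_num_new_slots_for_drafting."""
    slots_per_req = 0  # serial, non-draft-model methods add no slots
    if parallel_drafting:
        # one new slot per 'masked' token
        slots_per_req = num_speculative_tokens - 1
    if uses_draft_model:
        # one more slot per request, since draft tokens are not sliced
        slots_per_req += 1
    return slots_per_req


# Parallel drafting with a separate draft model and 3 speculative tokens:
assert max_new_slots_per_request(3, parallel_drafting=True, uses_draft_model=True) == 3
# A serial method with no separate draft model adds nothing:
assert max_new_slots_per_request(3, parallel_drafting=False, uses_draft_model=False) == 0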


@@ -822,6 +822,8 @@ class VllmConfig:
             self.speculative_config is None
         )

+        self._set_max_num_scheduled_tokens()
+
         if current_platform.support_static_graph_mode():
             # if cudagraph_mode has full cudagraphs, we need to check support
             if model_config := self.model_config:
@@ -1185,6 +1187,37 @@ class VllmConfig:
                 if size % self.parallel_config.tensor_parallel_size == 0
             ]

+    def _set_max_num_scheduled_tokens(self):
+        """
+        In most cases, the scheduler may schedule a batch with as many tokens as the
+        worker is configured to handle. However, for some speculative decoding
+        methods, drafting may insert additional slots into the batch. To account for
+        this, max_num_scheduled_tokens must stay below max_num_batched_tokens by an
+        upper bound on the number of slots that can be added.
+        """
+        if self.speculative_config is not None:
+            scheduled_token_delta = (
+                self.speculative_config.max_num_new_slots_for_drafting
+                * self.scheduler_config.max_num_seqs
+            )
+            max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+            if self.scheduler_config.max_num_scheduled_tokens is None:
+                self.scheduler_config.max_num_scheduled_tokens = (
+                    max_num_batched_tokens - scheduled_token_delta
+                )
+            max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
+            if max_num_batched_tokens < max_num_scheduled_tokens + scheduled_token_delta:
+                raise ValueError(
+                    "VllmConfig received max_num_scheduled_tokens, but it leaves"
+                    " too little headroom for the speculative decoding settings."
+                    " max_num_batched_tokens must exceed max_num_scheduled_tokens"
+                    f" by at least {scheduled_token_delta}, but got"
+                    f" {max_num_batched_tokens=} and {max_num_scheduled_tokens=}."
+                )
+
     def _set_cudagraph_sizes(self):
         """
         vLLM defines the default candidate list of batch sizes for CUDA graph
@@ -1347,22 +1380,8 @@ class VllmConfig:
         computed_compile_ranges_split_points = []
         # The upper bound of the compile ranges is the max_num_batched_tokens.
-        # For speculative decoding, the compile range must be extended
-        # - Sequential: + 1 * max_num_seqs (one draft token per iteration)
-        # - Parallel draft: + num_speculative_tokens * max_num_seqs
         compile_range_end = self.scheduler_config.max_num_batched_tokens
         if compile_range_end is not None:
-            if self.speculative_config is not None and (
-                self.speculative_config.uses_draft_model()
-                or self.speculative_config.use_eagle()
-            ):
-                multiplier = (
-                    self.speculative_config.num_speculative_tokens
-                    if self.speculative_config.parallel_drafting
-                    else 1
-                )
-                compile_range_end += multiplier * self.scheduler_config.max_num_seqs
             computed_compile_ranges_split_points.append(compile_range_end)

         # Add the compile ranges for flashinfer
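A worked example of the derivation and validation in `_set_max_num_scheduled_tokens` above, using made-up numbers (none of these are defaults introduced by this change):

# Hypothetical settings, for illustration only.
max_num_batched_tokens = 8192
max_num_seqs = 128
max_num_new_slots_for_drafting = 2  # e.g. parallel drafting, 3 speculative tokens

# Worst case: every running request gains this many slots while drafting.
scheduled_token_delta = max_num_new_slots_for_drafting * max_num_seqs  # 256

# When unset, the scheduler budget is derived with that headroom subtracted.
max_num_scheduled_tokens = max_num_batched_tokens - scheduled_token_delta  # 7936

# The validation keeps explicit user-provided values within the same headroom.
assert max_num_batched_tokens >= max_num_scheduled_tokens + scheduled_token_delta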


@@ -99,7 +99,11 @@ class Scheduler(SchedulerInterface):
         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
-        self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
+        self.max_num_scheduled_tokens = (
+            self.scheduler_config.max_num_scheduled_tokens
+            if self.scheduler_config.max_num_scheduled_tokens is not None
+            else self.scheduler_config.max_num_batched_tokens
+        )
         self.max_model_len = vllm_config.model_config.max_model_len
         self.enable_kv_cache_events = (
             self.kv_events_config is not None


@@ -100,11 +100,8 @@ class SpecDecodeBaseProposer:
         if self.parallel_drafting:
             self._init_parallel_drafting_params()

         # The drafter can get longer sequences than the target model.
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + (
-            self.net_num_new_slots_per_request * max_batch_size
-        )
+        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
         self.token_arange_np = np.arange(self.max_num_tokens)

         # Multi-modal data support
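Why the extra slack can be dropped here: with max_num_scheduled_tokens reduced in the config, the worst case the drafter can produce already fits within max_num_batched_tokens. A small sketch reusing the hypothetical numbers from the earlier example; new_slots_per_request stands in for the proposer's per-request slot count:

# Hypothetical numbers, reused from the earlier example.
max_num_batched_tokens = 8192
max_num_scheduled_tokens = 7936
new_slots_per_request = 2
max_num_seqs = 128

# Worst-case tokens the drafter can see in one iteration: everything the
# scheduler issued plus the drafted slots for every running request.
worst_case_tokens = max_num_scheduled_tokens + new_slots_per_request * max_num_seqs
assert worst_case_tokens <= max_num_batched_tokens  # 8192 <= 8192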