From 682566b18e69d12a1ee603906417f508d61ac7ea Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Sun, 22 Feb 2026 11:18:46 -0500
Subject: [PATCH] [Bug] Refactor max_num_batched_tokens to account for
 drafting (#34898)

Signed-off-by: Benjamin Chislett
---
 vllm/config/scheduler.py        |  9 ++++++-
 vllm/config/speculative.py      | 16 +++++++++++
 vllm/config/vllm.py             | 47 +++++++++++++++++++++++----------
 vllm/v1/core/sched/scheduler.py |  6 ++++-
 vllm/v1/spec_decode/eagle.py    |  5 +---
 5 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index fb162bd50..9f6284c4b 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -46,12 +46,19 @@ class SchedulerConfig:
     """The runner type to launch for the model."""
 
     max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
-    """Maximum number of tokens to be processed in a single iteration.
+    """Maximum number of tokens that can be processed in a single iteration.
 
     The default value here is mainly for convenience when testing. In real
     usage, this should be set in `EngineArgs.create_engine_config`.
     """
 
+    max_num_scheduled_tokens: int | None = Field(default=None)
+    """Maximum number of tokens that the scheduler may issue in a single iteration.
+
+    This is usually equal to max_num_batched_tokens, but can be smaller in cases
+    where the model may append tokens to the batch (such as speculative decoding).
+    If unset, it is derived from max_num_batched_tokens."""
+
     max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
     """Maximum number of sequences to be processed in a single iteration.
 
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index dcc549c4c..847e846d4 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -750,6 +750,22 @@ class SpeculativeConfig:
                 f"errors during speculative decoding."
             )
 
+    @property
+    def max_num_new_slots_for_drafting(self) -> int:
+        """
+        Upper bound on the number of new slots per request that drafting may
+        add to the batch.
+        """
+        slots_per_req = 0  # serial non-draft-model methods add no extra slots
+        if self.parallel_drafting:
+            # For parallel drafting, we need one new slot per 'masked' token
+            slots_per_req = self.num_speculative_tokens - 1
+        if self.uses_draft_model():
+            # For draft-model-based speculation, we need one extra slot per
+            # request, since we do not slice the draft tokens.
+            slots_per_req += 1
+        return slots_per_req
+
     def use_eagle(self) -> bool:
         return self.method in ("eagle", "eagle3", "mtp")
 
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index e951e6f2c..5db217b22 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -822,6 +822,8 @@ class VllmConfig:
             self.speculative_config is None
         )
 
+        self._set_max_num_scheduled_tokens()
+
         if current_platform.support_static_graph_mode():
             # if cudagraph_mode has full cudagraphs, we need to check support
             if model_config := self.model_config:
@@ -1185,6 +1187,37 @@
             if size % self.parallel_config.tensor_parallel_size == 0
         ]
 
+    def _set_max_num_scheduled_tokens(self):
+        """
+        In most cases, the scheduler may schedule a batch with as many tokens as
+        the worker is configured to handle. However, for some speculative decoding
+        methods, the drafter may insert additional slots into the batch when
+        drafting. To account for this, max_num_scheduled_tokens is capped at
+        max_num_batched_tokens minus an upper bound on the slots that can be added.
+ """ + if self.speculative_config is not None: + scheduled_token_delta = ( + self.speculative_config.max_num_new_slots_for_drafting + * self.scheduler_config.max_num_seqs + ) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + if self.scheduler_config.max_num_scheduled_tokens is None: + self.scheduler_config.max_num_scheduled_tokens = ( + max_num_batched_tokens - scheduled_token_delta + ) + + max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens + if max_num_batched_tokens < max_num_scheduled_tokens + ( + self.speculative_config.max_num_new_slots_for_drafting + * self.scheduler_config.max_num_seqs + ): + raise ValueError( + f"VllmConfig received max_num_scheduled_tokens but it does not have" + " enough slots to support the speculative decoding settings." + f" It should be greater by at least {scheduled_token_delta}, but" + f" got {max_num_batched_tokens=} and {max_num_scheduled_tokens=}." + ) + def _set_cudagraph_sizes(self): """ vLLM defines the default candidate list of batch sizes for CUDA graph @@ -1347,22 +1380,8 @@ class VllmConfig: computed_compile_ranges_split_points = [] # The upper bound of the compile ranges is the max_num_batched_tokens. - # For speculative decoding, the compile range must be extended - # - Sequential: + 1 * max_num_seqs (one draft token per iteration) - # - Parallel draft: + num_speculative_tokens * max_num_seqs compile_range_end = self.scheduler_config.max_num_batched_tokens if compile_range_end is not None: - if self.speculative_config is not None and ( - self.speculative_config.uses_draft_model() - or self.speculative_config.use_eagle() - ): - multiplier = ( - self.speculative_config.num_speculative_tokens - if self.speculative_config.parallel_drafting - else 1 - ) - compile_range_end += multiplier * self.scheduler_config.max_num_seqs - computed_compile_ranges_split_points.append(compile_range_end) # Add the compile ranges for flashinfer diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 25f848029..bf397ad68 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -99,7 +99,11 @@ class Scheduler(SchedulerInterface): # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = ( + self.scheduler_config.max_num_scheduled_tokens + if self.scheduler_config.max_num_scheduled_tokens + else self.scheduler_config.max_num_batched_tokens + ) self.max_model_len = vllm_config.model_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a6e7995bc..04450e989 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -100,11 +100,8 @@ class SpecDecodeBaseProposer: if self.parallel_drafting: self._init_parallel_drafting_params() - # The drafter can get longer sequences than the target model. max_batch_size = vllm_config.scheduler_config.max_num_seqs - self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + ( - self.net_num_new_slots_per_request * max_batch_size - ) + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) # Multi-modal data support