[Bug] Refactor max_num_batched_tokens to account for drafting (#34898)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Committed by: GitHub
Parent: b9c2a565cc
Commit: 682566b18e
@@ -822,6 +822,8 @@ class VllmConfig:
            self.speculative_config is None
        )

        self._set_max_num_scheduled_tokens()

        if current_platform.support_static_graph_mode():
            # if cudagraph_mode has full cudagraphs, we need to check support
            if model_config := self.model_config:
@@ -1185,6 +1187,37 @@ class VllmConfig:
            if size % self.parallel_config.tensor_parallel_size == 0
        ]

    def _set_max_num_scheduled_tokens(self):
        """
        In most cases, the scheduler may schedule a batch with as many tokens as the
        worker is configured to handle. However for some speculative decoding methods,
        the drafter model may insert additional slots into the batch when drafting.
        To account for this, we need to decrease the max_num_scheduled_tokens by an
        upper bound on the number of slots that can be added.
        """
        if self.speculative_config is not None:
            scheduled_token_delta = (
                self.speculative_config.max_num_new_slots_for_drafting
                * self.scheduler_config.max_num_seqs
            )
            max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
            if self.scheduler_config.max_num_scheduled_tokens is None:
                self.scheduler_config.max_num_scheduled_tokens = (
                    max_num_batched_tokens - scheduled_token_delta
                )

            max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
            if max_num_batched_tokens < max_num_scheduled_tokens + (
                self.speculative_config.max_num_new_slots_for_drafting
                * self.scheduler_config.max_num_seqs
            ):
                raise ValueError(
                    f"VllmConfig received max_num_scheduled_tokens but it does not have"
                    " enough slots to support the speculative decoding settings."
                    f" It should be greater by at least {scheduled_token_delta}, but"
                    f" got {max_num_batched_tokens=} and {max_num_scheduled_tokens=}."
                )

    def _set_cudagraph_sizes(self):
        """
        vLLM defines the default candidate list of batch sizes for CUDA graph
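As a rough illustration of the budget arithmetic in _set_max_num_scheduled_tokens above,
here is a standalone sketch; the concrete values (8192 tokens, 256 sequences, 4 draft
slots per sequence) are hypothetical and not vLLM defaults.

# Standalone sketch of the token-budget arithmetic; the values are hypothetical.
max_num_batched_tokens = 8192        # worker-side capacity for one batch
max_num_seqs = 256                   # max sequences scheduled per step
max_num_new_slots_for_drafting = 4   # upper bound on slots the drafter may add per sequence

# Upper bound on extra slots the drafter can insert in a single step.
scheduled_token_delta = max_num_new_slots_for_drafting * max_num_seqs   # 1024

# When max_num_scheduled_tokens is not set explicitly, reserve headroom for drafting.
max_num_scheduled_tokens = max_num_batched_tokens - scheduled_token_delta   # 7168

# The new validation rejects configurations where the scheduler budget plus the
# drafting headroom would exceed the worker's batch capacity.
assert max_num_batched_tokens >= max_num_scheduled_tokens + scheduled_token_delta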
@@ -1347,22 +1380,8 @@ class VllmConfig:
        computed_compile_ranges_split_points = []

        # The upper bound of the compile ranges is the max_num_batched_tokens.
        # For speculative decoding, the compile range must be extended
        # - Sequential: + 1 * max_num_seqs (one draft token per iteration)
        # - Parallel draft: + num_speculative_tokens * max_num_seqs
        compile_range_end = self.scheduler_config.max_num_batched_tokens
        if compile_range_end is not None:
            if self.speculative_config is not None and (
                self.speculative_config.uses_draft_model()
                or self.speculative_config.use_eagle()
            ):
                multiplier = (
                    self.speculative_config.num_speculative_tokens
                    if self.speculative_config.parallel_drafting
                    else 1
                )
                compile_range_end += multiplier * self.scheduler_config.max_num_seqs

            computed_compile_ranges_split_points.append(compile_range_end)

        # Add the compile ranges for flashinfer
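For comparison, here is a minimal standalone sketch of the compile-range extension
logic shown in this hunk; the concrete values are hypothetical and match the sketch above.

# Sketch of the compile-range extension for speculative decoding; hypothetical values.
max_num_batched_tokens = 8192
max_num_seqs = 256
num_speculative_tokens = 4
parallel_drafting = False   # False => sequential drafting, True => parallel draft

compile_range_end = max_num_batched_tokens
# Sequential drafting adds at most one draft token per sequence per iteration;
# parallel drafting can add num_speculative_tokens per sequence at once.
multiplier = num_speculative_tokens if parallel_drafting else 1
compile_range_end += multiplier * max_num_seqs   # 8448 sequential, 9216 with parallel draft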