[Bug] Refactor max_num_batched_tokens to account for drafting (#34898)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
b9c2a565cc
commit
682566b18e
@@ -46,12 +46,19 @@ class SchedulerConfig:
|
|||||||
"""The runner type to launch for the model."""
|
"""The runner type to launch for the model."""
|
||||||
|
|
||||||
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
|
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
|
||||||
"""Maximum number of tokens to be processed in a single iteration.
|
"""Maximum number of tokens that can be processed in a single iteration.
|
||||||
|
|
||||||
The default value here is mainly for convenience when testing.
|
The default value here is mainly for convenience when testing.
|
||||||
In real usage, this should be set in `EngineArgs.create_engine_config`.
|
In real usage, this should be set in `EngineArgs.create_engine_config`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
max_num_scheduled_tokens: int | None = Field(default=None)
|
||||||
|
"""Maximum number of tokens that the scheduler may issue in a single iteration.
|
||||||
|
|
||||||
|
This is usually equal to max_num_batched_tokens, but can be smaller in cases
|
||||||
|
when the model might append tokens into the batch (such as speculative decoding).
|
||||||
|
Defaults to max_num_batched_tokens."""
|
||||||
|
|
||||||
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
|
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
|
||||||
"""Maximum number of sequences to be processed in a single iteration.
|
"""Maximum number of sequences to be processed in a single iteration.
|
||||||
|
|
||||||
|
|||||||
@@ -750,6 +750,22 @@ class SpeculativeConfig:
|
|||||||
f"errors during speculative decoding."
|
f"errors during speculative decoding."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
def max_num_new_slots_for_drafting(self) -> int:
    """
    Upper bound on the extra batch slots a single request may need while
    drafting.

    Two independent contributions are summed:

    - Parallel drafting needs one extra slot per 'masked' token, i.e.
      `num_speculative_tokens - 1`.
    - Draft-model-based speculation needs one extra slot per request, since
      the draft tokens are not sliced.

    Serial methods that do not use a draft model need no extra slots, so the
    result is 0 for them.
    """
    extra_for_parallel = (
        self.num_speculative_tokens - 1 if self.parallel_drafting else 0
    )
    extra_for_draft_model = 1 if self.uses_draft_model() else 0
    return extra_for_parallel + extra_for_draft_model
|
||||||
|
|
||||||
def use_eagle(self) -> bool:
|
def use_eagle(self) -> bool:
|
||||||
return self.method in ("eagle", "eagle3", "mtp")
|
return self.method in ("eagle", "eagle3", "mtp")
|
||||||
|
|
||||||
|
|||||||
@@ -822,6 +822,8 @@ class VllmConfig:
|
|||||||
self.speculative_config is None
|
self.speculative_config is None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._set_max_num_scheduled_tokens()
|
||||||
|
|
||||||
if current_platform.support_static_graph_mode():
|
if current_platform.support_static_graph_mode():
|
||||||
# if cudagraph_mode has full cudagraphs, we need to check support
|
# if cudagraph_mode has full cudagraphs, we need to check support
|
||||||
if model_config := self.model_config:
|
if model_config := self.model_config:
|
||||||
@@ -1185,6 +1187,37 @@ class VllmConfig:
|
|||||||
if size % self.parallel_config.tensor_parallel_size == 0
|
if size % self.parallel_config.tensor_parallel_size == 0
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _set_max_num_scheduled_tokens(self):
|
||||||
|
"""
|
||||||
|
In most cases, the scheduler may schedule a batch with as many tokens as the
|
||||||
|
worker is configured to handle. However for some speculative decoding methods,
|
||||||
|
the drafter model may insert additional slots into the batch when drafting.
|
||||||
|
To account for this, we need to decrease the max_num_scheduled_tokens by an
|
||||||
|
upper bound on the number of slots that can be added.
|
||||||
|
"""
|
||||||
|
if self.speculative_config is not None:
|
||||||
|
scheduled_token_delta = (
|
||||||
|
self.speculative_config.max_num_new_slots_for_drafting
|
||||||
|
* self.scheduler_config.max_num_seqs
|
||||||
|
)
|
||||||
|
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
|
if self.scheduler_config.max_num_scheduled_tokens is None:
|
||||||
|
self.scheduler_config.max_num_scheduled_tokens = (
|
||||||
|
max_num_batched_tokens - scheduled_token_delta
|
||||||
|
)
|
||||||
|
|
||||||
|
max_num_scheduled_tokens = self.scheduler_config.max_num_scheduled_tokens
|
||||||
|
if max_num_batched_tokens < max_num_scheduled_tokens + (
|
||||||
|
self.speculative_config.max_num_new_slots_for_drafting
|
||||||
|
* self.scheduler_config.max_num_seqs
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"VllmConfig received max_num_scheduled_tokens but it does not have"
|
||||||
|
" enough slots to support the speculative decoding settings."
|
||||||
|
f" It should be greater by at least {scheduled_token_delta}, but"
|
||||||
|
f" got {max_num_batched_tokens=} and {max_num_scheduled_tokens=}."
|
||||||
|
)
|
||||||
|
|
||||||
def _set_cudagraph_sizes(self):
|
def _set_cudagraph_sizes(self):
|
||||||
"""
|
"""
|
||||||
vLLM defines the default candidate list of batch sizes for CUDA graph
|
vLLM defines the default candidate list of batch sizes for CUDA graph
|
||||||
@@ -1347,22 +1380,8 @@ class VllmConfig:
|
|||||||
computed_compile_ranges_split_points = []
|
computed_compile_ranges_split_points = []
|
||||||
|
|
||||||
# The upper bound of the compile ranges is the max_num_batched_tokens.
|
# The upper bound of the compile ranges is the max_num_batched_tokens.
|
||||||
# For speculative decoding, the compile range must be extended
|
|
||||||
# - Sequential: + 1 * max_num_seqs (one draft token per iteration)
|
|
||||||
# - Parallel draft: + num_speculative_tokens * max_num_seqs
|
|
||||||
compile_range_end = self.scheduler_config.max_num_batched_tokens
|
compile_range_end = self.scheduler_config.max_num_batched_tokens
|
||||||
if compile_range_end is not None:
|
if compile_range_end is not None:
|
||||||
if self.speculative_config is not None and (
|
|
||||||
self.speculative_config.uses_draft_model()
|
|
||||||
or self.speculative_config.use_eagle()
|
|
||||||
):
|
|
||||||
multiplier = (
|
|
||||||
self.speculative_config.num_speculative_tokens
|
|
||||||
if self.speculative_config.parallel_drafting
|
|
||||||
else 1
|
|
||||||
)
|
|
||||||
compile_range_end += multiplier * self.scheduler_config.max_num_seqs
|
|
||||||
|
|
||||||
computed_compile_ranges_split_points.append(compile_range_end)
|
computed_compile_ranges_split_points.append(compile_range_end)
|
||||||
|
|
||||||
# Add the compile ranges for flashinfer
|
# Add the compile ranges for flashinfer
|
||||||
|
|||||||
@@ -99,7 +99,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# Scheduling constraints.
|
# Scheduling constraints.
|
||||||
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
|
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
|
||||||
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
|
self.max_num_scheduled_tokens = (
|
||||||
|
self.scheduler_config.max_num_scheduled_tokens
|
||||||
|
if self.scheduler_config.max_num_scheduled_tokens
|
||||||
|
else self.scheduler_config.max_num_batched_tokens
|
||||||
|
)
|
||||||
self.max_model_len = vllm_config.model_config.max_model_len
|
self.max_model_len = vllm_config.model_config.max_model_len
|
||||||
self.enable_kv_cache_events = (
|
self.enable_kv_cache_events = (
|
||||||
self.kv_events_config is not None
|
self.kv_events_config is not None
|
||||||
|
|||||||
@@ -100,11 +100,8 @@ class SpecDecodeBaseProposer:
|
|||||||
if self.parallel_drafting:
|
if self.parallel_drafting:
|
||||||
self._init_parallel_drafting_params()
|
self._init_parallel_drafting_params()
|
||||||
|
|
||||||
# The drafter can get longer sequences than the target model.
|
|
||||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + (
|
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||||
self.net_num_new_slots_per_request * max_batch_size
|
|
||||||
)
|
|
||||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||||
|
|
||||||
# Multi-modal data support
|
# Multi-modal data support
|
||||||
|
|||||||
Reference in New Issue
Block a user