[Misc] Tidy up some spec decode logic in GPUModelRunner (#31591)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-01-08 09:10:07 -08:00
committed by GitHub
parent 49568d5cf9
commit a3d909ad2b
3 changed files with 51 additions and 50 deletions

View File

@@ -139,6 +139,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len
def get_supported_tasks(self) -> tuple[str]:
return ("generate",)

View File

@@ -452,6 +452,11 @@ class GPUModelRunner(
self.num_spec_tokens = 0
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens
draft_config = self.speculative_config.draft_model_config
if draft_config is not None and draft_config.max_model_len is not None:
self.effective_drafter_max_model_len = draft_config.max_model_len
else:
self.effective_drafter_max_model_len = self.max_model_len
# Request states.
self.requests: dict[str, CachedRequestState] = {}
@@ -674,6 +679,13 @@ class GPUModelRunner(
self.kv_connector_output: KVConnectorOutput | None = None
self.layerwise_nvtx_hooks_registered = False
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
if self.speculative_config:
draft_config = self.speculative_config.draft_model_config
if draft_config is None or draft_config.max_model_len is None:
self.effective_drafter_max_model_len = self.max_model_len
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@@ -3399,54 +3411,41 @@ class GPUModelRunner(
self._copy_draft_token_ids_to_cpu(scheduler_output)
spec_config = self.speculative_config
use_padded_batch_for_eagle = (
spec_config is not None
and spec_config.use_eagle()
and not spec_config.disable_padded_drafter_batch
)
effective_drafter_max_model_len = self.max_model_len
if effective_drafter_max_model_len is None:
effective_drafter_max_model_len = self.model_config.max_model_len
if (
spec_config is not None
and spec_config.draft_model_config is not None
and spec_config.draft_model_config.max_model_len is not None
):
effective_drafter_max_model_len = (
spec_config.draft_model_config.max_model_len
propose_drafts_after_bookkeeping = False
if spec_config is not None:
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len
)
input_fits_in_drafter = spec_decode_common_attn_metadata and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= effective_drafter_max_model_len
)
if use_padded_batch_for_eagle:
assert self.speculative_config is not None
assert isinstance(self.drafter, EagleProposer)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_mask.gpu,
assert isinstance(self.drafter, EagleProposer)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_mask.gpu,
)
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
else:
propose_drafts_after_bookkeeping = input_fits_in_drafter
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
@@ -3466,17 +3465,14 @@ class GPUModelRunner(
spec_decode_metadata,
)
if (
self.speculative_config
and not use_padded_batch_for_eagle
and input_fits_in_drafter
):
if propose_drafts_after_bookkeeping:
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
@@ -3494,6 +3490,7 @@ class GPUModelRunner(
if not self.use_async_scheduling:
return output
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):

View File

@@ -390,7 +390,7 @@ class Worker(WorkerBase):
"""
self.model_config.max_model_len = max_model_len
if self.model_runner is not None:
self.model_runner.max_model_len = max_model_len
self.model_runner.update_max_model_len(max_model_len)
logger.debug("Updated max_model_len to %d", max_model_len)
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: