From 50cd5674b39c69e60fefd0ec1d61652d693e6a06 Mon Sep 17 00:00:00 2001 From: danisereb Date: Fri, 3 Apr 2026 19:24:37 +0300 Subject: [PATCH] Fix invalid logprobs with MTP enabled and sync scheduling (#38711) Signed-off-by: Daniel Serebrenik --- vllm/v1/worker/gpu_model_runner.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 98b3e56ab..7a21117fb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4192,6 +4192,7 @@ class GPUModelRunner( spec_config = self.speculative_config propose_drafts_after_bookkeeping = False if spec_config is not None: + # Decide whether to run the drafter or zero out draft tokens. input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= self.effective_drafter_max_model_len @@ -4227,10 +4228,6 @@ class GPUModelRunner( self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) - self._draft_token_ids = torch.zeros( - 1, device=self.device, dtype=torch.int32 - ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) - self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) elif ( spec_config.use_ngram_gpu() and not spec_config.disable_padded_drafter_batch @@ -4253,15 +4250,20 @@ class GPUModelRunner( self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) - # Since we couldn't run the drafter, - # just use zeros for the draft tokens. - self._draft_token_ids = torch.zeros( - 1, device=self.device, dtype=torch.int32 - ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) - self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) else: propose_drafts_after_bookkeeping = input_fits_in_drafter + if not input_fits_in_drafter: + # Zero out draft tokens so the scheduler doesn't schedule + # stale drafts from the previous step. + # For Nemotron-H: it is necessary to zero out the draft tokens, + # otherwise the stale tokens will corrupt Mamba recurrent + # state and logprobs for sequences near max_model_len. + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits,