Fix invalid logprobs with MTP enabled and sync scheduling (#38711)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
danisereb
2026-04-03 19:24:37 +03:00
committed by GitHub
parent 7b1a7423be
commit 50cd5674b3

View File

@@ -4192,6 +4192,7 @@ class GPUModelRunner(
spec_config = self.speculative_config
propose_drafts_after_bookkeeping = False
if spec_config is not None:
# Decide whether to run the drafter or zero out draft tokens.
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
<= self.effective_drafter_max_model_len
@@ -4227,10 +4228,6 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
elif (
spec_config.use_ngram_gpu()
and not spec_config.disable_padded_drafter_batch
@@ -4253,15 +4250,20 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
else:
propose_drafts_after_bookkeeping = input_fits_in_drafter
if not input_fits_in_drafter:
# Zero out draft tokens so the scheduler doesn't schedule
# stale drafts from the previous step.
# For Nemotron-H: it is necessary to zero out the draft tokens,
# otherwise the stale tokens will corrupt Mamba recurrent
# state and logprobs for sequences near max_model_len.
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
num_nans_in_logits,