Fix invalid logprobs with MTP enabled and sync scheduling (#38711)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
This commit is contained in:
@@ -4192,6 +4192,7 @@ class GPUModelRunner(
|
||||
spec_config = self.speculative_config
|
||||
propose_drafts_after_bookkeeping = False
|
||||
if spec_config is not None:
|
||||
# Decide whether to run the drafter or zero out draft tokens.
|
||||
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
|
||||
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
|
||||
<= self.effective_drafter_max_model_len
|
||||
@@ -4227,10 +4228,6 @@ class GPUModelRunner(
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
self._draft_token_ids = torch.zeros(
|
||||
1, device=self.device, dtype=torch.int32
|
||||
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
|
||||
elif (
|
||||
spec_config.use_ngram_gpu()
|
||||
and not spec_config.disable_padded_drafter_batch
|
||||
@@ -4253,15 +4250,20 @@ class GPUModelRunner(
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
# Since we couldn't run the drafter,
|
||||
# just use zeros for the draft tokens.
|
||||
self._draft_token_ids = torch.zeros(
|
||||
1, device=self.device, dtype=torch.int32
|
||||
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
|
||||
else:
|
||||
propose_drafts_after_bookkeeping = input_fits_in_drafter
|
||||
|
||||
if not input_fits_in_drafter:
|
||||
# Zero out draft tokens so the scheduler doesn't schedule
|
||||
# stale drafts from the previous step.
|
||||
# For Nemotron-H: it is necessary to zero out the draft tokens,
|
||||
# otherwise the stale tokens will corrupt Mamba recurrent
|
||||
# state and logprobs for sequences near max_model_len.
|
||||
self._draft_token_ids = torch.zeros(
|
||||
1, device=self.device, dtype=torch.int32
|
||||
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
|
||||
(
|
||||
num_nans_in_logits,
|
||||
|
||||
Reference in New Issue
Block a user