[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-11-22 06:41:25 -08:00
committed by GitHub
parent 066209a045
commit d44a63c6d6
3 changed files with 22 additions and 15 deletions

View File

@@ -81,7 +81,10 @@ class Sampler(nn.Module):
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
if logits.dtype == torch.float32:
raw_logprobs = logits.clone()
else:
raw_logprobs = logits.to(torch.float32)
# Use float32 for the logits.
logits = logits.to(torch.float32)

View File

@@ -2466,7 +2466,9 @@ class GPUModelRunner(
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
cu_num_new_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2481,12 @@ class GPUModelRunner(
sampled_token_ids,
self.input_batch.vocab_size,
)
if logprobs_tensors:
# Needed for extracting logprobs when spec decoding.
# This must be done prior to discarding sampled tokens.
cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
@@ -2506,10 +2514,6 @@ class GPUModelRunner(
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
logprobs_tensors = sampler_output.logprobs_tensors
cu_num_accepted_tokens = (
[0] if spec_decode_metadata and logprobs_tensors else None
)
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2522,6 @@ class GPUModelRunner(
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
if cu_num_accepted_tokens is not None:
cu_num_accepted_tokens.append(
cu_num_accepted_tokens[-1] + num_sampled_ids
)
if not sampled_ids:
continue
@@ -2544,7 +2543,7 @@ class GPUModelRunner(
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
logprobs_tensors.tolists(cu_num_accepted_tokens)
logprobs_tensors.tolists(cu_num_new_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)