[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -81,7 +81,10 @@ class Sampler(nn.Module):
         if logprobs_mode == "raw_logprobs":
             raw_logprobs = self.compute_logprobs(logits)
         elif logprobs_mode == "raw_logits":
-            raw_logprobs = logits.clone()
+            if logits.dtype == torch.float32:
+                raw_logprobs = logits.clone()
+            else:
+                raw_logprobs = logits.to(torch.float32)

         # Use float32 for the logits.
         logits = logits.to(torch.float32)
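A minimal standalone sketch (plain PyTorch, not vLLM code) of the dtype behavior the new branch relies on: Tensor.to() returns the same tensor when no conversion is needed, so only the already-float32 case needs an explicit clone(), presumably to keep the raw values independent of whatever processing is applied to logits afterwards; when the dtype differs, the conversion already produces a copy.

import torch

logits_fp32 = torch.randn(4, 8, dtype=torch.float32)
logits_bf16 = torch.randn(4, 8, dtype=torch.bfloat16)

# No conversion needed: .to() returns the same tensor object (no copy).
assert logits_fp32.to(torch.float32) is logits_fp32
# Conversion needed: .to() materializes a new float32 tensor (already a copy).
assert logits_bf16.to(torch.float32) is not logits_bf16

raw_fp32 = logits_fp32.clone()            # explicit copy required
raw_bf16 = logits_bf16.to(torch.float32)  # conversion already copied

logits_fp32.mul_(0.0)  # simulate later in-place processing of the logits
assert not torch.equal(raw_fp32, logits_fp32)  # raw values preserved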
@@ -2466,7 +2466,9 @@ class GPUModelRunner(

         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
         invalid_req_indices = []
+        cu_num_new_tokens: list[int] | None = None
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2479,6 +2481,12 @@ class GPUModelRunner(
                 sampled_token_ids,
                 self.input_batch.vocab_size,
             )
+            if logprobs_tensors:
+                # Needed for extracting logprobs when spec decoding.
+                # This must be done prior to discarding sampled tokens.
+                cu_num_new_tokens = [0]
+                for toks in valid_sampled_token_ids:
+                    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[int(i)].clear()
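A hedged sketch, using plain Python lists rather than vLLM's real structures, of the bookkeeping added above: with spec decoding each request can accept a variable number of draft tokens, and the per-token logprobs are laid out flat across requests, so cumulative counts taken before any tokens are discarded give the slice offsets for each request.

# Toy accepted tokens per request (not real vLLM data).
valid_sampled_token_ids = [[7, 12], [3], [9, 9, 1]]

cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
    cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# cu_num_new_tokens == [0, 2, 3, 6]

flat_logprobs = [-0.1, -0.5, -0.2, -1.3, -0.7, -0.05]  # one entry per accepted token

per_request = [
    flat_logprobs[cu_num_new_tokens[i]:cu_num_new_tokens[i + 1]]
    for i in range(len(valid_sampled_token_ids))
]
# per_request == [[-0.1, -0.5], [-0.2], [-1.3, -0.7, -0.05]]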
@@ -2506,10 +2514,6 @@ class GPUModelRunner(
             # the sampled tokens back, because there's no direct communication
             # between the first-stage worker and the last-stage worker.
             req_ids = self.input_batch.req_ids
-            logprobs_tensors = sampler_output.logprobs_tensors
-            cu_num_accepted_tokens = (
-                [0] if spec_decode_metadata and logprobs_tensors else None
-            )
         for req_idx in range(num_sampled_tokens):
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2518,11 +2522,6 @@ class GPUModelRunner(

             num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0

-            if cu_num_accepted_tokens is not None:
-                cu_num_accepted_tokens.append(
-                    cu_num_accepted_tokens[-1] + num_sampled_ids
-                )
-
             if not sampled_ids:
                 continue

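A toy illustration (assumed numbers, not vLLM code) of why accumulating the counts inside the per-request loop, after discarded requests had already been cleared, could mis-slice the logprobs: the flattened logprobs still contain entries for the discarded tokens, so every later request's offsets shift.

valid_before_discard = [[7, 12], [3], [9, 9, 1]]       # middle request is mid chunked prefill
flat_logprobs = [-0.1, -0.5, -0.2, -1.3, -0.7, -0.05]  # still holds the middle request's entry

# Old behavior: counts accumulated after the middle request's tokens were cleared.
cleared = [[7, 12], [], [9, 9, 1]]
cu_after_discard = [0]
for toks in cleared:
    cu_after_discard.append(cu_after_discard[-1] + len(toks))
# cu_after_discard == [0, 2, 2, 5]: the last request would get
# flat_logprobs[2:5] == [-0.2, -1.3, -0.7], shifted by one entry.

# New behavior: cu_num_new_tokens == [0, 2, 3, 6], so the last request
# correctly gets flat_logprobs[3:6] == [-1.3, -0.7, -0.05].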
@@ -2544,7 +2543,7 @@ class GPUModelRunner(
             req_state.output_token_ids.extend(sampled_ids)

         logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_accepted_tokens)
+            logprobs_tensors.tolists(cu_num_new_tokens)
             if not self.use_async_scheduling and logprobs_tensors is not None
             else None
         )
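For reference, a hedged sketch of what a tolists-style conversion could look like; the helper name and shapes below are assumptions for illustration, not the actual LogprobsTensors API. It splits a flattened (num_new_tokens, k) top-k logprobs tensor into per-request Python lists using the cumulative offsets computed before any tokens were discarded.

import torch

def split_logprobs_by_request(
    topk_logprobs: torch.Tensor,    # shape: (total_new_tokens, k)
    cu_num_new_tokens: list[int],   # length: num_requests + 1
) -> list[list[list[float]]]:
    # Slice the flat per-token logprobs into one list per request.
    return [
        topk_logprobs[start:end].tolist()
        for start, end in zip(cu_num_new_tokens, cu_num_new_tokens[1:])
    ]

# Example: 3 requests contributing 2, 1, and 3 new tokens, top-2 logprobs each.
logprobs = torch.randn(6, 2)
per_req = split_logprobs_by_request(logprobs, [0, 2, 3, 6])
assert [len(x) for x in per_req] == [2, 1, 3]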