[Core] Support logprobs with spec decode + async scheduling (#29223)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import replace
|
||||
|
||||
import torch
|
||||
@@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
|
||||
def parse_output(
|
||||
output_token_ids: torch.Tensor,
|
||||
vocab_size: int,
|
||||
) -> list[list[int]]:
|
||||
discard_req_indices: Sequence[int] = (),
|
||||
return_cu_num_tokens: bool = False,
|
||||
) -> tuple[list[list[int]], list[int] | None]:
|
||||
"""Parse the output of the rejection sampler.
|
||||
Args:
|
||||
output_token_ids: The sampled token IDs in shape
|
||||
@@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
|
||||
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
|
||||
and will be filtered out in this function.
|
||||
vocab_size: The size of the vocabulary.
|
||||
discard_req_indices: Optional row indices to discard tokens in.
|
||||
return_cu_num_tokens: Whether to also return cumulative token counts.
|
||||
Returns:
|
||||
A list of lists of token IDs.
|
||||
"""
|
||||
@@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
|
||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||
output_token_ids_np < vocab_size
|
||||
)
|
||||
cu_num_tokens = None
|
||||
if return_cu_num_tokens:
|
||||
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
|
||||
if len(discard_req_indices) > 0:
|
||||
valid_mask[discard_req_indices] = False
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
return outputs, cu_num_tokens
|
||||
|
||||
def apply_logits_processors(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user