[V1] Do not detokenize if sampling param detokenize is False (#14224)
Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
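
For context, a minimal usage sketch of the behaviour this change enables (the model name and printed fields below are illustrative assumptions, not part of the commit): with detokenize=False in SamplingParams, the V1 engine skips detokenization for the request, so any requested logprobs come back without decoded token strings.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice

# detokenize=False skips detokenization for this request; logprobs can
# still be requested, but their decoded_token fields are left as None.
params = SamplingParams(max_tokens=8, logprobs=3, detokenize=False)

completion = llm.generate(["Hello, my name is"], params)[0].outputs[0]
print(completion.token_ids)    # sampled token ids are always available
print(completion.logprobs[0])  # Logprob entries without decoded tokens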
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import itertools
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional
 
@@ -13,12 +14,15 @@ from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
 logger = init_logger(__name__)
 
+NONES = itertools.repeat(None)
+
 
 @dataclass
 class LogprobsProcessor:
 
-    # Tokenizer for this request
-    tokenizer: AnyTokenizer
+    # Tokenizer for this request,
+    # None if detokenization is disabled.
+    tokenizer: Optional[AnyTokenizer]
 
     # Logprobs for this request
     logprobs: Optional[SampleLogprobs]
@@ -30,7 +34,7 @@ class LogprobsProcessor:
     @classmethod
     def from_new_request(
         cls,
-        tokenizer: AnyTokenizer,
+        tokenizer: Optional[AnyTokenizer],
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
         num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@ class LogprobsProcessor:
                                              token_ids_lst):
 
             # Detokenize (non-incrementally).
-            decoded_tokens = convert_ids_list_to_tokens(
-                self.tokenizer, token_ids)
+            decoded_tokens = NONES if self.tokenizer is None else (
+                convert_ids_list_to_tokens(self.tokenizer, token_ids))
 
             # Sampler puts the sampled logprob in first.
             sampled_token_logprob = logprobs[0]
@@ -103,9 +107,9 @@ class LogprobsProcessor:
 
         # Detokenize non-incrementally.
         # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
-        decoded_tokens = convert_ids_list_to_tokens(
-            self.tokenizer,
-            token_ids.flatten().tolist())
+        decoded_tokens = None if self.tokenizer is None else (
+            convert_ids_list_to_tokens(self.tokenizer,
+                                       token_ids.flatten().tolist()))
 
         # Recover shapes.
         num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@ class LogprobsProcessor:
             # Handle flattening.
             offset = pos * num_logprobs
             offset_end = offset + num_logprobs
-            decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
+            decoded_tokens_for_pos = NONES \
+                if decoded_tokens is None else decoded_tokens[offset:offset_end]
 
             # Update with the Logprob dictionary for this pos.
             self.prompt_logprobs.append(
@@ -153,7 +158,7 @@ class LogprobsProcessor:
     def _make_logprob_dict(
         logprobs: list[float],
         logprob_token_ids: list[int],
-        decoded_tokens: list[str],
+        decoded_tokens: Iterable[Optional[str]],
         rank: int,
         num_logprobs: int,
     ) -> dict[int, Logprob]:
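
A note on the NONES sentinel introduced above: itertools.repeat(None) with no count argument is an infinite, never-exhausted iterator, so a single module-level instance can stand in for the list of decoded strings wherever it is only consumed through zip(), yielding None for every decoded token with no extra branching in the per-position loop. A small self-contained sketch of the pattern (the names here are illustrative, not the vLLM helpers):

import itertools

NONES = itertools.repeat(None)

token_ids = [101, 7592, 2088]
logprobs = [-0.11, -2.35, -4.02]

# zip() stops at the shortest input, so the infinite repeat(None) simply
# supplies None for each position; a real list of strings works identically.
for decoded in (["a", "b", "c"], NONES):
    entries = {
        tid: (lp, tok)
        for tid, lp, tok in zip(token_ids, logprobs, decoded)
    }
    print(entries)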