[V1] Do not detokenize if sampling param detokenize is False (#14224)

Signed-off-by: Himanshu Jaju <hj@mistral.ai>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Himanshu Jaju
2025-03-06 19:40:24 +01:00
committed by GitHub
parent 9f1710f1ac
commit cd579352bf
4 changed files with 69 additions and 27 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
@@ -13,12 +14,15 @@ from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
NONES = itertools.repeat(None)
@dataclass
class LogprobsProcessor:
# Tokenizer for this request
tokenizer: AnyTokenizer
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: Optional[AnyTokenizer]
# Logprobs for this request
logprobs: Optional[SampleLogprobs]
@@ -30,7 +34,7 @@ class LogprobsProcessor:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: Optional[AnyTokenizer],
request: EngineCoreRequest,
) -> "LogprobsProcessor":
num_logprobs = request.sampling_params.logprobs
@@ -66,8 +70,8 @@ class LogprobsProcessor:
token_ids_lst):
# Detokenize (non-incrementally).
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer, token_ids)
decoded_tokens = NONES if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer, token_ids))
# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
@@ -103,9 +107,9 @@ class LogprobsProcessor:
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = convert_ids_list_to_tokens(
self.tokenizer,
token_ids.flatten().tolist())
decoded_tokens = None if self.tokenizer is None else (
convert_ids_list_to_tokens(self.tokenizer,
token_ids.flatten().tolist()))
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
@@ -121,7 +125,8 @@ class LogprobsProcessor:
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
decoded_tokens_for_pos = NONES \
if decoded_tokens is None else decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
self.prompt_logprobs.append(
@@ -153,7 +158,7 @@ class LogprobsProcessor:
def _make_logprob_dict(
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: list[str],
decoded_tokens: Iterable[Optional[str]],
rank: int,
num_logprobs: int,
) -> dict[int, Logprob]: