[V1][Perf] Faster incremental detokenization (#15137)

commit 05fcd1b430
parent 7c02d6a137
Author: Nick Hill <nhill@redhat.com>
Date:   2025-04-17 07:45:24 -07:00
Committed-by: GitHub

Signed-off-by: Nick Hill <nhill@redhat.com>

7 changed files with 317 additions and 145 deletions
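In short, this change replaces the stateless helper `detokenize_incrementally`, which re-derived its decode state from the full, growing token list on every call, with stateful per-request detokenizer classes in `vllm.v1.engine.detokenizer`. A minimal sketch of the new streaming pattern, assuming `tokenizer`, an `EngineCoreRequest` named `request`, and the sampled `token_ids` already exist (class and method names are taken from the test diff below):

    from vllm.v1.engine.detokenizer import IncrementalDetokenizer

    # from_new_request selects a fast or slow implementation depending on
    # the tokenizer (see the fast/slow split in the tests below).
    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

    output_text = ""
    for i, token_id in enumerate(token_ids):
        # Feed each newly sampled token, then pull only the new text delta.
        detokenizer.update([token_id], False)
        finished = i == len(token_ids) - 1
        output_text += detokenizer.get_next_output_text(finished, delta=True)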

tests/tokenization/test_detokenize.py

@@ -4,14 +4,22 @@ from collections.abc import Generator
 from typing import Any, Optional

 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)

 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)

+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]
+
 TRUTH = [
     "Hello here, this is a simple test",
@@ -22,7 +30,8 @@ TRUTH = [
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
     "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
@@ -38,26 +47,37 @@ TOKENIZERS = [
 ]


-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+    prompt_token_ids = all_input_ids[:starting_index]
+
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        finished = i == len(all_input_ids) - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids


 @pytest.fixture
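For orientation, the rewritten helper above is driven like this by the tests that follow; the tokenizer choice is illustrative, and `fast=None` exercises the `from_new_request` dispatch path:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    all_ids = tokenizer("Hello here, this is a simple test").input_ids

    text, out_ids = _run_incremental_decode(tokenizer,
                                            all_ids,
                                            skip_special_tokens=True,
                                            starting_index=0,
                                            fast=None)
    # Incremental decoding should reproduce the one-shot decode.
    assert text == tokenizer.decode(all_ids, skip_special_tokens=True)
    assert out_ids == all_ids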
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)

     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]


 @pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
spaces_between_special_tokens, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
if skip_special_tokens and not spaces_between_special_tokens:
pytest.skip()
if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer.add_special_tokens({
"additional_special_tokens": [
at for at in
tokenizer._tokenizer.get_added_tokens_decoder().values()
if at.special
]
})
extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
else {"spaces_between_special_tokens": spaces_between_special_tokens}
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
if tokenizer.bos_token_id is not None:
truth_tokens.insert(0, tokenizer.bos_token_id)
truth_tokens.append(tokenizer.eos_token_id)
new_truth = tokenizer.decode(truth_tokens,
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
if with_prompt:
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
prompt_input_ids = truth_tokens[:len(truth) // 2]
generated_input_ids = truth_tokens[len(truth) // 2:]
num_prompt_tokens = len(
tokenizer(truth[:len(truth) // 2],
add_special_tokens=False).input_ids)
if tokenizer.bos_token_id is not None:
num_prompt_tokens += 1
prompt_input_ids = truth_tokens[:num_prompt_tokens]
generated_input_ids = truth_tokens[num_prompt_tokens:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(prompt_input_ids,
skip_special_tokens=skip_special_tokens)
generated = truth[len(prompt):]
else:
generated = truth
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
if skip_special_tokens:
if tokenizer.bos_token_id is not None:
all_input_ids = [tokenizer.bos_token_id] + all_input_ids
starting_index += 1
all_input_ids = all_input_ids + [tokenizer.eos_token_id]
skip_special_tokens=skip_special_tokens,
**extra_decode_args)
decoded_text = _run_incremental_decode(
generated = new_truth[len(prompt):]
else:
generated = new_truth
starting_index = 0
all_input_ids = truth_tokens
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
starting_index=starting_index,
spaces_between_special_tokens=spaces_between_special_tokens,
fast=fast)
assert decoded_text == generated
assert out_ids == all_input_ids[starting_index:]
decoded_text = _run_incremental_decode(
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
decoded_text, out_ids = _run_incremental_decode(
tokenizer, [len(tokenizer)],
skip_special_tokens=skip_special_tokens,
starting_index=starting_index)
skip_special_tokens=True,
starting_index=0,
spaces_between_special_tokens=True,
fast=fast)
assert decoded_text == ''
assert out_ids == [len(tokenizer)]
@pytest.fixture
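`test_oov_decode` feeds `len(tokenizer)` as a token id, one past the largest valid id, pinning down that both implementations emit no text for an out-of-vocabulary id rather than raising. The invariant in isolation, under the same assumptions as the test:

    oov_id = len(tokenizer)        # valid ids are 0 .. len(tokenizer) - 1

    text, out_ids = _run_incremental_decode(
        tokenizer, [oov_id],
        skip_special_tokens=True,
        starting_index=0,
        spaces_between_special_tokens=True,
        fast=True)

    assert text == ''              # no text for an unknown id
    assert out_ids == [oov_id]     # but the id is still tracked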
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids


 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)

-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this will only be true if we skip
         # special tokens.
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
@pytest.mark.parametrize("complete_sequence", TRUTH)
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
def test_decode_prompt_logprobs(complete_sequence: str,
complete_sequence_token_ids: list[int],
detokenizer: Detokenizer):
# We want to use skip_special_tokens=False here but Mistral tokenizers
# don't support that.
if complete_sequence not in SPECIAL_TOKS_TRUTH:
skip_special_tokens = True
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
MistralTokenizer):
skip_special_tokens = False
else:
pytest.skip("MistralTokenizers don't support "
"skip_special_tokens=False")
return
"""Verify Detokenizer decodes prompt logprobs correctly."""
sampling_params = SamplingParams(skip_special_tokens=True,
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
prompt_logprobs=1)
# Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]

     # Text for logprobs for the chosen token should be the same as the
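The suffix computation in this hunk decodes the full sequence and strips the decoded first token as a prefix, so both `decode` calls must use the same `skip_special_tokens` value; keeping them consistent is exactly what the change above does. The same idea as a standalone helper, assuming any HF tokenizer:

    def decoded_suffix(tokenizer, token_ids, skip_special_tokens: bool):
        # Decode everything, then drop the first token's text as a prefix;
        # mismatched decode flags would misalign the prefix boundary.
        text_full = tokenizer.decode(token_ids,
                                     skip_special_tokens=skip_special_tokens)
        text_first = tokenizer.decode(token_ids[0],
                                      skip_special_tokens=skip_special_tokens)
        return text_full[len(text_first):]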