[V1][Perf] Faster incremental detokenization (#15137)
Signed-off-by: Nick Hill <nhill@redhat.com>
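The diff below replaces the test suite's use of the old `detokenize_incrementally()` helper with the new V1 incremental detokenizers (`FastIncrementalDetokenizer` for `PreTrainedTokenizerFast` tokenizers, `SlowIncrementalDetokenizer` otherwise). A minimal sketch of the streaming-decode pattern the updated tests drive is shown here; the positional `EngineCoreRequest` fields simply mirror the test helper in this diff, the model name is only an example, and the snippet should be read as illustrative rather than a stable API:

    from transformers import AutoTokenizer

    from vllm.sampling_params import SamplingParams
    from vllm.v1.engine import EngineCoreRequest
    from vllm.v1.engine.detokenizer import IncrementalDetokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model only
    token_ids = tokenizer("Hello here, this is a simple test").input_ids
    prompt_ids, generated_ids = token_ids[:2], token_ids[2:]

    params = SamplingParams(skip_special_tokens=True)
    # Positional field order copied from the test helper in this diff; it is
    # not a stable interface and may differ between vLLM versions.
    request = EngineCoreRequest("", "", prompt_ids, None, None, None, params,
                                None, 0.0, None)

    # from_new_request() selects the fast implementation when the tokenizer
    # is a PreTrainedTokenizerFast and falls back to the slow one otherwise.
    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

    output_text = ""
    for i, tok in enumerate(generated_ids):
        detokenizer.update([tok], False)
        finished = i == len(generated_ids) - 1
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    print(output_text)  # incrementally decoded continuation of the prompt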
@@ -4,14 +4,22 @@ from collections.abc import Generator
 from typing import Any, Optional
 
 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)
+
+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]
 
 TRUTH = [
     "Hello here, this is a simple test",
@@ -22,7 +30,8 @@ TRUTH = [
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
     "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
@@ -38,26 +47,37 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+
+    prompt_token_ids = all_input_ids[:starting_index]
+
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        finished = i == len(all_input_ids) - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids
 
 
 @pytest.fixture
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
 
-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)
     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]
 
 
 @pytest.fixture
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
+@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
+@pytest.mark.parametrize("fast", (True, False))
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
+                          spaces_between_special_tokens, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    if skip_special_tokens and not spaces_between_special_tokens:
+        pytest.skip()
+
+    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
+        # Fix up inconsistency in fast/slow tokenizer behaviour.
+        tokenizer.add_special_tokens({
+            "additional_special_tokens": [
+                at for at in
+                tokenizer._tokenizer.get_added_tokens_decoder().values()
+                if at.special
+            ]
+        })
+
+    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+
+    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
+    if tokenizer.bos_token_id is not None:
+        truth_tokens.insert(0, tokenizer.bos_token_id)
+    truth_tokens.append(tokenizer.eos_token_id)
+
+    new_truth = tokenizer.decode(truth_tokens,
+                                 skip_special_tokens=skip_special_tokens,
+                                 **extra_decode_args)
+
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
-        prompt_input_ids = truth_tokens[:len(truth) // 2]
-        generated_input_ids = truth_tokens[len(truth) // 2:]
+        num_prompt_tokens = len(
+            tokenizer(truth[:len(truth) // 2],
+                      add_special_tokens=False).input_ids)
+        if tokenizer.bos_token_id is not None:
+            num_prompt_tokens += 1
+
+        prompt_input_ids = truth_tokens[:num_prompt_tokens]
+        generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
         prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens)
-        generated = truth[len(prompt):]
-    else:
-        generated = truth
-        starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-        if skip_special_tokens:
-            if tokenizer.bos_token_id is not None:
-                all_input_ids = [tokenizer.bos_token_id] + all_input_ids
-                starting_index += 1
-            all_input_ids = all_input_ids + [tokenizer.eos_token_id]
+                                  skip_special_tokens=skip_special_tokens,
+                                  **extra_decode_args)
+
+        generated = new_truth[len(prompt):]
+    else:
+        generated = new_truth
+        starting_index = 0
+        all_input_ids = truth_tokens
 
-    decoded_text = _run_incremental_decode(
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer,
         all_input_ids,
         skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        starting_index=starting_index,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+        fast=fast)
 
     assert decoded_text == generated
+    assert out_ids == all_input_ids[starting_index:]
 
-    decoded_text = _run_incremental_decode(
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("fast", (True, False))
+def test_oov_decode(tokenizer, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer, [len(tokenizer)],
-        skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        skip_special_tokens=True,
+        starting_index=0,
+        spaces_between_special_tokens=True,
+        fast=fast)
 
     assert decoded_text == ''
+    assert out_ids == [len(tokenizer)]
 
 
 @pytest.fixture
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
 
 
 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
 
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)
 
-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this will only be true if we skip
         # special tokens.
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=True,
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the
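A note on the new `spaces_between_special_tokens` parametrization above: the flag is only honoured by slow (`PreTrainedTokenizer`) decoding, which is why the diff passes it through `extra_decode_args` for slow tokenizers only. A rough, assumption-laden illustration of the flag's effect using plain Hugging Face transformers (the tokenizer choice is just an example):

    from transformers import AutoTokenizer

    # Slow (pure-Python) tokenizer; fast tokenizers do not apply this flag,
    # which is part of the fast/slow inconsistency the tests work around.
    tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)

    # Two adjacent special tokens followed by ordinary text.
    ids = [tok.eos_token_id, tok.eos_token_id] + tok("hi").input_ids

    # Adjacent special tokens are joined with a space by default...
    print(tok.decode(ids, spaces_between_special_tokens=True))
    # ...and concatenated directly when the flag is False.
    print(tok.decode(ids, spaces_between_special_tokens=False))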