[Core] Improve detokenization performance for prefill (#3469)

Co-authored-by: MeloYang <meloyang05@gmail.com>
This commit is contained in:
Antoni Baum
2024-03-22 13:44:12 -07:00
committed by GitHub
parent cf2f084d56
commit bfdb1ba5c3
4 changed files with 385 additions and 89 deletions

View File

@@ -158,6 +158,34 @@ def _convert_tokens_to_string_with_added_encoders(
return "".join(sub_texts)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt_ids: List[int],
skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# Offset a little more in case we have special tokens.
prefix_offset = max(
len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
# We do not need to convert the whole prompt to tokens.
new_tokens = tokenizer.convert_ids_to_tokens(
prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
prefix_offset = max(
len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
read_offset = len(new_tokens)
return new_tokens, prefix_offset, read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
@@ -165,31 +193,53 @@ def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int = 0,
read_offset: int = 0,
prefix_offset: int,
read_offset: int,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
if prev_tokens is None:
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids, skip_special_tokens=skip_special_tokens)
output_tokens = new_tokens
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0)
# If the first new token is a special token, we can't skip 1 extra token
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
read_offset = max(len(output_tokens), 0)
else:
read_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
is_first_iter = prev_tokens is None
if is_first_iter:
(prev_tokens, prefix_offset,
read_offset) = convert_prompt_ids_to_tokens(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
if is_first_iter:
new_tokens = output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the