[Core] Improve detokenization performance for prefill (#3469)
Co-authored-by: MeloYang <meloyang05@gmail.com>
This commit is contained in:
@@ -158,6 +158,34 @@ def _convert_tokens_to_string_with_added_encoders(
|
||||
return "".join(sub_texts)
|
||||
|
||||
|
||||
# 5 is an arbitrary value that should work for all
|
||||
# tokenizers (bigger = more conservative).
|
||||
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
|
||||
def convert_prompt_ids_to_tokens(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
prompt_ids: List[int],
|
||||
skip_special_tokens: bool = False,
|
||||
) -> Tuple[List[str], int, int]:
|
||||
"""Converts the prompt ids to tokens and returns the tokens and offsets
|
||||
for incremental detokenization.
|
||||
|
||||
Note that not all tokens are converted to strings. Only the tokens that
|
||||
are necessary for incremental detokenization are converted to strings.
|
||||
"""
|
||||
# Offset a little more in case we have special tokens.
|
||||
prefix_offset = max(
|
||||
len(prompt_ids) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2, 0)
|
||||
# We do not need to convert the whole prompt to tokens.
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
prompt_ids[prefix_offset:], skip_special_tokens=skip_special_tokens)
|
||||
prefix_offset = max(
|
||||
len(new_tokens) - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
|
||||
read_offset = len(new_tokens)
|
||||
return new_tokens, prefix_offset, read_offset
|
||||
|
||||
|
||||
# Based on
|
||||
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
|
||||
# under Apache 2.0 license
|
||||
@@ -165,31 +193,53 @@ def detokenize_incrementally(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
all_input_ids: List[int],
|
||||
prev_tokens: Optional[List[str]],
|
||||
prefix_offset: int = 0,
|
||||
read_offset: int = 0,
|
||||
prefix_offset: int,
|
||||
read_offset: int,
|
||||
skip_special_tokens: bool = False,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
) -> Tuple[List[str], str, int, int]:
|
||||
"""Detokenizes the input ids incrementally and returns the new tokens
|
||||
and the new text.
|
||||
|
||||
If `prev_tokens` is None, this function will convert the input ids to
|
||||
tokens and return the tokens and the new text. Otherwise, it will return the
|
||||
new tokens and the new text.
|
||||
|
||||
This function will also return the new prefix offset and the new read
|
||||
offset to be used in the next iteration.
|
||||
|
||||
The offsets are necessary to defeat cleanup algorithms in the decode which
|
||||
decide to add a space or not depending on the surrounding ids.
|
||||
|
||||
Args:
|
||||
tokenizer: The tokenizer to use.
|
||||
all_input_ids: The input ids. The last id is the new token id.
|
||||
prev_tokens: The previous tokens. If None, this function will convert
|
||||
the input ids to tokens and return the tokens and the new text.
|
||||
prefix_offset: The prefix offset.
|
||||
read_offset: The read offset.
|
||||
skip_special_tokens: Whether to skip special tokens.
|
||||
spaces_between_special_tokens: Whether to add spaces between special
|
||||
tokens.
|
||||
"""
|
||||
new_token_id = all_input_ids[-1]
|
||||
# This is the first iteration for this sequence
|
||||
if prev_tokens is None:
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
all_input_ids, skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = new_tokens
|
||||
# 5 is an arbitrary value that should work for all
|
||||
# tokenizers (bigger = more conservative).
|
||||
# Subtract 1 extra to account for the generated token.
|
||||
prefix_offset = max(len(output_tokens) - 6, 0)
|
||||
# If the first new token is a special token, we can't skip 1 extra token
|
||||
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
|
||||
read_offset = max(len(output_tokens), 0)
|
||||
else:
|
||||
read_offset = max(len(output_tokens) - 1, 0)
|
||||
else:
|
||||
# Put new_token_id in a list so skip_special_tokens is respected
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
[new_token_id], skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_tokens + new_tokens
|
||||
is_first_iter = prev_tokens is None
|
||||
if is_first_iter:
|
||||
(prev_tokens, prefix_offset,
|
||||
read_offset) = convert_prompt_ids_to_tokens(
|
||||
tokenizer,
|
||||
all_input_ids[:-1],
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
|
||||
# Put new_token_id in a list so skip_special_tokens is respected
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
[new_token_id], skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_tokens + new_tokens
|
||||
|
||||
# If this is the first iteration, return all tokens.
|
||||
if is_first_iter:
|
||||
new_tokens = output_tokens
|
||||
|
||||
# The prefix text is necessary only to defeat cleanup algorithms in
|
||||
# the decode which decide to add a space or not depending on the
|
||||
|
||||
Reference in New Issue
Block a user