[Frontend] Exploit tokenizers "new stream" in FastIncrementalDetokenizer (#34217)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -19,9 +19,9 @@ from vllm.v1.engine import EngineCoreRequest
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# Only tokenizers >= 0.21.1 supports DecodeStream used for
|
# Only tokenizers >= 0.22.0 supports DecodeStream with native prefill
|
||||||
# FastIncrementalDetokenizer.
|
# (ids parameter) used for FastIncrementalDetokenizer.
|
||||||
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")
|
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.22.0")
|
||||||
|
|
||||||
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
|
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
|
||||||
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||||||
@@ -154,11 +154,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
|||||||
# We return the full output text if the sequence is finished.
|
# We return the full output text if the sequence is finished.
|
||||||
buffer_length = 0 if finished else self.stop_buffer_length
|
buffer_length = 0 if finished else self.stop_buffer_length
|
||||||
if not delta:
|
if not delta:
|
||||||
return (
|
if not buffer_length:
|
||||||
self.output_text[:-buffer_length]
|
return self.output_text
|
||||||
if buffer_length
|
return self.output_text[:-buffer_length]
|
||||||
else (self.output_text)
|
|
||||||
)
|
|
||||||
length = len(self.output_text) - buffer_length
|
length = len(self.output_text) - buffer_length
|
||||||
last_offset = self._last_output_text_offset
|
last_offset = self._last_output_text_offset
|
||||||
if last_offset < length:
|
if last_offset < length:
|
||||||
@@ -176,24 +175,14 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
|
|
||||||
self.request_id = request.request_id
|
self.request_id = request.request_id
|
||||||
self.skip_special_tokens = sampling_params.skip_special_tokens
|
self.skip_special_tokens = sampling_params.skip_special_tokens
|
||||||
self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
|
|
||||||
|
|
||||||
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
||||||
|
|
||||||
# Find a safe place to start.
|
# Use native prefill to prime the decode stream with prompt tokens.
|
||||||
prompt_token_ids = request.prompt_token_ids or []
|
self.stream = DecodeStream(
|
||||||
prompt_suffix = prompt_token_ids
|
ids=request.prompt_token_ids,
|
||||||
prompt_len = len(prompt_suffix)
|
skip_special_tokens=self.skip_special_tokens,
|
||||||
if prompt_len > 4:
|
)
|
||||||
for i in range(4, min(prompt_len + 1, 24)):
|
|
||||||
suffix = prompt_token_ids[-i:]
|
|
||||||
if "<EFBFBD>" not in self.tokenizer.decode(suffix):
|
|
||||||
prompt_suffix = suffix
|
|
||||||
break
|
|
||||||
|
|
||||||
# Prime the stream.
|
|
||||||
for tid in prompt_suffix:
|
|
||||||
self._protected_step(tid)
|
|
||||||
|
|
||||||
self.spaces_between_special_tokens = (
|
self.spaces_between_special_tokens = (
|
||||||
sampling_params.skip_special_tokens
|
sampling_params.skip_special_tokens
|
||||||
@@ -203,9 +192,8 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
if not self.spaces_between_special_tokens:
|
if not self.spaces_between_special_tokens:
|
||||||
# Store dict of added token ids so that we can suppress
|
# Store dict of added token ids so that we can suppress
|
||||||
# the spaces between them.
|
# the spaces between them.
|
||||||
if (
|
added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
|
||||||
added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
|
if added_token_ids is None:
|
||||||
) is None:
|
|
||||||
self.tokenizer.added_token_ids = added_token_ids = {
|
self.tokenizer.added_token_ids = added_token_ids = {
|
||||||
tid: tok.content
|
tid: tok.content
|
||||||
for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
|
for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
|
||||||
@@ -290,11 +278,9 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def output_token_ids(self) -> list[int]:
|
def output_token_ids(self) -> list[int]:
|
||||||
return (
|
if self.prompt_len:
|
||||||
self.token_ids
|
return self.token_ids[self.prompt_len :]
|
||||||
if not self.prompt_len
|
return self.token_ids
|
||||||
else (self.token_ids[self.prompt_len :])
|
|
||||||
)
|
|
||||||
|
|
||||||
def num_output_tokens(self) -> int:
|
def num_output_tokens(self) -> int:
|
||||||
return len(self.token_ids) - self.prompt_len
|
return len(self.token_ids) - self.prompt_len
|
||||||
|
|||||||
Reference in New Issue
Block a user