[Frontend] Exploit tokenizers "new stream" in FastIncrementalDetokenizer (#34217)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -19,9 +19,9 @@ from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Only tokenizers >= 0.21.1 supports DecodeStream used for
|
||||
# FastIncrementalDetokenizer.
|
||||
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")
|
||||
# Only tokenizers >= 0.22.0 supports DecodeStream with native prefill
|
||||
# (ids parameter) used for FastIncrementalDetokenizer.
|
||||
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.22.0")
|
||||
|
||||
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
|
||||
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||||
@@ -154,11 +154,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
|
||||
# We return the full output text if the sequence is finished.
|
||||
buffer_length = 0 if finished else self.stop_buffer_length
|
||||
if not delta:
|
||||
return (
|
||||
self.output_text[:-buffer_length]
|
||||
if buffer_length
|
||||
else (self.output_text)
|
||||
)
|
||||
if not buffer_length:
|
||||
return self.output_text
|
||||
return self.output_text[:-buffer_length]
|
||||
|
||||
length = len(self.output_text) - buffer_length
|
||||
last_offset = self._last_output_text_offset
|
||||
if last_offset < length:
|
||||
@@ -176,24 +175,14 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
|
||||
self.request_id = request.request_id
|
||||
self.skip_special_tokens = sampling_params.skip_special_tokens
|
||||
self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
|
||||
|
||||
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
||||
|
||||
# Find a safe place to start.
|
||||
prompt_token_ids = request.prompt_token_ids or []
|
||||
prompt_suffix = prompt_token_ids
|
||||
prompt_len = len(prompt_suffix)
|
||||
if prompt_len > 4:
|
||||
for i in range(4, min(prompt_len + 1, 24)):
|
||||
suffix = prompt_token_ids[-i:]
|
||||
if "<EFBFBD>" not in self.tokenizer.decode(suffix):
|
||||
prompt_suffix = suffix
|
||||
break
|
||||
|
||||
# Prime the stream.
|
||||
for tid in prompt_suffix:
|
||||
self._protected_step(tid)
|
||||
# Use native prefill to prime the decode stream with prompt tokens.
|
||||
self.stream = DecodeStream(
|
||||
ids=request.prompt_token_ids,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
)
|
||||
|
||||
self.spaces_between_special_tokens = (
|
||||
sampling_params.skip_special_tokens
|
||||
@@ -203,9 +192,8 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
if not self.spaces_between_special_tokens:
|
||||
# Store dict of added token ids so that we can suppress
|
||||
# the spaces between them.
|
||||
if (
|
||||
added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
|
||||
) is None:
|
||||
added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
|
||||
if added_token_ids is None:
|
||||
self.tokenizer.added_token_ids = added_token_ids = {
|
||||
tid: tok.content
|
||||
for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
|
||||
@@ -290,11 +278,9 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
|
||||
@property
|
||||
def output_token_ids(self) -> list[int]:
|
||||
return (
|
||||
self.token_ids
|
||||
if not self.prompt_len
|
||||
else (self.token_ids[self.prompt_len :])
|
||||
)
|
||||
if self.prompt_len:
|
||||
return self.token_ids[self.prompt_len :]
|
||||
return self.token_ids
|
||||
|
||||
def num_output_tokens(self) -> int:
|
||||
return len(self.token_ids) - self.prompt_len
|
||||
|
||||
Reference in New Issue
Block a user