[Frontend] Exploit tokenizers "new stream" in FastIncrementalDetokenizer (#34217)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-11 02:03:24 -08:00
committed by GitHub
parent 786806dd44
commit e09546cf05

View File

@@ -19,9 +19,9 @@ from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__) logger = init_logger(__name__)
# Only tokenizers >= 0.21.1 supports DecodeStream used for # Only tokenizers >= 0.22.0 supports DecodeStream with native prefill
# FastIncrementalDetokenizer. # (ids parameter) used for FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1") USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.22.0")
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042 # Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered" INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
@@ -154,11 +154,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
# We return the full output text if the sequence is finished. # We return the full output text if the sequence is finished.
buffer_length = 0 if finished else self.stop_buffer_length buffer_length = 0 if finished else self.stop_buffer_length
if not delta: if not delta:
return ( if not buffer_length:
self.output_text[:-buffer_length] return self.output_text
if buffer_length return self.output_text[:-buffer_length]
else (self.output_text)
)
length = len(self.output_text) - buffer_length length = len(self.output_text) - buffer_length
last_offset = self._last_output_text_offset last_offset = self._last_output_text_offset
if last_offset < length: if last_offset < length:
@@ -176,24 +175,14 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
self.request_id = request.request_id self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens self.skip_special_tokens = sampling_params.skip_special_tokens
self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
self.tokenizer: Tokenizer = tokenizer._tokenizer self.tokenizer: Tokenizer = tokenizer._tokenizer
# Find a safe place to start. # Use native prefill to prime the decode stream with prompt tokens.
prompt_token_ids = request.prompt_token_ids or [] self.stream = DecodeStream(
prompt_suffix = prompt_token_ids ids=request.prompt_token_ids,
prompt_len = len(prompt_suffix) skip_special_tokens=self.skip_special_tokens,
if prompt_len > 4: )
for i in range(4, min(prompt_len + 1, 24)):
suffix = prompt_token_ids[-i:]
if "<EFBFBD>" not in self.tokenizer.decode(suffix):
prompt_suffix = suffix
break
# Prime the stream.
for tid in prompt_suffix:
self._protected_step(tid)
self.spaces_between_special_tokens = ( self.spaces_between_special_tokens = (
sampling_params.skip_special_tokens sampling_params.skip_special_tokens
@@ -203,9 +192,8 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
if not self.spaces_between_special_tokens: if not self.spaces_between_special_tokens:
# Store dict of added token ids so that we can suppress # Store dict of added token ids so that we can suppress
# the spaces between them. # the spaces between them.
if ( added_token_ids = getattr(self.tokenizer, "added_token_ids", None)
added_token_ids := getattr(self.tokenizer, "added_token_ids", None) if added_token_ids is None:
) is None:
self.tokenizer.added_token_ids = added_token_ids = { self.tokenizer.added_token_ids = added_token_ids = {
tid: tok.content tid: tok.content
for tid, tok in self.tokenizer.get_added_tokens_decoder().items() for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
@@ -290,11 +278,9 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
@property @property
def output_token_ids(self) -> list[int]: def output_token_ids(self) -> list[int]:
return ( if self.prompt_len:
self.token_ids return self.token_ids[self.prompt_len :]
if not self.prompt_len return self.token_ids
else (self.token_ids[self.prompt_len :])
)
def num_output_tokens(self) -> int: def num_output_tokens(self) -> int:
return len(self.token_ids) - self.prompt_len return len(self.token_ids) - self.prompt_len