[Core] Remove tokenizer group in vLLM (#24078)

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Author: Zhuohan Li
Date: 2025-09-17 01:42:59 -07:00
Committed by: GitHub
parent c15309a730
commit 6c47f6bfa4
49 changed files with 276 additions and 934 deletions
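
Summary: this change deletes the TokenizerGroup abstraction, which wrapped a base tokenizer together with an LRU cache of per-LoRA tokenizers. The Detokenizer now receives a single AnyTokenizer directly, init_tokenizer_from_configs is reimplemented alongside get_tokenizer and returns a plain tokenizer, and TokenizerBase gains an abstract truncation_side property (which MistralTokenizer currently stubs out with NotImplementedError).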

View File

@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
 
 from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                 detokenize_incrementally)
 from .tokenizer import AnyTokenizer
-from .tokenizer_group import TokenizerGroup
 
 
 class Detokenizer:
     """Provides methods to decode the output of a model into text."""
 
-    def __init__(self, tokenizer_group: TokenizerGroup):
-        self.tokenizer_group = tokenizer_group
-
-    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
-        """Returns the HF tokenizer to use for a given sequence."""
-        return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
+    def __init__(self, tokenizer: AnyTokenizer):
+        self.tokenizer = tokenizer
 
     def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
                                        prompt_logprobs: list[Optional[dict[
@@ -32,9 +27,9 @@ class Detokenizer:
         Args:
             seq_group: The sequence group to decode.
             prompt_logprobs: The logprobs to decode.
-            position_offset: Offset of the first index of the logprobs 
+            position_offset: Offset of the first index of the logprobs
                 relative to the start of the sequence (for chunked prefill).
 
         Returns:
             The prompt logprobs with the decoded tokens.
         """
@@ -46,7 +41,6 @@ class Detokenizer:
             # Only prompt, without the generated token.
             all_token_ids = seq.get_token_ids()
             prompt_token_ids = all_token_ids[:-1]
-            tokenizer = self.get_tokenizer_for_seq(seq)
             prefix_offset = 0
             read_offset = 0
             next_iter_prefix_offset = 0
@@ -70,7 +64,7 @@ class Detokenizer:
                     prompt_token_ids[:token_position] + [token_id])
                 (new_tokens, new_text, new_prefix_offset,
                  new_read_offset) = detokenize_incrementally(
-                     tokenizer=tokenizer,
+                     tokenizer=self.tokenizer,
                      all_input_ids=prompt_token_ids_with_token,
                      prev_tokens=prev_tokens,
                      prefix_offset=prefix_offset,
@@ -111,7 +105,6 @@ class Detokenizer:
         """
         all_input_ids = seq.get_token_ids()
         token_id_generated_this_iteration = all_input_ids[-1]
-        tokenizer = self.get_tokenizer_for_seq(seq)
 
         # Convert prompt token IDs to tokens if necessary.
         # Do it here so that we don't have to repeat this
@@ -119,14 +112,14 @@ class Detokenizer:
         if seq.tokens is None:
             (seq.tokens, seq.prefix_offset,
              seq.read_offset) = convert_prompt_ids_to_tokens(
-                 tokenizer=tokenizer,
+                 tokenizer=self.tokenizer,
                  prompt_ids=all_input_ids[:-1],
                  skip_special_tokens=prms.skip_special_tokens,
             )
 
         (new_tokens, new_decoded_token_text, prefix_offset,
          read_offset) = detokenize_incrementally(
-             tokenizer=tokenizer,
+             tokenizer=self.tokenizer,
              all_input_ids=all_input_ids,
              prev_tokens=seq.tokens,
              prefix_offset=seq.prefix_offset,
@@ -150,7 +143,7 @@ class Detokenizer:
                     and token_id != VLLM_INVALID_TOKEN_ID):
                 all_input_ids_with_logprob = previous_tokens + [token_id]
                 (_, new_text, _, _) = detokenize_incrementally(
-                    tokenizer=tokenizer,
+                    tokenizer=self.tokenizer,
                     all_input_ids=all_input_ids_with_logprob,
                     prev_tokens=seq.tokens,
                     prefix_offset=seq.prefix_offset,
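
After this change, Detokenizer is constructed from one tokenizer rather than a TokenizerGroup, and every decode path reads self.tokenizer instead of resolving a per-sequence LoRA tokenizer. A minimal usage sketch, assuming the v0 module layout and an illustrative model name:

    # Hedged sketch: the import paths and model name are assumptions.
    from vllm.transformers_utils.detokenizer import Detokenizer
    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer("facebook/opt-125m")
    detokenizer = Detokenizer(tokenizer)  # previously: Detokenizer(tokenizer_group)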

View File

@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 import huggingface_hub
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)
+from typing_extensions import assert_never
 
 from vllm import envs
 from vllm.logger import init_logger
@@ -19,7 +20,6 @@ from vllm.transformers_utils.config import (
     get_sentence_transformer_tokenizer_config)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
-from vllm.utils import make_async
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -274,20 +274,19 @@ def cached_tokenizer_from_config(
     )
 
 
-def get_lora_tokenizer(lora_request: LoRARequest, *args,
-                       **kwargs) -> Optional[AnyTokenizer]:
-    if lora_request is None:
-        return None
-    try:
-        tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
-    except Exception as e:
-        # No tokenizer was found in the LoRA folder,
-        # use base model tokenizer
-        logger.warning(
-            "No tokenizer found in %s, using base model tokenizer instead. "
-            "(Exception: %s)", lora_request.lora_path, e)
-        tokenizer = None
-    return tokenizer
+def init_tokenizer_from_configs(model_config: ModelConfig):
+    runner_type = model_config.runner_type
+    if runner_type == "generate" or runner_type == "draft":
+        truncation_side = "left"
+    elif runner_type == "pooling":
+        truncation_side = "right"
+    else:
+        assert_never(runner_type)
 
-
-get_lora_tokenizer_async = make_async(get_lora_tokenizer)
+    return get_tokenizer(
+        model_config.tokenizer,
+        tokenizer_mode=model_config.tokenizer_mode,
+        trust_remote_code=model_config.trust_remote_code,
+        revision=model_config.tokenizer_revision,
+        truncation_side=truncation_side,
+    )
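
The replacement init_tokenizer_from_configs needs only a ModelConfig: it derives truncation_side from runner_type ("generate" and "draft" truncate on the left, "pooling" on the right) and delegates everything else to get_tokenizer. A rough call-site sketch; the ModelConfig arguments are illustrative, not from this diff:

    # Hedged sketch: the ModelConfig construction shown is an assumption.
    from vllm.config import ModelConfig
    from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs

    model_config = ModelConfig(model="facebook/opt-125m")
    tokenizer = init_tokenizer_from_configs(model_config)
    print(tokenizer.truncation_side)  # "left" for generate/draft, "right" for pooling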

View File

@@ -61,6 +61,11 @@ class TokenizerBase(ABC):
     def max_token_id(self) -> int:
         raise NotImplementedError()
 
+    @property
+    @abstractmethod
+    def truncation_side(self) -> str:
+        raise NotImplementedError()
+
     def __len__(self) -> int:
         return self.vocab_size
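
truncation_side is now part of the TokenizerBase contract, so every concrete implementation must expose it. A minimal sketch of a conforming subclass; WrappedTokenizer is hypothetical and the other abstract members are elided:

    # Hedged sketch: WrappedTokenizer is not part of this diff.
    class WrappedTokenizer(TokenizerBase):

        def __init__(self, inner):
            self._inner = inner

        @property
        def truncation_side(self) -> str:
            # Delegate to the wrapped HF-style tokenizer attribute.
            return self._inner.truncation_side

        # ... remaining abstract methods of TokenizerBase ...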

View File

@@ -1,132 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing_extensions import assert_never
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.utils import LRUCache
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.truncation_side = tokenizer_config.get("truncation_side", "left")
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: list[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
if max_input_length is not None and input_length > max_input_length:
raise ValueError("Input too long.", input_length, max_input_length)
def encode(self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret
async def encode_async(
self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig]):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return TokenizerGroup(
tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side)
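
For callers, the migration is mechanical: where a TokenizerGroup was built once and queried per request, a single tokenizer is now built once and shared. A hedged before/after sketch, assuming the configs are already in scope:

    # Before: per-request LoRA tokenizer lookup through the group.
    # tokenizer_group = init_tokenizer_from_configs(
    #     model_config, scheduler_config, lora_config)
    # tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)

    # After: one tokenizer, resolved once from the model config.
    tokenizer = init_tokenizer_from_configs(model_config)
    token_ids = tokenizer.encode("Hello, world!")

The group's side features disappear with it: the LRU-cached per-LoRA tokenizers, the max_input_length check in encode/encode_async, and the async lookup path.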

View File

@@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase):
     def max_token_id(self) -> int:
         return self._max_token_id
 
+    @property
+    def truncation_side(self) -> str:
+        raise NotImplementedError()
+
     def __len__(self) -> int:
         return self.vocab_size
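
MistralTokenizer satisfies the new abstract property only nominally and raises NotImplementedError, so generic code that inspects the truncation side needs a guard. A defensive sketch; get_truncation_side is a hypothetical helper, not part of this diff:

    # Hedged sketch: falls back when a tokenizer has no truncation side.
    def get_truncation_side(tokenizer, default: str = "left") -> str:
        try:
            return tokenizer.truncation_side
        except NotImplementedError:
            # e.g. MistralTokenizer after this change
            return default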