[Core] Remove tokenizer group in vLLM (#24078)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
|
||||
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
|
||||
detokenize_incrementally)
|
||||
from .tokenizer import AnyTokenizer
|
||||
from .tokenizer_group import TokenizerGroup
|
||||
|
||||
|
||||
class Detokenizer:
|
||||
"""Provides methods to decode the output of a model into text."""
|
||||
|
||||
def __init__(self, tokenizer_group: TokenizerGroup):
|
||||
self.tokenizer_group = tokenizer_group
|
||||
|
||||
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
|
||||
"""Returns the HF tokenizer to use for a given sequence."""
|
||||
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
|
||||
prompt_logprobs: list[Optional[dict[
|
||||
@@ -32,9 +27,9 @@ class Detokenizer:
|
||||
Args:
|
||||
seq_group: The sequence group to decode.
|
||||
prompt_logprobs: The logprobs to decode.
|
||||
position_offset: Offset of the first index of the logprobs
|
||||
position_offset: Offset of the first index of the logprobs
|
||||
relative to the start of the sequence (for chunked prefill).
|
||||
|
||||
|
||||
Returns:
|
||||
The prompt logprobs with the decoded tokens.
|
||||
"""
|
||||
@@ -46,7 +41,6 @@ class Detokenizer:
|
||||
# Only prompt, without the generated token.
|
||||
all_token_ids = seq.get_token_ids()
|
||||
prompt_token_ids = all_token_ids[:-1]
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
next_iter_prefix_offset = 0
|
||||
@@ -70,7 +64,7 @@ class Detokenizer:
|
||||
prompt_token_ids[:token_position] + [token_id])
|
||||
(new_tokens, new_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=prompt_token_ids_with_token,
|
||||
prev_tokens=prev_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
@@ -111,7 +105,6 @@ class Detokenizer:
|
||||
"""
|
||||
all_input_ids = seq.get_token_ids()
|
||||
token_id_generated_this_iteration = all_input_ids[-1]
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
|
||||
# Convert prompt token IDs to tokens if necessary.
|
||||
# Do it here so that we don't have to repeat this
|
||||
@@ -119,14 +112,14 @@ class Detokenizer:
|
||||
if seq.tokens is None:
|
||||
(seq.tokens, seq.prefix_offset,
|
||||
seq.read_offset) = convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt_ids=all_input_ids[:-1],
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
)
|
||||
|
||||
(new_tokens, new_decoded_token_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=all_input_ids,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
@@ -150,7 +143,7 @@ class Detokenizer:
|
||||
and token_id != VLLM_INVALID_TOKEN_ID):
|
||||
all_input_ids_with_logprob = previous_tokens + [token_id]
|
||||
(_, new_text, _, _) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=all_input_ids_with_logprob,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
import huggingface_hub
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
@@ -19,7 +20,6 @@ from vllm.transformers_utils.config import (
|
||||
get_sentence_transformer_tokenizer_config)
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import make_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@@ -274,20 +274,19 @@ def cached_tokenizer_from_config(
|
||||
)
|
||||
|
||||
|
||||
def get_lora_tokenizer(lora_request: LoRARequest, *args,
|
||||
**kwargs) -> Optional[AnyTokenizer]:
|
||||
if lora_request is None:
|
||||
return None
|
||||
try:
|
||||
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
|
||||
except Exception as e:
|
||||
# No tokenizer was found in the LoRA folder,
|
||||
# use base model tokenizer
|
||||
logger.warning(
|
||||
"No tokenizer found in %s, using base model tokenizer instead. "
|
||||
"(Exception: %s)", lora_request.lora_path, e)
|
||||
tokenizer = None
|
||||
return tokenizer
|
||||
def init_tokenizer_from_configs(model_config: ModelConfig):
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type == "generate" or runner_type == "draft":
|
||||
truncation_side = "left"
|
||||
elif runner_type == "pooling":
|
||||
truncation_side = "right"
|
||||
else:
|
||||
assert_never(runner_type)
|
||||
|
||||
|
||||
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
|
||||
return get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.tokenizer_revision,
|
||||
truncation_side=truncation_side,
|
||||
)
|
||||
|
||||
@@ -61,6 +61,11 @@ class TokenizerBase(ABC):
|
||||
def max_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def truncation_side(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.vocab_size
|
||||
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
|
||||
get_lora_tokenizer,
|
||||
get_lora_tokenizer_async,
|
||||
get_tokenizer)
|
||||
from vllm.utils import LRUCache
|
||||
|
||||
|
||||
class TokenizerGroup:
|
||||
"""A group of tokenizers that can be used for LoRA adapters."""
|
||||
|
||||
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
|
||||
max_input_length: Optional[int], **tokenizer_config):
|
||||
self.tokenizer_id = tokenizer_id
|
||||
self.tokenizer_config = tokenizer_config
|
||||
self.enable_lora = enable_lora
|
||||
self.max_input_length = max_input_length
|
||||
self.truncation_side = tokenizer_config.get("truncation_side", "left")
|
||||
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
|
||||
max_loras = tokenizer_config.get("max_loras", 0)
|
||||
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
|
||||
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
|
||||
|
||||
def get_max_input_len(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> Optional[int]:
|
||||
"""Get the maximum input length for the LoRA request."""
|
||||
return self.max_input_length
|
||||
|
||||
def _raise_if_input_too_long(self,
|
||||
encoded_tokens: list[int],
|
||||
lora_request: Optional[LoRARequest] = None):
|
||||
input_length = len(encoded_tokens)
|
||||
if lora_request:
|
||||
max_input_length = (lora_request.long_lora_max_len
|
||||
or self.max_input_length)
|
||||
else:
|
||||
max_input_length = self.max_input_length
|
||||
if max_input_length is not None and input_length > max_input_length:
|
||||
raise ValueError("Input too long.", input_length, max_input_length)
|
||||
|
||||
def encode(self,
|
||||
prompt: str,
|
||||
max_length: Optional[int] = None,
|
||||
truncation: Optional[bool] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
add_special_tokens: Optional[bool] = None) -> list[int]:
|
||||
|
||||
tokenizer = self.get_lora_tokenizer(lora_request)
|
||||
ret = encode_tokens(tokenizer,
|
||||
prompt,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
add_special_tokens=add_special_tokens)
|
||||
self._raise_if_input_too_long(ret, lora_request)
|
||||
return ret
|
||||
|
||||
async def encode_async(
|
||||
self,
|
||||
prompt: str,
|
||||
max_length: Optional[int] = None,
|
||||
truncation: Optional[bool] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
add_special_tokens: Optional[bool] = None) -> list[int]:
|
||||
tokenizer = await self.get_lora_tokenizer_async(lora_request)
|
||||
ret = encode_tokens(tokenizer,
|
||||
prompt,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
add_special_tokens=add_special_tokens)
|
||||
self._raise_if_input_too_long(ret, lora_request)
|
||||
return ret
|
||||
|
||||
def get_lora_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
if not lora_request or not self.enable_lora:
|
||||
return self.tokenizer
|
||||
if lora_request.lora_int_id not in self.lora_tokenizers:
|
||||
tokenizer = (get_lora_tokenizer(
|
||||
lora_request, **self.tokenizer_config) or self.tokenizer)
|
||||
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
|
||||
return tokenizer
|
||||
else:
|
||||
return self.lora_tokenizers[lora_request.lora_int_id]
|
||||
|
||||
async def get_lora_tokenizer_async(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
if not lora_request or not self.enable_lora:
|
||||
return self.tokenizer
|
||||
if lora_request.lora_int_id not in self.lora_tokenizers:
|
||||
tokenizer = (await get_lora_tokenizer_async(
|
||||
lora_request, **self.tokenizer_config) or self.tokenizer)
|
||||
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
|
||||
return tokenizer
|
||||
else:
|
||||
return self.lora_tokenizers[lora_request.lora_int_id]
|
||||
|
||||
|
||||
def init_tokenizer_from_configs(model_config: ModelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_config: Optional[LoRAConfig]):
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type == "generate" or runner_type == "draft":
|
||||
truncation_side = "left"
|
||||
elif runner_type == "pooling":
|
||||
truncation_side = "right"
|
||||
else:
|
||||
assert_never(runner_type)
|
||||
|
||||
return TokenizerGroup(
|
||||
tokenizer_id=model_config.tokenizer,
|
||||
enable_lora=bool(lora_config),
|
||||
max_num_seqs=scheduler_config.max_num_seqs,
|
||||
max_loras=lora_config.max_loras if lora_config else 0,
|
||||
max_input_length=None,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.tokenizer_revision,
|
||||
truncation_side=truncation_side)
|
||||
@@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase):
|
||||
def max_token_id(self) -> int:
|
||||
return self._max_token_id
|
||||
|
||||
@property
|
||||
def truncation_side(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.vocab_size
|
||||
|
||||
|
||||
Reference in New Issue
Block a user