[Refactor] TokenizerRegistry only uses lazy imports (#30609)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,7 +7,7 @@ from vllm.config import ModelConfig
|
|||||||
from vllm.inputs import zip_enc_dec_prompts
|
from vllm.inputs import zip_enc_dec_prompts
|
||||||
from vllm.inputs.parse import parse_raw_prompts
|
from vllm.inputs.parse import parse_raw_prompts
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.tokenizers import init_tokenizer_from_config
|
from vllm.tokenizers import cached_tokenizer_from_config
|
||||||
|
|
||||||
pytestmark = pytest.mark.cpu_test
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
|
|||||||
)
|
)
|
||||||
def test_preprocessor_always_mm_code_path(model_id, prompt):
|
def test_preprocessor_always_mm_code_path(model_id, prompt):
|
||||||
model_config = ModelConfig(model=model_id)
|
model_config = ModelConfig(model=model_id)
|
||||||
tokenizer = init_tokenizer_from_config(model_config)
|
tokenizer = cached_tokenizer_from_config(model_config)
|
||||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
||||||
|
|
||||||
# HF processor adds sep token
|
# HF processor adds sep token
|
||||||
|
|||||||
@@ -3,38 +3,39 @@
|
|||||||
from typing import _get_protocol_attrs # type: ignore
|
from typing import _get_protocol_attrs # type: ignore
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import (
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
PreTrainedTokenizerFast,
|
||||||
|
)
|
||||||
|
|
||||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||||
|
from vllm.tokenizers.mistral import MistralTokenizer
|
||||||
|
|
||||||
|
|
||||||
def _get_missing_attrs(obj: object, target: type):
|
def _get_missing_attrs(obj: object, target: type):
|
||||||
return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
|
return [k for k in _get_protocol_attrs(target) if not hasattr(obj, k)]
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_tokenizer_like(tokenizer: object):
|
||||||
|
missing_attrs = _get_missing_attrs(tokenizer, TokenizerLike)
|
||||||
|
assert not missing_attrs, f"Missing attrs: {missing_attrs}"
|
||||||
|
|
||||||
|
|
||||||
def test_tokenizer_like_protocol():
|
def test_tokenizer_like_protocol():
|
||||||
assert not (
|
tokenizer = get_tokenizer("gpt2", use_fast=False)
|
||||||
missing_attrs := _get_missing_attrs(
|
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||||
get_tokenizer("gpt2", use_fast=False),
|
_assert_tokenizer_like(tokenizer)
|
||||||
TokenizerLike,
|
|
||||||
)
|
|
||||||
), f"Missing attrs: {missing_attrs}"
|
|
||||||
|
|
||||||
assert not (
|
tokenizer = get_tokenizer("gpt2", use_fast=True)
|
||||||
missing_attrs := _get_missing_attrs(
|
assert isinstance(tokenizer, PreTrainedTokenizerFast)
|
||||||
get_tokenizer("gpt2", use_fast=True),
|
_assert_tokenizer_like(tokenizer)
|
||||||
TokenizerLike,
|
|
||||||
)
|
|
||||||
), f"Missing attrs: {missing_attrs}"
|
|
||||||
|
|
||||||
assert not (
|
tokenizer = get_tokenizer(
|
||||||
missing_attrs := _get_missing_attrs(
|
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
|
||||||
get_tokenizer(
|
)
|
||||||
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
|
assert isinstance(tokenizer, MistralTokenizer)
|
||||||
),
|
_assert_tokenizer_like(tokenizer)
|
||||||
TokenizerLike,
|
|
||||||
)
|
|
||||||
), f"Missing attrs: {missing_attrs}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
|
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
|
||||||
|
|||||||
@@ -2,7 +2,14 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from vllm.tokenizers import TokenizerLike, TokenizerRegistry, get_tokenizer
|
import pytest
|
||||||
|
|
||||||
|
from vllm.tokenizers import TokenizerLike
|
||||||
|
from vllm.tokenizers.registry import (
|
||||||
|
TokenizerRegistry,
|
||||||
|
get_tokenizer,
|
||||||
|
resolve_tokenizer_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestTokenizer(TokenizerLike):
|
class TestTokenizer(TokenizerLike):
|
||||||
@@ -40,10 +47,22 @@ class TestTokenizer(TokenizerLike):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("runner_type", ["generate", "pooling"])
|
||||||
|
def test_resolve_tokenizer_args_idempotent(runner_type):
|
||||||
|
tokenizer_mode, tokenizer_name, args, kwargs = resolve_tokenizer_args(
|
||||||
|
"facebook/opt-125m",
|
||||||
|
runner_type=runner_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (tokenizer_mode, tokenizer_name, args, kwargs) == resolve_tokenizer_args(
|
||||||
|
tokenizer_name, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_customized_tokenizer():
|
def test_customized_tokenizer():
|
||||||
TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)
|
TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)
|
||||||
|
|
||||||
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer", "abc")
|
tokenizer = TokenizerRegistry.load_tokenizer("test_tokenizer", "abc")
|
||||||
assert isinstance(tokenizer, TestTokenizer)
|
assert isinstance(tokenizer, TestTokenizer)
|
||||||
assert tokenizer.path_or_repo_id == "abc"
|
assert tokenizer.path_or_repo_id == "abc"
|
||||||
assert tokenizer.bos_token_id == 0
|
assert tokenizer.bos_token_id == 0
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ from vllm.model_executor.models import SupportsMultiModal
|
|||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||||
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
from vllm.tokenizers.mistral import MistralTokenizer
|
|
||||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
@@ -60,6 +59,8 @@ from vllm.utils.import_utils import LazyLoader
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.tokenizers.mistral import MistralTokenizer
|
||||||
else:
|
else:
|
||||||
torch = LazyLoader("torch", globals(), "torch")
|
torch = LazyLoader("torch", globals(), "torch")
|
||||||
|
|
||||||
@@ -1832,7 +1833,7 @@ def apply_hf_chat_template(
|
|||||||
|
|
||||||
|
|
||||||
def apply_mistral_chat_template(
|
def apply_mistral_chat_template(
|
||||||
tokenizer: MistralTokenizer,
|
tokenizer: "MistralTokenizer",
|
||||||
messages: list[ChatCompletionMessageParam],
|
messages: list[ChatCompletionMessageParam],
|
||||||
chat_template: str | None,
|
chat_template: str | None,
|
||||||
tools: list[dict[str, Any]] | None,
|
tools: list[dict[str, Any]] | None,
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from .deepseekv32 import DeepseekV32Tokenizer
|
|
||||||
from .hf import HfTokenizer
|
|
||||||
from .mistral import MistralTokenizer
|
|
||||||
from .protocol import TokenizerLike
|
from .protocol import TokenizerLike
|
||||||
from .registry import (
|
from .registry import (
|
||||||
TokenizerRegistry,
|
TokenizerRegistry,
|
||||||
@@ -15,12 +12,9 @@ from .registry import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TokenizerLike",
|
"TokenizerLike",
|
||||||
"HfTokenizer",
|
|
||||||
"MistralTokenizer",
|
|
||||||
"TokenizerRegistry",
|
"TokenizerRegistry",
|
||||||
"cached_get_tokenizer",
|
"cached_get_tokenizer",
|
||||||
"get_tokenizer",
|
"get_tokenizer",
|
||||||
"cached_tokenizer_from_config",
|
"cached_tokenizer_from_config",
|
||||||
"init_tokenizer_from_config",
|
"init_tokenizer_from_config",
|
||||||
"DeepseekV32Tokenizer",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,24 +2,18 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from transformers import BatchEncoding
|
from transformers import BatchEncoding
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
|
||||||
from .deepseek_v32_encoding import encode_messages
|
from .deepseek_v32_encoding import encode_messages
|
||||||
from .hf import HfTokenizer, TokenizerLike
|
from .hf import CachedHfTokenizer
|
||||||
from .registry import TokenizerRegistry
|
from .protocol import TokenizerLike
|
||||||
|
|
||||||
|
|
||||||
@TokenizerRegistry.register("deepseek_v32")
|
class DeepseekV32Tokenizer(CachedHfTokenizer):
|
||||||
class DeepseekV32Tokenizer(HfTokenizer):
|
|
||||||
def __init__(self, tokenizer: TokenizerLike):
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.name_or_path = (
|
|
||||||
tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else ""
|
|
||||||
)
|
|
||||||
self._added_vocab = self.tokenizer.get_added_vocab()
|
|
||||||
self._added_vocab_size = len(self._added_vocab)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -40,7 +34,21 @@ class DeepseekV32Tokenizer(HfTokenizer):
|
|||||||
)
|
)
|
||||||
return DeepseekV32Tokenizer(tokenizer)
|
return DeepseekV32Tokenizer(tokenizer)
|
||||||
|
|
||||||
def apply_chat_template(self, messages, tools=None, **kwargs):
|
def __init__(self, tokenizer: TokenizerLike) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.name_or_path = getattr(tokenizer, "name_or_path", "")
|
||||||
|
|
||||||
|
self._added_vocab = self.tokenizer.get_added_vocab()
|
||||||
|
self._added_vocab_size = len(self._added_vocab)
|
||||||
|
|
||||||
|
def apply_chat_template(
|
||||||
|
self,
|
||||||
|
messages: list["ChatCompletionMessageParam"],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str | list[int]:
|
||||||
thinking = kwargs.get("thinking", False)
|
thinking = kwargs.get("thinking", False)
|
||||||
thinking_mode = "thinking"
|
thinking_mode = "thinking"
|
||||||
if not thinking:
|
if not thinking:
|
||||||
@@ -49,13 +57,24 @@ class DeepseekV32Tokenizer(HfTokenizer):
|
|||||||
messages = conversation.copy()
|
messages = conversation.copy()
|
||||||
if tools is not None and len(tools) > 0:
|
if tools is not None and len(tools) > 0:
|
||||||
messages.insert(0, {"role": "system"})
|
messages.insert(0, {"role": "system"})
|
||||||
messages[0]["tools"] = tools
|
messages[0]["tools"] = tools # type: ignore[typeddict-unknown-key]
|
||||||
|
|
||||||
# Historical reasoning content is dropped when a new user message is introduced
|
# Historical reasoning content is dropped when a new user message is introduced
|
||||||
drop_thinking = messages[-1]["role"] == "user"
|
drop_thinking = messages[-1]["role"] == "user"
|
||||||
|
|
||||||
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
|
encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking)
|
||||||
prompt_str = encode_messages(messages, **encode_config) # type: ignore
|
prompt_str = encode_messages(messages, **encode_config) # type: ignore
|
||||||
|
|
||||||
|
if kwargs.get("tokenize", True):
|
||||||
|
tokenizer_kwargs = {
|
||||||
|
k: kwargs[k] for k in ("truncation", "max_length") if k in kwargs
|
||||||
|
}
|
||||||
|
return self.encode(
|
||||||
|
prompt_str,
|
||||||
|
add_special_tokens=False,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
return prompt_str
|
return prompt_str
|
||||||
|
|
||||||
def num_special_tokens_to_add(self) -> int:
|
def num_special_tokens_to_add(self) -> int:
|
||||||
|
|||||||
@@ -3,22 +3,18 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TypeAlias
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
|
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
|
||||||
|
|
||||||
from .protocol import TokenizerLike
|
from .protocol import TokenizerLike
|
||||||
from .registry import TokenizerRegistry
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
HfTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
|
||||||
|
|
||||||
|
|
||||||
def get_cached_tokenizer(
|
def get_cached_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
|
||||||
tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast",
|
|
||||||
) -> TokenizerLike:
|
|
||||||
"""
|
"""
|
||||||
By default, transformers will recompute multiple tokenizer properties
|
By default, transformers will recompute multiple tokenizer properties
|
||||||
each time they are called, leading to a significant slowdown.
|
each time they are called, leading to a significant slowdown.
|
||||||
@@ -65,11 +61,10 @@ def get_cached_tokenizer(
|
|||||||
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
|
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
|
||||||
|
|
||||||
cached_tokenizer.__class__ = CachedTokenizer
|
cached_tokenizer.__class__ = CachedTokenizer
|
||||||
return cached_tokenizer # type: ignore
|
return cached_tokenizer
|
||||||
|
|
||||||
|
|
||||||
@TokenizerRegistry.register("hf")
|
class CachedHfTokenizer(TokenizerLike):
|
||||||
class HfTokenizer(TokenizerLike):
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -79,7 +74,7 @@ class HfTokenizer(TokenizerLike):
|
|||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
download_dir: str | None = None,
|
download_dir: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> "TokenizerLike":
|
) -> HfTokenizer:
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
path_or_repo_id,
|
path_or_repo_id,
|
||||||
|
|||||||
@@ -3,10 +3,11 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .protocol import TokenizerLike
|
from .protocol import TokenizerLike
|
||||||
from .registry import TokenizerRegistry
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mistral_common.protocol.instruct.request import (
|
from mistral_common.protocol.instruct.request import (
|
||||||
@@ -15,9 +16,6 @@ if TYPE_CHECKING:
|
|||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||||
from transformers import BatchEncoding
|
from transformers import BatchEncoding
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
|
||||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Transformers v5
|
# Transformers v5
|
||||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||||
@@ -201,7 +199,6 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
|
|||||||
return tokenizer.unk_id
|
return tokenizer.unk_id
|
||||||
|
|
||||||
|
|
||||||
@TokenizerRegistry.register("mistral")
|
|
||||||
class MistralTokenizer(TokenizerLike):
|
class MistralTokenizer(TokenizerLike):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class TokenizerLike(Protocol):
|
|||||||
messages: list["ChatCompletionMessageParam"],
|
messages: list["ChatCompletionMessageParam"],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[int]:
|
) -> str | list[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from collections.abc import Callable
|
from dataclasses import dataclass, field
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, TypeVar, overload
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from typing_extensions import assert_never
|
from typing_extensions import TypeVar, assert_never, deprecated
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
|
|||||||
from .protocol import TokenizerLike
|
from .protocol import TokenizerLike
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config.model import ModelConfig, RunnerType
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_T = TypeVar("_T", bound=type[TokenizerLike])
|
|
||||||
|
_VLLM_TOKENIZERS = {
|
||||||
|
"deepseekv32": ("deepseekv32", "DeepseekV32Tokenizer"),
|
||||||
|
"hf": ("hf", "CachedHfTokenizer"),
|
||||||
|
"mistral": ("mistral", "MistralTokenizer"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TokenizerRegistry:
|
@dataclass
|
||||||
# Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class)
|
class _TokenizerRegistry:
|
||||||
REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {}
|
# Tokenizer mode -> (tokenizer module, tokenizer class)
|
||||||
|
tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict)
|
||||||
|
|
||||||
# In-tree tokenizers
|
def register(self, tokenizer_mode: str, module: str, class_name: str) -> None:
|
||||||
@staticmethod
|
if tokenizer_mode in self.tokenizers:
|
||||||
@overload
|
|
||||||
def register(tokenizer_mode: str) -> Callable[[_T], _T]: ...
|
|
||||||
|
|
||||||
# OOT tokenizers
|
|
||||||
@staticmethod
|
|
||||||
@overload
|
|
||||||
def register(tokenizer_mode: str, module: str, class_name: str) -> None: ...
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def register(
|
|
||||||
tokenizer_mode: str,
|
|
||||||
module: str | None = None,
|
|
||||||
class_name: str | None = None,
|
|
||||||
) -> Callable[[_T], _T] | None:
|
|
||||||
# In-tree tokenizers
|
|
||||||
if module is None or class_name is None:
|
|
||||||
|
|
||||||
def wrapper(tokenizer_cls: _T) -> _T:
|
|
||||||
assert tokenizer_mode not in TokenizerRegistry.REGISTRY
|
|
||||||
TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls
|
|
||||||
|
|
||||||
return tokenizer_cls
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
# OOT tokenizers
|
|
||||||
if tokenizer_mode in TokenizerRegistry.REGISTRY:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"%s.%s is already registered for tokenizer_mode=%r. "
|
"%s.%s is already registered for tokenizer_mode=%r. "
|
||||||
"It is overwritten by the new one.",
|
"It is overwritten by the new one.",
|
||||||
@@ -72,36 +51,42 @@ class TokenizerRegistry:
|
|||||||
tokenizer_mode,
|
tokenizer_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name)
|
self.tokenizers[tokenizer_mode] = (module, class_name)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]:
|
||||||
def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike":
|
if tokenizer_mode not in self.tokenizers:
|
||||||
if tokenizer_mode not in TokenizerRegistry.REGISTRY:
|
|
||||||
raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")
|
raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")
|
||||||
|
|
||||||
item = TokenizerRegistry.REGISTRY[tokenizer_mode]
|
module, class_name = self.tokenizers[tokenizer_mode]
|
||||||
if isinstance(item, type):
|
|
||||||
return item.from_pretrained(*args, **kwargs)
|
|
||||||
|
|
||||||
module, class_name = item
|
|
||||||
logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
|
logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
|
||||||
|
|
||||||
class_ = resolve_obj_by_qualname(f"{module}.{class_name}")
|
return resolve_obj_by_qualname(f"{module}.{class_name}")
|
||||||
return class_.from_pretrained(*args, **kwargs)
|
|
||||||
|
def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike:
|
||||||
|
tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode)
|
||||||
|
return tokenizer_cls.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(
|
TokenizerRegistry = _TokenizerRegistry(
|
||||||
|
{
|
||||||
|
mode: (f"vllm.tokenizers.{mod_relname}", cls_name)
|
||||||
|
for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_tokenizer_args(
|
||||||
tokenizer_name: str | Path,
|
tokenizer_name: str | Path,
|
||||||
*args,
|
*args,
|
||||||
|
runner_type: "RunnerType" = "generate",
|
||||||
tokenizer_mode: str = "auto",
|
tokenizer_mode: str = "auto",
|
||||||
trust_remote_code: bool = False,
|
|
||||||
revision: str | None = None,
|
|
||||||
download_dir: str | None = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> TokenizerLike:
|
):
|
||||||
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
|
revision: str | None = kwargs.get("revision")
|
||||||
|
download_dir: str | None = kwargs.get("download_dir")
|
||||||
|
|
||||||
if envs.VLLM_USE_MODELSCOPE:
|
if envs.VLLM_USE_MODELSCOPE:
|
||||||
# download model from ModelScope hub,
|
# download model from ModelScope hub,
|
||||||
# lazy import so that modelscope is not required for normal use.
|
# lazy import so that modelscope is not required for normal use.
|
||||||
@@ -125,16 +110,6 @@ def get_tokenizer(
|
|||||||
)
|
)
|
||||||
tokenizer_name = tokenizer_path
|
tokenizer_name = tokenizer_path
|
||||||
|
|
||||||
if tokenizer_mode == "slow":
|
|
||||||
if kwargs.get("use_fast", False):
|
|
||||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
|
||||||
|
|
||||||
tokenizer_mode = "hf"
|
|
||||||
kwargs["use_fast"] = False
|
|
||||||
|
|
||||||
if "truncation_side" not in kwargs:
|
|
||||||
kwargs["truncation_side"] = "left"
|
|
||||||
|
|
||||||
# Separate model folder from file path for GGUF models
|
# Separate model folder from file path for GGUF models
|
||||||
if is_gguf(tokenizer_name):
|
if is_gguf(tokenizer_name):
|
||||||
if check_gguf_file(tokenizer_name):
|
if check_gguf_file(tokenizer_name):
|
||||||
@@ -150,6 +125,21 @@ def get_tokenizer(
|
|||||||
)
|
)
|
||||||
kwargs["gguf_file"] = gguf_file
|
kwargs["gguf_file"] = gguf_file
|
||||||
|
|
||||||
|
if "truncation_side" not in kwargs:
|
||||||
|
if runner_type == "generate" or runner_type == "draft":
|
||||||
|
kwargs["truncation_side"] = "left"
|
||||||
|
elif runner_type == "pooling":
|
||||||
|
kwargs["truncation_side"] = "right"
|
||||||
|
else:
|
||||||
|
assert_never(runner_type)
|
||||||
|
|
||||||
|
if tokenizer_mode == "slow":
|
||||||
|
if kwargs.get("use_fast", False):
|
||||||
|
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
|
|
||||||
|
tokenizer_mode = "hf"
|
||||||
|
kwargs["use_fast"] = False
|
||||||
|
|
||||||
# Try to use official Mistral tokenizer if possible
|
# Try to use official Mistral tokenizer if possible
|
||||||
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
|
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
|
||||||
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
|
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
|
||||||
@@ -165,49 +155,70 @@ def get_tokenizer(
|
|||||||
if tokenizer_mode == "auto":
|
if tokenizer_mode == "auto":
|
||||||
tokenizer_mode = "hf"
|
tokenizer_mode = "hf"
|
||||||
|
|
||||||
tokenizer_args = (tokenizer_name, *args)
|
return tokenizer_mode, tokenizer_name, args, kwargs
|
||||||
tokenizer_kwargs = dict(
|
|
||||||
|
|
||||||
|
cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenizer_args_from_config(config: "ModelConfig", **kwargs):
|
||||||
|
return cached_resolve_tokenizer_args(
|
||||||
|
config.tokenizer,
|
||||||
|
runner_type=config.runner_type,
|
||||||
|
tokenizer_mode=config.tokenizer_mode,
|
||||||
|
revision=config.tokenizer_revision,
|
||||||
|
trust_remote_code=config.trust_remote_code,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(
|
||||||
|
tokenizer_name: str | Path,
|
||||||
|
*args,
|
||||||
|
tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment]
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
revision: str | None = None,
|
||||||
|
download_dir: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> _T:
|
||||||
|
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
|
||||||
|
tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args(
|
||||||
|
tokenizer_name,
|
||||||
|
*args,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
download_dir=download_dir,
|
download_dir=download_dir,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if tokenizer_mode == "custom":
|
if tokenizer_cls == TokenizerLike:
|
||||||
logger.warning_once(
|
tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)
|
||||||
"TokenizerRegistry now uses `tokenizer_mode` as the registry key "
|
else:
|
||||||
"instead of `tokenizer_name`. "
|
tokenizer_cls_ = tokenizer_cls
|
||||||
"Please update the definition of `.from_pretrained` in "
|
|
||||||
"your custom tokenizer to accept `args=%s`, `kwargs=%s`. "
|
|
||||||
"Then, you can pass `tokenizer_mode=%r` instead of "
|
|
||||||
"`tokenizer_mode='custom'` when initializing vLLM.",
|
|
||||||
tokenizer_args,
|
|
||||||
str(tokenizer_kwargs),
|
|
||||||
tokenizer_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer_mode = str(tokenizer_name)
|
tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs)
|
||||||
|
|
||||||
tokenizer = TokenizerRegistry.get_tokenizer(
|
|
||||||
tokenizer_mode,
|
|
||||||
*tokenizer_args,
|
|
||||||
**tokenizer_kwargs,
|
|
||||||
)
|
|
||||||
if not tokenizer.is_fast:
|
if not tokenizer.is_fast:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Using a slow tokenizer. This might cause a significant "
|
"Using a slow tokenizer. This might cause a significant "
|
||||||
"slowdown. Consider using a fast tokenizer instead."
|
"slowdown. Consider using a fast tokenizer instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer # type: ignore
|
||||||
|
|
||||||
|
|
||||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
|
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
|
||||||
|
if model_config.skip_tokenizer_init:
|
||||||
|
return None
|
||||||
|
|
||||||
return cached_get_tokenizer(
|
return cached_get_tokenizer(
|
||||||
model_config.tokenizer,
|
model_config.tokenizer,
|
||||||
|
runner_type=model_config.runner_type,
|
||||||
tokenizer_mode=model_config.tokenizer_mode,
|
tokenizer_mode=model_config.tokenizer_mode,
|
||||||
revision=model_config.tokenizer_revision,
|
revision=model_config.tokenizer_revision,
|
||||||
trust_remote_code=model_config.trust_remote_code,
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
@@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
"Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14."
|
||||||
|
)
|
||||||
def init_tokenizer_from_config(model_config: "ModelConfig"):
|
def init_tokenizer_from_config(model_config: "ModelConfig"):
|
||||||
runner_type = model_config.runner_type
|
return cached_tokenizer_from_config(model_config)
|
||||||
if runner_type == "generate" or runner_type == "draft":
|
|
||||||
truncation_side = "left"
|
|
||||||
elif runner_type == "pooling":
|
|
||||||
truncation_side = "right"
|
|
||||||
else:
|
|
||||||
assert_never(runner_type)
|
|
||||||
|
|
||||||
return get_tokenizer(
|
|
||||||
model_config.tokenizer,
|
|
||||||
tokenizer_mode=model_config.tokenizer_mode,
|
|
||||||
trust_remote_code=model_config.trust_remote_code,
|
|
||||||
revision=model_config.tokenizer_revision,
|
|
||||||
truncation_side=truncation_side,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -60,17 +60,17 @@ def __getattr__(name: str):
|
|||||||
|
|
||||||
return cached_tokenizer_from_config
|
return cached_tokenizer_from_config
|
||||||
if name == "init_tokenizer_from_configs":
|
if name == "init_tokenizer_from_configs":
|
||||||
from vllm.tokenizers import init_tokenizer_from_config
|
from vllm.tokenizers import cached_tokenizer_from_config
|
||||||
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` "
|
"`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` "
|
||||||
"has been moved to `vllm.tokenizers.init_tokenizer_from_config`. "
|
"has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. "
|
||||||
"The old name will be removed in v0.14.",
|
"The old name will be removed in v0.14.",
|
||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
return init_tokenizer_from_config
|
return cached_tokenizer_from_config
|
||||||
|
|
||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from vllm.plugins.io_processors import get_io_processor
|
|||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
|
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||||
from vllm.tracing import init_tracer
|
from vllm.tracing import init_tracer
|
||||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
@@ -111,7 +111,7 @@ class AsyncLLM(EngineClient):
|
|||||||
if self.model_config.skip_tokenizer_init:
|
if self.model_config.skip_tokenizer_init:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
else:
|
else:
|
||||||
tokenizer = init_tokenizer_from_config(self.model_config)
|
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||||
|
|
||||||
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
|
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
|
||||||
self.io_processor = get_io_processor(
|
self.io_processor = get_io_processor(
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from vllm.plugins.io_processors import get_io_processor
|
|||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
|
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||||
from vllm.tracing import init_tracer
|
from vllm.tracing import init_tracer
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
@@ -86,7 +86,7 @@ class LLMEngine:
|
|||||||
if self.model_config.skip_tokenizer_init:
|
if self.model_config.skip_tokenizer_init:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
else:
|
else:
|
||||||
tokenizer = init_tokenizer_from_config(self.model_config)
|
tokenizer = cached_tokenizer_from_config(self.model_config)
|
||||||
|
|
||||||
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
|
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
|
||||||
self.io_processor = get_io_processor(
|
self.io_processor = get_io_processor(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.reasoning import ReasoningParserManager
|
from vllm.reasoning import ReasoningParserManager
|
||||||
from vllm.tokenizers import init_tokenizer_from_config
|
from vllm.tokenizers import cached_tokenizer_from_config
|
||||||
from vllm.utils.import_utils import LazyLoader
|
from vllm.utils.import_utils import LazyLoader
|
||||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||||
from vllm.v1.structured_output.backend_types import (
|
from vllm.v1.structured_output.backend_types import (
|
||||||
@@ -71,7 +71,7 @@ class StructuredOutputManager:
|
|||||||
# of CPUs.
|
# of CPUs.
|
||||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||||
self.tokenizer = init_tokenizer_from_config(
|
self.tokenizer = cached_tokenizer_from_config(
|
||||||
model_config=self.vllm_config.model_config
|
model_config=self.vllm_config.model_config
|
||||||
)
|
)
|
||||||
reasoning_parser = (
|
reasoning_parser = (
|
||||||
|
|||||||
Reference in New Issue
Block a user