[Refactor] TokenizerRegistry only uses lazy imports (#30609)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar, overload
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import huggingface_hub
|
||||
from typing_extensions import assert_never
|
||||
from typing_extensions import TypeVar, assert_never, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@@ -24,46 +24,25 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from .protocol import TokenizerLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.model import ModelConfig, RunnerType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_T = TypeVar("_T", bound=type[TokenizerLike])
|
||||
|
||||
_VLLM_TOKENIZERS = {
|
||||
"deepseekv32": ("deepseekv32", "DeepseekV32Tokenizer"),
|
||||
"hf": ("hf", "CachedHfTokenizer"),
|
||||
"mistral": ("mistral", "MistralTokenizer"),
|
||||
}
|
||||
|
||||
|
||||
class TokenizerRegistry:
|
||||
# Tokenizer name -> tokenizer_cls or (tokenizer module, tokenizer class)
|
||||
REGISTRY: dict[str, type[TokenizerLike] | tuple[str, str]] = {}
|
||||
@dataclass
|
||||
class _TokenizerRegistry:
|
||||
# Tokenizer mode -> (tokenizer module, tokenizer class)
|
||||
tokenizers: dict[str, tuple[str, str]] = field(default_factory=dict)
|
||||
|
||||
# In-tree tokenizers
|
||||
@staticmethod
|
||||
@overload
|
||||
def register(tokenizer_mode: str) -> Callable[[_T], _T]: ...
|
||||
|
||||
# OOT tokenizers
|
||||
@staticmethod
|
||||
@overload
|
||||
def register(tokenizer_mode: str, module: str, class_name: str) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def register(
|
||||
tokenizer_mode: str,
|
||||
module: str | None = None,
|
||||
class_name: str | None = None,
|
||||
) -> Callable[[_T], _T] | None:
|
||||
# In-tree tokenizers
|
||||
if module is None or class_name is None:
|
||||
|
||||
def wrapper(tokenizer_cls: _T) -> _T:
|
||||
assert tokenizer_mode not in TokenizerRegistry.REGISTRY
|
||||
TokenizerRegistry.REGISTRY[tokenizer_mode] = tokenizer_cls
|
||||
|
||||
return tokenizer_cls
|
||||
|
||||
return wrapper
|
||||
|
||||
# OOT tokenizers
|
||||
if tokenizer_mode in TokenizerRegistry.REGISTRY:
|
||||
def register(self, tokenizer_mode: str, module: str, class_name: str) -> None:
|
||||
if tokenizer_mode in self.tokenizers:
|
||||
logger.warning(
|
||||
"%s.%s is already registered for tokenizer_mode=%r. "
|
||||
"It is overwritten by the new one.",
|
||||
@@ -72,36 +51,42 @@ class TokenizerRegistry:
|
||||
tokenizer_mode,
|
||||
)
|
||||
|
||||
TokenizerRegistry.REGISTRY[tokenizer_mode] = (module, class_name)
|
||||
self.tokenizers[tokenizer_mode] = (module, class_name)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_tokenizer(tokenizer_mode: str, *args, **kwargs) -> "TokenizerLike":
|
||||
if tokenizer_mode not in TokenizerRegistry.REGISTRY:
|
||||
def load_tokenizer_cls(self, tokenizer_mode: str) -> type[TokenizerLike]:
|
||||
if tokenizer_mode not in self.tokenizers:
|
||||
raise ValueError(f"No tokenizer registered for {tokenizer_mode=!r}.")
|
||||
|
||||
item = TokenizerRegistry.REGISTRY[tokenizer_mode]
|
||||
if isinstance(item, type):
|
||||
return item.from_pretrained(*args, **kwargs)
|
||||
|
||||
module, class_name = item
|
||||
module, class_name = self.tokenizers[tokenizer_mode]
|
||||
logger.debug_once(f"Loading {class_name} for {tokenizer_mode=!r}")
|
||||
|
||||
class_ = resolve_obj_by_qualname(f"{module}.{class_name}")
|
||||
return class_.from_pretrained(*args, **kwargs)
|
||||
return resolve_obj_by_qualname(f"{module}.{class_name}")
|
||||
|
||||
def load_tokenizer(self, tokenizer_mode: str, *args, **kwargs) -> TokenizerLike:
|
||||
tokenizer_cls = self.load_tokenizer_cls(tokenizer_mode)
|
||||
return tokenizer_cls.from_pretrained(*args, **kwargs)
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
TokenizerRegistry = _TokenizerRegistry(
|
||||
{
|
||||
mode: (f"vllm.tokenizers.{mod_relname}", cls_name)
|
||||
for mode, (mod_relname, cls_name) in _VLLM_TOKENIZERS.items()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def resolve_tokenizer_args(
|
||||
tokenizer_name: str | Path,
|
||||
*args,
|
||||
runner_type: "RunnerType" = "generate",
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
revision: str | None = None,
|
||||
download_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> TokenizerLike:
|
||||
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
|
||||
):
|
||||
revision: str | None = kwargs.get("revision")
|
||||
download_dir: str | None = kwargs.get("download_dir")
|
||||
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
@@ -125,16 +110,6 @@ def get_tokenizer(
|
||||
)
|
||||
tokenizer_name = tokenizer_path
|
||||
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
|
||||
tokenizer_mode = "hf"
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
if "truncation_side" not in kwargs:
|
||||
kwargs["truncation_side"] = "left"
|
||||
|
||||
# Separate model folder from file path for GGUF models
|
||||
if is_gguf(tokenizer_name):
|
||||
if check_gguf_file(tokenizer_name):
|
||||
@@ -150,6 +125,21 @@ def get_tokenizer(
|
||||
)
|
||||
kwargs["gguf_file"] = gguf_file
|
||||
|
||||
if "truncation_side" not in kwargs:
|
||||
if runner_type == "generate" or runner_type == "draft":
|
||||
kwargs["truncation_side"] = "left"
|
||||
elif runner_type == "pooling":
|
||||
kwargs["truncation_side"] = "right"
|
||||
else:
|
||||
assert_never(runner_type)
|
||||
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
|
||||
tokenizer_mode = "hf"
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
# Try to use official Mistral tokenizer if possible
|
||||
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
|
||||
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
|
||||
@@ -165,49 +155,70 @@ def get_tokenizer(
|
||||
if tokenizer_mode == "auto":
|
||||
tokenizer_mode = "hf"
|
||||
|
||||
tokenizer_args = (tokenizer_name, *args)
|
||||
tokenizer_kwargs = dict(
|
||||
return tokenizer_mode, tokenizer_name, args, kwargs
|
||||
|
||||
|
||||
cached_resolve_tokenizer_args = lru_cache(resolve_tokenizer_args)
|
||||
|
||||
|
||||
def tokenizer_args_from_config(config: "ModelConfig", **kwargs):
|
||||
return cached_resolve_tokenizer_args(
|
||||
config.tokenizer,
|
||||
runner_type=config.runner_type,
|
||||
tokenizer_mode=config.tokenizer_mode,
|
||||
revision=config.tokenizer_revision,
|
||||
trust_remote_code=config.trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_name: str | Path,
|
||||
*args,
|
||||
tokenizer_cls: type[_T] = TokenizerLike, # type: ignore[assignment]
|
||||
trust_remote_code: bool = False,
|
||||
revision: str | None = None,
|
||||
download_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> _T:
|
||||
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
|
||||
tokenizer_mode, tokenizer_name, args, kwargs = cached_resolve_tokenizer_args(
|
||||
tokenizer_name,
|
||||
*args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
download_dir=download_dir,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if tokenizer_mode == "custom":
|
||||
logger.warning_once(
|
||||
"TokenizerRegistry now uses `tokenizer_mode` as the registry key "
|
||||
"instead of `tokenizer_name`. "
|
||||
"Please update the definition of `.from_pretrained` in "
|
||||
"your custom tokenizer to accept `args=%s`, `kwargs=%s`. "
|
||||
"Then, you can pass `tokenizer_mode=%r` instead of "
|
||||
"`tokenizer_mode='custom'` when initializing vLLM.",
|
||||
tokenizer_args,
|
||||
str(tokenizer_kwargs),
|
||||
tokenizer_name,
|
||||
)
|
||||
if tokenizer_cls == TokenizerLike:
|
||||
tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)
|
||||
else:
|
||||
tokenizer_cls_ = tokenizer_cls
|
||||
|
||||
tokenizer_mode = str(tokenizer_name)
|
||||
|
||||
tokenizer = TokenizerRegistry.get_tokenizer(
|
||||
tokenizer_mode,
|
||||
*tokenizer_args,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
tokenizer = tokenizer_cls_.from_pretrained(tokenizer_name, *args, **kwargs)
|
||||
if not tokenizer.is_fast:
|
||||
logger.warning(
|
||||
"Using a slow tokenizer. This might cause a significant "
|
||||
"slowdown. Consider using a fast tokenizer instead."
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
return tokenizer # type: ignore
|
||||
|
||||
|
||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||
|
||||
|
||||
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
|
||||
if model_config.skip_tokenizer_init:
|
||||
return None
|
||||
|
||||
return cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
runner_type=model_config.runner_type,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
revision=model_config.tokenizer_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
@@ -215,19 +226,8 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
|
||||
)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14."
|
||||
)
|
||||
def init_tokenizer_from_config(model_config: "ModelConfig"):
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type == "generate" or runner_type == "draft":
|
||||
truncation_side = "left"
|
||||
elif runner_type == "pooling":
|
||||
truncation_side = "right"
|
||||
else:
|
||||
assert_never(runner_type)
|
||||
|
||||
return get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.tokenizer_revision,
|
||||
truncation_side=truncation_side,
|
||||
)
|
||||
return cached_tokenizer_from_config(model_config)
|
||||
|
||||
Reference in New Issue
Block a user