diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 06d4c6e7a..f28ed1733 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -83,7 +83,10 @@ def _assert_model_arch_config( assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype( - model_config.hf_config, model_config.model, revision=model_config.revision + model_config.hf_config, + model_config.model, + revision=model_config.revision, + config_format="hf", ) assert str(torch_dtype) == expected["dtype"] diff --git a/tests/conftest.py b/tests/conftest.py index 1352cdeea..0ecb623ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -365,6 +365,7 @@ class HfRunner: self.config, dtype=dtype, is_pooling_model=is_sentence_transformer or is_cross_encoder, + config_format="hf", ) model_kwargs = model_kwargs if model_kwargs is not None else {} diff --git a/tests/transformers_utils/test_repo_utils.py b/tests/transformers_utils/test_repo_utils.py index 7107ad0f7..e17e3de84 100644 --- a/tests/transformers_utils/test_repo_utils.py +++ b/tests/transformers_utils/test_repo_utils.py @@ -8,7 +8,11 @@ from unittest.mock import MagicMock, call, patch import pytest -from vllm.transformers_utils.repo_utils import list_filtered_repo_files +from vllm.transformers_utils.repo_utils import ( + any_pattern_in_repo_files, + is_mistral_model_repo, + list_filtered_repo_files, +) @pytest.mark.parametrize( @@ -60,3 +64,95 @@ def test_list_filtered_repo_files( repo_type="model", token="token", ) + + +@pytest.mark.parametrize( + ("allow_patterns", "expected_bool"), + [ + (["*.json", "correct*.txt"], True), + ( + ["*.jpeg"], + True, + ), + ( + ["not_found.jpeg"], + False, + ), + ], +) +def test_one_filtered_repo_files(allow_patterns: list[str], expected_bool: bool): + with tempfile.TemporaryDirectory() as tmp_dir: + # Prep folder and files + path_tmp_dir = Path(tmp_dir) + subfolder = path_tmp_dir / "subfolder" + subfolder.mkdir() + (path_tmp_dir / "uncorrect.jpeg").touch() + (subfolder / "correct.txt").touch() + + def _glob_path() -> list[str]: + return [ + str(file.relative_to(path_tmp_dir)) + for file in path_tmp_dir.glob("**/*") + if file.is_file() + ] + + # Patch list_repo_files called by fn + with patch( + "vllm.transformers_utils.repo_utils.list_repo_files", + MagicMock(return_value=_glob_path()), + ) as mock_list_repo_files: + assert ( + any_pattern_in_repo_files( + tmp_dir, allow_patterns, "revision", "model", "token" + ) + ) is expected_bool + assert mock_list_repo_files.call_count == 1 + assert mock_list_repo_files.call_args_list[0] == call( + repo_id=tmp_dir, + revision="revision", + repo_type="model", + token="token", + ) + + +@pytest.mark.parametrize( + ("files", "expected_bool"), + [ + (["consolidated.safetensors", "incorrect.txt"], True), + (["consolidated-1.safetensors", "incorrect.txt"], True), + ( + ["consolidated-1.json"], + False, + ), + ], +) +def test_is_mistral_model_repo(files: list[str], expected_bool: bool): + with tempfile.TemporaryDirectory() as tmp_dir: + # Prep folder and files + path_tmp_dir = Path(tmp_dir) + for file in files: + (path_tmp_dir / file).touch() + + def _glob_path() -> list[str]: + return [ + str(file.relative_to(path_tmp_dir)) + for file in path_tmp_dir.glob("**/*") + if file.is_file() + ] + + # Patch list_repo_files called by fn + with patch( + "vllm.transformers_utils.repo_utils.list_repo_files", + MagicMock(return_value=_glob_path()), + ) as mock_list_repo_files: + assert ( + is_mistral_model_repo(tmp_dir, "revision", "model", "token") + is expected_bool + ) + assert mock_list_repo_files.call_count == 1 + assert mock_list_repo_files.call_args_list[0] == call( + repo_id=tmp_dir, + revision="revision", + repo_type="model", + token="token", + ) diff --git a/vllm/config/model.py b/vllm/config/model.py index 527af4c54..c172069c8 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -565,6 +565,7 @@ class ModelConfig: self.dtype, is_pooling_model=self.runner_type == "pooling", revision=self.revision, + config_format=self.config_format, ) self.original_max_model_len = self.max_model_len @@ -1844,9 +1845,10 @@ def _get_and_verify_dtype( *, is_pooling_model: bool, revision: str | None = None, + config_format: ConfigFormat = "hf", ) -> torch.dtype: config_dtype = ModelArchConfigConvertorBase.get_torch_dtype( - config, model_id, revision=revision + config, model_id, revision=revision, config_format=config_format ) model_type = config.model_type diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index b5088a116..2da7842b0 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib.util from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path @@ -18,7 +17,10 @@ from vllm.transformers_utils.gguf_utils import ( is_remote_gguf, split_remote_gguf, ) -from vllm.transformers_utils.repo_utils import list_filtered_repo_files +from vllm.transformers_utils.repo_utils import ( + any_pattern_in_repo_files, + is_mistral_model_repo, +) from vllm.utils.import_utils import resolve_obj_by_qualname from .protocol import TokenizerLike @@ -142,26 +144,26 @@ def resolve_tokenizer_args( kwargs["use_fast"] = False # Try to use official Mistral tokenizer if possible - if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): - allow_patterns = ["tekken.json", "tokenizer.model.v*"] - files_list = list_filtered_repo_files( + if ( + tokenizer_mode == "auto" + and is_mistral_model_repo( + model_name_or_path=str(tokenizer_name), revision=revision + ) + and any_pattern_in_repo_files( model_name_or_path=str(tokenizer_name), - allow_patterns=allow_patterns, + allow_patterns=["tekken.json", "tokenizer.model.v*"], revision=revision, ) - if len(files_list) > 0: - tokenizer_mode = "mistral" + ): + tokenizer_mode = "mistral" # Try to use Grok2 tiktoken tokenizer if possible - if tokenizer_mode == "auto": - allow_patterns = ["tokenizer.tok.json"] - files_list = list_filtered_repo_files( - model_name_or_path=str(tokenizer_name), - allow_patterns=allow_patterns, - revision=revision, - ) - if len(files_list) > 0: - tokenizer_mode = "grok2" + if tokenizer_mode == "auto" and any_pattern_in_repo_files( + model_name_or_path=str(tokenizer_name), + allow_patterns=["tokenizer.tok.json"], + revision=revision, + ): + tokenizer_mode = "grok2" # Fallback to HF tokenizer if tokenizer_mode == "auto": diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3ec0cc5d0..a29fd17dc 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -23,6 +23,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger +from vllm.transformers_utils.repo_utils import is_mistral_model_repo from vllm.transformers_utils.utils import parse_safetensors_file_metadata from .config_parser_base import ConfigParserBase @@ -49,7 +50,6 @@ except ImportError: ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES, ) - if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig else: @@ -581,7 +581,11 @@ def get_config( try: # First check for Mistral to avoid defaulting to # Transformers implementation. - if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): + if is_mistral_model_repo( + model_name_or_path=str(model), revision=revision + ) and file_or_path_exists( + model=model, config_name=MISTRAL_CONFIG_NAME, revision=revision + ): config_format = "mistral" elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists( model, HF_CONFIG_NAME, revision=revision diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index dddeb2900..bd6b7376e 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator +from contextlib import contextmanager from typing import final import torch +from huggingface_hub import constants from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from transformers import PretrainedConfig @@ -14,6 +17,7 @@ from vllm.config.model_arch import ( from vllm.config.utils import getattr_iter from vllm.logger import init_logger from vllm.transformers_utils.config import ( + ConfigFormat, try_get_safetensors_metadata, ) from vllm.utils.torch_utils import common_broadcastable_dtype @@ -21,6 +25,22 @@ from vllm.utils.torch_utils import common_broadcastable_dtype logger = init_logger(__name__) +@contextmanager +def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]: + if config_format == "mistral": + hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE + hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE + constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors" + constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json" + try: + yield + finally: + constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file + constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file + else: + yield + + class ModelArchConfigConvertorBase: def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig): self.hf_config = hf_config @@ -123,7 +143,11 @@ class ModelArchConfigConvertorBase: @final @classmethod def get_torch_dtype( - cls, hf_config: PretrainedConfig, model_id: str, revision: str | None + cls, + hf_config: PretrainedConfig, + model_id: str, + revision: str | None, + config_format: ConfigFormat, ): # NOTE: getattr(config, "dtype", torch.float32) is not correct # because config.dtype can be None. @@ -140,7 +164,8 @@ class ModelArchConfigConvertorBase: # Try to read the dtype of the weights if they are in safetensors format if config_dtype is None: - repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + with _maybe_patch_hf_hub_constants(config_format): + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) if repo_mt and (files_mt := repo_mt.files_metadata): param_dtypes: set[torch.dtype] = { diff --git a/vllm/transformers_utils/repo_utils.py b/vllm/transformers_utils/repo_utils.py index a55bdf36a..552e053b2 100644 --- a/vllm/transformers_utils/repo_utils.py +++ b/vllm/transformers_utils/repo_utils.py @@ -127,6 +127,42 @@ def list_filtered_repo_files( return file_list +def any_pattern_in_repo_files( + model_name_or_path: str, + allow_patterns: list[str], + revision: str | None = None, + repo_type: str | None = None, + token: str | bool | None = None, +): + return ( + len( + list_filtered_repo_files( + model_name_or_path=model_name_or_path, + allow_patterns=allow_patterns, + revision=revision, + repo_type=repo_type, + token=token, + ) + ) + > 0 + ) + + +def is_mistral_model_repo( + model_name_or_path: str, + revision: str | None = None, + repo_type: str | None = None, + token: str | bool | None = None, +) -> bool: + return any_pattern_in_repo_files( + model_name_or_path=model_name_or_path, + allow_patterns=["consolidated*.safetensors"], + revision=revision, + repo_type=repo_type, + token=token, + ) + + def file_exists( repo_id: str, file_name: str,