Improve Mistral format checks. (#33253)
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: juliendenize <julien.denize@mistral.ai> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -83,7 +83,10 @@ def _assert_model_arch_config(
|
||||
assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"]
|
||||
|
||||
torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
|
||||
model_config.hf_config, model_config.model, revision=model_config.revision
|
||||
model_config.hf_config,
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
config_format="hf",
|
||||
)
|
||||
assert str(torch_dtype) == expected["dtype"]
|
||||
|
||||
|
||||
@@ -365,6 +365,7 @@ class HfRunner:
|
||||
self.config,
|
||||
dtype=dtype,
|
||||
is_pooling_model=is_sentence_transformer or is_cross_encoder,
|
||||
config_format="hf",
|
||||
)
|
||||
|
||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||
|
||||
@@ -8,7 +8,11 @@ from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
from vllm.transformers_utils.repo_utils import (
|
||||
any_pattern_in_repo_files,
|
||||
is_mistral_model_repo,
|
||||
list_filtered_repo_files,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -60,3 +64,95 @@ def test_list_filtered_repo_files(
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_patterns", "expected_bool"),
|
||||
[
|
||||
(["*.json", "correct*.txt"], True),
|
||||
(
|
||||
["*.jpeg"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
["not_found.jpeg"],
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_one_filtered_repo_files(allow_patterns: list[str], expected_bool: bool):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
subfolder = path_tmp_dir / "subfolder"
|
||||
subfolder.mkdir()
|
||||
(path_tmp_dir / "uncorrect.jpeg").touch()
|
||||
(subfolder / "correct.txt").touch()
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
assert (
|
||||
any_pattern_in_repo_files(
|
||||
tmp_dir, allow_patterns, "revision", "model", "token"
|
||||
)
|
||||
) is expected_bool
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("files", "expected_bool"),
|
||||
[
|
||||
(["consolidated.safetensors", "incorrect.txt"], True),
|
||||
(["consolidated-1.safetensors", "incorrect.txt"], True),
|
||||
(
|
||||
["consolidated-1.json"],
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_mistral_model_repo(files: list[str], expected_bool: bool):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
for file in files:
|
||||
(path_tmp_dir / file).touch()
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
assert (
|
||||
is_mistral_model_repo(tmp_dir, "revision", "model", "token")
|
||||
is expected_bool
|
||||
)
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
|
||||
@@ -565,6 +565,7 @@ class ModelConfig:
|
||||
self.dtype,
|
||||
is_pooling_model=self.runner_type == "pooling",
|
||||
revision=self.revision,
|
||||
config_format=self.config_format,
|
||||
)
|
||||
|
||||
self.original_max_model_len = self.max_model_len
|
||||
@@ -1844,9 +1845,10 @@ def _get_and_verify_dtype(
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
revision: str | None = None,
|
||||
config_format: ConfigFormat = "hf",
|
||||
) -> torch.dtype:
|
||||
config_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
|
||||
config, model_id, revision=revision
|
||||
config, model_id, revision=revision, config_format=config_format
|
||||
)
|
||||
model_type = config.model_type
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
@@ -18,7 +17,10 @@ from vllm.transformers_utils.gguf_utils import (
|
||||
is_remote_gguf,
|
||||
split_remote_gguf,
|
||||
)
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
from vllm.transformers_utils.repo_utils import (
|
||||
any_pattern_in_repo_files,
|
||||
is_mistral_model_repo,
|
||||
)
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
from .protocol import TokenizerLike
|
||||
@@ -142,26 +144,26 @@ def resolve_tokenizer_args(
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
# Try to use official Mistral tokenizer if possible
|
||||
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"):
|
||||
allow_patterns = ["tekken.json", "tokenizer.model.v*"]
|
||||
files_list = list_filtered_repo_files(
|
||||
if (
|
||||
tokenizer_mode == "auto"
|
||||
and is_mistral_model_repo(
|
||||
model_name_or_path=str(tokenizer_name), revision=revision
|
||||
)
|
||||
and any_pattern_in_repo_files(
|
||||
model_name_or_path=str(tokenizer_name),
|
||||
allow_patterns=allow_patterns,
|
||||
allow_patterns=["tekken.json", "tokenizer.model.v*"],
|
||||
revision=revision,
|
||||
)
|
||||
if len(files_list) > 0:
|
||||
tokenizer_mode = "mistral"
|
||||
):
|
||||
tokenizer_mode = "mistral"
|
||||
|
||||
# Try to use Grok2 tiktoken tokenizer if possible
|
||||
if tokenizer_mode == "auto":
|
||||
allow_patterns = ["tokenizer.tok.json"]
|
||||
files_list = list_filtered_repo_files(
|
||||
model_name_or_path=str(tokenizer_name),
|
||||
allow_patterns=allow_patterns,
|
||||
revision=revision,
|
||||
)
|
||||
if len(files_list) > 0:
|
||||
tokenizer_mode = "grok2"
|
||||
if tokenizer_mode == "auto" and any_pattern_in_repo_files(
|
||||
model_name_or_path=str(tokenizer_name),
|
||||
allow_patterns=["tokenizer.tok.json"],
|
||||
revision=revision,
|
||||
):
|
||||
tokenizer_mode = "grok2"
|
||||
|
||||
# Fallback to HF tokenizer
|
||||
if tokenizer_mode == "auto":
|
||||
|
||||
@@ -23,6 +23,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.repo_utils import is_mistral_model_repo
|
||||
from vllm.transformers_utils.utils import parse_safetensors_file_metadata
|
||||
|
||||
from .config_parser_base import ConfigParserBase
|
||||
@@ -49,7 +50,6 @@ except ImportError:
|
||||
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
|
||||
)
|
||||
|
||||
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
else:
|
||||
@@ -581,7 +581,11 @@ def get_config(
|
||||
try:
|
||||
# First check for Mistral to avoid defaulting to
|
||||
# Transformers implementation.
|
||||
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
|
||||
if is_mistral_model_repo(
|
||||
model_name_or_path=str(model), revision=revision
|
||||
) and file_or_path_exists(
|
||||
model=model, config_name=MISTRAL_CONFIG_NAME, revision=revision
|
||||
):
|
||||
config_format = "mistral"
|
||||
elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
|
||||
model, HF_CONFIG_NAME, revision=revision
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import final
|
||||
|
||||
import torch
|
||||
from huggingface_hub import constants
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
@@ -14,6 +17,7 @@ from vllm.config.model_arch import (
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat,
|
||||
try_get_safetensors_metadata,
|
||||
)
|
||||
from vllm.utils.torch_utils import common_broadcastable_dtype
|
||||
@@ -21,6 +25,22 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]:
|
||||
if config_format == "mistral":
|
||||
hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE
|
||||
hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE
|
||||
constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
|
||||
constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file
|
||||
constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
class ModelArchConfigConvertorBase:
|
||||
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
|
||||
self.hf_config = hf_config
|
||||
@@ -123,7 +143,11 @@ class ModelArchConfigConvertorBase:
|
||||
@final
|
||||
@classmethod
|
||||
def get_torch_dtype(
|
||||
cls, hf_config: PretrainedConfig, model_id: str, revision: str | None
|
||||
cls,
|
||||
hf_config: PretrainedConfig,
|
||||
model_id: str,
|
||||
revision: str | None,
|
||||
config_format: ConfigFormat,
|
||||
):
|
||||
# NOTE: getattr(config, "dtype", torch.float32) is not correct
|
||||
# because config.dtype can be None.
|
||||
@@ -140,7 +164,8 @@ class ModelArchConfigConvertorBase:
|
||||
|
||||
# Try to read the dtype of the weights if they are in safetensors format
|
||||
if config_dtype is None:
|
||||
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
|
||||
with _maybe_patch_hf_hub_constants(config_format):
|
||||
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
|
||||
|
||||
if repo_mt and (files_mt := repo_mt.files_metadata):
|
||||
param_dtypes: set[torch.dtype] = {
|
||||
|
||||
@@ -127,6 +127,42 @@ def list_filtered_repo_files(
|
||||
return file_list
|
||||
|
||||
|
||||
def any_pattern_in_repo_files(
|
||||
model_name_or_path: str,
|
||||
allow_patterns: list[str],
|
||||
revision: str | None = None,
|
||||
repo_type: str | None = None,
|
||||
token: str | bool | None = None,
|
||||
):
|
||||
return (
|
||||
len(
|
||||
list_filtered_repo_files(
|
||||
model_name_or_path=model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
revision=revision,
|
||||
repo_type=repo_type,
|
||||
token=token,
|
||||
)
|
||||
)
|
||||
> 0
|
||||
)
|
||||
|
||||
|
||||
def is_mistral_model_repo(
|
||||
model_name_or_path: str,
|
||||
revision: str | None = None,
|
||||
repo_type: str | None = None,
|
||||
token: str | bool | None = None,
|
||||
) -> bool:
|
||||
return any_pattern_in_repo_files(
|
||||
model_name_or_path=model_name_or_path,
|
||||
allow_patterns=["consolidated*.safetensors"],
|
||||
revision=revision,
|
||||
repo_type=repo_type,
|
||||
token=token,
|
||||
)
|
||||
|
||||
|
||||
def file_exists(
|
||||
repo_id: str,
|
||||
file_name: str,
|
||||
|
||||
Reference in New Issue
Block a user