Improve Mistral format checks. (#33253)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: juliendenize <julien.denize@mistral.ai>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Julien Denize
2026-01-30 15:23:33 +01:00
committed by GitHub
parent a11bc12d53
commit ae5b7aff2b
8 changed files with 193 additions and 24 deletions

View File

@@ -8,7 +8,11 @@ from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
from vllm.transformers_utils.repo_utils import (
any_pattern_in_repo_files,
is_mistral_model_repo,
list_filtered_repo_files,
)
@pytest.mark.parametrize(
@@ -60,3 +64,95 @@ def test_list_filtered_repo_files(
repo_type="model",
token="token",
)
@pytest.mark.parametrize(
    ("allow_patterns", "expected_bool"),
    [
        (["*.json", "correct*.txt"], True),
        (["*.jpeg"], True),
        (["not_found.jpeg"], False),
    ],
)
def test_one_filtered_repo_files(allow_patterns: list[str], expected_bool: bool):
    """Check ``any_pattern_in_repo_files`` against a mocked repo listing.

    A temporary directory holding ``uncorrect.jpeg`` and
    ``subfolder/correct.txt`` stands in for the repository. Its relative file
    paths are served by a mocked ``list_repo_files``. The test verifies both
    the boolean result for the given patterns and that the listing function
    was invoked exactly once with the expected keyword arguments.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Build the fake repository layout on disk.
        root = Path(tmp_dir)
        nested_dir = root / "subfolder"
        nested_dir.mkdir()
        (root / "uncorrect.jpeg").touch()
        (nested_dir / "correct.txt").touch()

        # Relative paths of every regular file, as the mocked listing result.
        repo_listing = [
            str(entry.relative_to(root))
            for entry in root.glob("**/*")
            if entry.is_file()
        ]
        fake_list_repo_files = MagicMock(return_value=repo_listing)

        # Swap out the listing function used by the code under test.
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            fake_list_repo_files,
        ):
            found = any_pattern_in_repo_files(
                tmp_dir, allow_patterns, "revision", "model", "token"
            )

        assert found is expected_bool
        # Exactly one listing call, with all arguments forwarded by keyword.
        fake_list_repo_files.assert_called_once_with(
            repo_id=tmp_dir,
            revision="revision",
            repo_type="model",
            token="token",
        )
@pytest.mark.parametrize(
    ("files", "expected_bool"),
    [
        (["consolidated.safetensors", "incorrect.txt"], True),
        (["consolidated-1.safetensors", "incorrect.txt"], True),
        (["consolidated-1.json"], False),
    ],
)
def test_is_mistral_model_repo(files: list[str], expected_bool: bool):
    """Check ``is_mistral_model_repo`` against a mocked repo listing.

    Each parametrized case creates the given file names in a temporary
    directory, serves their relative paths through a mocked
    ``list_repo_files``, and verifies the boolean returned by
    ``is_mistral_model_repo``. It also asserts that the listing function was
    called exactly once with the expected keyword arguments.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Materialize the requested (empty) files at the repo root.
        root = Path(tmp_dir)
        for name in files:
            (root / name).touch()

        # Collect relative paths of all regular files for the mocked listing.
        repo_listing: list[str] = []
        for entry in root.glob("**/*"):
            if entry.is_file():
                repo_listing.append(str(entry.relative_to(root)))
        fake_list_repo_files = MagicMock(return_value=repo_listing)

        # Swap out the listing function used by the code under test.
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            fake_list_repo_files,
        ):
            detected = is_mistral_model_repo(tmp_dir, "revision", "model", "token")

        assert detected is expected_bool
        # Exactly one listing call, with all arguments forwarded by keyword.
        fake_list_repo_files.assert_called_once_with(
            repo_id=tmp_dir,
            revision="revision",
            repo_type="model",
            token="token",
        )