Improve Mistral format checks. (#33253)
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: juliendenize <julien.denize@mistral.ai> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -8,7 +8,11 @@ from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
from vllm.transformers_utils.repo_utils import (
|
||||
any_pattern_in_repo_files,
|
||||
is_mistral_model_repo,
|
||||
list_filtered_repo_files,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -60,3 +64,95 @@ def test_list_filtered_repo_files(
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_patterns", "expected_bool"),
|
||||
[
|
||||
(["*.json", "correct*.txt"], True),
|
||||
(
|
||||
["*.jpeg"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
["not_found.jpeg"],
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_one_filtered_repo_files(allow_patterns: list[str], expected_bool: bool):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
subfolder = path_tmp_dir / "subfolder"
|
||||
subfolder.mkdir()
|
||||
(path_tmp_dir / "uncorrect.jpeg").touch()
|
||||
(subfolder / "correct.txt").touch()
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
assert (
|
||||
any_pattern_in_repo_files(
|
||||
tmp_dir, allow_patterns, "revision", "model", "token"
|
||||
)
|
||||
) is expected_bool
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("files", "expected_bool"),
|
||||
[
|
||||
(["consolidated.safetensors", "incorrect.txt"], True),
|
||||
(["consolidated-1.safetensors", "incorrect.txt"], True),
|
||||
(
|
||||
["consolidated-1.json"],
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_is_mistral_model_repo(files: list[str], expected_bool: bool):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
for file in files:
|
||||
(path_tmp_dir / file).touch()
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
assert (
|
||||
is_mistral_model_repo(tmp_dir, "revision", "model", "token")
|
||||
is expected_bool
|
||||
)
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user