[Misc] Refactor tokenizer interface (#29693)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,62 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This test file includes some cases where it is inappropriate to
|
||||
only get the `eos_token_id` from the tokenizer as defined by
|
||||
`vllm.LLMEngine._get_eos_token_id`.
|
||||
"""
|
||||
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
def test_get_llama3_eos_token():
|
||||
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
import pytest
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
assert tokenizer.eos_token_id == 128009
|
||||
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
|
||||
assert generation_config is not None
|
||||
assert generation_config.eos_token_id == [128001, 128008, 128009]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allow_patterns,expected_relative_files",
|
||||
[
|
||||
(
|
||||
["*.json", "correct*.txt"],
|
||||
["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_list_filtered_repo_files(
|
||||
allow_patterns: list[str], expected_relative_files: list[str]
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Prep folder and files
|
||||
path_tmp_dir = Path(tmp_dir)
|
||||
subfolder = path_tmp_dir / "subfolder"
|
||||
subfolder.mkdir()
|
||||
(path_tmp_dir / "json_file.json").touch()
|
||||
(path_tmp_dir / "correct_2.txt").touch()
|
||||
(path_tmp_dir / "uncorrect.txt").touch()
|
||||
(path_tmp_dir / "uncorrect.jpeg").touch()
|
||||
(subfolder / "correct.txt").touch()
|
||||
(subfolder / "uncorrect_sub.txt").touch()
|
||||
def test_get_blip2_eos_token():
|
||||
model_name = "Salesforce/blip2-opt-2.7b"
|
||||
|
||||
def _glob_path() -> list[str]:
|
||||
return [
|
||||
str(file.relative_to(path_tmp_dir))
|
||||
for file in path_tmp_dir.glob("**/*")
|
||||
if file.is_file()
|
||||
]
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
assert tokenizer.eos_token_id == 2
|
||||
|
||||
# Patch list_repo_files called by fn
|
||||
with patch(
|
||||
"vllm.transformers_utils.repo_utils.list_repo_files",
|
||||
MagicMock(return_value=_glob_path()),
|
||||
) as mock_list_repo_files:
|
||||
out_files = sorted(
|
||||
list_filtered_repo_files(
|
||||
tmp_dir, allow_patterns, "revision", "model", "token"
|
||||
)
|
||||
)
|
||||
assert out_files == sorted(expected_relative_files)
|
||||
assert mock_list_repo_files.call_count == 1
|
||||
assert mock_list_repo_files.call_args_list[0] == call(
|
||||
repo_id=tmp_dir,
|
||||
revision="revision",
|
||||
repo_type="model",
|
||||
token="token",
|
||||
)
|
||||
generation_config = try_get_generation_config(model_name, trust_remote_code=False)
|
||||
assert generation_config is not None
|
||||
assert generation_config.eos_token_id == 50118
|
||||
|
||||
Reference in New Issue
Block a user