Simplify TokenizerGroup (#16790)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,17 +5,14 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
|
||||
|
||||
from ..conftest import get_tokenizer_pool_config
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
|
||||
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
|
||||
tokenizer_group = get_tokenizer_group(
|
||||
get_tokenizer_pool_config(tokenizer_group_type),
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=True,
|
||||
max_num_seqs=1,
|
||||
@@ -60,8 +57,7 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path):
|
||||
@pytest.mark.parametrize("max_num_seqs", [1, 2])
|
||||
@pytest.mark.parametrize("max_loras", [1, 2])
|
||||
def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
|
||||
tokenizer_group = get_tokenizer_group(
|
||||
get_tokenizer_pool_config(None),
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=enable_lora,
|
||||
max_num_seqs=max_num_seqs,
|
||||
|
||||
Reference in New Issue
Block a user