[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com> Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com> Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
69
tests/lora/test_tokenizer.py
Normal file
69
tests/lora/test_tokenizer.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transformers_tokenizer():
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
tokenizer = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=False,
|
||||
max_num_seqs=1,
|
||||
max_input_length=None,
|
||||
)
|
||||
assert reference_tokenizer.encode("prompt") == tokenizer.encode(
|
||||
request_id="request_id", prompt="prompt", lora_request=None)
|
||||
assert reference_tokenizer.encode(
|
||||
"prompt") == await tokenizer.encode_async(request_id="request_id",
|
||||
prompt="prompt",
|
||||
lora_request=None)
|
||||
assert isinstance(tokenizer.get_lora_tokenizer(None),
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer.get_lora_tokenizer(
|
||||
None) == await tokenizer.get_lora_tokenizer_async(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transformers_tokenizer_lora(sql_lora_files):
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
|
||||
tokenizer = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=True,
|
||||
max_num_seqs=1,
|
||||
max_input_length=None,
|
||||
)
|
||||
lora_request = LoRARequest("1", 1, sql_lora_files)
|
||||
assert reference_tokenizer.encode("prompt") == tokenizer.encode(
|
||||
request_id="request_id", prompt="prompt", lora_request=lora_request)
|
||||
assert reference_tokenizer.encode(
|
||||
"prompt") == await tokenizer.encode_async(request_id="request_id",
|
||||
prompt="prompt",
|
||||
lora_request=lora_request)
|
||||
assert isinstance(tokenizer.get_lora_tokenizer(None),
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer.get_lora_tokenizer(
|
||||
None) == await tokenizer.get_lora_tokenizer_async(None)
|
||||
|
||||
assert isinstance(tokenizer.get_lora_tokenizer(lora_request),
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer.get_lora_tokenizer(
|
||||
lora_request) != tokenizer.get_lora_tokenizer(None)
|
||||
assert tokenizer.get_lora_tokenizer(
|
||||
lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request)
|
||||
|
||||
|
||||
def test_get_lora_tokenizer(sql_lora_files, tmpdir):
|
||||
lora_request = None
|
||||
tokenizer = get_lora_tokenizer(lora_request)
|
||||
assert not tokenizer
|
||||
|
||||
lora_request = LoRARequest("1", 1, sql_lora_files)
|
||||
tokenizer = get_lora_tokenizer(lora_request)
|
||||
assert tokenizer.get_added_vocab()
|
||||
|
||||
lora_request = LoRARequest("1", 1, str(tmpdir))
|
||||
tokenizer = get_lora_tokenizer(lora_request)
|
||||
assert not tokenizer
|
||||
Reference in New Issue
Block a user