Asynchronous tokenization (#2879)
This commit is contained in:
0
tests/tokenization/__init__.py
Normal file
0
tests/tokenization/__init__.py
Normal file
20
tests/tokenization/test_cached_tokenizer.py
Normal file
20
tests/tokenization/test_cached_tokenizer.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from copy import deepcopy
|
||||
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def test_cached_tokenizer():
    """Check that a cached tokenizer agrees with the tokenizer it was built from.

    A GPT-2 tokenizer is extended with a ``cls_token`` and one additional
    special token, then a deep copy of it is passed to
    ``get_cached_tokenizer``. The result must encode text the same way and
    report the same special ids/tokens as the untouched reference.
    """
    reference = AutoTokenizer.from_pretrained("gpt2")
    reference.add_special_tokens({"cls_token": "<CLS>"})
    reference.add_special_tokens({"additional_special_tokens": ["<SEP>"]})

    # Hand over a copy so the reference object itself stays unwrapped.
    cached = get_cached_tokenizer(deepcopy(reference))

    assert reference.encode("prompt") == cached.encode("prompt")

    # Compared as sets: only membership is checked, not ordering.
    for attr_name in ("all_special_ids", "all_special_tokens",
                      "all_special_tokens_extended"):
        assert set(getattr(reference, attr_name)) == set(
            getattr(cached, attr_name))
|
||||
62
tests/tokenization/test_detokenize.py
Normal file
62
tests/tokenization/test_detokenize.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||
|
||||
# Reference strings that the streaming-decode test must reproduce exactly.
TRUTH = [
    # Short ASCII sentence.
    "Hello here, this is a simple test",  # noqa: E501
    # Longer English passage.
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa: E501
    # Non-ASCII (Chinese) text.
    "我很感谢你的热情"  # noqa: E501
]

# Model identifiers passed to AutoTokenizer.from_pretrained by the test.
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
]
|
||||
|
||||
|
||||
def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool):
    """Decode ``all_input_ids`` one token at a time and return the full text.

    ``detokenize_incrementally`` is called once per prefix of
    ``all_input_ids`` (length 1, 2, ..., n). The two offsets it returns are
    fed back into the next call, the tokens it returns are accumulated and
    passed back as the previous tokens, and the text pieces it returns are
    concatenated in order.

    Args:
        tokenizer: Tokenizer handed through to ``detokenize_incrementally``.
        all_input_ids: Sequence of token ids to decode.
        skip_special_tokens: Forwarded unchanged to
            ``detokenize_incrementally``.

    Returns:
        The concatenation of every text piece produced along the way.
    """
    pieces = []
    cur_offset = 0
    cur_token_offset = 0
    seen_tokens = None  # None on the first call, then the running token list.

    for prefix_len in range(1, len(all_input_ids) + 1):
        fresh_tokens, piece, cur_offset, cur_token_offset = (
            detokenize_incrementally(
                tokenizer,
                all_input_ids[:prefix_len],
                seen_tokens,
                cur_offset,
                cur_token_offset,
                skip_special_tokens=skip_special_tokens))
        pieces.append(piece)

        if seen_tokens is not None:
            seen_tokens += fresh_tokens
        else:
            seen_tokens = fresh_tokens

    return "".join(pieces)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
    """Incremental decoding must reproduce the original text exactly.

    The text is tokenized without special tokens. When
    ``skip_special_tokens`` is set, BOS (if the tokenizer defines one) and
    EOS ids are wrapped around the ids, so the decoder has to drop them to
    get back to ``truth``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    token_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]

    if skip_special_tokens:
        bos_id = tokenizer.bos_token_id
        # Only prepend BOS when the tokenizer defines one.
        leading = [] if bos_id is None else [bos_id]
        token_ids = leading + token_ids + [tokenizer.eos_token_id]

    streamed = _run_incremental_decode(
        tokenizer, token_ids, skip_special_tokens=skip_special_tokens)

    assert streamed == truth
|
||||
100
tests/tokenization/test_tokenizer_group.py
Normal file
100
tests/tokenization/test_tokenizer_group.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
|
||||
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
|
||||
RayTokenizerGroupPool)
|
||||
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
||||
TokenizerGroup)
|
||||
from ..conftest import get_tokenizer_pool_config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
    """Sync and async tokenizer-group APIs must match a plain GPT-2 tokenizer.

    Runs once per ``tokenizer_group_type`` value (``None`` and ``"ray"``).
    """
    reference = AutoTokenizer.from_pretrained("gpt2")
    group = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    expected_ids = reference.encode("prompt")
    request = dict(request_id="request_id", prompt="prompt",
                   lora_request=None)

    # Both the blocking and the awaitable encode paths must agree with the
    # reference tokenizer.
    assert expected_ids == group.encode(**request)
    assert expected_ids == await group.encode_async(**request)

    # With no LoRA request the group should still hand back a tokenizer, and
    # the sync and async accessors should compare equal.
    sync_tokenizer = group.get_lora_tokenizer(None)
    assert isinstance(sync_tokenizer, PreTrainedTokenizerBase)
    assert group.get_lora_tokenizer(
        None) == await group.get_lora_tokenizer_async(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    """A tokenizer pool must correctly serve more requests than its size.

    Five times ``pool_size`` async encode requests are issued at once; the
    gathered results must equal what a plain GPT-2 tokenizer produces for
    the same prompts, in the same order.
    """
    reference = AutoTokenizer.from_pretrained("gpt2")
    pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Deliberately oversubscribe the pool so requests outnumber its size.
    num_requests = pool.pool_size * 5
    prompts = [f"prompt {i}" for i in range(num_requests)]

    pending = [
        pool.encode_async(request_id=str(idx),
                          prompt=prompt,
                          lora_request=None)
        for idx, prompt in enumerate(prompts)
    ]
    results = await asyncio.gather(*pending)

    expected_results = [reference.encode(prompt) for prompt in prompts]
    assert results == expected_results
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Env vars set in the calling process must reach the pool's workers.

    A worker class whose ``ping`` asserts on an environment variable is
    plugged into a ``RayTokenizerGroupPool`` subclass. Pinging a pool built
    without the variable must raise ``AssertionError``; pinging a pool built
    while the variable is patched into ``os.environ`` must succeed.
    """
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            # Fails unless the variable is visible where the worker runs.
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    def build_pool():
        # Config and pool are rebuilt on every call so that each pool sees
        # the environment as it is at construction time.
        pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        return EnvVarCheckerRayTokenizerGroupPool.from_config(
            pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)

    # Variable not set: the worker-side assertion is expected to fail.
    unset_pool = build_pool()
    with pytest.raises(AssertionError):
        unset_pool.ping()

    # Variable set in the caller before the pool is created: ping passes.
    with patch.dict(os.environ, {env_var: "1"}):
        patched_pool = build_pool()
        patched_pool.ping()
|
||||
Reference in New Issue
Block a user