[Core] Allow specifying custom Executor (#6557)
This commit is contained in:
@@ -7,17 +7,28 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
|
||||
from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
|
||||
get_tokenizer_group)
|
||||
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
|
||||
RayTokenizerGroupPool)
|
||||
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
||||
TokenizerGroup)
|
||||
|
||||
from ..conftest import get_tokenizer_pool_config
|
||||
|
||||
|
||||
class CustomTokenizerGroup(TokenizerGroup):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._i = 0
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
self._i += 1
|
||||
return super().encode(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
|
||||
@pytest.mark.parametrize("tokenizer_group_type",
|
||||
[None, "ray", CustomTokenizerGroup])
|
||||
async def test_tokenizer_group(tokenizer_group_type):
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
tokenizer_group = get_tokenizer_group(
|
||||
@@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer_group.get_lora_tokenizer(
|
||||
None) == await tokenizer_group.get_lora_tokenizer_async(None)
|
||||
if tokenizer_group_type is CustomTokenizerGroup:
|
||||
assert tokenizer_group._i > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user