[Misc] Refactor tokenizer interface (#29693)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,7 +10,7 @@ import pytest
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers import MistralTokenizer
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def default_tokenizer() -> AnyTokenizer:
|
||||
def default_tokenizer() -> TokenizerLike:
|
||||
return AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@@ -270,14 +270,14 @@ async def test_streaming_product_tool_call():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen_tokenizer() -> AnyTokenizer:
|
||||
def qwen_tokenizer() -> TokenizerLike:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_parser(qwen_tokenizer: AnyTokenizer) -> Hermes2ProToolParser:
|
||||
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
|
||||
return Hermes2ProToolParser(qwen_tokenizer)
|
||||
|
||||
|
||||
@@ -291,7 +291,7 @@ def any_chat_request() -> ChatCompletionRequest:
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
@@ -323,7 +323,7 @@ def test_hermes_parser_streaming_just_forward_text(
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
@@ -357,7 +357,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
|
||||
qwen_tokenizer: AnyTokenizer,
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
|
||||
@@ -7,11 +7,11 @@ import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(default_tokenizer: AnyTokenizer):
|
||||
def parser(default_tokenizer: TokenizerLike):
|
||||
return Llama3JsonToolParser(default_tokenizer)
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
# Test cases similar to pythonic parser but with Llama4 specific format
|
||||
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
|
||||
@@ -64,7 +64,7 @@ PYTHON_TAG_FUNCTION_OUTPUT = (
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@@ -208,7 +208,7 @@ def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: AnyTokenizer,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
@@ -224,7 +224,7 @@ def test_tool_call(
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@@ -246,7 +246,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
|
||||
@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
@@ -69,7 +69,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
@@ -188,7 +188,7 @@ def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: AnyTokenizer,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
@@ -205,7 +205,7 @@ def test_tool_call(
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
@@ -228,7 +228,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
|
||||
@@ -11,7 +11,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
@@ -61,7 +61,7 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@@ -168,7 +168,7 @@ def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: AnyTokenizer,
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
@@ -185,7 +185,7 @@ def test_tool_call(
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
@@ -208,7 +208,7 @@ def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
class StreamingToolReconstructor:
|
||||
@@ -111,7 +111,7 @@ def run_tool_extraction_nonstreaming(
|
||||
return tool_parser.extract_tool_calls(model_output, request)
|
||||
|
||||
|
||||
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
|
||||
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
|
||||
# Split a string into a series of deltas using the provided tokenizer. Each
|
||||
# delta will be the string equivalent of a single token.
|
||||
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
Reference in New Issue
Block a user