Convert formatting to use ruff instead of yapf + isort (#26247)

commit d6953beb91 (parent 17edd8a807)
Author: Harry Mellor (committed by GitHub)
Date: 2025-10-05 15:06:22 +01:00
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

1508 changed files with 115244 additions and 94146 deletions
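This commit migrates the repository from yapf (code formatting) plus isort (import sorting) to ruff, which handles both jobs. The project's actual configuration is not reproduced on this page; the pyproject.toml sketch below is only an illustration of the kind of setup such a migration enables (the table names and the "I" rule code are real ruff concepts, but the concrete values here are assumptions, not the settings from this commit):

[tool.ruff]
line-length = 88

[tool.ruff.lint]
# "I" enables ruff's isort-compatible import-sorting rules,
# replacing the standalone isort tool.
select = ["E", "F", "I"]

[tool.ruff.format]
# ruff format is a black-compatible formatter; it produces the
# hanging-indent, trailing-comma style seen in the hunks below.
quote-style = "double"

With a config along these lines, `ruff check --fix .` re-sorts the imports and `ruff format .` rewrites the code, which is what produced the mechanical churn across the 1508 files. The hunks below are from the tokenization tests.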

View File

@@ -6,17 +6,16 @@ from copy import deepcopy
 import pytest
 from transformers import AutoTokenizer

-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
-                                               get_cached_tokenizer)
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer


 @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
 def test_cached_tokenizer(model_id: str):
-    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
-                                                        trust_remote_code=True)
+    reference_tokenizer = AutoTokenizer.from_pretrained(
+        model_id, trust_remote_code=True
+    )
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
-    reference_tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<SEP>"]})
+    reference_tokenizer.add_special_tokens({"additional_special_tokens": ["<SEP>"]})

     cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
     _check_consistency(cached_tokenizer, reference_tokenizer)
@@ -32,13 +31,13 @@ def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
     # Cached attributes
     assert target.all_special_ids == expected.all_special_ids
     assert target.all_special_tokens == expected.all_special_tokens
-    assert (target.all_special_tokens_extended ==
-            expected.all_special_tokens_extended)
+    assert target.all_special_tokens_extended == expected.all_special_tokens_extended
     assert target.get_vocab() == expected.get_vocab()
     assert len(target) == len(expected)

     # Other attributes
-    assert getattr(target, "padding_side",
-                   None) == getattr(expected, "padding_side", None)
+    assert getattr(target, "padding_side", None) == getattr(
+        expected, "padding_side", None
+    )

     assert target.encode("prompt") == expected.encode("prompt")
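Every hunk above and below repeats the same mechanical transformation. As a minimal sketch (adapted from the hunk above; both forms are equivalent at runtime): yapf aligns wrapped arguments under the opening bracket, while ruff format, like black, breaks after the bracket and uses a hanging indent.

from transformers import AutoTokenizer

# yapf: continuation lines aligned under the opening parenthesis
tokenizer = AutoTokenizer.from_pretrained("gpt2",
                                          trust_remote_code=True)

# ruff format: break after the parenthesis, hanging indent
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2", trust_remote_code=True
)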

View File

@@ -5,15 +5,16 @@ from collections.abc import Generator
 from typing import Any, Optional

 import pytest
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
-                                        IncrementalDetokenizer,
-                                        SlowIncrementalDetokenizer)
+from vllm.v1.engine.detokenizer import (
+    FastIncrementalDetokenizer,
+    IncrementalDetokenizer,
+    SlowIncrementalDetokenizer,
+)

 SPECIAL_TOKS_TRUTH = [
     "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
@@ -45,33 +46,35 @@ TOKENIZERS = [
 ]


-def _run_incremental_decode(tokenizer,
-                            all_input_ids,
-                            skip_special_tokens: bool,
-                            starting_index: int,
-                            spaces_between_special_tokens: bool = True,
-                            fast: Optional[bool] = None):
+def _run_incremental_decode(
+    tokenizer,
+    all_input_ids,
+    skip_special_tokens: bool,
+    starting_index: int,
+    spaces_between_special_tokens: bool = True,
+    fast: Optional[bool] = None,
+):
     prompt_token_ids = all_input_ids[:starting_index]

     params = SamplingParams(
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest(request_id="",
-                                prompt_token_ids=prompt_token_ids,
-                                mm_features=None,
-                                sampling_params=params,
-                                pooling_params=None,
-                                eos_token_id=None,
-                                arrival_time=0.0,
-                                lora_request=None,
-                                cache_salt=None,
-                                data_parallel_rank=None)
+    request = EngineCoreRequest(
+        request_id="",
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )

     if fast is None:
-        detokenizer = IncrementalDetokenizer.from_new_request(
-            tokenizer, request)
+        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
     elif fast:
         detokenizer = FastIncrementalDetokenizer(tokenizer, request)
     else:
@@ -88,9 +91,11 @@ def _run_incremental_decode(tokenizer,
 @pytest.fixture
 def tokenizer(tokenizer_name):
-    return (MistralTokenizer.from_pretrained(tokenizer_name)
-            if "mistral" in tokenizer_name else
-            AutoTokenizer.from_pretrained(tokenizer_name))
+    return (
+        MistralTokenizer.from_pretrained(tokenizer_name)
+        if "mistral" in tokenizer_name
+        else AutoTokenizer.from_pretrained(tokenizer_name)
+    )


 @pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@@ -102,7 +107,8 @@ def tokenizer(tokenizer_name):
"ပုံပြင်လေးပြောပြပါ",
# Using "URGENCY" since "CY" has token id 130282
"URGENCY🌶",
])
],
)
def test_mistral_edge_case(tokenizer, truth):
"""Test for a specific edge cases with V3-Tekken MistralTokenizer.
@@ -115,7 +121,8 @@ def test_mistral_edge_case(tokenizer, truth):
         tokenizer,
         all_input_ids,
         skip_special_tokens=True,
-        starting_index=starting_index)
+        starting_index=starting_index,
+    )

     assert decoded_text == truth
     assert out_ids == all_input_ids[starting_index:]
@@ -124,8 +131,10 @@ def test_mistral_edge_case(tokenizer, truth):
 def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
     if "mistral" in tokenizer_name:
         yield (
-            True if request.param else
-            pytest.skip("mistral doesn't support skip_special_tokens=False"))
+            True
+            if request.param
+            else pytest.skip("mistral doesn't support skip_special_tokens=False")
+        )
     else:
         yield bool(request.param)
@@ -136,8 +145,14 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
spaces_between_special_tokens, fast):
def test_decode_streaming(
tokenizer,
truth,
with_prompt,
skip_special_tokens,
spaces_between_special_tokens,
fast,
):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
@@ -146,30 +161,35 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
     if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
         # Fix up inconsistency in fast/slow tokenizer behaviour.
-        tokenizer.add_special_tokens({
-            "additional_special_tokens": [
-                at for at in
-                tokenizer._tokenizer.get_added_tokens_decoder().values()
-                if at.special
-            ]
-        })
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    at
+                    for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
+                    if at.special
+                ]
+            }
+        )

-    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+    extra_decode_args = (
+        {}
+        if not isinstance(tokenizer, PreTrainedTokenizer)
         else {"spaces_between_special_tokens": spaces_between_special_tokens}
+    )

     truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
     if tokenizer.bos_token_id is not None:
         truth_tokens.insert(0, tokenizer.bos_token_id)
     truth_tokens.append(tokenizer.eos_token_id)

-    new_truth = tokenizer.decode(truth_tokens,
-                                 skip_special_tokens=skip_special_tokens,
-                                 **extra_decode_args)
+    new_truth = tokenizer.decode(
+        truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
+    )

     if with_prompt:
         num_prompt_tokens = len(
-            tokenizer(truth[:len(truth) // 2],
-                      add_special_tokens=False).input_ids)
+            tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
+        )
         if tokenizer.bos_token_id is not None:
             num_prompt_tokens += 1
@@ -177,11 +197,13 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         generated_input_ids = truth_tokens[num_prompt_tokens:]

         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
-        prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens,
-                                  **extra_decode_args)
+        prompt = tokenizer.decode(
+            prompt_input_ids,
+            skip_special_tokens=skip_special_tokens,
+            **extra_decode_args,
+        )
-        generated = new_truth[len(prompt):]
+        generated = new_truth[len(prompt) :]
     else:
         generated = new_truth
         starting_index = 0
@@ -193,7 +215,8 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         skip_special_tokens=skip_special_tokens,
         starting_index=starting_index,
         spaces_between_special_tokens=spaces_between_special_tokens,
-        fast=fast)
+        fast=fast,
+    )

     assert decoded_text == generated
     assert out_ids == all_input_ids[starting_index:]
@@ -206,11 +229,13 @@ def test_oov_decode(tokenizer, fast):
         pytest.skip()

     decoded_text, out_ids = _run_incremental_decode(
-        tokenizer, [len(tokenizer)],
+        tokenizer,
+        [len(tokenizer)],
         skip_special_tokens=True,
         starting_index=0,
         spaces_between_special_tokens=True,
-        fast=fast)
+        fast=fast,
+    )

-    assert decoded_text == ''
+    assert decoded_text == ""
     assert out_ids == [len(tokenizer)]
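Two smaller rewrites in this file recur throughout the commit: ruff format, like black, normalizes string literals to double quotes ('' becomes ""), and it puts a space before the slice colon when a slice bound is a non-trivial expression (new_truth[len(prompt):] becomes new_truth[len(prompt) :]).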

View File

@@ -13,6 +13,6 @@ TOKENIZER_NAMES = ["BAAI/bge-base-en"]
 def test_special_tokens(tokenizer_name: str, n_tokens: int):
     tokenizer = get_tokenizer(tokenizer_name, revision="main")
-    prompts = '[UNK]' * n_tokens
+    prompts = "[UNK]" * n_tokens
     prompt_token_ids = tokenizer.encode(prompts)
     assert len(prompt_token_ids) == n_tokens + 2

View File

@@ -5,6 +5,7 @@ This test file includes some cases where it is inappropriate to
 only get the `eos_token_id` from the tokenizer as defined by
 {meth}`vllm.LLMEngine._get_eos_token_id`.
 """
+
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -15,8 +16,7 @@ def test_get_llama3_eos_token():
     tokenizer = get_tokenizer(model_name)
     assert tokenizer.eos_token_id == 128009

-    generation_config = try_get_generation_config(model_name,
-                                                  trust_remote_code=False)
+    generation_config = try_get_generation_config(model_name, trust_remote_code=False)
     assert generation_config is not None
     assert generation_config.eos_token_id == [128001, 128008, 128009]
@@ -27,7 +27,6 @@ def test_get_blip2_eos_token():
     tokenizer = get_tokenizer(model_name)
     assert tokenizer.eos_token_id == 2

-    generation_config = try_get_generation_config(model_name,
-                                                  trust_remote_code=False)
+    generation_config = try_get_generation_config(model_name, trust_remote_code=False)
     assert generation_config is not None
     assert generation_config.eos_token_id == 50118

View File

@@ -2,187 +2,206 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import pytest
-from mistral_common.protocol.instruct.messages import (AssistantMessage,
-                                                       ToolMessage,
-                                                       UserMessage)
+from mistral_common.protocol.instruct.messages import (
+    AssistantMessage,
+    ToolMessage,
+    UserMessage,
+)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.protocol.instruct.tool_calls import (Function,
-                                                         FunctionCall, Tool,
-                                                         ToolCall)
+from mistral_common.protocol.instruct.tool_calls import (
+    Function,
+    FunctionCall,
+    Tool,
+    ToolCall,
+)
 from vllm.transformers_utils.tokenizers.mistral import (
-    make_mistral_chat_completion_request)
+    make_mistral_chat_completion_request,
+)
-@pytest.mark.parametrize(
-    "openai_request,expected_mistral_request",
-    [(
-        {
-            "messages": [{
-                "role": "user",
-                "content": "What is the current local date and time?",
-            }],
-            "tools": [{
-                "type": "function",
-                "function": {
-                    "description": "Fetch the current local date and time.",
-                    "name": "get_current_time",
-                },
-            }],
-        },
-        ChatCompletionRequest(
-            messages=[
-                UserMessage(content="What is the current local date and time?")
-            ],
-            tools=[
-                Tool(
-                    type="function",
-                    function=Function(
-                        name="get_current_time",
-                        description="Fetch the current local date and time.",
-                        parameters={},
-                    ),
-                )
-            ],
-        ),
-    ),
-    (
-        {
-            "messages":
-            [{
-                "role": "user",
-                "content": "What is the current local date and time?",
-            }],
-            "tools": [{
-                "type": "function",
-                "function": {
-                    "description": "Fetch the current local date and time.",
-                    "name": "get_current_time",
-                    "parameters": None,
-                },
-            }],
-        },
-        ChatCompletionRequest(
-            messages=[
-                UserMessage(
-                    content="What is the current local date and time?")
-            ],
-            tools=[
-                Tool(
-                    type="function",
-                    function=Function(
-                        name="get_current_time",
-                        description="Fetch the current local date and time.",
-                        parameters={},
-                    ),
-                )
-            ],
-        ),
-    )],
-)
-def test_make_mistral_chat_completion_request(openai_request,
-                                              expected_mistral_request):
-    actual_request = make_mistral_chat_completion_request(
-        openai_request["messages"], openai_request["tools"])
-    assert actual_request == expected_mistral_request
-
-
-# Tool use with list content and reasoning_content
-@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
-    {
-        "messages": [
-            {
-                "role": "user",
-                "content": "What's the weather in Paris?",
-            },
-            {
-                "role":
-                "assistant",
-                "reasoning_content":
-                None,
-                "content":
-                None,
-                "tool_calls": [{
-                    "id": "call123",
-                    "type": "function",
-                    "function": {
-                        "name": "get_weather",
-                        "arguments": '{"city": "Paris"}',
-                    },
-                }],
-            },
-            {
-                "role": "tool",
-                "content": [{
-                    "type": "text",
-                    "text": "Rainy"
-                }],
-                "name": "get_weather",
-                "tool_call_id": "call123",
-            },
-        ],
-        "tools": [{
-            "type": "function",
-            "function": {
-                "name": "get_weather",
-                "description": "Gets the current weather in a city.",
-                "parameters": {
-                    "type": "object",
-                    "properties": {
-                        "city": {
-                            "type": "string",
-                            "description": "The city name"
-                        }
-                    },
-                    "required": ["city"],
-                },
-            },
-        }],
-    },
-    ChatCompletionRequest(
-        messages=[
-            UserMessage(content="What's the weather in Paris?"),
-            AssistantMessage(
-                content=None,
-                tool_calls=[
-                    ToolCall(
-                        id="call123",
-                        function=FunctionCall(
-                            name="get_weather",
-                            arguments='{"city": "Paris"}',
-                        ),
-                    )
-                ],
-            ),
-            ToolMessage(
-                content="Rainy",
-                tool_call_id="call123",
-                name="get_weather",
-            ),
-        ],
-        tools=[
-            Tool(
-                type="function",
-                function=Function(
-                    name="get_weather",
-                    description="Gets the current weather in a city.",
-                    parameters={
-                        "type": "object",
-                        "properties": {
-                            "city": {
-                                "type": "string",
-                                "description": "The city name"
-                            }
-                        },
-                        "required": ["city"],
-                    },
-                ),
-            )
-        ],
-    ),
-)])
-def test_make_mistral_chat_completion_request_list_content(
-        openai_request, expected_mistral_request):
-    actual_request = make_mistral_chat_completion_request(
-        openai_request["messages"], openai_request["tools"])
-    assert actual_request == expected_mistral_request
+@pytest.mark.parametrize(
+    "openai_request,expected_mistral_request",
+    [
+        (
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What is the current local date and time?",
+                    }
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "description": "Fetch the current local date and time.",
+                            "name": "get_current_time",
+                        },
+                    }
+                ],
+            },
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(content="What is the current local date and time?")
+                ],
+                tools=[
+                    Tool(
+                        type="function",
+                        function=Function(
+                            name="get_current_time",
+                            description="Fetch the current local date and time.",
+                            parameters={},
+                        ),
+                    )
+                ],
+            ),
+        ),
+        (
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What is the current local date and time?",
+                    }
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "description": "Fetch the current local date and time.",
+                            "name": "get_current_time",
+                            "parameters": None,
+                        },
+                    }
+                ],
+            },
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(content="What is the current local date and time?")
+                ],
+                tools=[
+                    Tool(
+                        type="function",
+                        function=Function(
+                            name="get_current_time",
+                            description="Fetch the current local date and time.",
+                            parameters={},
+                        ),
+                    )
+                ],
+            ),
+        ),
+    ],
+)
+def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request):
+    actual_request = make_mistral_chat_completion_request(
+        openai_request["messages"], openai_request["tools"]
+    )
+    assert actual_request == expected_mistral_request
+
+
+# Tool use with list content and reasoning_content
+@pytest.mark.parametrize(
+    "openai_request,expected_mistral_request",
+    [
+        (
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "What's the weather in Paris?",
+                    },
+                    {
+                        "role": "assistant",
+                        "reasoning_content": None,
+                        "content": None,
+                        "tool_calls": [
+                            {
+                                "id": "call123",
+                                "type": "function",
+                                "function": {
+                                    "name": "get_weather",
+                                    "arguments": '{"city": "Paris"}',
+                                },
+                            }
+                        ],
+                    },
+                    {
+                        "role": "tool",
+                        "content": [{"type": "text", "text": "Rainy"}],
+                        "name": "get_weather",
+                        "tool_call_id": "call123",
+                    },
+                ],
+                "tools": [
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "get_weather",
+                            "description": "Gets the current weather in a city.",
+                            "parameters": {
+                                "type": "object",
+                                "properties": {
+                                    "city": {
+                                        "type": "string",
+                                        "description": "The city name",
+                                    }
+                                },
+                                "required": ["city"],
+                            },
+                        },
+                    }
+                ],
+            },
+            ChatCompletionRequest(
+                messages=[
+                    UserMessage(content="What's the weather in Paris?"),
+                    AssistantMessage(
+                        content=None,
+                        tool_calls=[
+                            ToolCall(
+                                id="call123",
+                                function=FunctionCall(
+                                    name="get_weather",
+                                    arguments='{"city": "Paris"}',
+                                ),
+                            )
+                        ],
+                    ),
+                    ToolMessage(
+                        content="Rainy",
+                        tool_call_id="call123",
+                        name="get_weather",
+                    ),
+                ],
+                tools=[
+                    Tool(
+                        type="function",
+                        function=Function(
+                            name="get_weather",
+                            description="Gets the current weather in a city.",
+                            parameters={
+                                "type": "object",
+                                "properties": {
+                                    "city": {
+                                        "type": "string",
+                                        "description": "The city name",
+                                    }
+                                },
+                                "required": ["city"],
+                            },
+                        ),
+                    )
+                ],
+            ),
+        )
+    ],
+)
+def test_make_mistral_chat_completion_request_list_content(
+    openai_request, expected_mistral_request
+):
+    actual_request = make_mistral_chat_completion_request(
+        openai_request["messages"], openai_request["tools"]
+    )
+    assert actual_request == expected_mistral_request

View File

@@ -19,5 +19,5 @@ def test_tokenizer_revision(tokenizer_name: str):
     assert isinstance(tokenizer, PreTrainedTokenizerBase)

     # Assume that "never" branch always does not exist
-    with pytest.raises(OSError, match='not a valid git identifier'):
+    with pytest.raises(OSError, match="not a valid git identifier"):
         get_tokenizer(tokenizer_name, revision="never")

View File

@@ -4,15 +4,13 @@
 from typing import TYPE_CHECKING, Any, Optional, Union

 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
-                                                    TokenizerRegistry)
+from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry

 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


 class TestTokenizer(TokenizerBase):
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
         return TestTokenizer()
@@ -85,23 +83,23 @@ class TestTokenizer(TokenizerBase):
     ) -> list[int]:
         raise NotImplementedError()

-    def encode(self,
-               text: str,
-               add_special_tokens: Optional[bool] = None) -> list[int]:
+    def encode(self, text: str, add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()

-    def apply_chat_template(self,
-                            messages: list["ChatCompletionMessageParam"],
-                            tools: Optional[list[dict[str, Any]]] = None,
-                            **kwargs) -> list[int]:
+    def apply_chat_template(
+        self,
+        messages: list["ChatCompletionMessageParam"],
+        tools: Optional[list[dict[str, Any]]] = None,
+        **kwargs,
+    ) -> list[int]:
         raise NotImplementedError()

     def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError()

-    def decode(self,
-               ids: Union[list[int], int],
-               skip_special_tokens: bool = True) -> str:
+    def decode(
+        self, ids: Union[list[int], int], skip_special_tokens: bool = True
+    ) -> str:
         raise NotImplementedError()

     def convert_ids_to_tokens(
@@ -113,9 +111,9 @@ class TestTokenizer(TokenizerBase):
 def test_customized_tokenizer():
-    TokenizerRegistry.register("test_tokenizer",
-                               "tests.tokenization.test_tokenizer_registry",
-                               "TestTokenizer")
+    TokenizerRegistry.register(
+        "test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer"
+    )

     tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
     assert isinstance(tokenizer, TestTokenizer)