Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91

1508 changed files with 115244 additions and 94146 deletions
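
For context, a migration like this one typically deletes the yapf and isort configuration and enables ruff's formatter and isort-compatible import sorting in their place. The sketch below shows what the pyproject.toml change usually looks like; it illustrates standard ruff settings, not the exact configuration adopted in this commit:

    [tool.ruff]
    line-length = 88  # assumed; a project keeps whatever width it used with yapf

    [tool.ruff.lint]
    # "I" selects the isort-compatible import-sorting rules,
    # replacing the standalone isort tool
    select = ["E", "F", "I"]

Formatting is then applied with "ruff format ." and import sorting with "ruff check --fix ."; the mechanical reflows in the diff below (exploded call arguments, trailing commas, parenthesized conditionals) are the output of that formatter pass.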

@@ -5,15 +5,16 @@ from collections.abc import Generator
 from typing import Any, Optional
 
 import pytest
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
-                                        IncrementalDetokenizer,
-                                        SlowIncrementalDetokenizer)
+from vllm.v1.engine.detokenizer import (
+    FastIncrementalDetokenizer,
+    IncrementalDetokenizer,
+    SlowIncrementalDetokenizer,
+)
 
 SPECIAL_TOKS_TRUTH = [
     "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
@@ -45,33 +46,35 @@ TOKENIZERS = [
 ]
 
 
-def _run_incremental_decode(tokenizer,
-                            all_input_ids,
-                            skip_special_tokens: bool,
-                            starting_index: int,
-                            spaces_between_special_tokens: bool = True,
-                            fast: Optional[bool] = None):
+def _run_incremental_decode(
+    tokenizer,
+    all_input_ids,
+    skip_special_tokens: bool,
+    starting_index: int,
+    spaces_between_special_tokens: bool = True,
+    fast: Optional[bool] = None,
+):
     prompt_token_ids = all_input_ids[:starting_index]
     params = SamplingParams(
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest(request_id="",
-                                prompt_token_ids=prompt_token_ids,
-                                mm_features=None,
-                                sampling_params=params,
-                                pooling_params=None,
-                                eos_token_id=None,
-                                arrival_time=0.0,
-                                lora_request=None,
-                                cache_salt=None,
-                                data_parallel_rank=None)
+    request = EngineCoreRequest(
+        request_id="",
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )
 
     if fast is None:
-        detokenizer = IncrementalDetokenizer.from_new_request(
-            tokenizer, request)
+        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
     elif fast:
         detokenizer = FastIncrementalDetokenizer(tokenizer, request)
     else:
@@ -88,9 +91,11 @@ def _run_incremental_decode(tokenizer,
 
 @pytest.fixture
 def tokenizer(tokenizer_name):
-    return (MistralTokenizer.from_pretrained(tokenizer_name)
-            if "mistral" in tokenizer_name else
-            AutoTokenizer.from_pretrained(tokenizer_name))
+    return (
+        MistralTokenizer.from_pretrained(tokenizer_name)
+        if "mistral" in tokenizer_name
+        else AutoTokenizer.from_pretrained(tokenizer_name)
+    )
 
 
 @pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@@ -102,7 +107,8 @@ def tokenizer(tokenizer_name):
         "ပုံပြင်လေးပြောပြပါ",
         # Using "URGENCY" since "CY" has token id 130282
         "URGENCY🌶",
-    ])
+    ],
+)
 def test_mistral_edge_case(tokenizer, truth):
     """Test for a specific edge cases with V3-Tekken MistralTokenizer.
@@ -115,7 +121,8 @@ def test_mistral_edge_case(tokenizer, truth):
         tokenizer,
         all_input_ids,
         skip_special_tokens=True,
-        starting_index=starting_index)
+        starting_index=starting_index,
+    )
 
     assert decoded_text == truth
     assert out_ids == all_input_ids[starting_index:]
@@ -124,8 +131,10 @@ def test_mistral_edge_case(tokenizer, truth):
 def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
     if "mistral" in tokenizer_name:
         yield (
-            True if request.param else
-            pytest.skip("mistral doesn't support skip_special_tokens=False"))
+            True
+            if request.param
+            else pytest.skip("mistral doesn't support skip_special_tokens=False")
+        )
     else:
         yield bool(request.param)
@@ -136,8 +145,14 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
 @pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
 @pytest.mark.parametrize("fast", (True, False))
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
-                          spaces_between_special_tokens, fast):
+def test_decode_streaming(
+    tokenizer,
+    truth,
+    with_prompt,
+    skip_special_tokens,
+    spaces_between_special_tokens,
+    fast,
+):
     if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
         pytest.skip()
@@ -146,30 +161,35 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
     if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
         # Fix up inconsistency in fast/slow tokenizer behaviour.
-        tokenizer.add_special_tokens({
-            "additional_special_tokens": [
-                at for at in
-                tokenizer._tokenizer.get_added_tokens_decoder().values()
-                if at.special
-            ]
-        })
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    at
+                    for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
+                    if at.special
+                ]
+            }
+        )
 
-    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+    extra_decode_args = (
+        {}
+        if not isinstance(tokenizer, PreTrainedTokenizer)
         else {"spaces_between_special_tokens": spaces_between_special_tokens}
+    )
 
     truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
     if tokenizer.bos_token_id is not None:
         truth_tokens.insert(0, tokenizer.bos_token_id)
     truth_tokens.append(tokenizer.eos_token_id)
 
-    new_truth = tokenizer.decode(truth_tokens,
-                                 skip_special_tokens=skip_special_tokens,
-                                 **extra_decode_args)
+    new_truth = tokenizer.decode(
+        truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
+    )
 
     if with_prompt:
         num_prompt_tokens = len(
-            tokenizer(truth[:len(truth) // 2],
-                      add_special_tokens=False).input_ids)
+            tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
+        )
         if tokenizer.bos_token_id is not None:
             num_prompt_tokens += 1
@@ -177,11 +197,13 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
-        prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens,
-                                  **extra_decode_args)
+        prompt = tokenizer.decode(
+            prompt_input_ids,
+            skip_special_tokens=skip_special_tokens,
+            **extra_decode_args,
+        )
-        generated = new_truth[len(prompt):]
+        generated = new_truth[len(prompt) :]
     else:
         generated = new_truth
         starting_index = 0
@@ -193,7 +215,8 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         skip_special_tokens=skip_special_tokens,
         starting_index=starting_index,
         spaces_between_special_tokens=spaces_between_special_tokens,
-        fast=fast)
+        fast=fast,
+    )
 
     assert decoded_text == generated
     assert out_ids == all_input_ids[starting_index:]
@@ -206,11 +229,13 @@ def test_oov_decode(tokenizer, fast):
         pytest.skip()
 
     decoded_text, out_ids = _run_incremental_decode(
-        tokenizer, [len(tokenizer)],
+        tokenizer,
+        [len(tokenizer)],
         skip_special_tokens=True,
         starting_index=0,
         spaces_between_special_tokens=True,
-        fast=fast)
+        fast=fast,
+    )
 
-    assert decoded_text == ''
+    assert decoded_text == ""
     assert out_ids == [len(tokenizer)]
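
A quick way to validate a pure-formatting commit like this is to confirm the formatter is now a no-op and that the touched tests still pass. A sketch of the invocation (the path to the diffed test file is an assumption, since the filename is not shown on this page):

    ruff format --check .
    ruff check .
    python -m pytest tests/v1/engine/test_detokenizer.py -q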