Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
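Most of this diff is mechanical: ruff's formatter (like Black) honours the "magic trailing comma", so calls and signatures gain a trailing comma and are expanded to one argument per line. A minimal sketch of the rule, using a hypothetical function that is not part of this change:

```python
def make_request(request_id: str, arrival_time: float) -> tuple[str, float]:
    # Hypothetical stand-in for a call like EngineCoreRequest(...).
    return (request_id, arrival_time)

# No trailing comma: the formatter collapses the call onto one line if it fits.
collapsed = make_request(request_id="", arrival_time=0.0)

# Magic trailing comma: the formatter keeps one argument per line.
expanded = make_request(
    request_id="",
    arrival_time=0.0,
)
```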
@@ -5,15 +5,16 @@ from collections.abc import Generator
 from typing import Any, Optional

 import pytest
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
-                                        IncrementalDetokenizer,
-                                        SlowIncrementalDetokenizer)
+from vllm.v1.engine.detokenizer import (
+    FastIncrementalDetokenizer,
+    IncrementalDetokenizer,
+    SlowIncrementalDetokenizer,
+)

 SPECIAL_TOKS_TRUTH = [
     "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
@@ -45,33 +46,35 @@ TOKENIZERS = [
 ]


-def _run_incremental_decode(tokenizer,
-                            all_input_ids,
-                            skip_special_tokens: bool,
-                            starting_index: int,
-                            spaces_between_special_tokens: bool = True,
-                            fast: Optional[bool] = None):
+def _run_incremental_decode(
+    tokenizer,
+    all_input_ids,
+    skip_special_tokens: bool,
+    starting_index: int,
+    spaces_between_special_tokens: bool = True,
+    fast: Optional[bool] = None,
+):
     prompt_token_ids = all_input_ids[:starting_index]

     params = SamplingParams(
         skip_special_tokens=skip_special_tokens,
         spaces_between_special_tokens=spaces_between_special_tokens,
     )
-    request = EngineCoreRequest(request_id="",
-                                prompt_token_ids=prompt_token_ids,
-                                mm_features=None,
-                                sampling_params=params,
-                                pooling_params=None,
-                                eos_token_id=None,
-                                arrival_time=0.0,
-                                lora_request=None,
-                                cache_salt=None,
-                                data_parallel_rank=None)
+    request = EngineCoreRequest(
+        request_id="",
+        prompt_token_ids=prompt_token_ids,
+        mm_features=None,
+        sampling_params=params,
+        pooling_params=None,
+        eos_token_id=None,
+        arrival_time=0.0,
+        lora_request=None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )

     if fast is None:
-        detokenizer = IncrementalDetokenizer.from_new_request(
-            tokenizer, request)
+        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
     elif fast:
         detokenizer = FastIncrementalDetokenizer(tokenizer, request)
     else:
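For context, a minimal sketch of how this helper is driven, mirroring the calls made later in this file (the input string is illustrative, and a tokenizer is assumed to be in scope, e.g. from the fixture below):

```python
# Encode some text, then replay the ids through the incremental detokenizer.
all_input_ids = tokenizer("Hello world", add_special_tokens=False).input_ids
decoded_text, out_ids = _run_incremental_decode(
    tokenizer,
    all_input_ids,
    skip_special_tokens=True,
    starting_index=0,
    spaces_between_special_tokens=True,
    fast=None,  # None lets IncrementalDetokenizer.from_new_request choose
)
assert out_ids == all_input_ids  # starting_index=0, so every id is "generated"
```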
@@ -88,9 +91,11 @@ def _run_incremental_decode(tokenizer,

 @pytest.fixture
 def tokenizer(tokenizer_name):
-    return (MistralTokenizer.from_pretrained(tokenizer_name)
-            if "mistral" in tokenizer_name else
-            AutoTokenizer.from_pretrained(tokenizer_name))
+    return (
+        MistralTokenizer.from_pretrained(tokenizer_name)
+        if "mistral" in tokenizer_name
+        else AutoTokenizer.from_pretrained(tokenizer_name)
+    )


 @pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@@ -102,7 +107,8 @@ def tokenizer(tokenizer_name):
         "ပုံပြင်လေးပြောပြပါ",
         # Using "URGENCY" since "CY" has token id 130282
         "URGENCY🌶️",
-    ])
+    ],
+)
 def test_mistral_edge_case(tokenizer, truth):
     """Test for a specific edge cases with V3-Tekken MistralTokenizer.

@@ -115,7 +121,8 @@ def test_mistral_edge_case(tokenizer, truth):
         tokenizer,
         all_input_ids,
         skip_special_tokens=True,
-        starting_index=starting_index)
+        starting_index=starting_index,
+    )
     assert decoded_text == truth
     assert out_ids == all_input_ids[starting_index:]

@@ -124,8 +131,10 @@ def test_mistral_edge_case(tokenizer, truth):
 def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
     if "mistral" in tokenizer_name:
         yield (
-            True if request.param else
-            pytest.skip("mistral doesn't support skip_special_tokens=False"))
+            True
+            if request.param
+            else pytest.skip("mistral doesn't support skip_special_tokens=False")
+        )
     else:
         yield bool(request.param)

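The fixture above relies on pytest's indirect parametrization: with `indirect=True`, the parametrized value reaches the fixture through `request.param` instead of being passed straight into the test. A self-contained sketch of the pattern, with hypothetical names:

```python
import pytest

@pytest.fixture
def flag(request):
    # request.param carries the value from parametrize(..., indirect=True).
    yield bool(request.param)

@pytest.mark.parametrize("flag", (True, False), indirect=True)
def test_flag(flag):
    assert isinstance(flag, bool)
```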
@@ -136,8 +145,14 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
 @pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
 @pytest.mark.parametrize("fast", (True, False))
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
-                          spaces_between_special_tokens, fast):
+def test_decode_streaming(
+    tokenizer,
+    truth,
+    with_prompt,
+    skip_special_tokens,
+    spaces_between_special_tokens,
+    fast,
+):
     if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
         pytest.skip()

@@ -146,30 +161,35 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,

     if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
         # Fix up inconsistency in fast/slow tokenizer behaviour.
-        tokenizer.add_special_tokens({
-            "additional_special_tokens": [
-                at for at in
-                tokenizer._tokenizer.get_added_tokens_decoder().values()
-                if at.special
-            ]
-        })
+        tokenizer.add_special_tokens(
+            {
+                "additional_special_tokens": [
+                    at
+                    for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
+                    if at.special
+                ]
+            }
+        )

-    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
-        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+    extra_decode_args = (
+        {}
+        if not isinstance(tokenizer, PreTrainedTokenizer)
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+    )

     truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
     if tokenizer.bos_token_id is not None:
         truth_tokens.insert(0, tokenizer.bos_token_id)
     truth_tokens.append(tokenizer.eos_token_id)

-    new_truth = tokenizer.decode(truth_tokens,
-                                 skip_special_tokens=skip_special_tokens,
-                                 **extra_decode_args)
+    new_truth = tokenizer.decode(
+        truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
+    )

     if with_prompt:
         num_prompt_tokens = len(
-            tokenizer(truth[:len(truth) // 2],
-                      add_special_tokens=False).input_ids)
+            tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
+        )
         if tokenizer.bos_token_id is not None:
             num_prompt_tokens += 1

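The `extra_decode_args` dict above is the usual trick for passing a keyword argument conditionally: splatting an empty dict is a no-op, so the same `decode` call works whether or not the tokenizer accepts `spaces_between_special_tokens`. A minimal sketch with hypothetical names:

```python
def decode(ids, skip_special_tokens=False, spaces_between_special_tokens=True):
    # Hypothetical stand-in for tokenizer.decode().
    return f"{ids} skip={skip_special_tokens} spaces={spaces_between_special_tokens}"

supports_extra_kwarg = True  # would be an isinstance() check in the real test
extra_decode_args = (
    {"spaces_between_special_tokens": False} if supports_extra_kwarg else {}
)
print(decode([1, 2, 3], skip_special_tokens=True, **extra_decode_args))
```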
@@ -177,11 +197,13 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
-        prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens,
-                                  **extra_decode_args)
+        prompt = tokenizer.decode(
+            prompt_input_ids,
+            skip_special_tokens=skip_special_tokens,
+            **extra_decode_args,
+        )

-        generated = new_truth[len(prompt):]
+        generated = new_truth[len(prompt) :]
     else:
         generated = new_truth
         starting_index = 0
@@ -193,7 +215,8 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
         skip_special_tokens=skip_special_tokens,
         starting_index=starting_index,
         spaces_between_special_tokens=spaces_between_special_tokens,
-        fast=fast)
+        fast=fast,
+    )

     assert decoded_text == generated
     assert out_ids == all_input_ids[starting_index:]
@@ -206,11 +229,13 @@ def test_oov_decode(tokenizer, fast):
         pytest.skip()

     decoded_text, out_ids = _run_incremental_decode(
-        tokenizer, [len(tokenizer)],
+        tokenizer,
+        [len(tokenizer)],
         skip_special_tokens=True,
         starting_index=0,
         spaces_between_special_tokens=True,
-        fast=fast)
+        fast=fast,
+    )

-    assert decoded_text == ''
+    assert decoded_text == ""
     assert out_ids == [len(tokenizer)]

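Two smaller rewrites in this last hunk also come from formatter defaults: ruff format normalizes string literals to double quotes, and a slice with a non-trivial bound gets a space before the colon (PEP 8's rule for complex slice expressions, also applied in the `new_truth[len(prompt) :]` change above). An illustrative, runnable sketch:

```python
# Quote style: ruff format prefers double quotes, hence '' -> "" in the diff.
empty = ""
assert empty == ""

# Slice spacing: a computed lower bound is written with a space, as in the diff.
prompt = "Hello "
new_truth = "Hello world"
generated = new_truth[len(prompt) :]
assert generated == "world"
```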