[Bugfix] Fix edge cases for MistralTokenizer (#9625)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -7,11 +7,17 @@ from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
 from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                  detokenize_incrementally)
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 TRUTH = [
     "Hello here, this is a simple test",
     "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
-    "我很感谢你的热情"
+    "我很感谢你的热情",
+    # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (e.g.
+    # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
+    # incomplete UTF-8 characters
+    # see https://github.com/vllm-project/vllm/pull/9625
+    "ပုံပြင်လေးပြောပြပါ်",
 ]
 TOKENIZERS = [
     "facebook/opt-125m",
@@ -24,6 +30,7 @@ TOKENIZERS = [
     "tiiuae/falcon-7b",
     "meta-llama/Llama-2-7b-hf",
     "codellama/CodeLlama-7b-hf",
+    "mistralai/Pixtral-12B-2409",
 ]
 
 
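The Burmese entry added to TRUTH exercises the bug this commit fixes: a byte-level vocabulary such as V3-Tekken can split one multi-byte UTF-8 character across several tokens, so decoding each token's bytes in isolation yields replacement characters instead of text. A minimal Python illustration of the failure mode, independent of vLLM and of the actual tokenizer:

    # A single Burmese letter occupies three bytes in UTF-8.
    text = "ပ"                   # U+1015 MYANMAR LETTER PA
    data = text.encode("utf-8")  # b'\xe1\x80\x95'

    # Decoding each byte on its own, as a naive per-token detokenizer would
    # if every token carried just one of these bytes:
    pieces = [bytes([b]).decode("utf-8", errors="replace") for b in data]
    print(pieces)  # ['\ufffd', '\ufffd', '\ufffd'] -- replacement chars only

    # Buffering the bytes until they form a complete character recovers the
    # text, which is roughly what an incremental detokenizer has to do.
    print(data.decode("utf-8"))  # 'ပ'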
@@ -49,15 +56,55 @@ def _run_incremental_decode(tokenizer, all_input_ids,
     return decoded_text
 
 
+@pytest.fixture
+def tokenizer(tokenizer_name):
+    return (MistralTokenizer.from_pretrained(tokenizer_name)
+            if "mistral" in tokenizer_name else
+            AutoTokenizer.from_pretrained(tokenizer_name))
+
+
+@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
+@pytest.mark.parametrize(
+    "truth",
+    [
+        # Burmese text triggers an edge-case where tokens may map to bytes with
+        # incomplete UTF-8 characters
+        "ပုံပြင်လေးပြောပြပါ",
+        # Using "URGENCY" since "CY" has token id 130282
+        "URGENCY🌶️",
+    ])
+def test_mistral_edge_case(tokenizer, truth):
+    """Test for specific edge cases with the V3-Tekken MistralTokenizer.
+
+    See https://github.com/vllm-project/vllm/pull/9625
+    """
+    starting_index = 0
+    all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
+
+    decoded_text = _run_incremental_decode(tokenizer,
+                                           all_input_ids,
+                                           skip_special_tokens=True,
+                                           starting_index=starting_index)
+    assert decoded_text == truth
+
+
+@pytest.fixture
+def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
+    if "mistral" in tokenizer_name:
+        yield (
+            bool(True) if request.param else
+            pytest.skip("mistral doesn't support skip_special_tokens=False"))
+    else:
+        yield bool(True) if request.param else bool(False)
+
+
 @pytest.mark.parametrize("truth", TRUTH)
 @pytest.mark.parametrize("with_prompt", [True, False])
-@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", (True, False))
-def test_decode_streaming(tokenizer_id, truth, with_prompt,
-                          skip_special_tokens):
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"]
+        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
         prompt_input_ids = truth_tokens[:len(truth) // 2]
         generated_input_ids = truth_tokens[len(truth) // 2:]
         all_input_ids = prompt_input_ids + generated_input_ids
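For readers without the full file: the `_run_incremental_decode` helper (its body sits outside these hunks) appears, from the imports above, to step vLLM's detokenize_incrementally over the token ids and concatenate the returned text. A simplified stand-in showing the shape of such a loop — the name naive_incremental_decode is hypothetical, and it assumes only an HF-style tokenizer.decode:

    def naive_incremental_decode(tokenizer, all_input_ids,
                                 skip_special_tokens=True):
        # Decode a growing prefix of the ids and emit only the newly
        # completed text at each step.
        decoded_text = ""
        prev = ""
        for i in range(1, len(all_input_ids) + 1):
            full = tokenizer.decode(all_input_ids[:i],
                                    skip_special_tokens=skip_special_tokens)
            # A trailing U+FFFD means the latest token ended mid-character;
            # hold the output back until the character completes.
            if not full.endswith("\ufffd"):
                decoded_text += full[len(prev):]
                prev = full
        return decoded_text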
@@ -68,7 +115,7 @@ def test_decode_streaming(tokenizer_id, truth, with_prompt,
     else:
         generated = truth
         starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
+        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
     if skip_special_tokens:
         if tokenizer.bos_token_id is not None:
             all_input_ids = [tokenizer.bos_token_id] + all_input_ids
@@ -98,7 +145,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
         enable_lora=False,
         max_num_seqs=100,
         max_input_length=None,
-        tokenizer_mode="auto",
+        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
         trust_remote_code=False,
         revision=None,
     )
@@ -113,9 +160,8 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
-                                       tokenizer_name: str) -> List[int]:
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-    complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"]
+                                       tokenizer) -> List[int]:
+    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
     return complete_sequence_token_ids
 
 
@@ -150,7 +196,7 @@ def create_dummy_prompt_logprobs(
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True, False])
+@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
 def test_decode_sequence_logprobs(complete_sequence: str,
                                   complete_sequence_token_ids: List[int],
                                   detokenizer: Detokenizer,
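Both this hunk and the streaming test above switch skip_special_tokens to indirect parametrization: with indirect=True, pytest routes each parametrized value through the same-named fixture as request.param, which is what lets the new skip_special_tokens fixture skip the unsupported False case for Mistral tokenizers. A standalone sketch of the mechanism (the fixture and test names here are illustrative, not from the diff):

    import pytest

    @pytest.fixture
    def flag(request):
        # Runs once per parametrized value and may transform or skip it.
        if request.param == "unsupported":
            pytest.skip("this backend does not support the flag")
        return request.param

    @pytest.mark.parametrize("flag", ["supported", "unsupported"],
                             indirect=True)
    def test_flag(flag):
        # The "unsupported" case never reaches the test body; it is
        # skipped inside the fixture.
        assert flag == "supported"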
@@ -208,9 +254,9 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
 
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
-    tokenzier = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
+    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the