[BUGFIX] MistralTokenizer._call__ adds an invalid EOS token (#29607)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Julien Denize
2025-11-28 09:44:47 +01:00
committed by GitHub
parent cc0f2a0e19
commit b2c1d294fa
2 changed files with 87 additions and 1 deletions

View File

@@ -331,6 +331,7 @@ class TestMistralTokenizer:
)
== token_ids
)
assert mistral_tokenizer.encode_one("") == []
def test_encode(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
@@ -370,6 +371,51 @@ class TestMistralTokenizer:
mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
== token_ids[1:]
)
assert mistral_tokenizer.encode("", add_special_tokens=False) == []
def test_call(self, mistral_tokenizer: MistralTokenizer):
token_ids = (
[1, 22177, 4304, 2662]
if mistral_tokenizer.is_tekken
else [1, 23325, 2294, 1686]
)
attn_mask = [1 for _ in range(len(token_ids))]
# Test 1: default
assert mistral_tokenizer("Hello world !") == {
"attention_mask": attn_mask[1:],
"input_ids": token_ids[1:],
}
# Test 2: special tokens
assert mistral_tokenizer("Hello world !", add_special_tokens=True) == {
"attention_mask": attn_mask,
"input_ids": token_ids,
}
# Test 3: special tokens + truncation
assert mistral_tokenizer(
"Hello world !", add_special_tokens=True, truncation=True, max_length=3
) == {
"attention_mask": attn_mask[:-1],
"input_ids": token_ids[:-1],
}
# Test 4: special tokens + no truncation + max length
assert mistral_tokenizer(
"Hello world !", add_special_tokens=True, max_length=3
) == {
"attention_mask": attn_mask,
"input_ids": token_ids,
}
# Test 5: empty string
assert mistral_tokenizer("") == {
"attention_mask": [],
"input_ids": [],
}
with pytest.raises(
ValueError,
match=(r"`text_pair` is not supported by `MistralTokenizer.__call__`."),
):
mistral_tokenizer("Hello world !", "invalid pair")
@pytest.mark.parametrize(
"openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
@@ -1087,6 +1133,24 @@ class TestMistralTokenizer:
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
assert (
mistral_tokenizer.decode(
ids[mistral_tokenizer.is_tekken],
skip_special_tokens=skip_special_tokens,
)
== expected_tokens[mistral_tokenizer.is_tekken]
)
def test_decode_empty(
self,
mistral_tokenizer: MistralTokenizer,
):
assert (
mistral_tokenizer.decode(
[],
)
== ""
)
def test_decode_int(
self,
@@ -1390,6 +1454,8 @@ class TestMistralTokenizer:
== expected_strings[mistral_tokenizer.is_tekken]
)
assert mistral_tokenizer.convert_tokens_to_string([]) == ""
@pytest.mark.parametrize(
"skip_special_tokens,tuple_expected_tokens",
(
@@ -2220,3 +2286,5 @@ class TestMistralTokenizer:
ids, skip_special_tokens=skip_special_tokens
)
assert actual_tokens == expected_tokens
assert mistral_tokenizer.convert_ids_to_tokens([]) == []