[Bugfix] vLLM produces invalid UTF-8 tokens and “�” (#28874)
Signed-off-by: John Calderon <jcalderon@nvidia.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
@@ -274,12 +274,28 @@ def _validate_logprobs(
# the logprob token id at this sequence position
decoded_token = pos_logprob_dict[lp_tok].decoded_token
ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok)
assert decoded_token == ref_decoded_token, (
    f"Sampled logprob token id {lp_tok} decodes to"
    f" {ref_decoded_token} but Logprob decoded"
    f" token is {decoded_token} instead"
    f" (at position {idx})"
)

# With UTF-8 correction logic, tokens ending with "�"
# (incomplete byte sequences) are corrected to either
# empty string or proper UTF-8 characters
if ref_decoded_token.endswith("�"):
    # Token needs UTF-8 correction
    assert not decoded_token.endswith("�"), (
        f"Sampled logprob token id {lp_tok} decodes to"
        f" '{ref_decoded_token}' (ends with replacement char)"
        f" but corrected decoded token '{decoded_token}'"
        f" still ends with replacement char"
        f" (at position {idx}). UTF-8 correction should"
        f" have removed it."
    )
else:
    # No correction needed, should match exactly
    assert decoded_token == ref_decoded_token, (
        f"Sampled logprob token id {lp_tok} decodes to"
        f" {ref_decoded_token} but Logprob decoded"
        f" token is {decoded_token} instead"
        f" (at position {idx})"
    )

ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob
# Assert that cumulative logprobs are correct
@@ -420,12 +436,28 @@ def _validate_logprobs(
    # the logprob token id at this sequence position
    decoded_token = pos_logprob_dict[plp_tok].decoded_token
    ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok)
    assert decoded_token == ref_decoded_token, (
        f"Prompt logprob token id {plp_tok} decodes to"
        f" {ref_decoded_token} but Logprob decoded"
        f" token is {decoded_token} instead"
        f" (at position {idx})"
    )

    # With UTF-8 correction logic, tokens ending with "�"
    # (incomplete byte sequences) are corrected to either
    # empty string or proper UTF-8 characters
    if ref_decoded_token.endswith("�"):
        # Token needs UTF-8 correction
        assert not decoded_token.endswith("�"), (
            f"Prompt logprob token id {plp_tok} decodes to"
            f" '{ref_decoded_token}' (ends with replacement char)"
            f" but corrected decoded token '{decoded_token}'"
            f" still ends with replacement char"
            f" (at position {idx}). UTF-8 correction should"
            f" have removed it."
        )
    else:
        # No correction needed, should match exactly
        assert decoded_token == ref_decoded_token, (
            f"Prompt logprob token id {plp_tok} decodes to"
            f" {ref_decoded_token} but Logprob decoded"
            f" token is {decoded_token} instead"
            f" (at position {idx})"
        )
else:
    # Prompt logprobs disabled for this request
    assert prompt_logprobs is None
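
An illustrative aside, not part of the diff: the "�" these assertions look for is U+FFFD, which appears when a multi-byte UTF-8 character is split across tokens and one piece is decoded on its own. A minimal stdlib sketch of both the failure and the joint-decode recovery the corrected path relies on:

# The right curly quote is three bytes in UTF-8 (e2 80 9d).
raw = "”".encode("utf-8")
assert raw[:2].decode("utf-8", errors="replace").endswith("�")  # truncated -> U+FFFD
assert raw.decode("utf-8") == "”"  # all bytes together decode cleanly
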
@@ -514,6 +514,424 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
del llm


class TestCorrectDecodedToken:
    """Unit tests for _correct_decoded_token method in LogprobsProcessor.

    This method handles UTF-8 decoding issues where incomplete byte sequences
    result in the Unicode replacement character "�" (U+FFFD). This commonly
    happens with byte-fallback tokenization when multi-byte UTF-8 characters
    are split across tokens.
    """

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock tokenizer for testing."""
        from unittest.mock import Mock

        tokenizer = Mock()
        return tokenizer

    @pytest.fixture
    def processor_with_empty_logprobs(self, mock_tokenizer):
        """Create a LogprobsProcessor with empty logprobs."""
        from vllm.v1.engine.logprobs import LogprobsProcessor

        processor = LogprobsProcessor(
            tokenizer=mock_tokenizer,
            logprobs=[],
            prompt_logprobs=None,
            cumulative_logprob=0.0,
            num_logprobs=1,
            num_prompt_logprobs=None,
        )
        return processor

    @pytest.fixture
    def processor_with_previous_logprobs(self, mock_tokenizer):
        """Create a LogprobsProcessor with previous logprobs."""
        from vllm.v1.engine.logprobs import LogprobsProcessor

        processor = LogprobsProcessor(
            tokenizer=mock_tokenizer,
            logprobs=[{123: None}],  # Previous token ID is 123
            prompt_logprobs=None,
            cumulative_logprob=0.0,
            num_logprobs=1,
            num_prompt_logprobs=None,
        )
        return processor
    def test_correction_with_previous_token_in_list(
        self, processor_with_empty_logprobs
    ):
        """Test correction using previous token in the same list.

        Scenario: Token at idx=1 ends with "�", but when decoded with
        the previous token (idx=0), it forms a valid UTF-8 sequence.
        Example: token[0]="�", token[1]="�" -> together form "polarized"
        """
        processor = processor_with_empty_logprobs
        tokens = [100, 101, 102]  # token IDs

        # Mock tokenizer behavior:
        # - decode([102]) returns "�" (ends with replacement char)
        # - decode([101, 102]) returns "valid" (no replacement char)
        processor.tokenizer.decode.side_effect = lambda ids: (
            "valid" if ids == [101, 102] else "�"
        )

        result = processor._correct_decoded_token(2, tokens)
        assert result == "valid"
        processor.tokenizer.decode.assert_called_with([101, 102])

    def test_correction_with_previous_logprob_token(
        self, processor_with_previous_logprobs
    ):
        """Test correction using previous logprob token.

        Scenario: Cannot correct with previous token in list (idx=0),
        but can correct with previous logprob token.
        """
        processor = processor_with_previous_logprobs
        tokens = [100]  # single token

        # Mock tokenizer behavior:
        # - decode([100]) returns "�" (ends with replacement char)
        # - decode([123, 100]) returns ' "polarized"' (no replacement char)
        # Token 123 is from previous logprobs
        def mock_decode(ids):
            if ids == [123, 100]:
                return ' "polarized"'
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(0, tokens)
        assert result == ' "polarized"'

    def test_correction_at_idx_zero_no_previous_logprobs(
        self, processor_with_empty_logprobs
    ):
        """Test correction at idx=0 with no previous logprobs.

        Scenario: First token in list, no previous logprobs available.
        Should return empty string as fallback.
        """
        processor = processor_with_empty_logprobs
        tokens = [100]

        # Mock tokenizer always returns "�"
        processor.tokenizer.decode.return_value = "�"

        result = processor._correct_decoded_token(0, tokens)
        assert result == ""

    def test_correction_at_idx_zero_with_previous_logprobs(
        self, processor_with_previous_logprobs
    ):
        """Test correction at idx=0 with previous logprobs available.

        Scenario: First token in list, but previous logprobs exist.
        Should try correction with previous logprob token.
        """
        processor = processor_with_previous_logprobs
        tokens = [200]

        # Mock tokenizer behavior
        def mock_decode(ids):
            if ids == [123, 200]:
                return "corrected"
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(0, tokens)
        assert result == "corrected"

    def test_no_correction_needed_returns_fallback(
        self, processor_with_previous_logprobs
    ):
        """Test fallback to empty string when no correction works.

        Scenario: All correction attempts still end with "�".
        Should return empty string as final fallback.
        """
        processor = processor_with_previous_logprobs
        tokens = [100, 101, 102]

        # Mock tokenizer always returns text ending with "�"
        processor.tokenizer.decode.return_value = "still�"

        result = processor._correct_decoded_token(2, tokens)
        assert result == ""

    def test_middle_token_correction(self, processor_with_previous_logprobs):
        """Test correction for a token in the middle of the list.

        Scenario: Token at idx=5 in a longer list needs correction.
        """
        processor = processor_with_previous_logprobs
        tokens = [10, 20, 30, 40, 50, 60, 70, 80]

        # Mock tokenizer behavior for middle token
        def mock_decode(ids):
            if ids == [50, 60]:
                return "olar"
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(5, tokens)
        assert result == "olar"

    def test_multiple_consecutive_replacement_chars(
        self, processor_with_previous_logprobs
    ):
        """Test handling of multiple consecutive replacement characters.

        Scenario: Sequence like ["�", "�", "p"] where first two should
        become empty strings.
        """
        processor = processor_with_previous_logprobs

        # Test first replacement char
        tokens = [100, 101, 102]
        processor.tokenizer.decode.return_value = "still�"
        result1 = processor._correct_decoded_token(0, tokens)
        assert result1 == ""

        # Test second replacement char
        result2 = processor._correct_decoded_token(1, tokens)
        assert result2 == ""
    def test_correction_with_multibyte_utf8(self, processor_with_previous_logprobs):
        """Test correction involving multi-byte UTF-8 characters.

        Scenario: Byte-fallback tokenization splits multi-byte UTF-8
        characters (e.g., curly quotes, Chinese characters, emojis).
        Example from user: "�", "�" -> "", "”"
        """
        processor = processor_with_previous_logprobs
        tokens = [200, 201]

        # Mock tokenizer behavior for multi-byte UTF-8 correction
        def mock_decode(ids):
            # When decoding first token (idx=0) with previous logprob token
            if ids == [123, 200]:
                return ' “'  # Space + left curly quote
            # When decoding second token (idx=1) with previous token in list
            elif ids == [200, 201]:
                return '”'  # Right curly quote
            # When decoding second token (idx=1) with previous logprob + prev token
            elif ids == [123, 200, 201]:
                return ' “”'  # Full sequence
            return "�"

        processor.tokenizer.decode.side_effect = mock_decode

        # First token correction (idx=0)
        # Will call decode([123, 200]) since idx=0 uses previous logprob token
        result1 = processor._correct_decoded_token(0, tokens)
        assert result1 == ' “'

        # Second token correction (idx=1)
        # Will call decode([200, 201]) since idx>0 uses previous token in list
        result2 = processor._correct_decoded_token(1, tokens)
        assert result2 == '”'

    def test_real_world_opt125m_scenario(self, mock_tokenizer):
        """Test the real-world scenario from user's example.

        User's example with facebook/opt-125m:
        Before: [" the", " term", " �", "�", "p", "olar", "ized", "�", "�", ...]
        After: [" the", " term", "", " “", "p", "olar", "ized", "", "”", ...]
        """
        from vllm.v1.engine.logprobs import LogprobsProcessor

        # Simulate the sequence of tokens
        processor = LogprobsProcessor(
            tokenizer=mock_tokenizer,
            logprobs=[],
            prompt_logprobs=None,
            cumulative_logprob=0.0,
            num_logprobs=1,
            num_prompt_logprobs=None,
        )

        # Token IDs representing the problematic sequence
        tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9]  # placeholder IDs

        # Mock decode behavior simulating the real scenario
        def mock_decode(ids):
            # Simulate cases where individual tokens decode to "�"
            # but combinations decode correctly
            if len(ids) == 1:
                if ids[0] == 3 or ids[0] == 4 or ids[0] == 8 or ids[0] == 9:
                    return "�"
            elif len(ids) == 2:
                if ids == [2, 3]:
                    return " term�"  # Still ends with �, need more context
                elif ids == [3, 4]:
                    return ' “'  # Corrected to space + left curly quote
                elif ids == [7, 8]:
                    return "ized�"  # Still ends with �
                elif ids == [8, 9]:
                    return '”'  # Corrected to right curly quote
            elif len(ids) == 3:
                if ids == [1, 2, 3]:
                    return " the term�"  # Still ends with issue
                elif ids == [2, 3, 4]:
                    return ' term “'  # With all context
            return "normal_text"

        mock_tokenizer.decode.side_effect = mock_decode

        # Test token at index 2 (should fail to correct, return "")
        # Token 3 individually is "�"
        # decode([2, 3]) = " term�" (still ends with �)
        # No previous logprobs, so fallback to ""
        result = processor._correct_decoded_token(2, tokens)
        assert result == ""

        # Test token at index 3 (should correct to " “")
        # Token 4 individually is "�"
        # decode([3, 4]) = " “" (corrected!)
        processor.logprobs = [{2: None}]  # Add previous logprob
        result = processor._correct_decoded_token(3, tokens)
        assert result == ' “'
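
# Illustrative companion sketch, not part of this commit: the mocks above
# imitate what byte-fallback tokenization does when one multi-byte character
# is split over several tokens. With plain stdlib decoding, every partial
# chunk ends in U+FFFD and only the full byte sequence decodes cleanly.
# The test name below is ours, not from the diff.
def test_replacement_chars_from_split_utf8_bytes():
    raw = "你".encode("utf-8")  # three bytes: b'\xe4\xbd\xa0'
    # Each byte decoded on its own is an invalid sequence -> "�"
    assert all(bytes([b]).decode("utf-8", errors="replace") == "�" for b in raw)
    # A truncated prefix is still incomplete -> ends with "�"
    assert raw[:2].decode("utf-8", errors="replace").endswith("�")
    # The full sequence decodes without any replacement character
    assert raw.decode("utf-8") == "你"
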

def test_verify_tokens_integration():
    """Integration test for _verify_tokens with real model.

    This test validates that _verify_tokens correctly identifies and
    corrects tokens ending with the replacement character "�".
    Uses facebook/opt-125m which is known to produce these issues.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=0,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )

    # Use a prompt that triggers multi-byte UTF-8 issues
    # Based on user's example: "In this example,"
    test_prompts = ["In this example,"]

    sampling_params = SamplingParams(
        max_tokens=16,
        temperature=0,
        logprobs=0,
    )

    results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

    # Verify that decoded tokens don't contain replacement characters
    for result in results:
        assert result.outputs[0].logprobs is not None
        for logprob_dict in result.outputs[0].logprobs:
            for token_id, logprob_info in logprob_dict.items():
                decoded_token = logprob_info.decoded_token
                # Decoded tokens should not end with replacement character
                # They should either be corrected or empty string
                assert not decoded_token.endswith("�"), (
                    f"Token {token_id} decoded to '{decoded_token}' which "
                    f"ends with replacement character"
                )
                # Decoded tokens should not contain lone replacement characters
                assert decoded_token != "�", (
                    f"Token {token_id} is a lone replacement character"
                )


def test_utf8_edge_cases_with_real_model():
    """Test various UTF-8 edge cases with a real model.

    Tests prompts that are likely to trigger byte-fallback tokenization
    and multi-byte UTF-8 splitting.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )

    # Prompts with various multi-byte UTF-8 characters
    test_prompts = [
        'Smart quotes: “Hello”',  # Curly quotes
        "Em dash — test",  # Em dash
        "Ellipsis… continues",  # Ellipsis
        "Chinese: 你好",  # Chinese characters
        "Emoji: 😀 🎉",  # Emojis
        'Mixed: “quoted” — with symbols',  # Mixed
    ]

    sampling_params = SamplingParams(
        max_tokens=10,
        temperature=0,
        logprobs=1,
    )

    results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

    for i, result in enumerate(results):
        prompt = test_prompts[i]
        assert result.outputs[0].logprobs is not None

        # Check that no decoded tokens end with replacement character
        for logprob_dict in result.outputs[0].logprobs:
            for token_id, logprob_info in logprob_dict.items():
                decoded_token = logprob_info.decoded_token
                assert not decoded_token.endswith("�"), (
                    f"Prompt: '{prompt}'\n"
                    f"Token {token_id} decoded to '{decoded_token}' which "
                    f"ends with replacement character"
                )


def test_correct_decoded_token_preserves_valid_tokens():
    """Test that valid tokens (not ending with �) are not modified.

    The _correct_decoded_token method should only be called for tokens
    ending with "�", but this test verifies the broader _verify_tokens
    logic doesn't affect valid tokens.
    """
    runner = VllmRunner(
        "facebook/opt-125m",
        max_logprobs=2,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    )

    # Simple prompt with standard ASCII characters
    test_prompts = ["Hello world, this is a test."]

    sampling_params = SamplingParams(
        max_tokens=10,
        temperature=0,
        logprobs=2,
    )

    results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

    for result in results:
        assert result.outputs[0].logprobs is not None

        # All decoded tokens should be valid strings
        for logprob_dict in result.outputs[0].logprobs:
            for token_id, logprob_info in logprob_dict.items():
                decoded_token = logprob_info.decoded_token
                # Valid tokens should be non-empty strings (or empty if corrected)
                assert isinstance(decoded_token, str)
                # Should not contain replacement character
                assert "�" not in decoded_token


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
    "model_setup",
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from collections.abc import Iterable
from dataclasses import dataclass

from vllm.logger import init_logger
@@ -88,11 +89,16 @@ class LogprobsProcessor:
logprobs = logprobs_np.tolist()
token_ids = token_ids_np.tolist()
# Detokenize (non-incrementally).
decoded_tokens = (
    NONES
    if self.tokenizer is None
    else (convert_ids_list_to_tokens(self.tokenizer, token_ids))
)
decoded_tokens: list[str] | Iterable[None]
if self.tokenizer is None:
    decoded_tokens = NONES
else:
    decoded_tokens_list = convert_ids_list_to_tokens(
        self.tokenizer, token_ids
    )
    decoded_tokens = self._verify_tokens(
        decoded_tokens_list=decoded_tokens_list, tokens=token_ids
    )

# Sampler puts the sampled logprob in first.
sampled_token_logprob = logprobs[0]
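
A hedged aside on NONES used above: its definition is not shown in this diff, but given the itertools import it is presumably an endless iterator of None (itertools.repeat(None)) standing in for "no decoded text" when no tokenizer is available. A minimal sketch under that assumption:

import itertools

NONES = itertools.repeat(None)  # assumption: definition not shown in this diff
token_ids = [11, 22, 33]
# zip() stops at the shortest input, so every token id pairs with None
assert list(zip(token_ids, NONES)) == [(11, None), (22, None), (33, None)]
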
@@ -126,37 +132,45 @@ class LogprobsProcessor:

token_ids, logprobs, ranks = prompt_logprobs_tensors

# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = (
    None
    if self.tokenizer is None
    else (
        convert_ids_list_to_tokens(self.tokenizer, token_ids.flatten().tolist())
    )
)

# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape

# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
all_decoded_tokens: list[str] | None = (
    None
    if self.tokenizer is None
    else convert_ids_list_to_tokens(
        self.tokenizer, token_ids.flatten().tolist()
    )
)

# Pythonize the torch tensors.
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
token_ids_list = token_ids.tolist()

# Make Logprob for each position.
for pos in range(num_prompt_tokens):
    # Handle flattening.
    # Handle flattening and UTF-8 correction per position
    offset = pos * num_logprobs
    offset_end = offset + num_logprobs
    decoded_tokens_for_pos = (
        NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
    )

    decoded_tokens_for_pos: list[str] | Iterable[None]
    if all_decoded_tokens is None:
        decoded_tokens_for_pos = NONES
    else:
        # Extract decoded tokens for this position
        decoded_tokens_slice = all_decoded_tokens[offset:offset_end]
        # Apply UTF-8 correction within this position's token boundaries
        decoded_tokens_for_pos = self._verify_tokens(
            decoded_tokens_list=decoded_tokens_slice, tokens=token_ids_list[pos]
        )

    # Update with the Logprob container for this pos.
    append_logprobs_for_next_position(
        self.prompt_logprobs,
        token_ids[pos],
        token_ids_list[pos],
        prompt_logprobs[pos],
        decoded_tokens_for_pos,
        prompt_token_ranks[pos],
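
A small worked example (ours, not from the commit) of the flat indexing used above: with num_logprobs candidates per prompt position, position pos owns the flat slice [pos * num_logprobs, pos * num_logprobs + num_logprobs), which is exactly the per-position slice handed to _verify_tokens.

flat = ["a", "b", "c", "d", "e", "f"]  # 3 positions x num_logprobs=2, flattened
num_logprobs = 2
slices = [flat[pos * num_logprobs : pos * num_logprobs + num_logprobs] for pos in range(3)]
assert slices == [["a", "b"], ["c", "d"], ["e", "f"]]
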
@@ -182,6 +196,48 @@ class LogprobsProcessor:
        self.prompt_logprobs = []
        return plp

    def _correct_decoded_token(self, idx: int, tokens: list[int]) -> str:
        assert self.tokenizer is not None, "self.tokenizer should not be None"

        # try with prev token id in same list
        if idx > 0:
            possible_decoded_token = self.tokenizer.decode(tokens[idx - 1 : idx + 1])
            if not possible_decoded_token.endswith("�"):
                return possible_decoded_token
        # try with previous logprob token id
        if self.logprobs:
            latest_token_id = next(iter(self.logprobs[-1]))

            decode_ids = [latest_token_id]
            if idx > 0:
                decode_ids.extend(tokens[idx - 1 : idx + 1])
            else:
                decode_ids.extend(tokens[idx : idx + 1])

            possible_decoded_token = self.tokenizer.decode(decode_ids)
            if not possible_decoded_token.endswith("�"):
                return possible_decoded_token

        # by default return empty string
        return ""

    def _verify_tokens(
        self, decoded_tokens_list: list[str], tokens: list[int]
    ) -> list[str]:
        corrected_decoded_token_map = dict()
        for idx, text in enumerate(decoded_tokens_list):
            if text.endswith("�"):
                # utf-8 char at the end means it's a potential unfinished byte sequence
                # from byte fallback tokenization.
                corrected_decoded_token_map[idx] = self._correct_decoded_token(
                    idx, tokens
                )

        for idx, text in corrected_decoded_token_map.items():
            decoded_tokens_list[idx] = text

        return decoded_tokens_list

    def update_from_output(self, output: EngineCoreOutput) -> None:
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
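
A hedged end-to-end sketch of the user-visible effect (ours, not part of the commit), mirroring the integration tests above via the public LLM API; the model and prompt follow the example used in those tests:

from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m", max_logprobs=1)
params = SamplingParams(max_tokens=16, temperature=0, logprobs=1)
outputs = llm.generate(["In this example,"], params)

for logprob_dict in outputs[0].outputs[0].logprobs:
    for token_id, info in logprob_dict.items():
        # With the fix, decoded logprob tokens are either valid UTF-8 text
        # or "" for unfinished byte sequences; never a trailing U+FFFD.
        assert not info.decoded_token.endswith("�")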