[Bugfix] Fix V1 logprobs empty strings for multi-byte UTF-8 tokens when logprobs > 0 (#34875)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -539,6 +539,10 @@ class TestCorrectDecodedToken:
|
|||||||
result in the Unicode replacement character "<EFBFBD>" (U+FFFD). This commonly
|
result in the Unicode replacement character "<EFBFBD>" (U+FFFD). This commonly
|
||||||
happens with byte-fallback tokenization when multi-byte UTF-8 characters
|
happens with byte-fallback tokenization when multi-byte UTF-8 characters
|
||||||
are split across tokens.
|
are split across tokens.
|
||||||
|
|
||||||
|
The method signature is _correct_decoded_token(token_id, context_token_ids)
|
||||||
|
where token_id is the single token to correct and context_token_ids are
|
||||||
|
the preceding sampled tokens in sequential order.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -550,8 +554,8 @@ class TestCorrectDecodedToken:
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def processor_with_empty_logprobs(self, mock_tokenizer):
|
def processor(self, mock_tokenizer):
|
||||||
"""Create a LogprobsProcessor with empty logprobs."""
|
"""Create a LogprobsProcessor."""
|
||||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||||
|
|
||||||
processor = LogprobsProcessor(
|
processor = LogprobsProcessor(
|
||||||
@@ -564,209 +568,191 @@ class TestCorrectDecodedToken:
|
|||||||
)
|
)
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
@pytest.fixture
|
def test_correction_with_context(self, processor):
|
||||||
def processor_with_previous_logprobs(self, mock_tokenizer):
|
"""Test correction using context from preceding sampled tokens.
|
||||||
"""Create a LogprobsProcessor with previous logprobs."""
|
|
||||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
|
||||||
|
|
||||||
processor = LogprobsProcessor(
|
Scenario: A byte-fallback token that completes a multi-byte
|
||||||
tokenizer=mock_tokenizer,
|
UTF-8 sequence when decoded with context.
|
||||||
logprobs=[{123: None}], # Previous token ID is 123
|
|
||||||
prompt_logprobs=None,
|
|
||||||
cumulative_logprob=0.0,
|
|
||||||
num_logprobs=1,
|
|
||||||
num_prompt_logprobs=None,
|
|
||||||
)
|
|
||||||
return processor
|
|
||||||
|
|
||||||
def test_correction_with_previous_token_in_list(
|
|
||||||
self, processor_with_empty_logprobs
|
|
||||||
):
|
|
||||||
"""Test correction using previous token in the same list.
|
|
||||||
|
|
||||||
Scenario: Token at idx=1 ends with "<EFBFBD>", but when decoded with
|
|
||||||
the previous token (idx=0), it forms a valid UTF-8 sequence.
|
|
||||||
Example: token[0]="<EFBFBD>", token[1]="<EFBFBD>" -> together form "polarized"
|
|
||||||
"""
|
"""
|
||||||
processor = processor_with_empty_logprobs
|
|
||||||
tokens = [100, 101, 102] # token IDs
|
|
||||||
|
|
||||||
# Mock tokenizer behavior:
|
# Context is [101] (a preceding sampled token)
|
||||||
# - decode([102]) returns "<22>" (ends with replacement char)
|
# Token 102 individually decodes to "<22>"
|
||||||
# - decode([101, 102]) returns "valid" (no replacement char)
|
# decode([101, 102]) returns "valid" (complete sequence)
|
||||||
processor.tokenizer.decode.side_effect = lambda ids: (
|
def mock_decode(ids):
|
||||||
"valid" if ids == [101, 102] else "<EFBFBD>"
|
if ids == [101, 102]:
|
||||||
)
|
return "hello valid"
|
||||||
|
if ids == [101]:
|
||||||
|
return "hello "
|
||||||
|
return "<EFBFBD>"
|
||||||
|
|
||||||
result = processor._correct_decoded_token(2, tokens)
|
processor.tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
|
result = processor._correct_decoded_token(102, [101])
|
||||||
assert result == "valid"
|
assert result == "valid"
|
||||||
processor.tokenizer.decode.assert_called_with([101, 102])
|
|
||||||
|
|
||||||
def test_correction_with_previous_logprob_token(
|
def test_correction_with_context_from_logprobs(self, processor):
|
||||||
self, processor_with_previous_logprobs
|
"""Test correction using context from previous logprob entries.
|
||||||
):
|
|
||||||
"""Test correction using previous logprob token.
|
|
||||||
|
|
||||||
Scenario: Cannot correct with previous token in list (idx=0),
|
Scenario: Token decoded with context from previously sampled
|
||||||
but can correct with previous logprob token.
|
tokens completes a UTF-8 sequence.
|
||||||
"""
|
"""
|
||||||
processor = processor_with_previous_logprobs
|
|
||||||
tokens = [100] # single token
|
|
||||||
|
|
||||||
# Mock tokenizer behavior:
|
# Token 123 was previously sampled (in context)
|
||||||
# - decode([100]) returns "<22>" (ends with replacement char)
|
|
||||||
# - decode([123, 100]) returns " "polarized" (no replacement char)
|
|
||||||
# Token 123 is from previous logprobs
|
|
||||||
def mock_decode(ids):
|
def mock_decode(ids):
|
||||||
if ids == [123, 100]:
|
if ids == [123, 100]:
|
||||||
return ' "polarized"'
|
return 'hello "polarized"'
|
||||||
|
if ids == [123]:
|
||||||
|
return "hello "
|
||||||
return "<EFBFBD>"
|
return "<EFBFBD>"
|
||||||
|
|
||||||
processor.tokenizer.decode.side_effect = mock_decode
|
processor.tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
result = processor._correct_decoded_token(0, tokens)
|
result = processor._correct_decoded_token(100, [123])
|
||||||
assert result == ' "polarized"'
|
assert result == '"polarized"'
|
||||||
|
|
||||||
def test_correction_at_idx_zero_no_previous_logprobs(
|
def test_correction_no_context(self, processor):
|
||||||
self, processor_with_empty_logprobs
|
"""Test correction with no context available.
|
||||||
):
|
|
||||||
"""Test correction at idx=0 with no previous logprobs.
|
|
||||||
|
|
||||||
Scenario: First token in list, no previous logprobs available.
|
|
||||||
Should return empty string as fallback.
|
Should return empty string as fallback.
|
||||||
"""
|
"""
|
||||||
processor = processor_with_empty_logprobs
|
|
||||||
tokens = [100]
|
|
||||||
|
|
||||||
# Mock tokenizer always returns "<22>"
|
|
||||||
processor.tokenizer.decode.return_value = "<EFBFBD>"
|
processor.tokenizer.decode.return_value = "<EFBFBD>"
|
||||||
|
|
||||||
result = processor._correct_decoded_token(0, tokens)
|
result = processor._correct_decoded_token(100, [])
|
||||||
assert result == ""
|
assert result == ""
|
||||||
|
|
||||||
def test_correction_at_idx_zero_with_previous_logprobs(
|
def test_correction_with_context_succeeds(self, processor):
|
||||||
self, processor_with_previous_logprobs
|
"""Test correction with context from previously sampled tokens."""
|
||||||
):
|
|
||||||
"""Test correction at idx=0 with previous logprobs available.
|
|
||||||
|
|
||||||
Scenario: First token in list, but previous logprobs exist.
|
|
||||||
Should try correction with previous logprob token.
|
|
||||||
"""
|
|
||||||
processor = processor_with_previous_logprobs
|
|
||||||
tokens = [200]
|
|
||||||
|
|
||||||
# Mock tokenizer behavior
|
|
||||||
def mock_decode(ids):
|
def mock_decode(ids):
|
||||||
if ids == [123, 200]:
|
if ids == [123, 200]:
|
||||||
return "corrected"
|
return "hello corrected"
|
||||||
|
if ids == [123]:
|
||||||
|
return "hello "
|
||||||
return "<EFBFBD>"
|
return "<EFBFBD>"
|
||||||
|
|
||||||
processor.tokenizer.decode.side_effect = mock_decode
|
processor.tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
result = processor._correct_decoded_token(0, tokens)
|
result = processor._correct_decoded_token(200, [123])
|
||||||
assert result == "corrected"
|
assert result == "corrected"
|
||||||
|
|
||||||
def test_no_correction_needed_returns_fallback(
|
def test_fallback_when_all_attempts_fail(self, processor):
|
||||||
self, processor_with_previous_logprobs
|
"""Test fallback to empty string when no correction works."""
|
||||||
):
|
|
||||||
"""Test fallback to empty string when no correction works.
|
|
||||||
|
|
||||||
Scenario: All correction attempts still end with "<EFBFBD>".
|
|
||||||
Should return empty string as final fallback.
|
|
||||||
"""
|
|
||||||
processor = processor_with_previous_logprobs
|
|
||||||
tokens = [100, 101, 102]
|
|
||||||
|
|
||||||
# Mock tokenizer always returns text ending with "<22>"
|
|
||||||
processor.tokenizer.decode.return_value = "still<EFBFBD>"
|
processor.tokenizer.decode.return_value = "still<EFBFBD>"
|
||||||
|
|
||||||
result = processor._correct_decoded_token(2, tokens)
|
result = processor._correct_decoded_token(102, [100, 101])
|
||||||
assert result == ""
|
assert result == ""
|
||||||
|
|
||||||
def test_middle_token_correction(self, processor_with_previous_logprobs):
|
def test_increasing_context_window(self, processor):
|
||||||
"""Test correction for a token in the middle of the list.
|
"""Test that increasing context window finds the correction.
|
||||||
|
|
||||||
Scenario: Token at idx=5 in a longer list needs correction.
|
Scenario: 3-byte UTF-8 char. With 1 context token, still
|
||||||
|
incomplete. With 2 context tokens, completes the sequence.
|
||||||
"""
|
"""
|
||||||
processor = processor_with_previous_logprobs
|
|
||||||
tokens = [10, 20, 30, 40, 50, 60, 70, 80]
|
|
||||||
|
|
||||||
# Mock tokenizer behavior for middle token
|
|
||||||
def mock_decode(ids):
|
def mock_decode(ids):
|
||||||
if ids == [50, 60]:
|
# 1 context token: still incomplete
|
||||||
return "olar"
|
if ids == [81, 82]:
|
||||||
|
return "<EFBFBD>"
|
||||||
|
# 2 context tokens: complete
|
||||||
|
if ids == [80, 81, 82]:
|
||||||
|
return "\u201c"
|
||||||
|
# Context-only decodes
|
||||||
|
if ids == [81]:
|
||||||
|
return "<EFBFBD>"
|
||||||
|
if ids == [80, 81]:
|
||||||
|
return "<EFBFBD>"
|
||||||
return "<EFBFBD>"
|
return "<EFBFBD>"
|
||||||
|
|
||||||
processor.tokenizer.decode.side_effect = mock_decode
|
processor.tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
result = processor._correct_decoded_token(5, tokens)
|
# Context has 2 preceding tokens [80, 81]
|
||||||
assert result == "olar"
|
result = processor._correct_decoded_token(82, [80, 81])
|
||||||
|
assert result == "\u201c"
|
||||||
|
|
||||||
def test_multiple_consecutive_replacement_chars(
|
def test_multiple_consecutive_replacement_chars(self, processor):
|
||||||
self, processor_with_previous_logprobs
|
|
||||||
):
|
|
||||||
"""Test handling of multiple consecutive replacement characters.
|
"""Test handling of multiple consecutive replacement characters.
|
||||||
|
|
||||||
Scenario: Sequence like ["<EFBFBD>", "<EFBFBD>", "p"] where first two should
|
Scenario: Multi-byte sequence where intermediate bytes return
|
||||||
become empty strings.
|
empty string and the final byte returns the complete character.
|
||||||
"""
|
"""
|
||||||
processor = processor_with_previous_logprobs
|
|
||||||
|
|
||||||
# Test first replacement char
|
|
||||||
tokens = [100, 101, 102]
|
|
||||||
processor.tokenizer.decode.return_value = "still<EFBFBD>"
|
processor.tokenizer.decode.return_value = "still<EFBFBD>"
|
||||||
result1 = processor._correct_decoded_token(0, tokens)
|
|
||||||
|
# First byte with no useful context: returns ""
|
||||||
|
result1 = processor._correct_decoded_token(100, [50])
|
||||||
assert result1 == ""
|
assert result1 == ""
|
||||||
|
|
||||||
# Test second replacement char
|
# Second byte with same context: still returns ""
|
||||||
result2 = processor._correct_decoded_token(1, tokens)
|
result2 = processor._correct_decoded_token(101, [50])
|
||||||
assert result2 == ""
|
assert result2 == ""
|
||||||
|
|
||||||
def test_correction_with_multibyte_utf8(self, processor_with_previous_logprobs):
|
def test_correction_with_multibyte_utf8(self, processor):
|
||||||
"""Test correction involving multi-byte UTF-8 characters.
|
"""Test correction involving multi-byte UTF-8 characters.
|
||||||
|
|
||||||
Scenario: Byte-fallback tokenization splits multi-byte UTF-8
|
Scenario: Byte-fallback tokenization splits curly quotes.
|
||||||
characters (e.g., curly quotes, Chinese characters, emojis).
|
The last byte token should produce the complete character.
|
||||||
Example from user: "<EFBFBD>", "<EFBFBD>" -> "", "\""
|
|
||||||
"""
|
"""
|
||||||
processor = processor_with_previous_logprobs
|
|
||||||
tokens = [200, 201]
|
|
||||||
|
|
||||||
# Mock tokenizer behavior for multi-byte UTF-8 correction
|
|
||||||
def mock_decode(ids):
|
def mock_decode(ids):
|
||||||
# When decoding first token (idx=0) with previous logprob token
|
# Context [123] + first byte: completes to left curly quote
|
||||||
if ids == [123, 200]:
|
if ids == [123, 200]:
|
||||||
return ' "' # Space + left curly quote
|
return "hello \u201c"
|
||||||
# When decoding second token (idx=1) with previous token in list
|
if ids == [123]:
|
||||||
elif ids == [200, 201]:
|
return "hello "
|
||||||
return '"' # Right curly quote
|
# Context [123] + second byte: completes to right curly quote
|
||||||
# When decoding second token (idx=1) with previous logprob + prev token
|
if ids == [123, 201]:
|
||||||
elif ids == [123, 200, 201]:
|
return "hello \u201d"
|
||||||
return ' ""' # Full sequence
|
return "\ufffd"
|
||||||
return "<EFBFBD>"
|
|
||||||
|
|
||||||
processor.tokenizer.decode.side_effect = mock_decode
|
processor.tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
# First token correction (idx=0)
|
# Each top-k token is corrected independently with same context
|
||||||
# Will call decode([123, 200]) since idx=0 uses previous logprob token
|
result1 = processor._correct_decoded_token(200, [123])
|
||||||
result1 = processor._correct_decoded_token(0, tokens)
|
assert result1 == "\u201c"
|
||||||
assert result1 == ' "'
|
|
||||||
|
|
||||||
# Second token correction (idx=1)
|
result2 = processor._correct_decoded_token(201, [123])
|
||||||
# Will call decode([200, 201]) since idx>0 uses previous token in list
|
assert result2 == "\u201d"
|
||||||
result2 = processor._correct_decoded_token(1, tokens)
|
|
||||||
assert result2 == '"'
|
def test_topk_tokens_corrected_independently(self, processor):
|
||||||
|
"""Test that top-k alternatives at the same position are each
|
||||||
|
corrected independently using only sequential context, not
|
||||||
|
each other.
|
||||||
|
|
||||||
|
This is the core fix for issue #27300: when logprobs > 0,
|
||||||
|
alternative tokens must not be combined with each other.
|
||||||
|
"""
|
||||||
|
# Context: previously sampled token 50
|
||||||
|
context = [50]
|
||||||
|
|
||||||
|
def mock_decode(ids):
|
||||||
|
# Token 100 (sampled) with context
|
||||||
|
if ids == [50, 100]:
|
||||||
|
return "prefix \u201c"
|
||||||
|
# Token 200 (top-k alternative) with context
|
||||||
|
if ids == [50, 200]:
|
||||||
|
return "prefix \u2014"
|
||||||
|
# Context alone
|
||||||
|
if ids == [50]:
|
||||||
|
return "prefix "
|
||||||
|
return "\ufffd"
|
||||||
|
|
||||||
|
processor.tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
|
# Both tokens at the same position use the SAME context [50]
|
||||||
|
result_sampled = processor._correct_decoded_token(100, context)
|
||||||
|
assert result_sampled == "\u201c"
|
||||||
|
|
||||||
|
result_alt = processor._correct_decoded_token(200, context)
|
||||||
|
assert result_alt == "\u2014"
|
||||||
|
|
||||||
def test_real_world_opt125m_scenario(self, mock_tokenizer):
|
def test_real_world_opt125m_scenario(self, mock_tokenizer):
|
||||||
"""Test the real-world scenario from user's example.
|
"""Test the real-world scenario from the bug report.
|
||||||
|
|
||||||
User's example with facebook/opt-125m:
|
Simulates the OPT-125m sequence where curly quotes are split
|
||||||
Before: [" the", " term", " <20>", "<EFBFBD>", "p", "olar", "ized", "<EFBFBD>", "<EFBFBD>", ...]
|
into byte-fallback tokens. Each token is corrected using only
|
||||||
After: [" the", " term", "", " "", "p", "olar", "ized", "", "\"", ...]
|
the preceding sampled tokens as context.
|
||||||
"""
|
"""
|
||||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||||
|
|
||||||
# Simulate the sequence of tokens
|
|
||||||
processor = LogprobsProcessor(
|
processor = LogprobsProcessor(
|
||||||
tokenizer=mock_tokenizer,
|
tokenizer=mock_tokenizer,
|
||||||
logprobs=[],
|
logprobs=[],
|
||||||
@@ -776,47 +762,106 @@ class TestCorrectDecodedToken:
|
|||||||
num_prompt_logprobs=None,
|
num_prompt_logprobs=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Token IDs representing the problematic sequence
|
# Simulating: byte tokens 3, 4 form left curly quote "\u201c"
|
||||||
tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9] # placeholder IDs
|
# byte tokens 8, 9 form right curly quote "\u201d"
|
||||||
|
|
||||||
# Mock decode behavior simulating the real scenario
|
|
||||||
def mock_decode(ids):
|
def mock_decode(ids):
|
||||||
# Simulate cases where individual tokens decode to "<22>"
|
# Context decodes
|
||||||
# but combinations decode correctly
|
if ids == [2]:
|
||||||
if len(ids) == 1:
|
return " term"
|
||||||
if ids[0] in (3, 4, 8, 9):
|
if ids == [1, 2]:
|
||||||
return "<EFBFBD>"
|
return " the term"
|
||||||
elif len(ids) == 2:
|
if ids == [3]:
|
||||||
if ids == [2, 3]:
|
return "\ufffd"
|
||||||
return " term<72>" # Still ends with <20>, need more context
|
if ids == [2, 3]:
|
||||||
elif ids == [3, 4]:
|
return " term\ufffd"
|
||||||
return ' "' # Corrected to space + left curly quote
|
if ids == [1, 2, 3]:
|
||||||
elif ids == [7, 8]:
|
return " the term\ufffd"
|
||||||
return "ized<EFBFBD>" # Still ends with <20>
|
# Token 4 with context [2, 3] -> completes left curly quote
|
||||||
elif ids == [8, 9]:
|
if ids == [3, 4]:
|
||||||
return '"' # Corrected to right curly quote
|
return "\u201c"
|
||||||
elif len(ids) == 3:
|
if ids == [2, 3, 4]:
|
||||||
if ids == [1, 2, 3]:
|
return " term\u201c"
|
||||||
return " the term<72>" # Still ends with issue
|
# Context for right curly quote
|
||||||
elif ids == [2, 3, 4]:
|
if ids == [7]:
|
||||||
return ' term "' # With all context
|
return "ized"
|
||||||
|
if ids == [7, 8]:
|
||||||
|
return "ized\ufffd"
|
||||||
|
if ids == [8, 9]:
|
||||||
|
return "\u201d"
|
||||||
|
if ids == [7, 8, 9]:
|
||||||
|
return "ized\u201d"
|
||||||
return "normal_text"
|
return "normal_text"
|
||||||
|
|
||||||
mock_tokenizer.decode.side_effect = mock_decode
|
mock_tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
# Test token at index 2 (should fail to correct, return "")
|
# First byte (token 3) of left curly quote with no context
|
||||||
# Token 3 individually is "<22>"
|
result = processor._correct_decoded_token(3, [])
|
||||||
# decode([2, 3]) = " term<72>" (still ends with <20>)
|
|
||||||
# No previous logprobs, so fallback to ""
|
|
||||||
result = processor._correct_decoded_token(2, tokens)
|
|
||||||
assert result == ""
|
assert result == ""
|
||||||
|
|
||||||
# Test token at index 3 (should correct to " "")
|
# First byte (token 3) with context [2] -> still incomplete
|
||||||
# Token 4 individually is "<22>"
|
result = processor._correct_decoded_token(3, [2])
|
||||||
# decode([3, 4]) = " "" (corrected!)
|
assert result == ""
|
||||||
processor.logprobs = [{2: None}] # Add previous logprob
|
|
||||||
result = processor._correct_decoded_token(3, tokens)
|
# Second byte (token 4) of left curly quote with context [2, 3]
|
||||||
assert result == ' "'
|
# Token 3 is byte-fallback, so clean context is [2] only.
|
||||||
|
# decode([2, 3, 4]) = " term\u201c", decode([2]) = " term"
|
||||||
|
# result = "\u201c"
|
||||||
|
result = processor._correct_decoded_token(4, [2, 3])
|
||||||
|
assert result == "\u201c"
|
||||||
|
|
||||||
|
# Second byte (token 9) of right curly quote with context [7, 8]
|
||||||
|
result = processor._correct_decoded_token(9, [7, 8])
|
||||||
|
assert result == "\u201d"
|
||||||
|
|
||||||
|
def test_byte_fallback_context_preserves_space(self, mock_tokenizer):
|
||||||
|
"""Test that text from byte-fallback context tokens is preserved.
|
||||||
|
|
||||||
|
In OPT-125m, token 44 = space + 2 bytes of curly quote.
|
||||||
|
When token 44 returns "" (incomplete), the space it carried
|
||||||
|
must be attributed to the completing token (48).
|
||||||
|
"""
|
||||||
|
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||||
|
|
||||||
|
processor = LogprobsProcessor(
|
||||||
|
tokenizer=mock_tokenizer,
|
||||||
|
logprobs=[],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
cumulative_logprob=0.0,
|
||||||
|
num_logprobs=1,
|
||||||
|
num_prompt_logprobs=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def mock_decode(ids):
|
||||||
|
# Token 44 = space + 2 bytes (like OPT-125m's \u0120\u00e2\u0080)
|
||||||
|
if ids == [44]:
|
||||||
|
return " \ufffd"
|
||||||
|
if ids == [48]:
|
||||||
|
return "\ufffd"
|
||||||
|
# Together they form: space + left curly quote
|
||||||
|
if ids == [44, 48]:
|
||||||
|
return " \u201c"
|
||||||
|
# With preceding clean context
|
||||||
|
if ids == [1385]:
|
||||||
|
return " term"
|
||||||
|
if ids == [1385, 44]:
|
||||||
|
return " term \ufffd"
|
||||||
|
if ids == [1385, 44, 48]:
|
||||||
|
return " term \u201c"
|
||||||
|
return "\ufffd"
|
||||||
|
|
||||||
|
mock_tokenizer.decode.side_effect = mock_decode
|
||||||
|
|
||||||
|
# Token 44 with context [1385] -> still ends with replacement
|
||||||
|
result = processor._correct_decoded_token(44, [1385])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
# Token 48 with context [1385, 44]:
|
||||||
|
# Token 44 is byte-fallback, so clean context is [1385].
|
||||||
|
# decode([1385, 44, 48]) = " term \u201c"
|
||||||
|
# decode([1385]) = " term"
|
||||||
|
# result = " \u201c" (space preserved from token 44!)
|
||||||
|
result = processor._correct_decoded_token(48, [1385, 44])
|
||||||
|
assert result == " \u201c"
|
||||||
|
|
||||||
|
|
||||||
def test_verify_tokens_integration():
|
def test_verify_tokens_integration():
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logprobs import (
|
from vllm.logprobs import (
|
||||||
|
FlatLogprobs,
|
||||||
PromptLogprobs,
|
PromptLogprobs,
|
||||||
SampleLogprobs,
|
SampleLogprobs,
|
||||||
append_logprobs_for_next_position,
|
append_logprobs_for_next_position,
|
||||||
@@ -96,8 +97,11 @@ class LogprobsProcessor:
|
|||||||
decoded_tokens_list = convert_ids_list_to_tokens(
|
decoded_tokens_list = convert_ids_list_to_tokens(
|
||||||
self.tokenizer, token_ids
|
self.tokenizer, token_ids
|
||||||
)
|
)
|
||||||
|
context_token_ids = self._get_sampled_context_ids(self.logprobs)
|
||||||
decoded_tokens = self._verify_tokens(
|
decoded_tokens = self._verify_tokens(
|
||||||
decoded_tokens_list=decoded_tokens_list, tokens=token_ids
|
decoded_tokens_list=decoded_tokens_list,
|
||||||
|
tokens=token_ids,
|
||||||
|
context_token_ids=context_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampler puts the sampled logprob in first.
|
# Sampler puts the sampled logprob in first.
|
||||||
@@ -162,9 +166,14 @@ class LogprobsProcessor:
|
|||||||
else:
|
else:
|
||||||
# Extract decoded tokens for this position
|
# Extract decoded tokens for this position
|
||||||
decoded_tokens_slice = all_decoded_tokens[offset:offset_end]
|
decoded_tokens_slice = all_decoded_tokens[offset:offset_end]
|
||||||
|
# Context: preceding prompt tokens accumulated in
|
||||||
|
# self.prompt_logprobs from previous loop iterations.
|
||||||
|
context_token_ids = self._get_sampled_context_ids(self.prompt_logprobs)
|
||||||
# Apply UTF-8 correction within this position's token boundaries
|
# Apply UTF-8 correction within this position's token boundaries
|
||||||
decoded_tokens_for_pos = self._verify_tokens(
|
decoded_tokens_for_pos = self._verify_tokens(
|
||||||
decoded_tokens_list=decoded_tokens_slice, tokens=token_ids_list[pos]
|
decoded_tokens_list=decoded_tokens_slice,
|
||||||
|
tokens=token_ids_list[pos],
|
||||||
|
context_token_ids=context_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update with the Logprob container for this pos.
|
# Update with the Logprob container for this pos.
|
||||||
@@ -196,41 +205,139 @@ class LogprobsProcessor:
|
|||||||
self.prompt_logprobs = []
|
self.prompt_logprobs = []
|
||||||
return plp
|
return plp
|
||||||
|
|
||||||
def _correct_decoded_token(self, idx: int, tokens: list[int]) -> str:
|
@staticmethod
|
||||||
assert self.tokenizer is not None, "self.tokenizer should not be None"
|
def _get_sampled_context_ids(
|
||||||
|
logprobs_source: SampleLogprobs | PromptLogprobs | None,
|
||||||
|
max_context: int = 4,
|
||||||
|
) -> list[int]:
|
||||||
|
"""Extract recent sampled token IDs from a logprobs source.
|
||||||
|
|
||||||
# try with prev token id in same list
|
The sampled (or prompt) token at each position is the first
|
||||||
if idx > 0:
|
entry, since it is always inserted first by
|
||||||
possible_decoded_token = self.tokenizer.decode(tokens[idx - 1 : idx + 1])
|
append_logprobs_for_next_position.
|
||||||
if not possible_decoded_token.endswith("<EFBFBD>"):
|
|
||||||
return possible_decoded_token
|
|
||||||
# try with previous logprob token id
|
|
||||||
if self.logprobs:
|
|
||||||
latest_token_id = next(iter(self.logprobs[-1]))
|
|
||||||
|
|
||||||
decode_ids = [latest_token_id]
|
Args:
|
||||||
if idx > 0:
|
logprobs_source: The logprobs container to extract from.
|
||||||
decode_ids.extend(tokens[idx - 1 : idx + 1])
|
max_context: Maximum number of preceding tokens to return.
|
||||||
|
4 is sufficient for any UTF-8 multi-byte sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sampled token IDs, oldest first, most recent last.
|
||||||
|
"""
|
||||||
|
if not logprobs_source:
|
||||||
|
return []
|
||||||
|
|
||||||
|
n = len(logprobs_source)
|
||||||
|
start = max(0, n - max_context)
|
||||||
|
|
||||||
|
# Efficient path for FlatLogprobs: access token_ids directly.
|
||||||
|
if isinstance(logprobs_source, FlatLogprobs):
|
||||||
|
return [
|
||||||
|
logprobs_source.token_ids[logprobs_source.start_indices[i]]
|
||||||
|
for i in range(start, n)
|
||||||
|
if logprobs_source.start_indices[i] < logprobs_source.end_indices[i]
|
||||||
|
]
|
||||||
|
|
||||||
|
# list[dict] path
|
||||||
|
result: list[int] = []
|
||||||
|
for i in range(start, n):
|
||||||
|
entry = logprobs_source[i]
|
||||||
|
if entry is not None:
|
||||||
|
result.append(next(iter(entry)))
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _correct_decoded_token(
|
||||||
|
self, token_id: int, context_token_ids: list[int]
|
||||||
|
) -> str:
|
||||||
|
"""Correct a decoded token that contains the replacement character.
|
||||||
|
|
||||||
|
When byte-fallback tokenization splits multi-byte UTF-8
|
||||||
|
characters across tokens, individual token decoding produces
|
||||||
|
the replacement character U+FFFD. This method uses preceding
|
||||||
|
sampled tokens as context to reconstruct the correct text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_id: The single token ID to correct.
|
||||||
|
context_token_ids: Preceding sampled token IDs in sequential
|
||||||
|
order (oldest first). These are the actual tokens in
|
||||||
|
the generated sequence, NOT top-k alternatives.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The corrected decoded string, or empty string if the byte
|
||||||
|
sequence is genuinely incomplete at this point.
|
||||||
|
"""
|
||||||
|
assert self.tokenizer is not None
|
||||||
|
|
||||||
|
max_ctx = min(len(context_token_ids), 4)
|
||||||
|
|
||||||
|
for num_ctx in range(1, max_ctx + 1):
|
||||||
|
context = context_token_ids[-num_ctx:]
|
||||||
|
full_decoded = self.tokenizer.decode(context + [token_id])
|
||||||
|
|
||||||
|
if full_decoded.endswith("<EFBFBD>"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find the boundary between "clean" context tokens and
|
||||||
|
# byte-fallback tokens that are part of the same incomplete
|
||||||
|
# sequence. Byte-fallback context tokens returned "" when
|
||||||
|
# they were processed, so their text must be attributed to
|
||||||
|
# this completing token.
|
||||||
|
clean_end = len(context)
|
||||||
|
for j in range(len(context) - 1, -1, -1):
|
||||||
|
if self.tokenizer.decode([context[j]]).endswith("<EFBFBD>"):
|
||||||
|
clean_end = j
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Decode only the clean (non-byte-fallback) prefix.
|
||||||
|
if clean_end > 0:
|
||||||
|
clean_prefix = self.tokenizer.decode(context[:clean_end])
|
||||||
else:
|
else:
|
||||||
decode_ids.extend(tokens[idx : idx + 1])
|
clean_prefix = ""
|
||||||
|
|
||||||
possible_decoded_token = self.tokenizer.decode(decode_ids)
|
if full_decoded.startswith(clean_prefix):
|
||||||
if not possible_decoded_token.endswith("<EFBFBD>"):
|
return full_decoded[len(clean_prefix) :]
|
||||||
return possible_decoded_token
|
|
||||||
|
# Tokenizer normalization may cause prefix mismatch.
|
||||||
|
# Find the longest common prefix between them.
|
||||||
|
common_len = 0
|
||||||
|
for a, b in zip(clean_prefix, full_decoded):
|
||||||
|
if a != b:
|
||||||
|
break
|
||||||
|
common_len += 1
|
||||||
|
return full_decoded[common_len:]
|
||||||
|
|
||||||
# by default return empty string
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def _verify_tokens(
|
def _verify_tokens(
|
||||||
self, decoded_tokens_list: list[str], tokens: list[int]
|
self,
|
||||||
|
decoded_tokens_list: list[str],
|
||||||
|
tokens: list[int],
|
||||||
|
context_token_ids: list[int] | None = None,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
|
"""Verify and correct decoded tokens with replacement characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoded_tokens_list: Decoded token strings to verify.
|
||||||
|
tokens: Token IDs corresponding to decoded_tokens_list.
|
||||||
|
These are alternatives at the SAME position (e.g.
|
||||||
|
[sampled, top1, top2]), NOT sequential tokens.
|
||||||
|
context_token_ids: Preceding sampled token IDs providing
|
||||||
|
sequential context. If None, extracted from
|
||||||
|
self.logprobs.
|
||||||
|
"""
|
||||||
|
if context_token_ids is None:
|
||||||
|
context_token_ids = self._get_sampled_context_ids(self.logprobs)
|
||||||
|
|
||||||
corrected_decoded_token_map = dict()
|
corrected_decoded_token_map = dict()
|
||||||
for idx, text in enumerate(decoded_tokens_list):
|
for idx, text in enumerate(decoded_tokens_list):
|
||||||
if text.endswith("<EFBFBD>"):
|
if text.endswith("<EFBFBD>"):
|
||||||
# utf-8 char at the end means it's a potential unfinished byte sequence
|
# Replacement char at the end means a potential
|
||||||
# from byte fallback tokenization.
|
# unfinished byte sequence from byte-fallback
|
||||||
|
# tokenization. Correct each token independently
|
||||||
|
# using only the sequential context.
|
||||||
corrected_decoded_token_map[idx] = self._correct_decoded_token(
|
corrected_decoded_token_map[idx] = self._correct_decoded_token(
|
||||||
idx, tokens
|
tokens[idx], context_token_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for idx, text in corrected_decoded_token_map.items():
|
for idx, text in corrected_decoded_token_map.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user