diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index d029a6ce0..28fb2931b 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -539,6 +539,10 @@ class TestCorrectDecodedToken: result in the Unicode replacement character "�" (U+FFFD). This commonly happens with byte-fallback tokenization when multi-byte UTF-8 characters are split across tokens. + + The method signature is _correct_decoded_token(token_id, context_token_ids) + where token_id is the single token to correct and context_token_ids are + the preceding sampled tokens in sequential order. """ @pytest.fixture @@ -550,8 +554,8 @@ class TestCorrectDecodedToken: return tokenizer @pytest.fixture - def processor_with_empty_logprobs(self, mock_tokenizer): - """Create a LogprobsProcessor with empty logprobs.""" + def processor(self, mock_tokenizer): + """Create a LogprobsProcessor.""" from vllm.v1.engine.logprobs import LogprobsProcessor processor = LogprobsProcessor( @@ -564,209 +568,191 @@ class TestCorrectDecodedToken: ) return processor - @pytest.fixture - def processor_with_previous_logprobs(self, mock_tokenizer): - """Create a LogprobsProcessor with previous logprobs.""" - from vllm.v1.engine.logprobs import LogprobsProcessor + def test_correction_with_context(self, processor): + """Test correction using context from preceding sampled tokens. - processor = LogprobsProcessor( - tokenizer=mock_tokenizer, - logprobs=[{123: None}], # Previous token ID is 123 - prompt_logprobs=None, - cumulative_logprob=0.0, - num_logprobs=1, - num_prompt_logprobs=None, - ) - return processor - - def test_correction_with_previous_token_in_list( - self, processor_with_empty_logprobs - ): - """Test correction using previous token in the same list. - - Scenario: Token at idx=1 ends with "�", but when decoded with - the previous token (idx=0), it forms a valid UTF-8 sequence. - Example: token[0]="�", token[1]="�" -> together form "polarized" + Scenario: A byte-fallback token that completes a multi-byte + UTF-8 sequence when decoded with context. """ - processor = processor_with_empty_logprobs - tokens = [100, 101, 102] # token IDs - # Mock tokenizer behavior: - # - decode([102]) returns "�" (ends with replacement char) - # - decode([101, 102]) returns "valid" (no replacement char) - processor.tokenizer.decode.side_effect = lambda ids: ( - "valid" if ids == [101, 102] else "�" - ) + # Context is [101] (a preceding sampled token) + # Token 102 individually decodes to "�" + # decode([101, 102]) returns "valid" (complete sequence) + def mock_decode(ids): + if ids == [101, 102]: + return "hello valid" + if ids == [101]: + return "hello " + return "�" - result = processor._correct_decoded_token(2, tokens) + processor.tokenizer.decode.side_effect = mock_decode + + result = processor._correct_decoded_token(102, [101]) assert result == "valid" - processor.tokenizer.decode.assert_called_with([101, 102]) - def test_correction_with_previous_logprob_token( - self, processor_with_previous_logprobs - ): - """Test correction using previous logprob token. + def test_correction_with_context_from_logprobs(self, processor): + """Test correction using context from previous logprob entries. - Scenario: Cannot correct with previous token in list (idx=0), - but can correct with previous logprob token. + Scenario: Token decoded with context from previously sampled + tokens completes a UTF-8 sequence. 
""" - processor = processor_with_previous_logprobs - tokens = [100] # single token - # Mock tokenizer behavior: - # - decode([100]) returns "�" (ends with replacement char) - # - decode([123, 100]) returns " "polarized" (no replacement char) - # Token 123 is from previous logprobs + # Token 123 was previously sampled (in context) def mock_decode(ids): if ids == [123, 100]: - return ' "polarized"' + return 'hello "polarized"' + if ids == [123]: + return "hello " return "�" processor.tokenizer.decode.side_effect = mock_decode - result = processor._correct_decoded_token(0, tokens) - assert result == ' "polarized"' + result = processor._correct_decoded_token(100, [123]) + assert result == '"polarized"' - def test_correction_at_idx_zero_no_previous_logprobs( - self, processor_with_empty_logprobs - ): - """Test correction at idx=0 with no previous logprobs. + def test_correction_no_context(self, processor): + """Test correction with no context available. - Scenario: First token in list, no previous logprobs available. Should return empty string as fallback. """ - processor = processor_with_empty_logprobs - tokens = [100] - - # Mock tokenizer always returns "�" processor.tokenizer.decode.return_value = "�" - result = processor._correct_decoded_token(0, tokens) + result = processor._correct_decoded_token(100, []) assert result == "" - def test_correction_at_idx_zero_with_previous_logprobs( - self, processor_with_previous_logprobs - ): - """Test correction at idx=0 with previous logprobs available. + def test_correction_with_context_succeeds(self, processor): + """Test correction with context from previously sampled tokens.""" - Scenario: First token in list, but previous logprobs exist. - Should try correction with previous logprob token. - """ - processor = processor_with_previous_logprobs - tokens = [200] - - # Mock tokenizer behavior def mock_decode(ids): if ids == [123, 200]: - return "corrected" + return "hello corrected" + if ids == [123]: + return "hello " return "�" processor.tokenizer.decode.side_effect = mock_decode - result = processor._correct_decoded_token(0, tokens) + result = processor._correct_decoded_token(200, [123]) assert result == "corrected" - def test_no_correction_needed_returns_fallback( - self, processor_with_previous_logprobs - ): - """Test fallback to empty string when no correction works. - - Scenario: All correction attempts still end with "�". - Should return empty string as final fallback. - """ - processor = processor_with_previous_logprobs - tokens = [100, 101, 102] - - # Mock tokenizer always returns text ending with "�" + def test_fallback_when_all_attempts_fail(self, processor): + """Test fallback to empty string when no correction works.""" processor.tokenizer.decode.return_value = "still�" - result = processor._correct_decoded_token(2, tokens) + result = processor._correct_decoded_token(102, [100, 101]) assert result == "" - def test_middle_token_correction(self, processor_with_previous_logprobs): - """Test correction for a token in the middle of the list. + def test_increasing_context_window(self, processor): + """Test that increasing context window finds the correction. - Scenario: Token at idx=5 in a longer list needs correction. + Scenario: 3-byte UTF-8 char. With 1 context token, still + incomplete. With 2 context tokens, completes the sequence. 
""" - processor = processor_with_previous_logprobs - tokens = [10, 20, 30, 40, 50, 60, 70, 80] - # Mock tokenizer behavior for middle token def mock_decode(ids): - if ids == [50, 60]: - return "olar" + # 1 context token: still incomplete + if ids == [81, 82]: + return "�" + # 2 context tokens: complete + if ids == [80, 81, 82]: + return "\u201c" + # Context-only decodes + if ids == [81]: + return "�" + if ids == [80, 81]: + return "�" return "�" processor.tokenizer.decode.side_effect = mock_decode - result = processor._correct_decoded_token(5, tokens) - assert result == "olar" + # Context has 2 preceding tokens [80, 81] + result = processor._correct_decoded_token(82, [80, 81]) + assert result == "\u201c" - def test_multiple_consecutive_replacement_chars( - self, processor_with_previous_logprobs - ): + def test_multiple_consecutive_replacement_chars(self, processor): """Test handling of multiple consecutive replacement characters. - Scenario: Sequence like ["�", "�", "p"] where first two should - become empty strings. + Scenario: Multi-byte sequence where intermediate bytes return + empty string and the final byte returns the complete character. """ - processor = processor_with_previous_logprobs - - # Test first replacement char - tokens = [100, 101, 102] processor.tokenizer.decode.return_value = "still�" - result1 = processor._correct_decoded_token(0, tokens) + + # First byte with no useful context: returns "" + result1 = processor._correct_decoded_token(100, [50]) assert result1 == "" - # Test second replacement char - result2 = processor._correct_decoded_token(1, tokens) + # Second byte with same context: still returns "" + result2 = processor._correct_decoded_token(101, [50]) assert result2 == "" - def test_correction_with_multibyte_utf8(self, processor_with_previous_logprobs): + def test_correction_with_multibyte_utf8(self, processor): """Test correction involving multi-byte UTF-8 characters. - Scenario: Byte-fallback tokenization splits multi-byte UTF-8 - characters (e.g., curly quotes, Chinese characters, emojis). - Example from user: "�", "�" -> "", "\"" + Scenario: Byte-fallback tokenization splits curly quotes. + The last byte token should produce the complete character. 
""" - processor = processor_with_previous_logprobs - tokens = [200, 201] - # Mock tokenizer behavior for multi-byte UTF-8 correction def mock_decode(ids): - # When decoding first token (idx=0) with previous logprob token + # Context [123] + first byte: completes to left curly quote if ids == [123, 200]: - return ' "' # Space + left curly quote - # When decoding second token (idx=1) with previous token in list - elif ids == [200, 201]: - return '"' # Right curly quote - # When decoding second token (idx=1) with previous logprob + prev token - elif ids == [123, 200, 201]: - return ' ""' # Full sequence - return "�" + return "hello \u201c" + if ids == [123]: + return "hello " + # Context [123] + second byte: completes to right curly quote + if ids == [123, 201]: + return "hello \u201d" + return "\ufffd" processor.tokenizer.decode.side_effect = mock_decode - # First token correction (idx=0) - # Will call decode([123, 200]) since idx=0 uses previous logprob token - result1 = processor._correct_decoded_token(0, tokens) - assert result1 == ' "' + # Each top-k token is corrected independently with same context + result1 = processor._correct_decoded_token(200, [123]) + assert result1 == "\u201c" - # Second token correction (idx=1) - # Will call decode([200, 201]) since idx>0 uses previous token in list - result2 = processor._correct_decoded_token(1, tokens) - assert result2 == '"' + result2 = processor._correct_decoded_token(201, [123]) + assert result2 == "\u201d" + + def test_topk_tokens_corrected_independently(self, processor): + """Test that top-k alternatives at the same position are each + corrected independently using only sequential context, not + each other. + + This is the core fix for issue #27300: when logprobs > 0, + alternative tokens must not be combined with each other. + """ + # Context: previously sampled token 50 + context = [50] + + def mock_decode(ids): + # Token 100 (sampled) with context + if ids == [50, 100]: + return "prefix \u201c" + # Token 200 (top-k alternative) with context + if ids == [50, 200]: + return "prefix \u2014" + # Context alone + if ids == [50]: + return "prefix " + return "\ufffd" + + processor.tokenizer.decode.side_effect = mock_decode + + # Both tokens at the same position use the SAME context [50] + result_sampled = processor._correct_decoded_token(100, context) + assert result_sampled == "\u201c" + + result_alt = processor._correct_decoded_token(200, context) + assert result_alt == "\u2014" def test_real_world_opt125m_scenario(self, mock_tokenizer): - """Test the real-world scenario from user's example. + """Test the real-world scenario from the bug report. - User's example with facebook/opt-125m: - Before: [" the", " term", " �", "�", "p", "olar", "ized", "�", "�", ...] - After: [" the", " term", "", " "", "p", "olar", "ized", "", "\"", ...] + Simulates the OPT-125m sequence where curly quotes are split + into byte-fallback tokens. Each token is corrected using only + the preceding sampled tokens as context. 
""" from vllm.v1.engine.logprobs import LogprobsProcessor - # Simulate the sequence of tokens processor = LogprobsProcessor( tokenizer=mock_tokenizer, logprobs=[], @@ -776,47 +762,106 @@ class TestCorrectDecodedToken: num_prompt_logprobs=None, ) - # Token IDs representing the problematic sequence - tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9] # placeholder IDs - - # Mock decode behavior simulating the real scenario + # Simulating: byte tokens 3, 4 form left curly quote "\u201c" + # byte tokens 8, 9 form right curly quote "\u201d" def mock_decode(ids): - # Simulate cases where individual tokens decode to "�" - # but combinations decode correctly - if len(ids) == 1: - if ids[0] in (3, 4, 8, 9): - return "�" - elif len(ids) == 2: - if ids == [2, 3]: - return " term�" # Still ends with �, need more context - elif ids == [3, 4]: - return ' "' # Corrected to space + left curly quote - elif ids == [7, 8]: - return "ized�" # Still ends with � - elif ids == [8, 9]: - return '"' # Corrected to right curly quote - elif len(ids) == 3: - if ids == [1, 2, 3]: - return " the term�" # Still ends with issue - elif ids == [2, 3, 4]: - return ' term "' # With all context + # Context decodes + if ids == [2]: + return " term" + if ids == [1, 2]: + return " the term" + if ids == [3]: + return "\ufffd" + if ids == [2, 3]: + return " term\ufffd" + if ids == [1, 2, 3]: + return " the term\ufffd" + # Token 4 with context [2, 3] -> completes left curly quote + if ids == [3, 4]: + return "\u201c" + if ids == [2, 3, 4]: + return " term\u201c" + # Context for right curly quote + if ids == [7]: + return "ized" + if ids == [7, 8]: + return "ized\ufffd" + if ids == [8, 9]: + return "\u201d" + if ids == [7, 8, 9]: + return "ized\u201d" return "normal_text" mock_tokenizer.decode.side_effect = mock_decode - # Test token at index 2 (should fail to correct, return "") - # Token 3 individually is "�" - # decode([2, 3]) = " term�" (still ends with �) - # No previous logprobs, so fallback to "" - result = processor._correct_decoded_token(2, tokens) + # First byte (token 3) of left curly quote with no context + result = processor._correct_decoded_token(3, []) assert result == "" - # Test token at index 3 (should correct to " "") - # Token 4 individually is "�" - # decode([3, 4]) = " "" (corrected!) - processor.logprobs = [{2: None}] # Add previous logprob - result = processor._correct_decoded_token(3, tokens) - assert result == ' "' + # First byte (token 3) with context [2] -> still incomplete + result = processor._correct_decoded_token(3, [2]) + assert result == "" + + # Second byte (token 4) of left curly quote with context [2, 3] + # Token 3 is byte-fallback, so clean context is [2] only. + # decode([2, 3, 4]) = " term\u201c", decode([2]) = " term" + # result = "\u201c" + result = processor._correct_decoded_token(4, [2, 3]) + assert result == "\u201c" + + # Second byte (token 9) of right curly quote with context [7, 8] + result = processor._correct_decoded_token(9, [7, 8]) + assert result == "\u201d" + + def test_byte_fallback_context_preserves_space(self, mock_tokenizer): + """Test that text from byte-fallback context tokens is preserved. + + In OPT-125m, token 44 = space + 2 bytes of curly quote. + When token 44 returns "" (incomplete), the space it carried + must be attributed to the completing token (48). 
+ """ + from vllm.v1.engine.logprobs import LogprobsProcessor + + processor = LogprobsProcessor( + tokenizer=mock_tokenizer, + logprobs=[], + prompt_logprobs=None, + cumulative_logprob=0.0, + num_logprobs=1, + num_prompt_logprobs=None, + ) + + def mock_decode(ids): + # Token 44 = space + 2 bytes (like OPT-125m's \u0120\u00e2\u0080) + if ids == [44]: + return " \ufffd" + if ids == [48]: + return "\ufffd" + # Together they form: space + left curly quote + if ids == [44, 48]: + return " \u201c" + # With preceding clean context + if ids == [1385]: + return " term" + if ids == [1385, 44]: + return " term \ufffd" + if ids == [1385, 44, 48]: + return " term \u201c" + return "\ufffd" + + mock_tokenizer.decode.side_effect = mock_decode + + # Token 44 with context [1385] -> still ends with replacement + result = processor._correct_decoded_token(44, [1385]) + assert result == "" + + # Token 48 with context [1385, 44]: + # Token 44 is byte-fallback, so clean context is [1385]. + # decode([1385, 44, 48]) = " term \u201c" + # decode([1385]) = " term" + # result = " \u201c" (space preserved from token 44!) + result = processor._correct_decoded_token(48, [1385, 44]) + assert result == " \u201c" def test_verify_tokens_integration(): diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 513531c31..9ada6eda4 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from vllm.logger import init_logger from vllm.logprobs import ( + FlatLogprobs, PromptLogprobs, SampleLogprobs, append_logprobs_for_next_position, @@ -96,8 +97,11 @@ class LogprobsProcessor: decoded_tokens_list = convert_ids_list_to_tokens( self.tokenizer, token_ids ) + context_token_ids = self._get_sampled_context_ids(self.logprobs) decoded_tokens = self._verify_tokens( - decoded_tokens_list=decoded_tokens_list, tokens=token_ids + decoded_tokens_list=decoded_tokens_list, + tokens=token_ids, + context_token_ids=context_token_ids, ) # Sampler puts the sampled logprob in first. @@ -162,9 +166,14 @@ class LogprobsProcessor: else: # Extract decoded tokens for this position decoded_tokens_slice = all_decoded_tokens[offset:offset_end] + # Context: preceding prompt tokens accumulated in + # self.prompt_logprobs from previous loop iterations. + context_token_ids = self._get_sampled_context_ids(self.prompt_logprobs) # Apply UTF-8 correction within this position's token boundaries decoded_tokens_for_pos = self._verify_tokens( - decoded_tokens_list=decoded_tokens_slice, tokens=token_ids_list[pos] + decoded_tokens_list=decoded_tokens_slice, + tokens=token_ids_list[pos], + context_token_ids=context_token_ids, ) # Update with the Logprob container for this pos. @@ -196,41 +205,139 @@ class LogprobsProcessor: self.prompt_logprobs = [] return plp - def _correct_decoded_token(self, idx: int, tokens: list[int]) -> str: - assert self.tokenizer is not None, "self.tokenizer should not be None" + @staticmethod + def _get_sampled_context_ids( + logprobs_source: SampleLogprobs | PromptLogprobs | None, + max_context: int = 4, + ) -> list[int]: + """Extract recent sampled token IDs from a logprobs source. 
- # try with prev token id in same list - if idx > 0: - possible_decoded_token = self.tokenizer.decode(tokens[idx - 1 : idx + 1]) - if not possible_decoded_token.endswith("�"): - return possible_decoded_token - # try with previous logprob token id - if self.logprobs: - latest_token_id = next(iter(self.logprobs[-1])) + The sampled (or prompt) token at each position is the first + entry, since it is always inserted first by + append_logprobs_for_next_position. - decode_ids = [latest_token_id] - if idx > 0: - decode_ids.extend(tokens[idx - 1 : idx + 1]) + Args: + logprobs_source: The logprobs container to extract from. + max_context: Maximum number of preceding tokens to return. + 4 is sufficient for any UTF-8 multi-byte sequence. + + Returns: + List of sampled token IDs, oldest first, most recent last. + """ + if not logprobs_source: + return [] + + n = len(logprobs_source) + start = max(0, n - max_context) + + # Efficient path for FlatLogprobs: access token_ids directly. + if isinstance(logprobs_source, FlatLogprobs): + return [ + logprobs_source.token_ids[logprobs_source.start_indices[i]] + for i in range(start, n) + if logprobs_source.start_indices[i] < logprobs_source.end_indices[i] + ] + + # list[dict] path + result: list[int] = [] + for i in range(start, n): + entry = logprobs_source[i] + if entry is not None: + result.append(next(iter(entry))) + return result + + def _correct_decoded_token( + self, token_id: int, context_token_ids: list[int] + ) -> str: + """Correct a decoded token that contains the replacement character. + + When byte-fallback tokenization splits multi-byte UTF-8 + characters across tokens, individual token decoding produces + the replacement character U+FFFD. This method uses preceding + sampled tokens as context to reconstruct the correct text. + + Args: + token_id: The single token ID to correct. + context_token_ids: Preceding sampled token IDs in sequential + order (oldest first). These are the actual tokens in + the generated sequence, NOT top-k alternatives. + + Returns: + The corrected decoded string, or empty string if the byte + sequence is genuinely incomplete at this point. + """ + assert self.tokenizer is not None + + max_ctx = min(len(context_token_ids), 4) + + for num_ctx in range(1, max_ctx + 1): + context = context_token_ids[-num_ctx:] + full_decoded = self.tokenizer.decode(context + [token_id]) + + if full_decoded.endswith("�"): + continue + + # Find the boundary between "clean" context tokens and + # byte-fallback tokens that are part of the same incomplete + # sequence. Byte-fallback context tokens returned "" when + # they were processed, so their text must be attributed to + # this completing token. + clean_end = len(context) + for j in range(len(context) - 1, -1, -1): + if self.tokenizer.decode([context[j]]).endswith("�"): + clean_end = j + else: + break + + # Decode only the clean (non-byte-fallback) prefix. + if clean_end > 0: + clean_prefix = self.tokenizer.decode(context[:clean_end]) else: - decode_ids.extend(tokens[idx : idx + 1]) + clean_prefix = "" - possible_decoded_token = self.tokenizer.decode(decode_ids) - if not possible_decoded_token.endswith("�"): - return possible_decoded_token + if full_decoded.startswith(clean_prefix): + return full_decoded[len(clean_prefix) :] + + # Tokenizer normalization may cause prefix mismatch. + # Find the longest common prefix between them. 
+ common_len = 0 + for a, b in zip(clean_prefix, full_decoded): + if a != b: + break + common_len += 1 + return full_decoded[common_len:] - # by default return empty string return "" def _verify_tokens( - self, decoded_tokens_list: list[str], tokens: list[int] + self, + decoded_tokens_list: list[str], + tokens: list[int], + context_token_ids: list[int] | None = None, ) -> list[str]: + """Verify and correct decoded tokens with replacement characters. + + Args: + decoded_tokens_list: Decoded token strings to verify. + tokens: Token IDs corresponding to decoded_tokens_list. + These are alternatives at the SAME position (e.g. + [sampled, top1, top2]), NOT sequential tokens. + context_token_ids: Preceding sampled token IDs providing + sequential context. If None, extracted from + self.logprobs. + """ + if context_token_ids is None: + context_token_ids = self._get_sampled_context_ids(self.logprobs) + corrected_decoded_token_map = dict() for idx, text in enumerate(decoded_tokens_list): if text.endswith("�"): - # utf-8 char at the end means it's a potential unfinished byte sequence - # from byte fallback tokenization. + # Replacement char at the end means a potential + # unfinished byte sequence from byte-fallback + # tokenization. Correct each token independently + # using only the sequential context. corrected_decoded_token_map[idx] = self._correct_decoded_token( - idx, tokens + tokens[idx], context_token_ids ) for idx, text in corrected_decoded_token_map.items():
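# --- Standalone sketch (hypothetical vocabulary, not the real OPT-125m one) ---
# Illustrates the context-window reconstruction that the new
# _correct_decoded_token performs: a byte-fallback token that decodes to
# U+FFFD on its own is re-decoded together with preceding *sampled* tokens,
# and only the text contributed by clean context tokens is stripped off.
# The VOCAB table and decode() below are stand-ins for a real tokenizer.

REPLACEMENT = "\ufffd"

# Hypothetical token -> bytes table; the left curly quote U+201C (e2 80 9c)
# is split across tokens 3 and 4, with token 3 also carrying a space.
VOCAB = {1: b" the", 2: b" term", 3: b" \xe2\x80", 4: b"\x9c"}


def decode(ids: list[int]) -> str:
    return b"".join(VOCAB[i] for i in ids).decode("utf-8", errors="replace")


def correct_decoded_token(token_id: int, context: list[int], max_ctx: int = 4) -> str:
    for n in range(1, min(len(context), max_ctx) + 1):
        ctx = context[-n:]
        full = decode(ctx + [token_id])
        if full.endswith(REPLACEMENT):
            continue  # still an incomplete byte sequence; widen the window
        # Strip only the text of "clean" context tokens; bytes carried by
        # byte-fallback context tokens stay attributed to this token.
        clean_end = len(ctx)
        for j in range(len(ctx) - 1, -1, -1):
            if decode([ctx[j]]).endswith(REPLACEMENT):
                clean_end = j
            else:
                break
        prefix = decode(ctx[:clean_end]) if clean_end else ""
        return full[len(prefix):] if full.startswith(prefix) else full
    return ""  # genuinely incomplete at this position


assert decode([3]) == " " + REPLACEMENT        # partial UTF-8 bytes alone
assert correct_decoded_token(3, [1, 2]) == ""  # still incomplete with context
assert correct_decoded_token(4, [2, 3]) == " \u201c"  # space carried over from token 3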