diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index f1185222f..9630f8cae 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -274,12 +274,28 @@ def _validate_logprobs( # the logprob token id at this sequence position decoded_token = pos_logprob_dict[lp_tok].decoded_token ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok) - assert decoded_token == ref_decoded_token, ( - f"Sampled logprob token id {lp_tok} decodes to" - f" {ref_decoded_token} but Logprob decoded" - f" token is {decoded_token} instead" - f" (at position {idx})" - ) + + # With UTF-8 correction logic, tokens ending with "�" + # (incomplete byte sequences) are corrected to either + # empty string or proper UTF-8 characters + if ref_decoded_token.endswith("�"): + # Token needs UTF-8 correction + assert not decoded_token.endswith("�"), ( + f"Sampled logprob token id {lp_tok} decodes to" + f" '{ref_decoded_token}' (ends with replacement char)" + f" but corrected decoded token '{decoded_token}'" + f" still ends with replacement char" + f" (at position {idx}). UTF-8 correction should" + f" have removed it." + ) + else: + # No correction needed, should match exactly + assert decoded_token == ref_decoded_token, ( + f"Sampled logprob token id {lp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})" + ) ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob # Assert that cumulative logprobs are correct @@ -420,12 +436,28 @@ def _validate_logprobs( # the logprob token id at this sequence position decoded_token = pos_logprob_dict[plp_tok].decoded_token ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok) - assert decoded_token == ref_decoded_token, ( - f"Prompt logprob token id {plp_tok} decodes to" - f" {ref_decoded_token} but Logprob decoded" - f" token is {decoded_token} instead" - f" (at position {idx})" - ) + + # With UTF-8 correction logic, tokens ending with "�" + # (incomplete byte sequences) are corrected to either + # empty string or proper UTF-8 characters + if ref_decoded_token.endswith("�"): + # Token needs UTF-8 correction + assert not decoded_token.endswith("�"), ( + f"Prompt logprob token id {plp_tok} decodes to" + f" '{ref_decoded_token}' (ends with replacement char)" + f" but corrected decoded token '{decoded_token}'" + f" still ends with replacement char" + f" (at position {idx}). UTF-8 correction should" + f" have removed it." + ) + else: + # No correction needed, should match exactly + assert decoded_token == ref_decoded_token, ( + f"Prompt logprob token id {plp_tok} decodes to" + f" {ref_decoded_token} but Logprob decoded" + f" token is {decoded_token} instead" + f" (at position {idx})" + ) else: # Prompt logprobs disabled for this request assert prompt_logprobs is None diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 1e2cc2241..abb3ce2ef 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -514,6 +514,424 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): del llm +class TestCorrectDecodedToken: + """Unit tests for _correct_decoded_token method in LogprobsProcessor. + + This method handles UTF-8 decoding issues where incomplete byte sequences + result in the Unicode replacement character "�" (U+FFFD). This commonly + happens with byte-fallback tokenization when multi-byte UTF-8 characters + are split across tokens. + """ + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer for testing.""" + from unittest.mock import Mock + + tokenizer = Mock() + return tokenizer + + @pytest.fixture + def processor_with_empty_logprobs(self, mock_tokenizer): + """Create a LogprobsProcessor with empty logprobs.""" + from vllm.v1.engine.logprobs import LogprobsProcessor + + processor = LogprobsProcessor( + tokenizer=mock_tokenizer, + logprobs=[], + prompt_logprobs=None, + cumulative_logprob=0.0, + num_logprobs=1, + num_prompt_logprobs=None, + ) + return processor + + @pytest.fixture + def processor_with_previous_logprobs(self, mock_tokenizer): + """Create a LogprobsProcessor with previous logprobs.""" + from vllm.v1.engine.logprobs import LogprobsProcessor + + processor = LogprobsProcessor( + tokenizer=mock_tokenizer, + logprobs=[{123: None}], # Previous token ID is 123 + prompt_logprobs=None, + cumulative_logprob=0.0, + num_logprobs=1, + num_prompt_logprobs=None, + ) + return processor + + def test_correction_with_previous_token_in_list( + self, processor_with_empty_logprobs + ): + """Test correction using previous token in the same list. + + Scenario: Token at idx=1 ends with "�", but when decoded with + the previous token (idx=0), it forms a valid UTF-8 sequence. + Example: token[0]="�", token[1]="�" -> together form "polarized" + """ + processor = processor_with_empty_logprobs + tokens = [100, 101, 102] # token IDs + + # Mock tokenizer behavior: + # - decode([102]) returns "�" (ends with replacement char) + # - decode([101, 102]) returns "valid" (no replacement char) + processor.tokenizer.decode.side_effect = lambda ids: ( + "valid" if ids == [101, 102] else "�" + ) + + result = processor._correct_decoded_token(2, tokens) + assert result == "valid" + processor.tokenizer.decode.assert_called_with([101, 102]) + + def test_correction_with_previous_logprob_token( + self, processor_with_previous_logprobs + ): + """Test correction using previous logprob token. + + Scenario: Cannot correct with previous token in list (idx=0), + but can correct with previous logprob token. + """ + processor = processor_with_previous_logprobs + tokens = [100] # single token + + # Mock tokenizer behavior: + # - decode([100]) returns "�" (ends with replacement char) + # - decode([123, 100]) returns " "polarized" (no replacement char) + # Token 123 is from previous logprobs + def mock_decode(ids): + if ids == [123, 100]: + return ' "polarized"' + return "�" + + processor.tokenizer.decode.side_effect = mock_decode + + result = processor._correct_decoded_token(0, tokens) + assert result == ' "polarized"' + + def test_correction_at_idx_zero_no_previous_logprobs( + self, processor_with_empty_logprobs + ): + """Test correction at idx=0 with no previous logprobs. + + Scenario: First token in list, no previous logprobs available. + Should return empty string as fallback. + """ + processor = processor_with_empty_logprobs + tokens = [100] + + # Mock tokenizer always returns "�" + processor.tokenizer.decode.return_value = "�" + + result = processor._correct_decoded_token(0, tokens) + assert result == "" + + def test_correction_at_idx_zero_with_previous_logprobs( + self, processor_with_previous_logprobs + ): + """Test correction at idx=0 with previous logprobs available. + + Scenario: First token in list, but previous logprobs exist. + Should try correction with previous logprob token. + """ + processor = processor_with_previous_logprobs + tokens = [200] + + # Mock tokenizer behavior + def mock_decode(ids): + if ids == [123, 200]: + return "corrected" + return "�" + + processor.tokenizer.decode.side_effect = mock_decode + + result = processor._correct_decoded_token(0, tokens) + assert result == "corrected" + + def test_no_correction_needed_returns_fallback( + self, processor_with_previous_logprobs + ): + """Test fallback to empty string when no correction works. + + Scenario: All correction attempts still end with "�". + Should return empty string as final fallback. + """ + processor = processor_with_previous_logprobs + tokens = [100, 101, 102] + + # Mock tokenizer always returns text ending with "�" + processor.tokenizer.decode.return_value = "still�" + + result = processor._correct_decoded_token(2, tokens) + assert result == "" + + def test_middle_token_correction(self, processor_with_previous_logprobs): + """Test correction for a token in the middle of the list. + + Scenario: Token at idx=5 in a longer list needs correction. + """ + processor = processor_with_previous_logprobs + tokens = [10, 20, 30, 40, 50, 60, 70, 80] + + # Mock tokenizer behavior for middle token + def mock_decode(ids): + if ids == [50, 60]: + return "olar" + return "�" + + processor.tokenizer.decode.side_effect = mock_decode + + result = processor._correct_decoded_token(5, tokens) + assert result == "olar" + + def test_multiple_consecutive_replacement_chars( + self, processor_with_previous_logprobs + ): + """Test handling of multiple consecutive replacement characters. + + Scenario: Sequence like ["�", "�", "p"] where first two should + become empty strings. + """ + processor = processor_with_previous_logprobs + + # Test first replacement char + tokens = [100, 101, 102] + processor.tokenizer.decode.return_value = "still�" + result1 = processor._correct_decoded_token(0, tokens) + assert result1 == "" + + # Test second replacement char + result2 = processor._correct_decoded_token(1, tokens) + assert result2 == "" + + def test_correction_with_multibyte_utf8(self, processor_with_previous_logprobs): + """Test correction involving multi-byte UTF-8 characters. + + Scenario: Byte-fallback tokenization splits multi-byte UTF-8 + characters (e.g., curly quotes, Chinese characters, emojis). + Example from user: "�", "�" -> "", "\"" + """ + processor = processor_with_previous_logprobs + tokens = [200, 201] + + # Mock tokenizer behavior for multi-byte UTF-8 correction + def mock_decode(ids): + # When decoding first token (idx=0) with previous logprob token + if ids == [123, 200]: + return ' "' # Space + left curly quote + # When decoding second token (idx=1) with previous token in list + elif ids == [200, 201]: + return '"' # Right curly quote + # When decoding second token (idx=1) with previous logprob + prev token + elif ids == [123, 200, 201]: + return ' ""' # Full sequence + return "�" + + processor.tokenizer.decode.side_effect = mock_decode + + # First token correction (idx=0) + # Will call decode([123, 200]) since idx=0 uses previous logprob token + result1 = processor._correct_decoded_token(0, tokens) + assert result1 == ' "' + + # Second token correction (idx=1) + # Will call decode([200, 201]) since idx>0 uses previous token in list + result2 = processor._correct_decoded_token(1, tokens) + assert result2 == '"' + + def test_real_world_opt125m_scenario(self, mock_tokenizer): + """Test the real-world scenario from user's example. + + User's example with facebook/opt-125m: + Before: [" the", " term", " �", "�", "p", "olar", "ized", "�", "�", ...] + After: [" the", " term", "", " "", "p", "olar", "ized", "", "\"", ...] + """ + from vllm.v1.engine.logprobs import LogprobsProcessor + + # Simulate the sequence of tokens + processor = LogprobsProcessor( + tokenizer=mock_tokenizer, + logprobs=[], + prompt_logprobs=None, + cumulative_logprob=0.0, + num_logprobs=1, + num_prompt_logprobs=None, + ) + + # Token IDs representing the problematic sequence + tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9] # placeholder IDs + + # Mock decode behavior simulating the real scenario + def mock_decode(ids): + # Simulate cases where individual tokens decode to "�" + # but combinations decode correctly + if len(ids) == 1: + if ids[0] == 3 or ids[0] == 4 or ids[0] == 8 or ids[0] == 9: + return "�" + elif len(ids) == 2: + if ids == [2, 3]: + return " term�" # Still ends with �, need more context + elif ids == [3, 4]: + return ' "' # Corrected to space + left curly quote + elif ids == [7, 8]: + return "ized�" # Still ends with � + elif ids == [8, 9]: + return '"' # Corrected to right curly quote + elif len(ids) == 3: + if ids == [1, 2, 3]: + return " the term�" # Still ends with issue + elif ids == [2, 3, 4]: + return ' term "' # With all context + return "normal_text" + + mock_tokenizer.decode.side_effect = mock_decode + + # Test token at index 2 (should fail to correct, return "") + # Token 3 individually is "�" + # decode([2, 3]) = " term�" (still ends with �) + # No previous logprobs, so fallback to "" + result = processor._correct_decoded_token(2, tokens) + assert result == "" + + # Test token at index 3 (should correct to " "") + # Token 4 individually is "�" + # decode([3, 4]) = " "" (corrected!) + processor.logprobs = [{2: None}] # Add previous logprob + result = processor._correct_decoded_token(3, tokens) + assert result == ' "' + + +def test_verify_tokens_integration(): + """Integration test for _verify_tokens with real model. + + This test validates that _verify_tokens correctly identifies and + corrects tokens ending with the replacement character "�". + Uses facebook/opt-125m which is known to produce these issues. + """ + runner = VllmRunner( + "facebook/opt-125m", + max_logprobs=0, + enable_prefix_caching=False, + gpu_memory_utilization=0.15, + max_model_len=256, + ) + + # Use a prompt that triggers multi-byte UTF-8 issues + # Based on user's example: "In this example," + test_prompts = ["In this example,"] + + sampling_params = SamplingParams( + max_tokens=16, + temperature=0, + logprobs=0, + ) + + results = runner.llm.generate(test_prompts, sampling_params=sampling_params) + + # Verify that decoded tokens don't contain replacement characters + for result in results: + assert result.outputs[0].logprobs is not None + for logprob_dict in result.outputs[0].logprobs: + for token_id, logprob_info in logprob_dict.items(): + decoded_token = logprob_info.decoded_token + # Decoded tokens should not end with replacement character + # They should either be corrected or empty string + assert not decoded_token.endswith("�"), ( + f"Token {token_id} decoded to '{decoded_token}' which " + f"ends with replacement character" + ) + # Decoded tokens should not contain lone replacement characters + assert decoded_token != "�", ( + f"Token {token_id} is a lone replacement character" + ) + + +def test_utf8_edge_cases_with_real_model(): + """Test various UTF-8 edge cases with a real model. + + Tests prompts that are likely to trigger byte-fallback tokenization + and multi-byte UTF-8 splitting. + """ + runner = VllmRunner( + "facebook/opt-125m", + max_logprobs=1, + enable_prefix_caching=False, + gpu_memory_utilization=0.15, + max_model_len=256, + ) + + # Prompts with various multi-byte UTF-8 characters + test_prompts = [ + 'Smart quotes: "Hello"', # Curly quotes + "Em dash — test", # Em dash + "Ellipsis… continues", # Ellipsis + "Chinese: 你好", # Chinese characters + "Emoji: 😀 🎉", # Emojis + 'Mixed: "quoted" — with symbols', # Mixed + ] + + sampling_params = SamplingParams( + max_tokens=10, + temperature=0, + logprobs=1, + ) + + results = runner.llm.generate(test_prompts, sampling_params=sampling_params) + + for i, result in enumerate(results): + prompt = test_prompts[i] + assert result.outputs[0].logprobs is not None + + # Check that no decoded tokens end with replacement character + for logprob_dict in result.outputs[0].logprobs: + for token_id, logprob_info in logprob_dict.items(): + decoded_token = logprob_info.decoded_token + assert not decoded_token.endswith("�"), ( + f"Prompt: '{prompt}'\n" + f"Token {token_id} decoded to '{decoded_token}' which " + f"ends with replacement character" + ) + + +def test_correct_decoded_token_preserves_valid_tokens(): + """Test that valid tokens (not ending with �) are not modified. + + The _correct_decoded_token method should only be called for tokens + ending with "�", but this test verifies the broader _verify_tokens + logic doesn't affect valid tokens. + """ + runner = VllmRunner( + "facebook/opt-125m", + max_logprobs=2, + enable_prefix_caching=False, + gpu_memory_utilization=0.15, + max_model_len=256, + ) + + # Simple prompt with standard ASCII characters + test_prompts = ["Hello world, this is a test."] + + sampling_params = SamplingParams( + max_tokens=10, + temperature=0, + logprobs=2, + ) + + results = runner.llm.generate(test_prompts, sampling_params=sampling_params) + + for result in results: + assert result.outputs[0].logprobs is not None + + # All decoded tokens should be valid strings + for logprob_dict in result.outputs[0].logprobs: + for token_id, logprob_info in logprob_dict.items(): + decoded_token = logprob_info.decoded_token + # Valid tokens should be non-empty strings (or empty if corrected) + assert isinstance(decoded_token, str) + # Should not contain replacement character + assert "�" not in decoded_token + + @pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode)) @pytest.mark.parametrize( "model_setup", diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 599725b6d..64ac32312 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +from collections.abc import Iterable from dataclasses import dataclass from vllm.logger import init_logger @@ -88,11 +89,16 @@ class LogprobsProcessor: logprobs = logprobs_np.tolist() token_ids = token_ids_np.tolist() # Detokenize (non-incrementally). - decoded_tokens = ( - NONES - if self.tokenizer is None - else (convert_ids_list_to_tokens(self.tokenizer, token_ids)) - ) + decoded_tokens: list[str] | Iterable[None] + if self.tokenizer is None: + decoded_tokens = NONES + else: + decoded_tokens_list = convert_ids_list_to_tokens( + self.tokenizer, token_ids + ) + decoded_tokens = self._verify_tokens( + decoded_tokens_list=decoded_tokens_list, tokens=token_ids + ) # Sampler puts the sampled logprob in first. sampled_token_logprob = logprobs[0] @@ -126,37 +132,45 @@ class LogprobsProcessor: token_ids, logprobs, ranks = prompt_logprobs_tensors - # Detokenize non-incrementally. - # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] - decoded_tokens = ( - None - if self.tokenizer is None - else ( - convert_ids_list_to_tokens(self.tokenizer, token_ids.flatten().tolist()) - ) - ) - # Recover shapes. num_prompt_tokens, num_logprobs = logprobs.shape + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + all_decoded_tokens: list[str] | None = ( + None + if self.tokenizer is None + else convert_ids_list_to_tokens( + self.tokenizer, token_ids.flatten().tolist() + ) + ) + # Pythonize the torch tensors. prompt_token_ranks = ranks.tolist() prompt_logprobs = logprobs.tolist() - token_ids = token_ids.tolist() + token_ids_list = token_ids.tolist() # Make Logprob for each position. for pos in range(num_prompt_tokens): - # Handle flattening. + # Handle flattening and UTF-8 correction per position offset = pos * num_logprobs offset_end = offset + num_logprobs - decoded_tokens_for_pos = ( - NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] - ) + + decoded_tokens_for_pos: list[str] | Iterable[None] + if all_decoded_tokens is None: + decoded_tokens_for_pos = NONES + else: + # Extract decoded tokens for this position + decoded_tokens_slice = all_decoded_tokens[offset:offset_end] + # Apply UTF-8 correction within this position's token boundaries + decoded_tokens_for_pos = self._verify_tokens( + decoded_tokens_list=decoded_tokens_slice, tokens=token_ids_list[pos] + ) # Update with the Logprob container for this pos. append_logprobs_for_next_position( self.prompt_logprobs, - token_ids[pos], + token_ids_list[pos], prompt_logprobs[pos], decoded_tokens_for_pos, prompt_token_ranks[pos], @@ -182,6 +196,48 @@ class LogprobsProcessor: self.prompt_logprobs = [] return plp + def _correct_decoded_token(self, idx: int, tokens: list[int]) -> str: + assert self.tokenizer is not None, "self.tokenizer should not be None" + + # try with prev token id in same list + if idx > 0: + possible_decoded_token = self.tokenizer.decode(tokens[idx - 1 : idx + 1]) + if not possible_decoded_token.endswith("�"): + return possible_decoded_token + # try with previous logprob token id + if self.logprobs: + latest_token_id = next(iter(self.logprobs[-1])) + + decode_ids = [latest_token_id] + if idx > 0: + decode_ids.extend(tokens[idx - 1 : idx + 1]) + else: + decode_ids.extend(tokens[idx : idx + 1]) + + possible_decoded_token = self.tokenizer.decode(decode_ids) + if not possible_decoded_token.endswith("�"): + return possible_decoded_token + + # by default return empty string + return "" + + def _verify_tokens( + self, decoded_tokens_list: list[str], tokens: list[int] + ) -> list[str]: + corrected_decoded_token_map = dict() + for idx, text in enumerate(decoded_tokens_list): + if text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + corrected_decoded_token_map[idx] = self._correct_decoded_token( + idx, tokens + ) + + for idx, text in corrected_decoded_token_map.items(): + decoded_tokens_list[idx] = text + + return decoded_tokens_list + def update_from_output(self, output: EngineCoreOutput) -> None: if output.new_logprobs is not None: self._update_sample_logprobs(output.new_logprobs)