Files
vllm/tests/v1/sample/test_logprobs.py

1237 lines
45 KiB
Python
Raw Blame History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import math
from collections.abc import Generator
from typing import get_args
import pytest
import torch
from tests.utils import large_gpu_mark
from tests.v1.sample.utils import (
BatchLogprobsComposition,
BatchLogprobsSpecType,
assert_incr_detok_str_matches_non_incr_detok_str,
compute_correct_cumulative_logprob,
get_test_batch,
)
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from ...conftest import HfRunner, VllmRunner
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
DTYPE = "half"

# Short aliases for the batch logprobs compositions used in parametrization.
NONE = BatchLogprobsComposition.NONE
SAMPLE = BatchLogprobsComposition.SAMPLE
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT

# On ROCm, floating-point reductions inside attention and GEMM kernels are
# non-associative and depend on batch geometry. A reference LLM (no spec
# decode, default scheduling) and a spec-decode LLM (chunked prefill,
# different effective batch sizes) therefore reduce in different orders and
# produce numerically divergent logprobs, which would be misattributed to
# spec-decode incorrectness.
#
# These kwargs put every LLM instance into the same deterministic execution
# mode so that the comparison isolates spec-decode correctness alone. On
# other platforms they are empty.
ROCM_DETERMINISM_KWARGS: dict = (
    {"max_num_seqs": 1, "attention_backend": "TRITON_ATTN"}
    if current_platform.is_rocm()
    else {}
)
@pytest.fixture(
    scope="module",
    # Run dependent tests once without and once with automatic prefix
    # caching (APC).
    params=[False, True],
)
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
    """Yield a module-scoped vLLM runner for ``MODEL``.

    ``request.param`` selects whether prefix caching is enabled. The batched
    token budget is kept tiny so that chunked prefill is exercised.
    """
    enable_apc = request.param
    runner_kwargs = dict(
        dtype=DTYPE,
        max_logprobs=7,
        # A very small batched-token budget guarantees prefill chunking.
        max_num_batched_tokens=16,
        max_num_seqs=16,
        max_model_len=128,
        enable_chunked_prefill=True,
        enforce_eager=True,
        # TODO: enable this once we support it for prompt logprobs.
        enable_prefix_caching=enable_apc,
        gpu_memory_utilization=0.4,
    )
    with vllm_runner(MODEL, **runner_kwargs) as runner:
        yield runner
@pytest.fixture(scope="module")
def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
    """Yield a module-scoped HuggingFace reference runner for ``MODEL``."""
    runner_cm = hf_runner(MODEL, dtype=DTYPE)
    with runner_cm as reference_model:
        yield reference_model
def _repeat_logprob_config(
    test_prompts,
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
    """Return exactly one logprob config per test prompt.

    A logprob config is a ``(num sample logprobs, num prompt logprobs)``
    tuple in which either entry may be ``None``.

    The configs are repeated cyclically when there are more prompts than
    configs, and cut short when there are fewer. When the counts already
    match, the same configs come back in the same order.

    Args:
        test_prompts: the prompts under test; only their count is used.
        logprob_prompt_logprob_list: the available logprob configs.

    Returns:
        A new list with ``len(test_prompts)`` configs drawn cyclically from
        ``logprob_prompt_logprob_list``.
    """
    wanted = len(test_prompts)
    # ``zip`` pulls from the range first, so it stops after ``wanted`` items.
    # This both tiles a short config list and truncates a long one.
    tiled = itertools.cycle(logprob_prompt_logprob_list)
    configs = [cfg for _, cfg in zip(range(wanted), tiled)]
    # An empty config list cannot be tiled; this assertion catches that case.
    assert wanted == len(configs)
    return configs
def _run_and_validate(
    vllm_model: VllmRunner,
    test_prompts: list[str],
    vllm_sampling_params: SamplingParams,
    hf_logprobs: list[list[torch.Tensor]],
    hf_outputs: list[tuple[list[int], str]],
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
    temperature: float,
    max_tokens: int,
    do_apc: bool,
) -> None:
    """Generate with vLLM and validate tokens and (prompt) logprobs vs. HF.

    Args:
        vllm_model: vLLM runner under test.
        test_prompts: prompts to generate from.
        vllm_sampling_params: forwarded to ``llm.generate``. The caller in
            this file passes a list with one ``SamplingParams`` per prompt.
        hf_logprobs: per-prompt HF reference logprobs. They are indexed as
            ``hf_logprob[step][-1][token_id]`` for sampled tokens and
            ``hf_logprob[0][pos][token_id]`` for prompt tokens.
        hf_outputs: per-prompt HF greedy outputs. Element ``[0]`` holds the
            prompt + completion token ids.
        logprob_prompt_logprob_list: per-prompt
            ``(num sample logprobs, num prompt logprobs)`` configs. ``None``
            means that kind of logprobs is disabled for the request.
        temperature: sampling temperature shared by every request.
        max_tokens: number of tokens generated per request.
        do_apc: whether prefix caching is enabled. None of the checks below
            use it; it is kept for call-site compatibility.

    Raises:
        AssertionError: on any mismatch.
    """
    vllm_results = vllm_model.llm.generate(
        test_prompts, sampling_params=vllm_sampling_params
    )
    # ``zip`` below silently drops trailing items on a length mismatch, which
    # would let missing requests go unvalidated. Check the counts up front.
    num_prompts = len(test_prompts)
    assert len(vllm_results) == num_prompts
    assert len(hf_logprobs) == num_prompts
    assert len(hf_outputs) == num_prompts
    assert len(logprob_prompt_logprob_list) == num_prompts

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
        vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
    ):
        # Extract the request-level (prompt) logprobs config.
        num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob

        # The vLLM prompt (+ completion, when greedy) must match HF's tokens.
        if temperature == 0.0:
            assert (
                vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
                == hf_output[0]
            )
        else:
            # Sampled tokens won't match if not greedy.
            assert (
                vllm_result.prompt_token_ids
                == hf_output[0][: len(vllm_result.prompt_token_ids)]
            )

        if num_top_logprobs is not None:
            _validate_request_sample_logprobs(
                vllm_result, hf_logprob, num_top_logprobs, temperature, max_tokens
            )
        else:
            # Logprobs disabled for this request; should be None.
            assert vllm_result.outputs[0].logprobs is None

        if num_top_prompt_logprobs is not None:
            _validate_request_prompt_logprobs(
                vllm_result, hf_logprob, num_top_prompt_logprobs
            )
        else:
            assert vllm_result.prompt_logprobs is None


def _check_topk_logprobs_dict(
    logprobs_dict: dict | None, token_id: int, num_top: int
) -> None:
    """Check one position's logprobs dict against a top-``num_top`` request.

    ``token_id`` (the sampled or prompt token) must be present. If its rank
    is within the top ``num_top`` and ``num_top`` is non-zero, the dict holds
    exactly ``num_top`` entries. Otherwise it holds one extra entry for
    ``token_id``. When ``num_top > 0``, every rank ``1..num_top`` must appear.
    """
    assert logprobs_dict is not None
    # Confirm that the token appears among the logprobs.
    assert token_id in logprobs_dict
    token_in_topk = logprobs_dict[token_id].rank <= num_top
    # If the token is not in the top K logprobs, one extra entry is returned
    # for it.
    if token_in_topk and num_top != 0:
        assert len(logprobs_dict) == num_top
    else:
        assert len(logprobs_dict) == num_top + 1
    if num_top > 0:
        # There must be an entry for each of the top-k ranks.
        all_ranks = {lp.rank for lp in logprobs_dict.values()}
        assert all(r in all_ranks for r in range(1, num_top + 1))


def _validate_request_sample_logprobs(
    vllm_result,
    hf_logprob: list[torch.Tensor],
    num_top_logprobs: int,
    temperature: float,
    max_tokens: int,
) -> None:
    """Validate the sample logprobs of one request that requested them."""
    completion = vllm_result.outputs[0]
    # Confirm that the structure of the sample logprobs is correct.
    assert completion.logprobs is not None
    assert len(completion.logprobs) == max_tokens
    for logprobs, token_id in zip(completion.logprobs, completion.token_ids):
        _check_topk_logprobs_dict(logprobs, token_id, num_top_logprobs)

    # Concatenate the decoded text of the first entry at each position. The
    # result must reproduce the request's output text.
    output_string_from_most_likely_tokens = "".join(
        next(iter(top_logprobs.values())).decoded_token
        for top_logprobs in completion.logprobs
    )
    assert_incr_detok_str_matches_non_incr_detok_str(
        completion.text,
        output_string_from_most_likely_tokens,
        "The output text from the top logprob for each token "
        "position should be the same as the output text in the "
        "result.",
    )

    # Compare vLLM sample logprobs to HF. The HF reference comes from greedy
    # decoding, so under non-greedy sampling only step 0 shares its context.
    for i, top_logprobs in enumerate(completion.logprobs):
        for token_id, sample_logprob in top_logprobs.items():
            if temperature == 0.0 or i == 0:
                torch.testing.assert_close(
                    sample_logprob.logprob,
                    hf_logprob[i][-1][token_id].item(),
                    atol=1e-2,
                    rtol=1e-2,
                )
            assert isinstance(sample_logprob.decoded_token, str), (
                "The token should be decoded by the time it is"
                " returned to the user."
            )

    # The sample logprobs are now known to be correct. Verify that
    # cumulative_logprob matches the reference helper's value.
    torch.testing.assert_close(
        completion.cumulative_logprob,
        compute_correct_cumulative_logprob(completion),
        atol=1e-6,
        rtol=1e-6,
    )


def _validate_request_prompt_logprobs(
    vllm_result,
    hf_logprob: list[torch.Tensor],
    num_top_prompt_logprobs: int,
) -> None:
    """Validate the prompt logprobs of one request that requested them."""
    prompt_logprobs = vllm_result.prompt_logprobs
    prompt_token_ids = vllm_result.prompt_token_ids
    assert prompt_logprobs is not None
    # The first prompt logprob is always None.
    assert prompt_logprobs[0] is None
    # Prompt logprobs are returned for every index in the prompt.
    assert len(prompt_logprobs) == len(prompt_token_ids)
    for position_logprobs, prompt_token_id in zip(
        prompt_logprobs[1:], prompt_token_ids[1:]
    ):
        _check_topk_logprobs_dict(
            position_logprobs, prompt_token_id, num_top_prompt_logprobs
        )

    # Compare prompt logprobs to HF, skipping the always-None first entry.
    for i, position_logprobs in enumerate(prompt_logprobs[1:]):
        for token_id, logprob in position_logprobs.items():
            torch.testing.assert_close(
                logprob.logprob,
                hf_logprob[0][i][token_id].item(),
                atol=2e-2,
                rtol=2e-2,
            )
@pytest.mark.parametrize(
    "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
)
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: BatchLogprobsComposition,
    temperature: float,
    example_prompts: list[str],
) -> None:
    """Test V1 Engine logprobs & prompt logprobs.

    Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
    settings and validate that
    * The generated logprobs and prompt logprobs are consistent with the
      configuration settings, in terms of whether or not the logprobs
      (of either type) were requested and how many were requested
    * The generated logprobs are consistent with the generated tokens
    * The generated (prompt)logprobs are consistent with HuggingFace
      (prompt)logprobs, as a reference

    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    APC tests run two test iterations so that cache hits occur.

    To save time, only test one APC-enabled scenario
    (sample & prompt logprobs enabled, temperature>0.0).

    Args:
        hf_model: HuggingFace reference model fixture
        vllm_model: vLLM model fixture
        batch_logprobs_composition: logprobs configuration for test batch
        temperature: "temperature" sampling parameter
        example_prompts: example prompt fixture
    """
    vllm_config = vllm_model.llm.llm_engine.vllm_config
    do_apc = vllm_config.cache_config.enable_prefix_caching
    if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
        # Only one APC scenario is run, to save time. Give a reason so the
        # skip is self-explanatory in the pytest report.
        pytest.skip(
            "APC is only exercised for SAMPLE_PROMPT at temperature 2.0 "
            "to save time"
        )

    test_prompts = example_prompts

    max_tokens = 5
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list
    )
    # Generate SamplingParams
    vllm_sampling_params = [
        SamplingParams(
            max_tokens=max_tokens,
            logprobs=num_lp,
            prompt_logprobs=num_plp,
            temperature=temperature,
            seed=1984,
        )
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]
    # Under APC, run twice so that the second pass produces cache hits.
    for _ in range(2 if do_apc else 1):
        _run_and_validate(
            vllm_model=vllm_model,
            test_prompts=test_prompts,
            vllm_sampling_params=vllm_sampling_params,
            hf_logprobs=hf_logprobs,
            hf_outputs=hf_outputs,
            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
            temperature=temperature,
            max_tokens=max_tokens,
            do_apc=do_apc,
        )
def test_max_logprobs():
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`.

    It should also fail a request with `prompt_logprobs > max_logprobs`.

    APC should not matter as this test checks basic request validation.
    """
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        vllm_sampling_params = SamplingParams(logprobs=1)
        # should pass
        runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

        # Sample logprobs above the engine limit must be rejected.
        bad_sampling_params = SamplingParams(logprobs=2)
        with pytest.raises(ValueError):
            runner.generate(["Hello world"], sampling_params=bad_sampling_params)

        # The docstring promises the same for prompt logprobs; the original
        # test never checked it.
        bad_prompt_sampling_params = SamplingParams(prompt_logprobs=2)
        with pytest.raises(ValueError):
            runner.generate(
                ["Hello world"], sampling_params=bad_prompt_sampling_params
            )
def test_none_logprobs(vllm_model, example_prompts):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`.

    Args:
        vllm_model: vLLM model fixture
        example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_none = SamplingParams(
        max_tokens=max_tokens,
        logprobs=None,
        prompt_logprobs=None,
        temperature=0.0,
    )
    results_logprobs_none = vllm_model.llm.generate(
        example_prompts,
        sampling_params=sampling_params_logprobs_none,
    )

    # Guard against an empty result list making the loop vacuously pass.
    assert len(results_logprobs_none) == len(example_prompts)
    for result in results_logprobs_none:
        # Check sample logprobs are None
        assert result.outputs[0].logprobs is None
        assert result.outputs[0].cumulative_logprob is None
        # Check prompt logprobs are None
        assert result.prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts):
    """Engine should return sampled token and prompt token logprobs.

    With ``logprobs=0`` / ``prompt_logprobs=0`` each position's dict holds
    only the sampled (resp. prompt) token itself.

    Args:
        vllm_model: vLLM model fixture
        example_prompts: list of example prompts (test fixture)
    """
    max_tokens = 5

    sampling_params_logprobs_zero = SamplingParams(
        max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
    )
    results_logprobs_zero = vllm_model.llm.generate(
        example_prompts, sampling_params=sampling_params_logprobs_zero
    )

    # Guard against an empty result list making the loop vacuously pass.
    assert len(results_logprobs_zero) == len(example_prompts)
    for result in results_logprobs_zero:
        completion = result.outputs[0]
        logprobs = completion.logprobs
        prompt_logprobs = result.prompt_logprobs
        sampled_token_ids = completion.token_ids
        prompt_token_ids = result.prompt_token_ids

        # One sample logprob dict per sampled token.
        assert logprobs is not None
        assert len(sampled_token_ids) == len(logprobs)
        assert completion.cumulative_logprob is not None
        # With zero top logprobs, only the sampled token itself is returned.
        # This matches the `num_top_logprobs + 1` rule in _run_and_validate.
        for position_logprobs, token_id in zip(logprobs, sampled_token_ids):
            assert token_id in position_logprobs
            assert len(position_logprobs) == 1

        # One prompt logprob dict per prompt token; the first is None.
        assert prompt_logprobs is not None
        assert len(prompt_token_ids) == len(prompt_logprobs)
        assert prompt_logprobs[0] is None
        for position_logprobs, token_id in zip(
            prompt_logprobs[1:], prompt_token_ids[1:]
        ):
            assert token_id in position_logprobs
            assert len(position_logprobs) == 1
def test_all_logprobs(example_prompts):
    """Engine should return all vocabulary logprobs and prompt logprobs.

    Args:
        example_prompts: list of example prompts (test fixture)
    """
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=-1,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        # -1 requests logprobs over the full vocabulary.
        sampling_params_logprobs_all = SamplingParams(
            max_tokens=5, logprobs=-1, prompt_logprobs=-1
        )
        results_logprobs_all = runner.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_all
        )
        vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()

        # Guard against an empty result list making the loop vacuously pass.
        assert len(results_logprobs_all) == len(example_prompts)
        for result in results_logprobs_all:
            logprobs = result.outputs[0].logprobs
            prompt_logprobs = result.prompt_logprobs
            assert logprobs is not None
            for logprob in logprobs:
                assert len(logprob) == vocab_size
            assert prompt_logprobs is not None
            assert prompt_logprobs[0] is None
            for prompt_logprob in prompt_logprobs[1:]:
                assert len(prompt_logprob) == vocab_size
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode):
    """Test the LLM engine with each ``logprobs_mode``.

    Logprob modes must produce only non-positive values. Logit modes are
    expected to produce at least one positive value.
    """
    from vllm import LLM

    llm = LLM(
        "facebook/opt-125m",
        max_logprobs=5,
        enable_prefix_caching=False,
        # 2 other llms alive during whole session
        gpu_memory_utilization=0.05,
        max_model_len=16,
        logprobs_mode=logprobs_mode,
    )
    try:
        vllm_sampling_params = SamplingParams(logprobs=1)
        results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)

        total_token_with_logprobs = 0
        positive_values = 0
        for output in results[0].outputs:
            # Fail with a clear assertion instead of a TypeError if missing.
            assert output.logprobs is not None
            for logprobs in output.logprobs:
                for logprob in logprobs.values():
                    if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                        assert logprob.logprob <= 0
                    if logprob.logprob > 0:
                        positive_values += 1
                    total_token_with_logprobs += 1
        assert total_token_with_logprobs >= len(results[0].outputs)
        if logprobs_mode in ("raw_logits", "processed_logits"):
            assert positive_values > 0
    finally:
        del llm
        torch.accelerator.empty_cache()
        cleanup_dist_env_and_memory()
class TestCorrectDecodedToken:
    """Unit tests for _correct_decoded_token method in LogprobsProcessor.

    This method handles UTF-8 decoding issues where incomplete byte sequences
    result in the Unicode replacement character (U+FFFD). This commonly
    happens with byte-fallback tokenization when multi-byte UTF-8 characters
    are split across tokens.

    The method signature is _correct_decoded_token(token_id, context_token_ids)
    where token_id is the single token to correct and context_token_ids are
    the preceding sampled tokens in sequential order.

    All mocked decode outputs use the real U+FFFD character (``"\\ufffd"``).
    A textual stand-in such as ``"<EFBFBD>"`` would not simulate an
    incomplete UTF-8 sequence at all.
    """

    @staticmethod
    def _make_processor(tokenizer):
        """Build a LogprobsProcessor with empty state around ``tokenizer``."""
        from vllm.v1.engine.logprobs import LogprobsProcessor

        return LogprobsProcessor(
            tokenizer=tokenizer,
            logprobs=[],
            prompt_logprobs=None,
            cumulative_logprob=0.0,
            num_logprobs=1,
            num_prompt_logprobs=None,
        )

    @pytest.fixture
    def mock_tokenizer(self):
        """Create a mock tokenizer for testing."""
        from unittest.mock import Mock

        tokenizer = Mock()
        return tokenizer

    @pytest.fixture
    def processor(self, mock_tokenizer):
        """Create a LogprobsProcessor."""
        return self._make_processor(mock_tokenizer)

    def test_correction_with_context(self, processor):
        """Test correction using context from preceding sampled tokens.

        Scenario: A byte-fallback token that completes a multi-byte
        UTF-8 sequence when decoded with context.
        """

        # Context is [101] (a preceding sampled token)
        # Token 102 individually decodes to U+FFFD
        # decode([101, 102]) returns "valid" (complete sequence)
        def mock_decode(ids):
            if ids == [101, 102]:
                return "hello valid"
            if ids == [101]:
                return "hello "
            return "\ufffd"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(102, [101])
        assert result == "valid"

    def test_correction_with_context_from_logprobs(self, processor):
        """Test correction using context from previous logprob entries.

        Scenario: Token decoded with context from previously sampled
        tokens completes a UTF-8 sequence.
        """

        # Token 123 was previously sampled (in context)
        def mock_decode(ids):
            if ids == [123, 100]:
                return 'hello "polarized"'
            if ids == [123]:
                return "hello "
            return "\ufffd"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(100, [123])
        assert result == '"polarized"'

    def test_correction_no_context(self, processor):
        """Test correction with no context available.

        Should return empty string as fallback.
        """
        processor.tokenizer.decode.return_value = "\ufffd"

        result = processor._correct_decoded_token(100, [])
        assert result == ""

    def test_correction_with_context_succeeds(self, processor):
        """Test correction with context from previously sampled tokens."""

        def mock_decode(ids):
            if ids == [123, 200]:
                return "hello corrected"
            if ids == [123]:
                return "hello "
            return "\ufffd"

        processor.tokenizer.decode.side_effect = mock_decode

        result = processor._correct_decoded_token(200, [123])
        assert result == "corrected"

    def test_fallback_when_all_attempts_fail(self, processor):
        """Test fallback to empty string when no correction works."""
        processor.tokenizer.decode.return_value = "still\ufffd"

        result = processor._correct_decoded_token(102, [100, 101])
        assert result == ""

    def test_increasing_context_window(self, processor):
        """Test that increasing context window finds the correction.

        Scenario: 3-byte UTF-8 char. With 1 context token, still
        incomplete. With 2 context tokens, completes the sequence.
        """

        def mock_decode(ids):
            # 1 context token: still incomplete
            if ids == [81, 82]:
                return "\ufffd"
            # 2 context tokens: complete
            if ids == [80, 81, 82]:
                return "\u201c"
            # Context-only decodes
            if ids == [81]:
                return "\ufffd"
            if ids == [80, 81]:
                return "\ufffd"
            return "\ufffd"

        processor.tokenizer.decode.side_effect = mock_decode

        # Context has 2 preceding tokens [80, 81]
        result = processor._correct_decoded_token(82, [80, 81])
        assert result == "\u201c"

    def test_multiple_consecutive_replacement_chars(self, processor):
        """Test handling of multiple consecutive replacement characters.

        Scenario: Multi-byte sequence where intermediate bytes return
        empty string and the final byte returns the complete character.
        """
        processor.tokenizer.decode.return_value = "still\ufffd"

        # First byte with no useful context: returns ""
        result1 = processor._correct_decoded_token(100, [50])
        assert result1 == ""

        # Second byte with same context: still returns ""
        result2 = processor._correct_decoded_token(101, [50])
        assert result2 == ""

    def test_correction_with_multibyte_utf8(self, processor):
        """Test correction involving multi-byte UTF-8 characters.

        Scenario: Byte-fallback tokenization splits curly quotes.
        The last byte token should produce the complete character.
        """

        def mock_decode(ids):
            # Context [123] + first byte: completes to left curly quote
            if ids == [123, 200]:
                return "hello \u201c"
            if ids == [123]:
                return "hello "
            # Context [123] + second byte: completes to right curly quote
            if ids == [123, 201]:
                return "hello \u201d"
            return "\ufffd"

        processor.tokenizer.decode.side_effect = mock_decode

        # Each top-k token is corrected independently with same context
        result1 = processor._correct_decoded_token(200, [123])
        assert result1 == "\u201c"

        result2 = processor._correct_decoded_token(201, [123])
        assert result2 == "\u201d"

    def test_topk_tokens_corrected_independently(self, processor):
        """Test that top-k alternatives at the same position are each
        corrected independently using only sequential context, not
        each other.

        This is the core fix for issue #27300: when logprobs > 0,
        alternative tokens must not be combined with each other.
        """
        # Context: previously sampled token 50
        context = [50]

        def mock_decode(ids):
            # Token 100 (sampled) with context
            if ids == [50, 100]:
                return "prefix \u201c"
            # Token 200 (top-k alternative) with context
            if ids == [50, 200]:
                return "prefix \u2014"
            # Context alone
            if ids == [50]:
                return "prefix "
            return "\ufffd"

        processor.tokenizer.decode.side_effect = mock_decode

        # Both tokens at the same position use the SAME context [50]
        result_sampled = processor._correct_decoded_token(100, context)
        assert result_sampled == "\u201c"

        result_alt = processor._correct_decoded_token(200, context)
        assert result_alt == "\u2014"

    def test_real_world_opt125m_scenario(self, mock_tokenizer):
        """Test the real-world scenario from the bug report.

        Simulates the OPT-125m sequence where curly quotes are split
        into byte-fallback tokens. Each token is corrected using only
        the preceding sampled tokens as context.
        """
        processor = self._make_processor(mock_tokenizer)

        # Simulating: byte tokens 3, 4 form left curly quote "\u201c"
        # byte tokens 8, 9 form right curly quote "\u201d"
        def mock_decode(ids):
            # Context decodes
            if ids == [2]:
                return " term"
            if ids == [1, 2]:
                return " the term"
            if ids == [3]:
                return "\ufffd"
            if ids == [2, 3]:
                return " term\ufffd"
            if ids == [1, 2, 3]:
                return " the term\ufffd"
            # Token 4 with context [2, 3] -> completes left curly quote
            if ids == [3, 4]:
                return "\u201c"
            if ids == [2, 3, 4]:
                return " term\u201c"
            # Context for right curly quote
            if ids == [7]:
                return "ized"
            if ids == [7, 8]:
                return "ized\ufffd"
            if ids == [8, 9]:
                return "\u201d"
            if ids == [7, 8, 9]:
                return "ized\u201d"
            return "normal_text"

        mock_tokenizer.decode.side_effect = mock_decode

        # First byte (token 3) of left curly quote with no context
        result = processor._correct_decoded_token(3, [])
        assert result == ""

        # First byte (token 3) with context [2] -> still incomplete
        result = processor._correct_decoded_token(3, [2])
        assert result == ""

        # Second byte (token 4) of left curly quote with context [2, 3]
        # Token 3 is byte-fallback, so clean context is [2] only.
        # decode([2, 3, 4]) = " term\u201c", decode([2]) = " term"
        # result = "\u201c"
        result = processor._correct_decoded_token(4, [2, 3])
        assert result == "\u201c"

        # Second byte (token 9) of right curly quote with context [7, 8]
        result = processor._correct_decoded_token(9, [7, 8])
        assert result == "\u201d"

    def test_byte_fallback_context_preserves_space(self, mock_tokenizer):
        """Test that text from byte-fallback context tokens is preserved.

        In OPT-125m, token 44 = space + 2 bytes of curly quote.
        When token 44 returns "" (incomplete), the space it carried
        must be attributed to the completing token (48).
        """
        processor = self._make_processor(mock_tokenizer)

        def mock_decode(ids):
            # Token 44 = space + 2 bytes (like OPT-125m's \u0120\u00e2\u0080)
            if ids == [44]:
                return " \ufffd"
            if ids == [48]:
                return "\ufffd"
            # Together they form: space + left curly quote
            if ids == [44, 48]:
                return " \u201c"
            # With preceding clean context
            if ids == [1385]:
                return " term"
            if ids == [1385, 44]:
                return " term \ufffd"
            if ids == [1385, 44, 48]:
                return " term \u201c"
            return "\ufffd"

        mock_tokenizer.decode.side_effect = mock_decode

        # Token 44 with context [1385] -> still ends with replacement
        result = processor._correct_decoded_token(44, [1385])
        assert result == ""

        # Token 48 with context [1385, 44]:
        # Token 44 is byte-fallback, so clean context is [1385].
        # decode([1385, 44, 48]) = " term \u201c"
        # decode([1385]) = " term"
        # result = " \u201c" (space preserved from token 44!)
        result = processor._correct_decoded_token(48, [1385, 44])
        assert result == " \u201c"
def test_verify_tokens_integration():
    """Integration test for _verify_tokens with real model.

    This test validates that _verify_tokens correctly identifies and
    corrects tokens ending with the Unicode replacement character (U+FFFD).

    Uses facebook/opt-125m which is known to produce these issues.
    """
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=0,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        # Use a prompt that triggers multi-byte UTF-8 issues
        # Based on user's example: "In this example,"
        test_prompts = ["In this example,"]

        sampling_params = SamplingParams(
            max_tokens=16,
            temperature=0,
            logprobs=0,
        )

        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

        # Verify that decoded tokens don't contain replacement characters
        for result in results:
            assert result.outputs[0].logprobs is not None
            for logprob_dict in result.outputs[0].logprobs:
                for token_id, logprob_info in logprob_dict.items():
                    decoded_token = logprob_info.decoded_token
                    assert isinstance(decoded_token, str)
                    # Decoded tokens should not end with the replacement
                    # character; they should be corrected or be "". This also
                    # rules out a lone U+FFFD, which trivially ends with it.
                    assert not decoded_token.endswith("\ufffd"), (
                        f"Token {token_id} decoded to '{decoded_token}' which "
                        f"ends with replacement character"
                    )
def test_utf8_edge_cases_with_real_model():
    """Test various UTF-8 edge cases with a real model.

    Tests prompts that are likely to trigger byte-fallback tokenization
    and multi-byte UTF-8 splitting.
    """
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=1,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        # Prompts with various multi-byte UTF-8 characters. Curly quotes are
        # written as escapes (U+201C / U+201D) so they are real multi-byte
        # characters rather than ASCII '"'.
        test_prompts = [
            "Smart quotes: \u201cHello\u201d",  # Curly quotes
            "Em dash — test",  # Em dash
            "Ellipsis… continues",  # Ellipsis
            "Chinese: 你好",  # Chinese characters
            "Emoji: 😀 🎉",  # Emojis
            "Mixed: \u201cquoted\u201d — with symbols",  # Mixed
        ]

        sampling_params = SamplingParams(
            max_tokens=10,
            temperature=0,
            logprobs=1,
        )

        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

        for prompt, result in zip(test_prompts, results):
            assert result.outputs[0].logprobs is not None

            # Check that no decoded token ends with the replacement character.
            for logprob_dict in result.outputs[0].logprobs:
                for token_id, logprob_info in logprob_dict.items():
                    decoded_token = logprob_info.decoded_token
                    assert not decoded_token.endswith("\ufffd"), (
                        f"Prompt: '{prompt}'\n"
                        f"Token {token_id} decoded to '{decoded_token}' which "
                        f"ends with replacement character"
                    )
def test_correct_decoded_token_preserves_valid_tokens():
    """Test that valid tokens (not ending with U+FFFD) are not modified.

    The _correct_decoded_token method should only be called for tokens
    ending with the replacement character (U+FFFD). This test verifies that
    the broader _verify_tokens logic doesn't affect valid tokens.
    """
    with VllmRunner(
        "facebook/opt-125m",
        max_logprobs=2,
        enable_prefix_caching=False,
        gpu_memory_utilization=0.15,
        max_model_len=256,
    ) as runner:
        # Simple prompt with standard ASCII characters
        test_prompts = ["Hello world, this is a test."]

        sampling_params = SamplingParams(
            max_tokens=10,
            temperature=0,
            logprobs=2,
        )

        results = runner.llm.generate(test_prompts, sampling_params=sampling_params)

        for result in results:
            assert result.outputs[0].logprobs is not None

            # All decoded tokens should be valid strings
            for logprob_dict in result.outputs[0].logprobs:
                for logprob_info in logprob_dict.values():
                    decoded_token = logprob_info.decoded_token
                    # Must be a str (possibly empty if it was corrected).
                    assert isinstance(decoded_token, str)
                    # Should not contain the replacement character
                    assert "\ufffd" not in decoded_token
def _generate_and_collect_logprobs(
    prompts: list[str],
    sampling_params: list[SamplingParams],
    **llm_kwargs,
) -> list:
    """Run one LLM instance, return its flattened sample logprobs, then free it.

    Only one LLM is alive at a time, to avoid GPU memory contention. Teardown
    runs in a ``finally`` block. GPU memory and the distributed environment are
    therefore released even if generation raises, instead of leaking the
    instance into later tests in the same process.

    Args:
        prompts: Prompts passed to ``LLM.generate``.
        sampling_params: Per-prompt sampling params (same length as prompts).
        **llm_kwargs: Keyword arguments forwarded to the ``LLM`` constructor.

    Returns:
        Every logprob entry of every generated position, flattened in order:
        request, then completion, then position, then the dict order of that
        position's logprobs.
    """
    from vllm import LLM

    llm = LLM(**llm_kwargs)
    try:
        results = llm.generate(prompts, sampling_params)
        return [
            logprob
            for request_output in results
            for completion in request_output.outputs
            for position_logprobs in completion.logprobs
            for logprob in position_logprobs.values()
        ]
    finally:
        del llm
        torch.accelerator.empty_cache()
        cleanup_dist_env_and_memory()


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
    "model_setup",
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.2-1B-Instruct",
                {
                    "method": "eagle",
                    "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
                    "num_speculative_tokens": 3,
                },
                0,
            ),
            marks=large_gpu_mark(min_gb=32),
            id="eagle0",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.2-1B-Instruct",
                {
                    "method": "eagle",
                    "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
                    "num_speculative_tokens": 3,
                },
                3,
            ),
            marks=large_gpu_mark(min_gb=32),
            id="eagle3",
        ),
        pytest.param(
            (
                "ngram",
                "meta-llama/Llama-3.2-1B-Instruct",
                {
                    "method": "ngram",
                    "prompt_lookup_max": 5,
                    "prompt_lookup_min": 3,
                    "num_speculative_tokens": 3,
                },
                3,
            ),
            marks=large_gpu_mark(min_gb=32),
            id="ngram",
        ),
    ],
)
def test_spec_decode_logprobs(
    logprobs_mode: LogprobsMode,
    model_setup: tuple[str, str, dict, int],
    monkeypatch,
):
    """Spec decode logprobs should match those of the base model.

    Runs the base model and the spec decode model sequentially, so only one
    LLM instance is alive at a time and GPU memory is not contended. Both
    runs share the same model, seed, logprobs settings, and prefix-caching
    setting, plus ``ROCM_DETERMINISM_KWARGS`` on ROCm. The reference run
    uses default scheduling. The spec decode run additionally forces prefill
    chunking (``max_num_batched_tokens=32``), so speculation is also
    exercised across prefill chunks.

    Each run generates two requests: plain greedy sampling, and greedy
    sampling with a presence penalty. The per-token logprob, rank and
    decoded token of every returned entry are then compared.

    Args:
        logprobs_mode: logprobs mode.
        model_setup: Tuple of (method, base model name,
            speculative_config dict, top_logprobs).
        monkeypatch: pytest fixture for setting env vars.
    """
    # The ROCm skinny GEMM kernels (gemm_kernels.cu) are
    # non-deterministic across LLM instantiations due to persistent
    # workgroup scheduling and wave-level shuffle reductions, which
    # causes logprob differences that get misattributed to spec decode.
    # Disable them so this test isolates spec decode correctness only.
    # TODO(akaratza): Remove this workaround once the follow-up to
    # https://github.com/vllm-project/vllm/pull/33493#issuecomment-3906083975
    # lands with a determinism fix for wvSplitK kernels.
    monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")

    # The method name is only used for test ids; the spec config carries it.
    _, model_name, spec_config, top_logprobs = model_setup
    prompt = "Hello world " * 50
    sampling_params = SamplingParams(
        temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
    )
    penalty_sampling_params = SamplingParams(
        temperature=0,
        logprobs=top_logprobs,
        max_tokens=10,
        ignore_eos=False,
        presence_penalty=-1.0,
    )
    prompts = [prompt, prompt]
    all_sampling_params = [sampling_params, penalty_sampling_params]
    max_model_len = 256

    # Settings shared by the reference and the spec decode LLM.
    common_llm_kwargs = dict(
        model=model_name,
        max_logprobs=5,
        max_model_len=max_model_len,
        seed=42,
        logprobs_mode=logprobs_mode,
        gpu_memory_utilization=0.4,
        enable_prefix_caching=False,
        **ROCM_DETERMINISM_KWARGS,
    )

    # Run base LLM.
    ref_logprobs = _generate_and_collect_logprobs(
        prompts, all_sampling_params, **common_llm_kwargs
    )

    # Run spec decode LLM. The spec config's max_model_len is always set
    # to the target model's (overriding any value already present).
    spec_config_with_len = {**spec_config, "max_model_len": max_model_len}
    spec_logprobs = _generate_and_collect_logprobs(
        prompts,
        all_sampling_params,
        **common_llm_kwargs,
        speculative_config=spec_config_with_len,
        # Force prefill chunking
        enable_chunked_prefill=True,
        max_num_batched_tokens=32,
    )

    # Per-token logprobs are expected to be the same.
    assert len(ref_logprobs) == len(spec_logprobs), (
        f"Logprob count mismatch: ref={len(ref_logprobs)} "
        f"spec={len(spec_logprobs)}"
    )
    for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
        assert math.isclose(
            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
        ), (
            f"Logprob mismatch: ref={ref_logprob.logprob} "
            f"spec={spec_logprob.logprob} "
            f"diff={abs(ref_logprob.logprob - spec_logprob.logprob)} "
            f"(token={ref_logprob.decoded_token!r})"
        )
        assert ref_logprob.rank == spec_logprob.rank, (
            f"Rank mismatch: ref={ref_logprob.rank} "
            f"spec={spec_logprob.rank} "
            f"(token={ref_logprob.decoded_token!r})"
        )
        assert ref_logprob.decoded_token == spec_logprob.decoded_token, (
            f"Decoded token mismatch: ref={ref_logprob.decoded_token!r} "
            f"spec={spec_logprob.decoded_token!r}"
        )
def test_prompt_logprobs_with_chunking_and_preemption():
    """Test that prompt logprobs are correctly returned when using
    both chunked prefill and preemption.

    This test ensures that the num_prompt_logprobs tracking persists
    across preemptions and prefill chunks. It also asserts, via the
    ``vllm:num_preemptions`` metric, that at least one preemption
    actually happened, so the scenario is really exercised.
    """
    # Create prompts that will trigger chunking and preemption: the
    # small batched-token budget forces prefill chunking, and the many
    # concurrent requests against a tiny KV cache force preemptions.
    prompts = [
        "The following numbers of the sequence "
        + ", ".join(str(i) for i in range(10))
        + " are:",
        "In one word, the capital of France is ",
    ] + [f"Tell me about the number {i}: " for i in range(32)]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=40,
        min_tokens=20,
        prompt_logprobs=2,  # Request prompt logprobs
    )
    num_prompt_logprobs = sampling_params.prompt_logprobs

    def _num_preemptions(metrics) -> float:
        # Treat a missing counter as zero preemptions.
        return next(
            (m.value for m in metrics if m.name == "vllm:num_preemptions"), 0
        )

    with VllmRunner(
        "Qwen/Qwen3-0.6B",
        max_model_len=512,
        enable_chunked_prefill=True,
        max_num_batched_tokens=48,  # Force prefill chunking
        num_gpu_blocks_override=32,  # Force preemptions
        disable_log_stats=False,
        gpu_memory_utilization=0.25,
    ) as vllm_model:
        metrics_before = vllm_model.llm.get_metrics()

        # With include_prompt_token_ids=True, each output is a 5-tuple whose
        # last two entries are the prompt token ids and the prompt logprobs.
        outputs = vllm_model.generate_w_logprobs(
            prompts, sampling_params=sampling_params, include_prompt_token_ids=True
        )

        # Verify that all outputs have prompt logprobs
        for i, output in enumerate(outputs):
            _, _, _, prompt_token_ids, prompt_logprobs = output
            assert prompt_logprobs is not None and len(prompt_logprobs) > 0, (
                f"Output {i} missing prompt logprobs"
            )
            assert len(prompt_logprobs) == len(prompt_token_ids), (
                f"Output {i}: unexpected number of prompt logprob positions "
                f"({len(prompt_logprobs)} positions for "
                f"{len(prompt_token_ids)} prompt tokens)"
            )

            # Each position should hold the requested number of logprobs,
            # or one more (presumably when the actual prompt token is not
            # among the top-k and is included as well).
            for pos, logprobs_dict in enumerate(prompt_logprobs):
                if logprobs_dict is None:  # First token may be None
                    continue
                assert (
                    num_prompt_logprobs
                    <= len(logprobs_dict)
                    <= num_prompt_logprobs + 1
                ), (
                    f"Output {i} position {pos} has {len(logprobs_dict)} "
                    f"logprobs, expected {num_prompt_logprobs} or "
                    f"{num_prompt_logprobs + 1}"
                )

        # Check that we actually had preemptions
        metrics_after = vllm_model.llm.get_metrics()
        preemptions = _num_preemptions(metrics_after) - _num_preemptions(
            metrics_before
        )
        assert preemptions > 0, "Test did not trigger any preemptions"
        print(f"Test passed with {preemptions} preemptions")