[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-11-22 06:41:25 -08:00
committed by GitHub
parent 066209a045
commit d44a63c6d6
3 changed files with 22 additions and 15 deletions

View File

@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
pytest.param(
(
"eagle",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
"meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama3_2_1B_speculator.eagle3",
),
marks=large_gpu_mark(min_gb=32),
),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
"""
from vllm import LLM
prompt = "Hello world"
prompt = "Hello world " * 50
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
# Force prefill chunking
enable_chunked_prefill=True,
max_num_batched_tokens=32,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
assert math.isclose(
ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token