[BugFix] Fix returned logprobs with spec decode + prefill chunking (#29216)
Signed-off-by: Nick Hill <nhill@redhat.com>
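
The bug: when speculative decoding runs together with chunked prefill, the per-token logprobs returned for generated tokens could be incorrect. A minimal sketch of the scenario the updated test covers, assuming vLLM's speculative_config dict form; the num_speculative_tokens value is an illustrative assumption, not taken from this diff:

from vllm import LLM, SamplingParams

# Target/draft pairing mirrors the updated test below; num_speculative_tokens
# is an illustrative assumption.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
        "num_speculative_tokens": 3,
    },
    # Force the prompt prefill to be split across engine steps.
    enable_chunked_prefill=True,
    max_num_batched_tokens=32,
)
params = SamplingParams(temperature=0, logprobs=3, max_tokens=10, ignore_eos=False)
out = llm.generate(["Hello world " * 50], params)[0]
# logprobs holds one {token_id: Logprob} dict per generated token.
print(out.outputs[0].logprobs[0])
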
@@ -521,8 +521,8 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
         pytest.param(
             (
                 "eagle",
-                "meta-llama/Llama-3.1-8B-Instruct",
-                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+                "meta-llama/Llama-3.2-1B-Instruct",
+                "nm-testing/Llama3_2_1B_speculator.eagle3",
             ),
             marks=large_gpu_mark(min_gb=32),
         ),
@@ -541,7 +541,7 @@ def test_spec_decode_logprobs(
     """
     from vllm import LLM
 
-    prompt = "Hello world"
+    prompt = "Hello world " * 50
     sampling_params = SamplingParams(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
@@ -582,6 +582,9 @@ def test_spec_decode_logprobs(
         seed=42,
         logprobs_mode=logprobs_mode,
         gpu_memory_utilization=0.4,
+        # Force prefill chunking
+        enable_chunked_prefill=True,
+        max_num_batched_tokens=32,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
     # Collect logprobs outputs from spec decode LLM.
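
For the new flags to matter, the prompt must exceed the 32-token budget so the prefill actually gets chunked, which is what the longer "Hello world " * 50 prompt above ensures. A quick back-of-the-envelope count, assuming a Llama-3-family tokenizer (the meta-llama repo is gated; this is a sketch, not part of the test):

from math import ceil
from transformers import AutoTokenizer

# Assumption: any Llama-3-family tokenizer gives a similar count.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
n_tokens = len(tok("Hello world " * 50).input_ids)
# Roughly 100 prompt tokens -> prefill spans several 32-token steps.
print(n_tokens, "prompt tokens ->", ceil(n_tokens / 32), "prefill chunks")
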
@@ -597,6 +600,8 @@ def test_spec_decode_logprobs(
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)
     for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
-        assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+        assert math.isclose(
+            ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
+        )
         assert ref_logprob.rank == spec_logprob.rank
         assert ref_logprob.decoded_token == spec_logprob.decoded_token
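
The loosened comparison is clearer with math.isclose's semantics written out: it passes when abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol). With the illustrative logprob values below, the old strict check fails while the new one passes:

import math

# Illustrative values: a reference logprob and a spec-decode logprob
# differing by 0.09.
a, b = -2.00, -2.09

# Old check, abs_tol=1e-3 (default rel_tol=1e-9): 0.09 > 1e-3 -> False.
print(math.isclose(a, b, abs_tol=1e-3))
# New check: tolerance is max(5e-2 * 2.09, 1e-1) = 0.1045 -> True.
print(math.isclose(a, b, rel_tol=5e-2, abs_tol=1e-1))
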