[BugFix] Fix logprobs with spec decode and modified logits (#30846)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-12-18 19:58:28 -08:00
committed by GitHub
parent 7b43db210c
commit 2ac85a4544
2 changed files with 30 additions and 11 deletions

View File

@@ -547,6 +547,13 @@ def test_spec_decode_logprobs(
sampling_params = SamplingParams(
temperature=0, logprobs=top_logprobs, max_tokens=10, ignore_eos=False
)
penalty_sampling_params = SamplingParams(
temperature=0,
logprobs=top_logprobs,
max_tokens=10,
ignore_eos=False,
presence_penalty=-1.0,
)
method, model_name, spec_model_name = model_setup
max_model_len = 256
@@ -558,14 +565,17 @@ def test_spec_decode_logprobs(
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
enable_prefix_caching=False,
)
ref_results = ref_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
for results in ref_results:
for output in results.outputs:
for logprobs in output.logprobs:
ref_logprobs.extend(logprobs.values())
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@@ -587,14 +597,17 @@ def test_spec_decode_logprobs(
# Force prefill chunking
enable_chunked_prefill=True,
max_num_batched_tokens=32,
enable_prefix_caching=False,
)
spec_results = spec_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
for results in spec_results:
for output in results.outputs:
for logprobs in output.logprobs:
spec_logprobs.extend(logprobs.values())
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()