[Sampler] Support returning all logprobs or logits (#21792)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn
2025-08-04 03:04:12 -07:00
committed by GitHub
parent fed5849d3f
commit 54de71d0df
6 changed files with 45 additions and 9 deletions

View File

@@ -429,6 +429,33 @@ def test_zero_logprobs(vllm_model, example_prompts,
assert len(prompt_token_ids) == len(prompt_logprobs)
def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
"""Engine should return all vocabulary logprobs
Args:
example_prompts: list of example prompts (test fixture)
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
runner = VllmRunner(
"facebook/opt-125m",
max_logprobs=-1,
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15,
max_model_len=256)
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
logprobs=-1)
results_logprobs_all = runner.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_all)
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
for i in range(len(results_logprobs_all)):
logprobs = results_logprobs_all[i].outputs[0].logprobs
assert logprobs is not None
for logprob in logprobs:
assert len(logprob) == vocab_size
@pytest.mark.parametrize(
"logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])