[Sampler] Introduce logprobs mode for logging (#21398)

Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
Lu Fang
2025-07-23 01:39:25 -07:00
committed by GitHub
parent 23637dcdef
commit accac82928
7 changed files with 83 additions and 13 deletions

View File

@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
assert_incr_detok_str_matches_non_incr_detok_str,
compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams
from vllm.config import LogprobsMode
from ...conftest import HfRunner, VllmRunner
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
# prompt token
assert prompt_logprobs is not None
assert len(prompt_token_ids) == len(prompt_logprobs)
@pytest.mark.parametrize(
"logprobs_mode",
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
def test_logprobs_mode(logprobs_mode: LogprobsMode,
monkeypatch: pytest.MonkeyPatch):
"""Test with LLM engine with different logprobs_mode.
For logprobs, we should have non-positive values.
For logits, we should expect at least one positive values.
"""
from vllm import LLM
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
llm = LLM(
"facebook/opt-125m",
max_logprobs=5,
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.05,
max_model_len=16,
logprobs_mode=logprobs_mode)
vllm_sampling_params = SamplingParams(logprobs=1)
results = llm.generate(["Hello world"],
sampling_params=vllm_sampling_params)
total_token_with_logprobs = 0
positive_values = 0
for output in results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
logprob = logprobs[token_id]
if "logprobs" in logprobs_mode:
assert logprob.logprob <= 0
if logprob.logprob > 0:
positive_values = positive_values + 1
total_token_with_logprobs = total_token_with_logprobs + 1
assert total_token_with_logprobs >= len(results[0].outputs)
if "logits" in logprobs_mode:
assert positive_values > 0
del llm