[Sampler] Introduce logprobs mode for logging (#21398)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from tests.v1.sample.utils import (
|
||||
assert_incr_detok_str_matches_non_incr_detok_str,
|
||||
compute_correct_cumulative_logprob, get_test_batch)
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import LogprobsMode
|
||||
|
||||
from ...conftest import HfRunner, VllmRunner
|
||||
|
||||
@@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts,
|
||||
# prompt token
|
||||
assert prompt_logprobs is not None
|
||||
assert len(prompt_token_ids) == len(prompt_logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"logprobs_mode",
|
||||
["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"])
|
||||
def test_logprobs_mode(logprobs_mode: LogprobsMode,
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test with LLM engine with different logprobs_mode.
|
||||
For logprobs, we should have non-positive values.
|
||||
For logits, we should expect at least one positive values.
|
||||
"""
|
||||
from vllm import LLM
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
llm = LLM(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=5,
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.05,
|
||||
max_model_len=16,
|
||||
logprobs_mode=logprobs_mode)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
results = llm.generate(["Hello world"],
|
||||
sampling_params=vllm_sampling_params)
|
||||
|
||||
total_token_with_logprobs = 0
|
||||
positive_values = 0
|
||||
for output in results[0].outputs:
|
||||
for logprobs in output.logprobs:
|
||||
for token_id in logprobs:
|
||||
logprob = logprobs[token_id]
|
||||
if "logprobs" in logprobs_mode:
|
||||
assert logprob.logprob <= 0
|
||||
if logprob.logprob > 0:
|
||||
positive_values = positive_values + 1
|
||||
total_token_with_logprobs = total_token_with_logprobs + 1
|
||||
assert total_token_with_logprobs >= len(results[0].outputs)
|
||||
if "logits" in logprobs_mode:
|
||||
assert positive_values > 0
|
||||
del llm
|
||||
|
||||
Reference in New Issue
Block a user