[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
Jialin Ouyang
2025-11-18 18:10:02 -08:00
committed by GitHub
parent da94c7c0eb
commit 40b6b38f2c
6 changed files with 33 additions and 41 deletions

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.logprobs import (
FlatLogprobs,
Logprob,
@@ -14,24 +12,20 @@ from vllm.logprobs import (
)
def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
prompt_logprobs = create_prompt_logprobs()
def test_create_logprobs_non_flat() -> None:
prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
assert isinstance(prompt_logprobs, list)
# Ensure first prompt position logprobs is None
assert len(prompt_logprobs) == 1
assert prompt_logprobs[0] is None
sample_logprobs = create_sample_logprobs()
sample_logprobs = create_sample_logprobs(flat_logprobs=False)
assert isinstance(sample_logprobs, list)
assert len(sample_logprobs) == 0
def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
prompt_logprobs = create_prompt_logprobs()
def test_create_logprobs_flat() -> None:
prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
assert isinstance(prompt_logprobs, FlatLogprobs)
assert prompt_logprobs.start_indices == [0]
assert prompt_logprobs.end_indices == [0]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(prompt_logprobs) == 1
assert prompt_logprobs[0] == dict()
sample_logprobs = create_sample_logprobs()
sample_logprobs = create_sample_logprobs(flat_logprobs=True)
assert isinstance(sample_logprobs, FlatLogprobs)
assert len(sample_logprobs.start_indices) == 0
assert len(sample_logprobs.end_indices) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sample_logprobs) == 0
def test_append_logprobs_for_next_position_none_flat(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
logprobs = create_sample_logprobs()
def test_append_logprobs_for_next_position_none_flat() -> None:
logprobs = create_sample_logprobs(flat_logprobs=False)
append_logprobs_for_next_position(
logprobs,
token_ids=[1],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
]
def test_append_logprobs_for_next_position_flat(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
logprobs = create_sample_logprobs()
def test_append_logprobs_for_next_position_flat() -> None:
logprobs = create_sample_logprobs(flat_logprobs=True)
append_logprobs_for_next_position(
logprobs,
token_ids=[1],