[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.logprobs import (
|
||||
FlatLogprobs,
|
||||
Logprob,
|
||||
@@ -14,24 +12,20 @@ from vllm.logprobs import (
|
||||
)
|
||||
|
||||
|
||||
def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
|
||||
|
||||
prompt_logprobs = create_prompt_logprobs()
|
||||
def test_create_logprobs_non_flat() -> None:
|
||||
prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
|
||||
assert isinstance(prompt_logprobs, list)
|
||||
# Ensure first prompt position logprobs is None
|
||||
assert len(prompt_logprobs) == 1
|
||||
assert prompt_logprobs[0] is None
|
||||
|
||||
sample_logprobs = create_sample_logprobs()
|
||||
sample_logprobs = create_sample_logprobs(flat_logprobs=False)
|
||||
assert isinstance(sample_logprobs, list)
|
||||
assert len(sample_logprobs) == 0
|
||||
|
||||
|
||||
def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
|
||||
|
||||
prompt_logprobs = create_prompt_logprobs()
|
||||
def test_create_logprobs_flat() -> None:
|
||||
prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
|
||||
assert isinstance(prompt_logprobs, FlatLogprobs)
|
||||
assert prompt_logprobs.start_indices == [0]
|
||||
assert prompt_logprobs.end_indices == [0]
|
||||
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert len(prompt_logprobs) == 1
|
||||
assert prompt_logprobs[0] == dict()
|
||||
|
||||
sample_logprobs = create_sample_logprobs()
|
||||
sample_logprobs = create_sample_logprobs(flat_logprobs=True)
|
||||
assert isinstance(sample_logprobs, FlatLogprobs)
|
||||
assert len(sample_logprobs.start_indices) == 0
|
||||
assert len(sample_logprobs.end_indices) == 0
|
||||
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert len(sample_logprobs) == 0
|
||||
|
||||
|
||||
def test_append_logprobs_for_next_position_none_flat(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
|
||||
logprobs = create_sample_logprobs()
|
||||
def test_append_logprobs_for_next_position_none_flat() -> None:
|
||||
logprobs = create_sample_logprobs(flat_logprobs=False)
|
||||
append_logprobs_for_next_position(
|
||||
logprobs,
|
||||
token_ids=[1],
|
||||
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
|
||||
]
|
||||
|
||||
|
||||
def test_append_logprobs_for_next_position_flat(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
|
||||
logprobs = create_sample_logprobs()
|
||||
def test_append_logprobs_for_next_position_flat() -> None:
|
||||
logprobs = create_sample_logprobs(flat_logprobs=True)
|
||||
append_logprobs_for_next_position(
|
||||
logprobs,
|
||||
token_ids=[1],
|
||||
|
||||
Reference in New Issue
Block a user