Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
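A quick sketch of the kind of pyproject.toml configuration this sort of switch relies on (a minimal sketch using standard ruff options, not the exact settings from vLLM's pyproject.toml, which may differ):

    [tool.ruff]
    line-length = 88          # one formatter setting instead of a yapf style section

    [tool.ruff.lint]
    select = ["I"]            # ruff's "I" (isort) rules take over import sorting

    [tool.ruff.format]
    docstring-code-format = false

With a config like this, `ruff format .` replaces `yapf -i` and `ruff check --fix .` applies the import-sorting fixes previously handled by `isort`; the diff below is the purely mechanical result of re-running the formatter.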
@@ -9,9 +9,12 @@ import pytest
 import torch
 
 from tests.v1.sample.utils import (
-    BatchLogprobsComposition, BatchLogprobsSpecType,
+    BatchLogprobsComposition,
+    BatchLogprobsSpecType,
     assert_incr_detok_str_matches_non_incr_detok_str,
-    compute_correct_cumulative_logprob, get_test_batch)
+    compute_correct_cumulative_logprob,
+    get_test_batch,
+)
 from vllm import SamplingParams
 from vllm.config import LogprobsMode
 
@@ -29,22 +32,23 @@ SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
 @pytest.fixture(
     scope="module",
     # Parameterize APC
-    params=[False, True])
+    params=[False, True],
+)
 def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
     with vllm_runner(
-            MODEL,
-            dtype=DTYPE,
-            max_logprobs=7,
-            # Very small number of batched tokens to ensure
-            # that we test chunking.
-            max_num_batched_tokens=16,
-            max_num_seqs=16,
-            max_model_len=128,
-            enforce_eager=True,
-            #TODO: enable this once we support it for
-            # prompt logprobs.
-            enable_prefix_caching=request.param,
-            gpu_memory_utilization=0.4,  # up to 2 alive concurrently
+        MODEL,
+        dtype=DTYPE,
+        max_logprobs=7,
+        # Very small number of batched tokens to ensure
+        # that we test chunking.
+        max_num_batched_tokens=16,
+        max_num_seqs=16,
+        max_model_len=128,
+        enforce_eager=True,
+        # TODO: enable this once we support it for
+        # prompt logprobs.
+        enable_prefix_caching=request.param,
+        gpu_memory_utilization=0.4,  # up to 2 alive concurrently
     ) as vllm_model:
         yield vllm_model
@@ -96,8 +100,8 @@ def _repeat_logprob_config(
     num_test_prompts = len(test_prompts)
     # Make sure there is a logprobs configuration for each test prompt
     logprob_prompt_logprob_list = list(
-        itertools.islice(itertools.cycle(logprob_prompt_logprob_list),
-                         num_test_prompts))
+        itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
+    )
     # Now the number of prompts should match the number of sample params combos
     assert num_test_prompts == len(logprob_prompt_logprob_list)
     return logprob_prompt_logprob_list
@@ -115,24 +119,28 @@ def _run_and_validate(
     do_apc: bool,
 ) -> None:
     vllm_results = vllm_model.llm.generate(
-        test_prompts, sampling_params=vllm_sampling_params)
+        test_prompts, sampling_params=vllm_sampling_params
+    )
 
     for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
-            vllm_results, hf_logprobs, hf_outputs,
-            logprob_prompt_logprob_list):
-
+        vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
+    ):
         # Extract request-level (prompt)logprobs config
         num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob
 
         # Test whether sampled token output is consistent between vLLM and HF
         # vLLM prompt+completion should match HF output
         if temperature == 0.0:
-            assert (vllm_result.prompt_token_ids +
-                    vllm_result.outputs[0].token_ids == hf_output[0])
+            assert (
+                vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
+                == hf_output[0]
+            )
         else:
             # Sampled tokens won't match if not greedy
-            assert (vllm_result.prompt_token_ids == hf_output[0]
-                    [:len(vllm_result.prompt_token_ids)])
+            assert (
+                vllm_result.prompt_token_ids
+                == hf_output[0][: len(vllm_result.prompt_token_ids)]
+            )
 
         # Validate sample logprobs
         if num_top_logprobs is not None:
@@ -141,8 +149,9 @@ def _run_and_validate(
             # correct
             assert vllm_result.outputs[0].logprobs is not None
             assert len(vllm_result.outputs[0].logprobs) == max_tokens
-            for logprobs, token_id in zip(vllm_result.outputs[0].logprobs,
-                                          vllm_result.outputs[0].token_ids):
+            for logprobs, token_id in zip(
+                vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids
+            ):
                 assert logprobs is not None
 
                 # Confirm that the output token appears among the logprobs
@@ -159,23 +168,26 @@ def _run_and_validate(
                 if num_top_logprobs > 0:
                     # We should have an entry for each of the topk ranks
                     all_ranks = {lp.rank for lp in logprobs.values()}
-                    assert all(r in all_ranks
-                               for r in range(1, num_top_logprobs + 1))
+                    assert all(r in all_ranks for r in range(1, num_top_logprobs + 1))
 
             output_text = vllm_result.outputs[0].text
             output_string_from_most_likely_tokens_lst: list[str] = []
             for top_logprobs in vllm_result.outputs[0].logprobs:
                 top_logprob = next(iter(top_logprobs.values()))
                 output_string_from_most_likely_tokens_lst.append(
-                    top_logprob.decoded_token)
+                    top_logprob.decoded_token
+                )
 
             output_string_from_most_likely_tokens = "".join(
-                output_string_from_most_likely_tokens_lst)
+                output_string_from_most_likely_tokens_lst
+            )
             assert_incr_detok_str_matches_non_incr_detok_str(
-                output_text, output_string_from_most_likely_tokens,
+                output_text,
+                output_string_from_most_likely_tokens,
                 "The output text from the top logprob for each token "
                 "position should be the same as the output text in the "
-                "result.")
+                "result.",
+            )
 
             # Compare vLLM sample logprobs to HF
             vllm_sample_logprobs = vllm_result.outputs[0].logprobs
@@ -187,11 +199,12 @@ def _run_and_validate(
                         logprob,
                         hf_logprob[i][-1][token_id].item(),
                         atol=1e-2,
-                        rtol=1e-2)
-                    assert isinstance(
-                        sample_logprob.decoded_token,
-                        str), ("The token should be decoded by the time it is"
-                               " returned to the user.")
+                        rtol=1e-2,
+                    )
+                    assert isinstance(sample_logprob.decoded_token, str), (
+                        "The token should be decoded by the time it is"
+                        " returned to the user."
+                    )
 
             # At this point we know the sample logprobs are correct for this
             # request. Validate that cumulative_logprob is actually the sum.
@@ -201,7 +214,8 @@ def _run_and_validate(
                 vllm_result.outputs[0].cumulative_logprob,
                 compute_correct_cumulative_logprob(vllm_result.outputs[0]),
                 atol=1e-6,
-                rtol=1e-6)
+                rtol=1e-6,
+            )
         else:
             # Logprobs disabled for this request; should be None
             assert vllm_result.outputs[0].logprobs is None
@@ -214,17 +228,17 @@ def _run_and_validate(
             assert vllm_result.prompt_logprobs[0] is None
             # - Prompt logprobs are returned for all indices in
             #   the prompt
-            assert len(vllm_result.prompt_logprobs) == len(
-                vllm_result.prompt_token_ids)
+            assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids)
             for prompt_logprobs, prompt_token_id in zip(
-                    vllm_result.prompt_logprobs[1:],
-                    vllm_result.prompt_token_ids[1:]):
+                vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:]
+            ):
                 assert prompt_logprobs is not None
 
                 # Confirm that the prompt token appears among the logprobs
                 assert prompt_token_id in prompt_logprobs
-                token_in_topk = prompt_logprobs[
-                    prompt_token_id].rank <= num_top_prompt_logprobs
+                token_in_topk = (
+                    prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs
+                )
 
                 # If the prompt token is not included in the top K
                 # logprob, it can return 1 more data
@@ -236,8 +250,9 @@ def _run_and_validate(
                 if num_top_prompt_logprobs > 0:
                     # We should have an entry for each of the topk ranks
                     all_ranks = {lp.rank for lp in prompt_logprobs.values()}
-                    assert all(r in all_ranks
-                               for r in range(1, num_top_prompt_logprobs + 1))
+                    assert all(
+                        r in all_ranks for r in range(1, num_top_prompt_logprobs + 1)
+                    )
 
             # Compare prompt logprobs to HF
             # The first prompt logprob is always None, so we compare it from
@@ -249,19 +264,24 @@ def _run_and_validate(
                         logprob.logprob,
                         hf_logprob[0][i][token_id].item(),
                         atol=2e-2,
-                        rtol=2e-2)
+                        rtol=2e-2,
+                    )
         else:
             assert vllm_result.prompt_logprobs is None
 
 
-@pytest.mark.parametrize("batch_logprobs_composition",
-                         [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
+@pytest.mark.parametrize(
+    "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
+)
 @pytest.mark.parametrize("temperature", [0.0, 2.0])
 def test_get_logprobs_and_prompt_logprobs(
-        hf_model, vllm_model,
-        batch_logprobs_composition: BatchLogprobsComposition,
-        temperature: float, example_prompts: list[str],
-        monkeypatch: pytest.MonkeyPatch) -> None:
+    hf_model,
+    vllm_model,
+    batch_logprobs_composition: BatchLogprobsComposition,
+    temperature: float,
+    example_prompts: list[str],
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
     """Test V1 Engine logprobs & prompt logprobs
 
     Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
@@ -291,8 +311,9 @@ def test_get_logprobs_and_prompt_logprobs(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
-        if do_apc and (temperature < 2.0
-                       or batch_logprobs_composition != SAMPLE_PROMPT):
+        if do_apc and (
+            temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT
+        ):
            # Skip some test-cases to save time.
            pytest.skip()
        test_prompts = example_prompts
@@ -309,19 +330,21 @@ def test_get_logprobs_and_prompt_logprobs(
 
         # Batch has mixed sample params
         # (different logprobs/prompt logprobs combos)
-        logprob_prompt_logprob_list = get_test_batch(
-            batch_logprobs_composition)
+        logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
 
         # Ensure that each test prompt has a logprob config for testing
         logprob_prompt_logprob_list = _repeat_logprob_config(
-            test_prompts, logprob_prompt_logprob_list)
+            test_prompts, logprob_prompt_logprob_list
+        )
         # Generate SamplingParams
         vllm_sampling_params = [
-            SamplingParams(max_tokens=max_tokens,
-                           logprobs=num_lp,
-                           prompt_logprobs=num_plp,
-                           temperature=temperature,
-                           seed=1984)
+            SamplingParams(
+                max_tokens=max_tokens,
+                logprobs=num_lp,
+                prompt_logprobs=num_plp,
+                temperature=temperature,
+                seed=1984,
+            )
             for num_lp, num_plp in logprob_prompt_logprob_list
         ]
         for _ in range(2 if do_apc else 1):
@@ -334,7 +357,8 @@ def test_get_logprobs_and_prompt_logprobs(
                 logprob_prompt_logprob_list=logprob_prompt_logprob_list,
                 temperature=temperature,
                 max_tokens=max_tokens,
-                do_apc=do_apc)
+                do_apc=do_apc,
+            )
 
 
 def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
@@ -351,19 +375,18 @@ def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
             enable_prefix_caching=False,
             # 2 other llms alive during whole session
             gpu_memory_utilization=0.15,
-            max_model_len=256)
+            max_model_len=256,
+        )
         vllm_sampling_params = SamplingParams(logprobs=1)
         # should pass
         runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
 
         bad_sampling_params = SamplingParams(logprobs=2)
         with pytest.raises(ValueError):
-            runner.generate(["Hello world"],
-                            sampling_params=bad_sampling_params)
+            runner.generate(["Hello world"], sampling_params=bad_sampling_params)
 
 
-def test_none_logprobs(vllm_model, example_prompts,
-                       monkeypatch: pytest.MonkeyPatch):
+def test_none_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
     """Engine should return `logprobs` and `prompt_logprobs` as `None`
 
     Args:
@@ -388,14 +411,12 @@ def test_none_logprobs(vllm_model, example_prompts,
         for i in range(len(results_logprobs_none)):
             # Check sample logprobs are None
             assert results_logprobs_none[i].outputs[0].logprobs is None
-            assert results_logprobs_none[i].outputs[
-                0].cumulative_logprob is None
+            assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
             # Check prompt logprobs are None
             assert results_logprobs_none[i].prompt_logprobs is None
 
 
-def test_zero_logprobs(vllm_model, example_prompts,
-                       monkeypatch: pytest.MonkeyPatch):
+def test_zero_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
     """Engine should return sampled token and prompt token logprobs
 
     Args:
@@ -406,12 +427,12 @@ def test_zero_logprobs(vllm_model, example_prompts,
         m.setenv("VLLM_USE_V1", "1")
         max_tokens = 5
 
-        sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
-                                                       logprobs=0,
-                                                       prompt_logprobs=0,
-                                                       temperature=0.0)
+        sampling_params_logprobs_zero = SamplingParams(
+            max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
+        )
         results_logprobs_zero = vllm_model.llm.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_zero)
+            example_prompts, sampling_params=sampling_params_logprobs_zero
+        )
 
         for i in range(len(results_logprobs_zero)):
             # Check that there is one sample logprob dict for each
@@ -422,8 +443,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
             prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
             assert logprobs is not None
             assert len(sampled_token_ids) == len(logprobs)
-            assert results_logprobs_zero[i].outputs[
-                0].cumulative_logprob is not None
+            assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None
             # Check that there is one prompt logprob dict for each
             # prompt token
             assert prompt_logprobs is not None
@@ -444,13 +464,15 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
             enable_prefix_caching=False,
             # 2 other llms alive during whole session
             gpu_memory_utilization=0.15,
-            max_model_len=256)
+            max_model_len=256,
+        )
 
-        sampling_params_logprobs_all = SamplingParams(max_tokens=5,
-                                                      logprobs=-1,
-                                                      prompt_logprobs=-1)
+        sampling_params_logprobs_all = SamplingParams(
+            max_tokens=5, logprobs=-1, prompt_logprobs=-1
+        )
         results_logprobs_all = runner.llm.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_all)
+            example_prompts, sampling_params=sampling_params_logprobs_all
+        )
         vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
 
         for i in range(len(results_logprobs_all)):
@@ -466,13 +488,13 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
 
 
 @pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
-def test_logprobs_mode(logprobs_mode: LogprobsMode,
-                       monkeypatch: pytest.MonkeyPatch):
+def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch):
     """Test with LLM engine with different logprobs_mode.
     For logprobs, we should have non-positive values.
     For logits, we should expect at least one positive values.
     """
     from vllm import LLM
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
@@ -483,10 +505,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
             # 2 other llms alive during whole session
             gpu_memory_utilization=0.05,
             max_model_len=16,
-            logprobs_mode=logprobs_mode)
+            logprobs_mode=logprobs_mode,
+        )
         vllm_sampling_params = SamplingParams(logprobs=1)
-        results = llm.generate(["Hello world"],
-                               sampling_params=vllm_sampling_params)
+        results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
 
         total_token_with_logprobs = 0
         positive_values = 0