Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,9 +9,12 @@ import pytest
import torch
from tests.v1.sample.utils import (
BatchLogprobsComposition, BatchLogprobsSpecType,
BatchLogprobsComposition,
BatchLogprobsSpecType,
assert_incr_detok_str_matches_non_incr_detok_str,
compute_correct_cumulative_logprob, get_test_batch)
compute_correct_cumulative_logprob,
get_test_batch,
)
from vllm import SamplingParams
from vllm.config import LogprobsMode
@@ -29,22 +32,23 @@ SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
@pytest.fixture(
scope="module",
# Parameterize APC
params=[False, True])
params=[False, True],
)
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
with vllm_runner(
MODEL,
dtype=DTYPE,
max_logprobs=7,
# Very small number of batched tokens to ensure
# that we test chunking.
max_num_batched_tokens=16,
max_num_seqs=16,
max_model_len=128,
enforce_eager=True,
#TODO: enable this once we support it for
# prompt logprobs.
enable_prefix_caching=request.param,
gpu_memory_utilization=0.4, # up to 2 alive concurrently
MODEL,
dtype=DTYPE,
max_logprobs=7,
# Very small number of batched tokens to ensure
# that we test chunking.
max_num_batched_tokens=16,
max_num_seqs=16,
max_model_len=128,
enforce_eager=True,
# TODO: enable this once we support it for
# prompt logprobs.
enable_prefix_caching=request.param,
gpu_memory_utilization=0.4, # up to 2 alive concurrently
) as vllm_model:
yield vllm_model
@@ -96,8 +100,8 @@ def _repeat_logprob_config(
num_test_prompts = len(test_prompts)
# Make sure there is a logprobs configuration for each test prompt
logprob_prompt_logprob_list = list(
itertools.islice(itertools.cycle(logprob_prompt_logprob_list),
num_test_prompts))
itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts)
)
# Now the number of prompts should match the number of sample params combos
assert num_test_prompts == len(logprob_prompt_logprob_list)
return logprob_prompt_logprob_list
@@ -115,24 +119,28 @@ def _run_and_validate(
do_apc: bool,
) -> None:
vllm_results = vllm_model.llm.generate(
test_prompts, sampling_params=vllm_sampling_params)
test_prompts, sampling_params=vllm_sampling_params
)
for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
vllm_results, hf_logprobs, hf_outputs,
logprob_prompt_logprob_list):
vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list
):
# Extract request-level (prompt)logprobs config
num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob
# Test whether sampled token output is consistent between vLLM and HF
# vLLM prompt+completion should match HF output
if temperature == 0.0:
assert (vllm_result.prompt_token_ids +
vllm_result.outputs[0].token_ids == hf_output[0])
assert (
vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids
== hf_output[0]
)
else:
# Sampled tokens won't match if not greedy
assert (vllm_result.prompt_token_ids == hf_output[0]
[:len(vllm_result.prompt_token_ids)])
assert (
vllm_result.prompt_token_ids
== hf_output[0][: len(vllm_result.prompt_token_ids)]
)
# Validate sample logprobs
if num_top_logprobs is not None:
@@ -141,8 +149,9 @@ def _run_and_validate(
# correct
assert vllm_result.outputs[0].logprobs is not None
assert len(vllm_result.outputs[0].logprobs) == max_tokens
for logprobs, token_id in zip(vllm_result.outputs[0].logprobs,
vllm_result.outputs[0].token_ids):
for logprobs, token_id in zip(
vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids
):
assert logprobs is not None
# Confirm that the output token appears among the logprobs
@@ -159,23 +168,26 @@ def _run_and_validate(
if num_top_logprobs > 0:
# We should have an entry for each of the topk ranks
all_ranks = {lp.rank for lp in logprobs.values()}
assert all(r in all_ranks
for r in range(1, num_top_logprobs + 1))
assert all(r in all_ranks for r in range(1, num_top_logprobs + 1))
output_text = vllm_result.outputs[0].text
output_string_from_most_likely_tokens_lst: list[str] = []
for top_logprobs in vllm_result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens_lst.append(
top_logprob.decoded_token)
top_logprob.decoded_token
)
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens_lst)
output_string_from_most_likely_tokens_lst
)
assert_incr_detok_str_matches_non_incr_detok_str(
output_text, output_string_from_most_likely_tokens,
output_text,
output_string_from_most_likely_tokens,
"The output text from the top logprob for each token "
"position should be the same as the output text in the "
"result.")
"result.",
)
# Compare vLLM sample logprobs to HF
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
@@ -187,11 +199,12 @@ def _run_and_validate(
logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
assert isinstance(
sample_logprob.decoded_token,
str), ("The token should be decoded by the time it is"
" returned to the user.")
rtol=1e-2,
)
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is"
" returned to the user."
)
# At this point we know the sample logprobs are correct for this
# request. Validate that cumulative_logprob is actually the sum.
@@ -201,7 +214,8 @@ def _run_and_validate(
vllm_result.outputs[0].cumulative_logprob,
compute_correct_cumulative_logprob(vllm_result.outputs[0]),
atol=1e-6,
rtol=1e-6)
rtol=1e-6,
)
else:
# Logprobs disabled for this request; should be None
assert vllm_result.outputs[0].logprobs is None
@@ -214,17 +228,17 @@ def _run_and_validate(
assert vllm_result.prompt_logprobs[0] is None
# - Prompt logprobs are returned for all indices in
# the prompt
assert len(vllm_result.prompt_logprobs) == len(
vllm_result.prompt_token_ids)
assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids)
for prompt_logprobs, prompt_token_id in zip(
vllm_result.prompt_logprobs[1:],
vllm_result.prompt_token_ids[1:]):
vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:]
):
assert prompt_logprobs is not None
# Confirm that the prompt token appears among the logprobs
assert prompt_token_id in prompt_logprobs
token_in_topk = prompt_logprobs[
prompt_token_id].rank <= num_top_prompt_logprobs
token_in_topk = (
prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs
)
# If the prompt token is not included in the top K
# logprob, it can return 1 more data
@@ -236,8 +250,9 @@ def _run_and_validate(
if num_top_prompt_logprobs > 0:
# We should have an entry for each of the topk ranks
all_ranks = {lp.rank for lp in prompt_logprobs.values()}
assert all(r in all_ranks
for r in range(1, num_top_prompt_logprobs + 1))
assert all(
r in all_ranks for r in range(1, num_top_prompt_logprobs + 1)
)
# Compare prompt logprobs to HF
# The first prompt logprob is always None, so we compare it from
@@ -249,19 +264,24 @@ def _run_and_validate(
logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=2e-2,
rtol=2e-2)
rtol=2e-2,
)
else:
assert vllm_result.prompt_logprobs is None
@pytest.mark.parametrize("batch_logprobs_composition",
[NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize(
"batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]
)
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
hf_model, vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float, example_prompts: list[str],
monkeypatch: pytest.MonkeyPatch) -> None:
hf_model,
vllm_model,
batch_logprobs_composition: BatchLogprobsComposition,
temperature: float,
example_prompts: list[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test V1 Engine logprobs & prompt logprobs
Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
@@ -291,8 +311,9 @@ def test_get_logprobs_and_prompt_logprobs(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0
or batch_logprobs_composition != SAMPLE_PROMPT):
if do_apc and (
temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT
):
# Skip some test-cases to save time.
pytest.skip()
test_prompts = example_prompts
@@ -309,19 +330,21 @@ def test_get_logprobs_and_prompt_logprobs(
# Batch has mixed sample params
# (different logprobs/prompt logprobs combos)
logprob_prompt_logprob_list = get_test_batch(
batch_logprobs_composition)
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
# Ensure that each test prompt has a logprob config for testing
logprob_prompt_logprob_list = _repeat_logprob_config(
test_prompts, logprob_prompt_logprob_list)
test_prompts, logprob_prompt_logprob_list
)
# Generate SamplingParams
vllm_sampling_params = [
SamplingParams(max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984)
SamplingParams(
max_tokens=max_tokens,
logprobs=num_lp,
prompt_logprobs=num_plp,
temperature=temperature,
seed=1984,
)
for num_lp, num_plp in logprob_prompt_logprob_list
]
for _ in range(2 if do_apc else 1):
@@ -334,7 +357,8 @@ def test_get_logprobs_and_prompt_logprobs(
logprob_prompt_logprob_list=logprob_prompt_logprob_list,
temperature=temperature,
max_tokens=max_tokens,
do_apc=do_apc)
do_apc=do_apc,
)
def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
@@ -351,19 +375,18 @@ def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15,
max_model_len=256)
max_model_len=256,
)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"],
sampling_params=bad_sampling_params)
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
def test_none_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
def test_none_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
Args:
@@ -388,14 +411,12 @@ def test_none_logprobs(vllm_model, example_prompts,
for i in range(len(results_logprobs_none)):
# Check sample logprobs are None
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[
0].cumulative_logprob is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
# Check prompt logprobs are None
assert results_logprobs_none[i].prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts,
monkeypatch: pytest.MonkeyPatch):
def test_zero_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
"""Engine should return sampled token and prompt token logprobs
Args:
@@ -406,12 +427,12 @@ def test_zero_logprobs(vllm_model, example_prompts,
m.setenv("VLLM_USE_V1", "1")
max_tokens = 5
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
logprobs=0,
prompt_logprobs=0,
temperature=0.0)
sampling_params_logprobs_zero = SamplingParams(
max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
)
results_logprobs_zero = vllm_model.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_zero)
example_prompts, sampling_params=sampling_params_logprobs_zero
)
for i in range(len(results_logprobs_zero)):
# Check that there is one sample logprob dict for each
@@ -422,8 +443,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
assert logprobs is not None
assert len(sampled_token_ids) == len(logprobs)
assert results_logprobs_zero[i].outputs[
0].cumulative_logprob is not None
assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None
# Check that there is one prompt logprob dict for each
# prompt token
assert prompt_logprobs is not None
@@ -444,13 +464,15 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15,
max_model_len=256)
max_model_len=256,
)
sampling_params_logprobs_all = SamplingParams(max_tokens=5,
logprobs=-1,
prompt_logprobs=-1)
sampling_params_logprobs_all = SamplingParams(
max_tokens=5, logprobs=-1, prompt_logprobs=-1
)
results_logprobs_all = runner.llm.generate(
example_prompts, sampling_params=sampling_params_logprobs_all)
example_prompts, sampling_params=sampling_params_logprobs_all
)
vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
for i in range(len(results_logprobs_all)):
@@ -466,13 +488,13 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode,
monkeypatch: pytest.MonkeyPatch):
def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch):
"""Test with LLM engine with different logprobs_mode.
For logprobs, we should have non-positive values.
For logits, we should expect at least one positive values.
"""
from vllm import LLM
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@@ -483,10 +505,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
# 2 other llms alive during whole session
gpu_memory_utilization=0.05,
max_model_len=16,
logprobs_mode=logprobs_mode)
logprobs_mode=logprobs_mode,
)
vllm_sampling_params = SamplingParams(logprobs=1)
results = llm.generate(["Hello world"],
sampling_params=vllm_sampling_params)
results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
total_token_with_logprobs = 0
positive_values = 0