Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,23 +10,31 @@ import pytest
import torch
from tests.utils import create_new_process_for_each_test
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state)
from tests.v1.sample.utils import (
LogitsprocsTestFakes,
create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state,
)
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
build_logitsprocs)
from vllm.v1.sample.logits_processor import (
BatchUpdate,
BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
build_logitsprocs,
)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata
@@ -49,9 +57,10 @@ LogitprocType = Union[type[LogitsProcessor], str]
class LogitsProcsRequestParams:
"""Encapsulates key params for a single request in a batch.
Params can be customized based on the enabled logitproc
"""
workload_index: int
logitproc_type: LogitprocType # Logitproc enabled, specified by str id
out_tokens: list[int] # Output tokens required for min tokens test
@@ -64,14 +73,13 @@ class LogitsProcsRequestParams:
# Number of output tokens is randomly 0 or twice the min-tokens
# threshold which will be used in testing. Output token values
# don't matter *for these tests* so use 0 as a dummy value
self.out_tokens = ([0] *
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))
self.prompt_tokens = []
self.params = _sampling_params_from_logitproc(logitproc_type)
def __str__(self):
"""For debugging"""
summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
summ = ", ".join(f"{k}={v}" for k, v in vars(self).items())
return f"MyClass({summ})"
@@ -86,12 +94,13 @@ def _generate_fake_sampling_metadata(
prompt_token_ids: list[list[int]] = []
for _ in range(batch_size):
output_token_ids.append(
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
np.random.randint(0, vocab_size, size=num_output_tokens).tolist()
)
prompt_token_ids.append(
np.random.randint(0,
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())
np.random.randint(
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
).tolist()
)
logitsprocs = build_logitsprocs(
vllm_config=VllmConfig(),
device=device,
@@ -99,15 +108,16 @@ def _generate_fake_sampling_metadata(
is_pooling_model=False,
)
fake_sampling_metadata = SamplingMetadata(
temperature=torch.full((batch_size, ), 0.0),
temperature=torch.full((batch_size,), 0.0),
all_greedy=True,
all_random=False,
top_p=None,
top_k=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
prompt_token_ids=create_prompt_tokens_tensor(
prompt_token_ids, vocab_size, device
),
output_token_ids=output_token_ids,
frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
@@ -115,7 +125,8 @@ def _generate_fake_sampling_metadata(
no_penalties=True,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=logitsprocs)
logitsprocs=logitsprocs,
)
return fake_sampling_metadata
@@ -127,15 +138,15 @@ def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
fake_logits[i, 0] = 10.0 # High logit for first token
fake_logits[i, 1:] = 1e-2 # Others remain low
sampling_metadata = _generate_fake_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
)
return LogitsprocsTestFakes(
logits=fake_logits,
sampling_metadata=sampling_metadata,
)
def _sampling_params_from_logitproc(
logitproc_type: LogitprocType) -> SamplingParams:
def _sampling_params_from_logitproc(logitproc_type: LogitprocType) -> SamplingParams:
"""Customize request SamplingParams for a specified logitproc"""
# SamplingParams for req with no logitproc
kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
@@ -150,7 +161,7 @@ def _generate_mixed_logitsprocs_batch_params(
) -> list[LogitsProcsRequestParams]:
"""Define key params for a batch of requests with a different
logitproc enabled per request.
The batch will have `reqs_per_logitproc` repeats for all
`logitsprocs_types` under test, including the case where
no logitsproc is enabled. The batch is randomly shuffled. The
@@ -173,7 +184,8 @@ def _generate_mixed_logitsprocs_batch_params(
return [
LogitsProcsRequestParams(
workload_index=idx,
logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc])
logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc],
)
for idx, pdx in enumerate(batch_perm)
]
@@ -185,10 +197,12 @@ def _raise_error_invalid(
step_idx: int,
err_cls: type[Exception] = ValueError,
) -> None:
raise err_cls(f"Validation failed for step={step_idx}, "
f"batch_index={batch_index}, "
f"workload_index={request_params.workload_index}, "
f"req_params={request_params}. Reason: {msg_suffix}")
raise err_cls(
f"Validation failed for step={step_idx}, "
f"batch_index={batch_index}, "
f"workload_index={request_params.workload_index}, "
f"req_params={request_params}. Reason: {msg_suffix}"
)
def _logit_bias_params(kwargs: dict) -> None:
@@ -208,8 +222,7 @@ def _logit_bias_validate(
) -> None:
"""Validate logit bias logitproc applied correctly"""
logit_bias = request_params.params.logit_bias
logits_old = (
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
logits_old = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()
logits_new = logits_new[batch_index].cpu()
for token_id in range(VOCAB_SIZE):
logit_old_value = logits_old[token_id]
@@ -218,22 +231,28 @@ def _logit_bias_validate(
bias_value = logit_bias[token_id]
exp_value = bias_value + logit_old_value
if logit_new_value != pytest.approx(exp_value):
_raise_error_invalid(msg_suffix=(
f"Biased token {token_id} logit value {logit_new_value} "
f"does not match expected value {exp_value} "
f"given bias {bias_value}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
_raise_error_invalid(
msg_suffix=(
f"Biased token {token_id} logit value {logit_new_value} "
f"does not match expected value {exp_value} "
f"given bias {bias_value}"
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
)
else:
if logit_new_value != pytest.approx(logit_old_value):
_raise_error_invalid(msg_suffix=(
f"Unbiased token {token_id} logit value {logit_new_value} "
f"does not match expected value {logit_old_value}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
_raise_error_invalid(
msg_suffix=(
f"Unbiased token {token_id} logit value {logit_new_value} "
f"does not match expected value {logit_old_value}"
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
)
def _min_p_params(kwargs: dict) -> None:
@@ -259,26 +278,27 @@ def _min_p_validate(
msg_suffix="Invalid: dominant token 0 masked (-inf)",
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
step_idx=step_idx,
)
else:
if request_params.params.min_p > 0.0:
# Non-dominant tokens should be masked when min_p > 0
if logits_for_token != -float("inf"):
_raise_error_invalid(
msg_suffix=
f"Invalid: non-dominant token {token_id} not masked",
msg_suffix=f"Invalid: non-dominant token {token_id} not masked",
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
step_idx=step_idx,
)
else:
# No masking when min_p is 0
if logits_for_token == -float("inf"):
_raise_error_invalid(
msg_suffix=
f"Invalid: token {token_id} masked when min_p=0.0",
msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0",
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
step_idx=step_idx,
)
def _min_tokens_params(kwargs: dict) -> None:
@@ -303,7 +323,8 @@ def _min_tokens_validate(
min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
ref_all_stop_token_ids = request_params.params.all_stop_token_ids
mt_lp: MinTokensLogitsProcessor = next(
test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)
)
assert isinstance(mt_lp, MinTokensLogitsProcessor)
min_tok = mt_lp.min_toks.get(batch_index, None)
@@ -312,38 +333,50 @@ def _min_tokens_validate(
(_, out_tok, all_stop_token_ids) = min_tok
num_out_tokens = len(out_tok)
if num_out_tokens != ref_num_out_tokens:
_raise_error_invalid(msg_suffix=(
"Number of output tokens in min-token logit processor "
f"request metadata ({num_out_tokens}) does not match "
f"reference ({ref_num_out_tokens})."),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
_raise_error_invalid(
msg_suffix=(
"Number of output tokens in min-token logit processor "
f"request metadata ({num_out_tokens}) does not match "
f"reference ({ref_num_out_tokens})."
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
)
if ref_all_stop_token_ids != all_stop_token_ids:
_raise_error_invalid(msg_suffix=(
"Stop token ids do not match reference; all_stop_token_ids: "
f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
f"{sorted(ref_all_stop_token_ids)}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
_raise_error_invalid(
msg_suffix=(
"Stop token ids do not match reference; all_stop_token_ids: "
f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
f"{sorted(ref_all_stop_token_ids)}"
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
)
if min_reached:
_raise_error_invalid(msg_suffix=(
"Expected min-tokens request with min reached, but batch "
"index is recognized by min-tokens logits processor."),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
err_cls=RuntimeError)
_raise_error_invalid(
msg_suffix=(
"Expected min-tokens request with min reached, but batch "
"index is recognized by min-tokens logits processor."
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
err_cls=RuntimeError,
)
elif not min_reached:
_raise_error_invalid(msg_suffix=(
"Expected min-tokens request with min not reached, but batch "
"index is not recognized by min-tokens logits processor."),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
err_cls=RuntimeError)
_raise_error_invalid(
msg_suffix=(
"Expected min-tokens request with min not reached, but batch "
"index is not recognized by min-tokens logits processor."
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
err_cls=RuntimeError,
)
# Validate min-token logits
for token_id in range(VOCAB_SIZE):
@@ -351,21 +384,27 @@ def _min_tokens_validate(
if token_id in ref_all_stop_token_ids and not min_reached:
if logits_for_token != -float("inf"):
_raise_error_invalid(
msg_suffix=(f"Token {token_id} is a stop token and "
"the sequence has not reached min length, "
"but the token is not masked "
f"(logit={logits_for_token})"),
msg_suffix=(
f"Token {token_id} is a stop token and "
"the sequence has not reached min length, "
"but the token is not masked "
f"(logit={logits_for_token})"
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
step_idx=step_idx,
)
else:
if logits_for_token == -float("inf"):
_raise_error_invalid(
msg_suffix=(f"Token {token_id} should not be masked but "
f"is (output len={ref_num_out_tokens})"),
msg_suffix=(
f"Token {token_id} should not be masked but "
f"is (output len={ref_num_out_tokens})"
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
step_idx=step_idx,
)
def _none_validate(
@@ -377,52 +416,58 @@ def _none_validate(
step_idx: int,
) -> None:
"""Validate that no logits processors are applied"""
logits = (
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
logits = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()
ref_logits = logits_new[batch_index]
if not torch.all(ref_logits == logits):
mismatch_toks = (ref_logits
!= logits).nonzero(as_tuple=True)[0].tolist()
mismatch_toks = (ref_logits != logits).nonzero(as_tuple=True)[0].tolist()
mismatch_strs = []
for token in mismatch_toks:
val = float(logits[token])
ref_val = float(ref_logits[token])
mismatch_strs.append(f"({token=},{val=},{ref_val=})")
_raise_error_invalid(msg_suffix=(
f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
_raise_error_invalid(
msg_suffix=(
f"Unexpected modification of logits: {','.join(mismatch_strs)}"
),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
)
class LogitsprocTestHelpers(NamedTuple):
"""Supports setting up and validating logitsprocs unit tests."""
eval_fxn: Callable
gen_request_fxn: Optional[Callable] = None
logitsprocs_test_mapping = {
STR_NO_LOGITPROC:
LogitsprocTestHelpers(eval_fxn=_none_validate),
LogitBiasLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params,
eval_fxn=_logit_bias_validate),
MinPLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_p_params,
eval_fxn=_min_p_validate),
MinTokensLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
eval_fxn=_min_tokens_validate),
STR_NO_LOGITPROC: LogitsprocTestHelpers(eval_fxn=_none_validate),
LogitBiasLogitsProcessor: LogitsprocTestHelpers(
gen_request_fxn=_logit_bias_params, eval_fxn=_logit_bias_validate
),
MinPLogitsProcessor: LogitsprocTestHelpers(
gen_request_fxn=_min_p_params, eval_fxn=_min_p_validate
),
MinTokensLogitsProcessor: LogitsprocTestHelpers(
gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate
),
}
def _get_test_cases() -> list[list[str]]:
"""Each test case is a set of logitsprocs"""
logitsprocs_types = list(logitsprocs_test_mapping.keys())
return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
for logitproc_type in logitsprocs_types
if logitproc_type != STR_NO_LOGITPROC
] + [logitsprocs_types]
return (
[[STR_NO_LOGITPROC]]
+ [
[logitproc_type, STR_NO_LOGITPROC]
for logitproc_type in logitsprocs_types
if logitproc_type != STR_NO_LOGITPROC
]
+ [logitsprocs_types]
)
def _generate_fake_step_update(
@@ -440,11 +485,18 @@ def _generate_fake_step_update(
# Other 50%: add a limited number of reqs (less than the number
# of workload reqs remaining, less than an arbitrary max)
# If no workload reqs remain: 100% of steps have 0 adds
num_step_add = random.choice([
0,
random.randint(1, min(max_add_remove_per_step,
workload_reqs_remaining))
]) if workload_reqs_remaining else 0
num_step_add = (
random.choice(
[
0,
random.randint(
1, min(max_add_remove_per_step, workload_reqs_remaining)
),
]
)
if workload_reqs_remaining
else 0
)
# 50% of steps: remove no requests
# Other 50%: remove a limited number of reqs (less than the number
@@ -452,9 +504,11 @@ def _generate_fake_step_update(
# If persistent batch is empty: 100% of steps have 0 removals until
# more requests are added. Assume that removed requests are always
# drawn from the current batch, before new adds
num_step_remove = random.choice([
0, random.randint(1, min(max_add_remove_per_step, batch_size))
]) if batch_size else 0
num_step_remove = (
random.choice([0, random.randint(1, min(max_add_remove_per_step, batch_size))])
if batch_size
else 0
)
num_step_add_replace = min(num_step_add, num_step_remove)
@@ -463,23 +517,34 @@ def _generate_fake_step_update(
batch_update_builder.removed_append(removal)
# Get added requests from workload
for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
for add_req_params in workload_params[wdx : (wdx + num_step_add_replace)]:
# Replace as many removed requests as possible with added requests
add_remove_idx = batch_update_builder.pop_removed()
batch_update_builder.added.append(
(add_remove_idx, add_req_params.params,
add_req_params.prompt_tokens, add_req_params.out_tokens))
(
add_remove_idx,
add_req_params.params,
add_req_params.prompt_tokens,
add_req_params.out_tokens,
)
)
persistent_batch[add_remove_idx] = add_req_params
# Append remaining added requests to end of batch
add_reqs_append = workload_params[(wdx +
num_step_add_replace):(wdx +
num_step_add)]
batch_update_builder.added.extend([
(adx + batch_size, add_req_params.params, add_req_params.prompt_tokens,
add_req_params.out_tokens)
for adx, add_req_params in enumerate(add_reqs_append)
])
add_reqs_append = workload_params[
(wdx + num_step_add_replace) : (wdx + num_step_add)
]
batch_update_builder.added.extend(
[
(
adx + batch_size,
add_req_params.params,
add_req_params.prompt_tokens,
add_req_params.out_tokens,
)
for adx, add_req_params in enumerate(add_reqs_append)
]
)
persistent_batch.extend(add_reqs_append)
pre_condense_batch_size = len(persistent_batch)
wdx += num_step_add # Update workload offset
@@ -488,8 +553,10 @@ def _generate_fake_step_update(
last_nonempty_index = pre_condense_batch_size - 1
condensed_to_idxs = set()
while batch_update_builder.removed:
if (last_nonempty_index in batch_update_builder.removed
or last_nonempty_index in condensed_to_idxs):
if (
last_nonempty_index in batch_update_builder.removed
or last_nonempty_index in condensed_to_idxs
):
last_nonempty_index -= 1
continue
# last_nonempty_index is the highest persistent batch index that was
@@ -504,11 +571,10 @@ def _generate_fake_step_update(
# move last_nonempty_index -> first_empty_index
batch_update_builder.pop_removed()
condensed_to_idxs.add(first_empty_index)
persistent_batch[first_empty_index] = persistent_batch[
last_nonempty_index]
persistent_batch[first_empty_index] = persistent_batch[last_nonempty_index]
batch_update_builder.moved.append(
(last_nonempty_index, first_empty_index,
MoveDirectionality.UNIDIRECTIONAL))
(last_nonempty_index, first_empty_index, MoveDirectionality.UNIDIRECTIONAL)
)
last_nonempty_index -= 1
@@ -524,18 +590,21 @@ def _generate_fake_step_update(
k = random.randint(0, condensed_batch_size // 2)
idxs = list(range(condensed_batch_size))
random.shuffle(idxs)
swaps = [
tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
]
batch_update_builder.moved.extend([
(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
])
swaps = [tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)]
batch_update_builder.moved.extend(
[(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps]
)
for adx, bdx in swaps:
persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
bdx], persistent_batch[adx]
persistent_batch[adx], persistent_batch[bdx] = (
persistent_batch[bdx],
persistent_batch[adx],
)
return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
workload_size - wdx)
return (
batch_update_builder.get_and_reset(condensed_batch_size),
wdx,
workload_size - wdx,
)
def _assert_valid(
@@ -550,8 +619,10 @@ def _assert_valid(
# Trivial case of empty persistent batch
assert len(persistent_batch) == 0
if logits_w_lp.shape[0] != 0:
raise ValueError("Fake persistent batch is empty but logitsprocs "
f"output batch has shape {logits_w_lp.shape}")
raise ValueError(
"Fake persistent batch is empty but logitsprocs "
f"output batch has shape {logits_w_lp.shape}"
)
return
# Validate logits for each fake request
@@ -560,36 +631,40 @@ def _assert_valid(
# Invoke the appropriate validation function for
# the logitproc employed by this request
fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
fxn(test_fakes=test_fakes,
fxn(
test_fakes=test_fakes,
persistent_batch=persistent_batch,
logits_new=logits_w_lp,
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
step_idx=step_idx,
)
@create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(device: str, reqs_per_logitproc: int,
logitsprocs_under_test: list[str]):
def test_logitsprocs(
device: str, reqs_per_logitproc: int, logitsprocs_under_test: list[str]
):
random.seed(40)
torch.set_default_device(device)
# Define a shuffled batch of requests which individually use a different
# logitproc, or no logitproc at all
workload_params = _generate_mixed_logitsprocs_batch_params(
reqs_per_logitproc=reqs_per_logitproc,
logitsprocs_types=logitsprocs_under_test)
reqs_per_logitproc=reqs_per_logitproc, logitsprocs_types=logitsprocs_under_test
)
workload_size = len(workload_params)
# Create fake test data structures for testing.
test_fakes = _generate_test_fakes(workload_size, device)
wdx = 0 # Next request index in workload to add
persistent_batch: list[LogitsProcsRequestParams] = [
] # Persistent batch state, as list of workload indices
persistent_batch: list[
LogitsProcsRequestParams
] = [] # Persistent batch state, as list of workload indices
# Generate fake removed request indices from current persistent
# batch before adds

View File

@@ -7,32 +7,44 @@ from typing import Union
import pytest
from tests.utils import create_new_process_for_each_test
# yapf: disable
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MODEL_NAME,
POOLING_MODEL_NAME, TEMP_GREEDY,
CustomLogitprocSource,
DummyLogitsProcessor,
WrappedPerReqLogitsProcessor,
dummy_module)
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS,
MODEL_NAME,
POOLING_MODEL_NAME,
TEMP_GREEDY,
CustomLogitprocSource,
DummyLogitsProcessor,
WrappedPerReqLogitsProcessor,
dummy_module,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts
# yapf: enable
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS,
LogitsProcessor)
from vllm.v1.sample.logits_processor import (
STR_POOLING_REJECTS_LOGITSPROCS,
LogitsProcessor,
)
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=TEMP_GREEDY,
max_tokens=MAX_TOKENS,
extra_args={DUMMY_LOGITPROC_ARG: 128}),
SamplingParams(
temperature=TEMP_GREEDY,
max_tokens=MAX_TOKENS,
extra_args={DUMMY_LOGITPROC_ARG: 128},
),
SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
SamplingParams(temperature=TEMP_GREEDY,
max_tokens=MAX_TOKENS,
extra_args={DUMMY_LOGITPROC_ARG: 67}),
SamplingParams(
temperature=TEMP_GREEDY,
max_tokens=MAX_TOKENS,
extra_args={DUMMY_LOGITPROC_ARG: 67},
),
SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
]
@@ -49,7 +61,7 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
2. Server has *not* loaded dummy logitproc; test that all requests
behave as if logitproc is *not* operating (output matches reference
`LLM` output.)
Args:
kwargs: `LLM` constructor kwargs
logitproc_loaded: server has loaded dummy logitproc if True
@@ -73,7 +85,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
# Validate outputs
for bdx, (out_lp, out_ref, params) in enumerate(
zip(outputs_logitproc, outputs_ref, sampling_params_list)):
zip(outputs_logitproc, outputs_ref, sampling_params_list)
):
lp_toks = out_lp.outputs[0].token_ids
if logitproc_loaded and params.extra_args:
# This request exercises custom logitproc; validate that logitproc
@@ -81,8 +94,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
target_token = params.extra_args[DUMMY_LOGITPROC_ARG]
if not all(x == target_token for x in lp_toks):
raise AssertionError(
f"Request {bdx} generated {lp_toks}, should all be "
f"{target_token}")
f"Request {bdx} generated {lp_toks}, should all be {target_token}"
)
else:
# This request does not exercise custom logitproc (or custom
# logitproc is not enabled on this server); validate against
@@ -90,16 +103,15 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
ref_toks = out_ref.outputs[0].token_ids
if lp_toks != ref_toks:
raise AssertionError(
f"Request {bdx} generated {lp_toks}, should match "
f"{ref_toks}")
f"Request {bdx} generated {lp_toks}, should match {ref_toks}"
)
@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource))
def test_custom_logitsprocs(monkeypatch,
logitproc_source: CustomLogitprocSource):
def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource):
"""Test offline Python interface for passing custom logitsprocs
Construct an `LLM` instance which loads a custom logitproc that has a
well-defined behavior (mask out all tokens except one `target_token`)
@@ -118,7 +130,7 @@ def test_custom_logitsprocs(monkeypatch,
instance output
* Logitproc passed in via {entrypoint, class object, fully-qualified class
name (FQCN)} - test that dummy logitproc is utilized correctly when
provided via any of these three possible sources
provided via any of these three possible sources
Args:
monkeypatch: for setting env vars
@@ -142,6 +154,7 @@ def test_custom_logitsprocs(monkeypatch,
# Scenario: vLLM loads a logitproc from a preconfigured entrypoint
# To that end, mock a dummy logitproc entrypoint
import importlib.metadata
importlib.metadata.entry_points = fake_entry_points # type: ignore
# fork is required for workers to see entrypoint patch
@@ -165,7 +178,7 @@ def test_custom_logitsprocs(monkeypatch,
@create_new_process_for_each_test()
def test_custom_logitsprocs_req(monkeypatch):
"""Test passing request-level logits processor to offline Python interface
Wrap a request-level logits processor to create a batch level logits
processor that has a well-defined behavior (mask out all tokens except one
`target_token`)
@@ -190,18 +203,23 @@ def test_custom_logitsprocs_req(monkeypatch):
# Test that logitproc info is passed to workers
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
random.seed(40)
_run_test({"logits_processors": [WrappedPerReqLogitsProcessor]},
logitproc_loaded=True)
_run_test(
{"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True
)
@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", [
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
])
@pytest.mark.parametrize(
"logitproc_source",
[
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
],
)
def test_pooling_rejects_custom_logitsprocs(
monkeypatch, logitproc_source: CustomLogitprocSource):
monkeypatch, logitproc_source: CustomLogitprocSource
):
"""Validate that vLLM engine initialization properly rejects custom
logitsprocs when the model is a pooling model.
@@ -233,6 +251,7 @@ def test_pooling_rejects_custom_logitsprocs(
# Patch in dummy logitproc entrypoint
import importlib.metadata
importlib.metadata.entry_points = fake_entry_points # type: ignore
# fork is required for entrypoint patch to be visible to workers,
@@ -245,10 +264,15 @@ def test_pooling_rejects_custom_logitsprocs(
gpu_memory_utilization=0.1,
)
# Require that no logitsprocs have been loaded
assert sum([
1 for _ in llm.llm_engine.model_executor.driver_worker.worker.
model_runner.input_batch.logitsprocs.all
]) == 0
assert (
sum(
[
1
for _ in llm.llm_engine.model_executor.driver_worker.worker.model_runner.input_batch.logitsprocs.all
]
)
== 0
)
return
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}

View File

@@ -10,16 +10,20 @@ import openai
import pytest
import pytest_asyncio
from tests.utils import (RemoteOpenAIServerCustom,
create_new_process_for_each_test)
from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test
# yapf: disable
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MODEL_NAME,
TEMP_GREEDY, dummy_module)
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS,
MODEL_NAME,
TEMP_GREEDY,
dummy_module,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts
# yapf: enable
@@ -33,11 +37,12 @@ def _server_with_logitproc_entrypoint(
# Patch `entry_points` to inject logitproc entrypoint
import importlib.metadata
importlib.metadata.entry_points = fake_entry_points # type: ignore
from vllm.entrypoints.cli import main
# fork is required for workers to see entrypoint patch
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None:
os.environ.update(env_dict)
@@ -55,10 +60,11 @@ def _server_with_logitproc_module(
# Patch `modules` to inject dummy logitproc module
from vllm.entrypoints.cli import main
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
# fork is required for workers to see entrypoint patch
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None:
os.environ.update(env_dict)
@@ -80,8 +86,9 @@ def default_server_args():
]
@pytest.fixture(scope="function",
params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]])
@pytest.fixture(
scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]
)
def server(default_server_args, request, monkeypatch):
"""Consider two server configurations:
(1) --logits-processors cli arg specifies dummy logits processor via fully-
@@ -102,8 +109,7 @@ def server(default_server_args, request, monkeypatch):
args = default_server_args
_server_fxn = _server_with_logitproc_entrypoint
with RemoteOpenAIServerCustom(MODEL_NAME, args,
_server_fxn) as remote_server:
with RemoteOpenAIServerCustom(MODEL_NAME, args, _server_fxn) as remote_server:
yield remote_server
@@ -133,7 +139,7 @@ api_keyword_args = {
)
async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
"""Test custom logitsprocs when starting OpenAI server from CLI
Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
that has a well-defined behavior (mask out all tokens except one
`target_token`).
@@ -157,9 +163,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
# For requests which activate the dummy logitproc, choose one of
# two `target_token` values which are known not to be EOS tokens
request_keyword_args["extra_body"] = {
"vllm_xargs": {
DUMMY_LOGITPROC_ARG: target_token
}
"vllm_xargs": {DUMMY_LOGITPROC_ARG: target_token}
}
batch = await client.completions.create(
model=model_name,
@@ -173,8 +177,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
choices: openai.types.CompletionChoice = batch.choices
toks = choices[0].logprobs.tokens
if not all([x == toks[0] for x in toks]):
raise AssertionError(
f"Generated {toks} should all be {toks[0]}")
raise AssertionError(f"Generated {toks} should all be {toks[0]}")
# Alternate whether to activate dummy logitproc for each request
use_dummy_logitproc = not use_dummy_logitproc

View File

@@ -10,10 +10,13 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP,
AdapterLogitsProcessor,
BatchUpdate, LogitsProcessor,
RequestLogitsProcessor)
from vllm.v1.sample.logits_processor import (
LOGITSPROCS_GROUP,
AdapterLogitsProcessor,
BatchUpdate,
LogitsProcessor,
RequestLogitsProcessor,
)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
logger = init_logger(__name__)
@@ -30,6 +33,7 @@ DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
class CustomLogitprocSource(Enum):
"""How to source a logitproc for testing purposes"""
LOGITPROC_SOURCE_NONE = auto() # No custom logitproc
LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint
LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN)
@@ -48,8 +52,9 @@ prompts = [
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
@@ -60,8 +65,8 @@ class DummyLogitsProcessor(LogitsProcessor):
process_dict_updates(
self.req_info,
batch_update,
lambda params, _, __: params.extra_args and
(params.extra_args.get("target_token")),
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
@@ -69,16 +74,16 @@ class DummyLogitsProcessor(LogitsProcessor):
return logits
# Save target values before modification
cols = torch.tensor(list(self.req_info.values()),
dtype=torch.long,
device=logits.device)
rows = torch.tensor(list(self.req_info.keys()),
dtype=torch.long,
device=logits.device)
cols = torch.tensor(
list(self.req_info.values()), dtype=torch.long, device=logits.device
)
rows = torch.tensor(
list(self.req_info.keys()), dtype=torch.long, device=logits.device
)
values_to_keep = logits[rows, cols].clone()
# Mask all but target tokens
logits[rows] = float('-inf')
logits[rows] = float("-inf")
logits[rows, cols] = values_to_keep
return logits
@@ -154,14 +159,17 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
Returns:
`Callable` request logits processor, or None
"""
target_token: Optional[
Any] = params.extra_args and params.extra_args.get("target_token")
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.", target_token)
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)