Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,23 +10,31 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
|
||||
create_penalty_tensor,
|
||||
create_prompt_tokens_tensor,
|
||||
fake_apply_logitsprocs,
|
||||
fake_update_logitsprocs_state)
|
||||
from tests.v1.sample.utils import (
|
||||
LogitsprocsTestFakes,
|
||||
create_fake_logits,
|
||||
create_penalty_tensor,
|
||||
create_prompt_tokens_tensor,
|
||||
fake_apply_logitsprocs,
|
||||
fake_update_logitsprocs_state,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
# yapf: disable
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
|
||||
LogitBiasLogitsProcessor,
|
||||
LogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
MinTokensLogitsProcessor,
|
||||
MoveDirectionality,
|
||||
build_logitsprocs)
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
BatchUpdate,
|
||||
BatchUpdateBuilder,
|
||||
LogitBiasLogitsProcessor,
|
||||
LogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
MinTokensLogitsProcessor,
|
||||
MoveDirectionality,
|
||||
build_logitsprocs,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
@@ -49,9 +57,10 @@ LogitprocType = Union[type[LogitsProcessor], str]
|
||||
|
||||
class LogitsProcsRequestParams:
|
||||
"""Encapsulates key params for a single request in a batch.
|
||||
|
||||
|
||||
Params can be customized based on the enabled logitproc
|
||||
"""
|
||||
|
||||
workload_index: int
|
||||
logitproc_type: LogitprocType # Logitproc enabled, specified by str id
|
||||
out_tokens: list[int] # Output tokens required for min tokens test
|
||||
@@ -64,14 +73,13 @@ class LogitsProcsRequestParams:
|
||||
# Number of output tokens is randomly 0 or twice the min-tokens
|
||||
# threshold which will be used in testing. Output token values
|
||||
# don't matter *for these tests* so use 0 as a dummy value
|
||||
self.out_tokens = ([0] *
|
||||
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
|
||||
self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))
|
||||
self.prompt_tokens = []
|
||||
self.params = _sampling_params_from_logitproc(logitproc_type)
|
||||
|
||||
def __str__(self):
|
||||
"""For debugging"""
|
||||
summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
|
||||
summ = ", ".join(f"{k}={v}" for k, v in vars(self).items())
|
||||
return f"MyClass({summ})"
|
||||
|
||||
|
||||
@@ -86,12 +94,13 @@ def _generate_fake_sampling_metadata(
|
||||
prompt_token_ids: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
output_token_ids.append(
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist()
|
||||
)
|
||||
prompt_token_ids.append(
|
||||
np.random.randint(0,
|
||||
vocab_size,
|
||||
size=np.random.randint(
|
||||
1, MAX_NUM_PROMPT_TOKENS)).tolist())
|
||||
np.random.randint(
|
||||
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
|
||||
).tolist()
|
||||
)
|
||||
logitsprocs = build_logitsprocs(
|
||||
vllm_config=VllmConfig(),
|
||||
device=device,
|
||||
@@ -99,15 +108,16 @@ def _generate_fake_sampling_metadata(
|
||||
is_pooling_model=False,
|
||||
)
|
||||
fake_sampling_metadata = SamplingMetadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
temperature=torch.full((batch_size,), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
prompt_token_ids=create_prompt_tokens_tensor(
|
||||
prompt_token_ids, vocab_size, device
|
||||
),
|
||||
output_token_ids=output_token_ids,
|
||||
frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
|
||||
@@ -115,7 +125,8 @@ def _generate_fake_sampling_metadata(
|
||||
no_penalties=True,
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=logitsprocs)
|
||||
logitsprocs=logitsprocs,
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
|
||||
|
||||
@@ -127,15 +138,15 @@ def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
|
||||
fake_logits[i, 0] = 10.0 # High logit for first token
|
||||
fake_logits[i, 1:] = 1e-2 # Others remain low
|
||||
sampling_metadata = _generate_fake_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
return LogitsprocsTestFakes(
|
||||
logits=fake_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _sampling_params_from_logitproc(
|
||||
logitproc_type: LogitprocType) -> SamplingParams:
|
||||
def _sampling_params_from_logitproc(logitproc_type: LogitprocType) -> SamplingParams:
|
||||
"""Customize request SamplingParams for a specified logitproc"""
|
||||
# SamplingParams for req with no logitproc
|
||||
kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
|
||||
@@ -150,7 +161,7 @@ def _generate_mixed_logitsprocs_batch_params(
|
||||
) -> list[LogitsProcsRequestParams]:
|
||||
"""Define key params for a batch of requests with a different
|
||||
logitproc enabled per request.
|
||||
|
||||
|
||||
The batch will have `reqs_per_logitproc` repeats for all
|
||||
`logitsprocs_types` under test, including the case where
|
||||
no logitsproc is enabled. The batch is randomly shuffled. The
|
||||
@@ -173,7 +184,8 @@ def _generate_mixed_logitsprocs_batch_params(
|
||||
return [
|
||||
LogitsProcsRequestParams(
|
||||
workload_index=idx,
|
||||
logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc])
|
||||
logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc],
|
||||
)
|
||||
for idx, pdx in enumerate(batch_perm)
|
||||
]
|
||||
|
||||
@@ -185,10 +197,12 @@ def _raise_error_invalid(
|
||||
step_idx: int,
|
||||
err_cls: type[Exception] = ValueError,
|
||||
) -> None:
|
||||
raise err_cls(f"Validation failed for step={step_idx}, "
|
||||
f"batch_index={batch_index}, "
|
||||
f"workload_index={request_params.workload_index}, "
|
||||
f"req_params={request_params}. Reason: {msg_suffix}")
|
||||
raise err_cls(
|
||||
f"Validation failed for step={step_idx}, "
|
||||
f"batch_index={batch_index}, "
|
||||
f"workload_index={request_params.workload_index}, "
|
||||
f"req_params={request_params}. Reason: {msg_suffix}"
|
||||
)
|
||||
|
||||
|
||||
def _logit_bias_params(kwargs: dict) -> None:
|
||||
@@ -208,8 +222,7 @@ def _logit_bias_validate(
|
||||
) -> None:
|
||||
"""Validate logit bias logitproc applied correctly"""
|
||||
logit_bias = request_params.params.logit_bias
|
||||
logits_old = (
|
||||
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
|
||||
logits_old = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()
|
||||
logits_new = logits_new[batch_index].cpu()
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
logit_old_value = logits_old[token_id]
|
||||
@@ -218,22 +231,28 @@ def _logit_bias_validate(
|
||||
bias_value = logit_bias[token_id]
|
||||
exp_value = bias_value + logit_old_value
|
||||
if logit_new_value != pytest.approx(exp_value):
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
f"Biased token {token_id} logit value {logit_new_value} "
|
||||
f"does not match expected value {exp_value} "
|
||||
f"given bias {bias_value}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
f"Biased token {token_id} logit value {logit_new_value} "
|
||||
f"does not match expected value {exp_value} "
|
||||
f"given bias {bias_value}"
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
else:
|
||||
if logit_new_value != pytest.approx(logit_old_value):
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
f"Unbiased token {token_id} logit value {logit_new_value} "
|
||||
f"does not match expected value {logit_old_value}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
f"Unbiased token {token_id} logit value {logit_new_value} "
|
||||
f"does not match expected value {logit_old_value}"
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
|
||||
def _min_p_params(kwargs: dict) -> None:
|
||||
@@ -259,26 +278,27 @@ def _min_p_validate(
|
||||
msg_suffix="Invalid: dominant token 0 masked (-inf)",
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
step_idx=step_idx,
|
||||
)
|
||||
else:
|
||||
if request_params.params.min_p > 0.0:
|
||||
# Non-dominant tokens should be masked when min_p > 0
|
||||
if logits_for_token != -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=
|
||||
f"Invalid: non-dominant token {token_id} not masked",
|
||||
msg_suffix=f"Invalid: non-dominant token {token_id} not masked",
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
step_idx=step_idx,
|
||||
)
|
||||
else:
|
||||
# No masking when min_p is 0
|
||||
if logits_for_token == -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=
|
||||
f"Invalid: token {token_id} masked when min_p=0.0",
|
||||
msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0",
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
|
||||
def _min_tokens_params(kwargs: dict) -> None:
|
||||
@@ -303,7 +323,8 @@ def _min_tokens_validate(
|
||||
min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
|
||||
ref_all_stop_token_ids = request_params.params.all_stop_token_ids
|
||||
mt_lp: MinTokensLogitsProcessor = next(
|
||||
test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
|
||||
test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)
|
||||
)
|
||||
assert isinstance(mt_lp, MinTokensLogitsProcessor)
|
||||
min_tok = mt_lp.min_toks.get(batch_index, None)
|
||||
|
||||
@@ -312,38 +333,50 @@ def _min_tokens_validate(
|
||||
(_, out_tok, all_stop_token_ids) = min_tok
|
||||
num_out_tokens = len(out_tok)
|
||||
if num_out_tokens != ref_num_out_tokens:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Number of output tokens in min-token logit processor "
|
||||
f"request metadata ({num_out_tokens}) does not match "
|
||||
f"reference ({ref_num_out_tokens})."),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
"Number of output tokens in min-token logit processor "
|
||||
f"request metadata ({num_out_tokens}) does not match "
|
||||
f"reference ({ref_num_out_tokens})."
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
)
|
||||
if ref_all_stop_token_ids != all_stop_token_ids:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Stop token ids do not match reference; all_stop_token_ids: "
|
||||
f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
|
||||
f"{sorted(ref_all_stop_token_ids)}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
"Stop token ids do not match reference; all_stop_token_ids: "
|
||||
f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
|
||||
f"{sorted(ref_all_stop_token_ids)}"
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
)
|
||||
if min_reached:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Expected min-tokens request with min reached, but batch "
|
||||
"index is recognized by min-tokens logits processor."),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
err_cls=RuntimeError)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
"Expected min-tokens request with min reached, but batch "
|
||||
"index is recognized by min-tokens logits processor."
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
err_cls=RuntimeError,
|
||||
)
|
||||
|
||||
elif not min_reached:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Expected min-tokens request with min not reached, but batch "
|
||||
"index is not recognized by min-tokens logits processor."),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
err_cls=RuntimeError)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
"Expected min-tokens request with min not reached, but batch "
|
||||
"index is not recognized by min-tokens logits processor."
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
err_cls=RuntimeError,
|
||||
)
|
||||
|
||||
# Validate min-token logits
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
@@ -351,21 +384,27 @@ def _min_tokens_validate(
|
||||
if token_id in ref_all_stop_token_ids and not min_reached:
|
||||
if logits_for_token != -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(f"Token {token_id} is a stop token and "
|
||||
"the sequence has not reached min length, "
|
||||
"but the token is not masked "
|
||||
f"(logit={logits_for_token})"),
|
||||
msg_suffix=(
|
||||
f"Token {token_id} is a stop token and "
|
||||
"the sequence has not reached min length, "
|
||||
"but the token is not masked "
|
||||
f"(logit={logits_for_token})"
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
step_idx=step_idx,
|
||||
)
|
||||
else:
|
||||
if logits_for_token == -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(f"Token {token_id} should not be masked but "
|
||||
f"is (output len={ref_num_out_tokens})"),
|
||||
msg_suffix=(
|
||||
f"Token {token_id} should not be masked but "
|
||||
f"is (output len={ref_num_out_tokens})"
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
|
||||
def _none_validate(
|
||||
@@ -377,52 +416,58 @@ def _none_validate(
|
||||
step_idx: int,
|
||||
) -> None:
|
||||
"""Validate that no logits processors are applied"""
|
||||
logits = (
|
||||
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
|
||||
logits = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()
|
||||
ref_logits = logits_new[batch_index]
|
||||
if not torch.all(ref_logits == logits):
|
||||
mismatch_toks = (ref_logits
|
||||
!= logits).nonzero(as_tuple=True)[0].tolist()
|
||||
mismatch_toks = (ref_logits != logits).nonzero(as_tuple=True)[0].tolist()
|
||||
mismatch_strs = []
|
||||
for token in mismatch_toks:
|
||||
val = float(logits[token])
|
||||
ref_val = float(ref_logits[token])
|
||||
mismatch_strs.append(f"({token=},{val=},{ref_val=})")
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(
|
||||
f"Unexpected modification of logits: {','.join(mismatch_strs)}"
|
||||
),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
|
||||
class LogitsprocTestHelpers(NamedTuple):
|
||||
"""Supports setting up and validating logitsprocs unit tests."""
|
||||
|
||||
eval_fxn: Callable
|
||||
gen_request_fxn: Optional[Callable] = None
|
||||
|
||||
|
||||
logitsprocs_test_mapping = {
|
||||
STR_NO_LOGITPROC:
|
||||
LogitsprocTestHelpers(eval_fxn=_none_validate),
|
||||
LogitBiasLogitsProcessor:
|
||||
LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params,
|
||||
eval_fxn=_logit_bias_validate),
|
||||
MinPLogitsProcessor:
|
||||
LogitsprocTestHelpers(gen_request_fxn=_min_p_params,
|
||||
eval_fxn=_min_p_validate),
|
||||
MinTokensLogitsProcessor:
|
||||
LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
|
||||
eval_fxn=_min_tokens_validate),
|
||||
STR_NO_LOGITPROC: LogitsprocTestHelpers(eval_fxn=_none_validate),
|
||||
LogitBiasLogitsProcessor: LogitsprocTestHelpers(
|
||||
gen_request_fxn=_logit_bias_params, eval_fxn=_logit_bias_validate
|
||||
),
|
||||
MinPLogitsProcessor: LogitsprocTestHelpers(
|
||||
gen_request_fxn=_min_p_params, eval_fxn=_min_p_validate
|
||||
),
|
||||
MinTokensLogitsProcessor: LogitsprocTestHelpers(
|
||||
gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _get_test_cases() -> list[list[str]]:
|
||||
"""Each test case is a set of logitsprocs"""
|
||||
logitsprocs_types = list(logitsprocs_test_mapping.keys())
|
||||
return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
|
||||
for logitproc_type in logitsprocs_types
|
||||
if logitproc_type != STR_NO_LOGITPROC
|
||||
] + [logitsprocs_types]
|
||||
return (
|
||||
[[STR_NO_LOGITPROC]]
|
||||
+ [
|
||||
[logitproc_type, STR_NO_LOGITPROC]
|
||||
for logitproc_type in logitsprocs_types
|
||||
if logitproc_type != STR_NO_LOGITPROC
|
||||
]
|
||||
+ [logitsprocs_types]
|
||||
)
|
||||
|
||||
|
||||
def _generate_fake_step_update(
|
||||
@@ -440,11 +485,18 @@ def _generate_fake_step_update(
|
||||
# Other 50%: add a limited number of reqs (less than the number
|
||||
# of workload reqs remaining, less than an arbitrary max)
|
||||
# If no workload reqs remain: 100% of steps have 0 adds
|
||||
num_step_add = random.choice([
|
||||
0,
|
||||
random.randint(1, min(max_add_remove_per_step,
|
||||
workload_reqs_remaining))
|
||||
]) if workload_reqs_remaining else 0
|
||||
num_step_add = (
|
||||
random.choice(
|
||||
[
|
||||
0,
|
||||
random.randint(
|
||||
1, min(max_add_remove_per_step, workload_reqs_remaining)
|
||||
),
|
||||
]
|
||||
)
|
||||
if workload_reqs_remaining
|
||||
else 0
|
||||
)
|
||||
|
||||
# 50% of steps: remove no requests
|
||||
# Other 50%: remove a limited number of reqs (less than the number
|
||||
@@ -452,9 +504,11 @@ def _generate_fake_step_update(
|
||||
# If persistent batch is empty: 100% of steps have 0 removals until
|
||||
# more requests are added. Assume that removed requests are always
|
||||
# drawn from the current batch, before new adds
|
||||
num_step_remove = random.choice([
|
||||
0, random.randint(1, min(max_add_remove_per_step, batch_size))
|
||||
]) if batch_size else 0
|
||||
num_step_remove = (
|
||||
random.choice([0, random.randint(1, min(max_add_remove_per_step, batch_size))])
|
||||
if batch_size
|
||||
else 0
|
||||
)
|
||||
|
||||
num_step_add_replace = min(num_step_add, num_step_remove)
|
||||
|
||||
@@ -463,23 +517,34 @@ def _generate_fake_step_update(
|
||||
batch_update_builder.removed_append(removal)
|
||||
|
||||
# Get added requests from workload
|
||||
for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
|
||||
for add_req_params in workload_params[wdx : (wdx + num_step_add_replace)]:
|
||||
# Replace as many removed requests as possible with added requests
|
||||
add_remove_idx = batch_update_builder.pop_removed()
|
||||
batch_update_builder.added.append(
|
||||
(add_remove_idx, add_req_params.params,
|
||||
add_req_params.prompt_tokens, add_req_params.out_tokens))
|
||||
(
|
||||
add_remove_idx,
|
||||
add_req_params.params,
|
||||
add_req_params.prompt_tokens,
|
||||
add_req_params.out_tokens,
|
||||
)
|
||||
)
|
||||
persistent_batch[add_remove_idx] = add_req_params
|
||||
|
||||
# Append remaining added requests to end of batch
|
||||
add_reqs_append = workload_params[(wdx +
|
||||
num_step_add_replace):(wdx +
|
||||
num_step_add)]
|
||||
batch_update_builder.added.extend([
|
||||
(adx + batch_size, add_req_params.params, add_req_params.prompt_tokens,
|
||||
add_req_params.out_tokens)
|
||||
for adx, add_req_params in enumerate(add_reqs_append)
|
||||
])
|
||||
add_reqs_append = workload_params[
|
||||
(wdx + num_step_add_replace) : (wdx + num_step_add)
|
||||
]
|
||||
batch_update_builder.added.extend(
|
||||
[
|
||||
(
|
||||
adx + batch_size,
|
||||
add_req_params.params,
|
||||
add_req_params.prompt_tokens,
|
||||
add_req_params.out_tokens,
|
||||
)
|
||||
for adx, add_req_params in enumerate(add_reqs_append)
|
||||
]
|
||||
)
|
||||
persistent_batch.extend(add_reqs_append)
|
||||
pre_condense_batch_size = len(persistent_batch)
|
||||
wdx += num_step_add # Update workload offset
|
||||
@@ -488,8 +553,10 @@ def _generate_fake_step_update(
|
||||
last_nonempty_index = pre_condense_batch_size - 1
|
||||
condensed_to_idxs = set()
|
||||
while batch_update_builder.removed:
|
||||
if (last_nonempty_index in batch_update_builder.removed
|
||||
or last_nonempty_index in condensed_to_idxs):
|
||||
if (
|
||||
last_nonempty_index in batch_update_builder.removed
|
||||
or last_nonempty_index in condensed_to_idxs
|
||||
):
|
||||
last_nonempty_index -= 1
|
||||
continue
|
||||
# last_nonempty_index is the highest persistent batch index that was
|
||||
@@ -504,11 +571,10 @@ def _generate_fake_step_update(
|
||||
# move last_nonempty_index -> first_empty_index
|
||||
batch_update_builder.pop_removed()
|
||||
condensed_to_idxs.add(first_empty_index)
|
||||
persistent_batch[first_empty_index] = persistent_batch[
|
||||
last_nonempty_index]
|
||||
persistent_batch[first_empty_index] = persistent_batch[last_nonempty_index]
|
||||
batch_update_builder.moved.append(
|
||||
(last_nonempty_index, first_empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
(last_nonempty_index, first_empty_index, MoveDirectionality.UNIDIRECTIONAL)
|
||||
)
|
||||
|
||||
last_nonempty_index -= 1
|
||||
|
||||
@@ -524,18 +590,21 @@ def _generate_fake_step_update(
|
||||
k = random.randint(0, condensed_batch_size // 2)
|
||||
idxs = list(range(condensed_batch_size))
|
||||
random.shuffle(idxs)
|
||||
swaps = [
|
||||
tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
|
||||
]
|
||||
batch_update_builder.moved.extend([
|
||||
(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
|
||||
])
|
||||
swaps = [tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)]
|
||||
batch_update_builder.moved.extend(
|
||||
[(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps]
|
||||
)
|
||||
for adx, bdx in swaps:
|
||||
persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
|
||||
bdx], persistent_batch[adx]
|
||||
persistent_batch[adx], persistent_batch[bdx] = (
|
||||
persistent_batch[bdx],
|
||||
persistent_batch[adx],
|
||||
)
|
||||
|
||||
return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
|
||||
workload_size - wdx)
|
||||
return (
|
||||
batch_update_builder.get_and_reset(condensed_batch_size),
|
||||
wdx,
|
||||
workload_size - wdx,
|
||||
)
|
||||
|
||||
|
||||
def _assert_valid(
|
||||
@@ -550,8 +619,10 @@ def _assert_valid(
|
||||
# Trivial case of empty persistent batch
|
||||
assert len(persistent_batch) == 0
|
||||
if logits_w_lp.shape[0] != 0:
|
||||
raise ValueError("Fake persistent batch is empty but logitsprocs "
|
||||
f"output batch has shape {logits_w_lp.shape}")
|
||||
raise ValueError(
|
||||
"Fake persistent batch is empty but logitsprocs "
|
||||
f"output batch has shape {logits_w_lp.shape}"
|
||||
)
|
||||
return
|
||||
|
||||
# Validate logits for each fake request
|
||||
@@ -560,36 +631,40 @@ def _assert_valid(
|
||||
# Invoke the appropriate validation function for
|
||||
# the logitproc employed by this request
|
||||
fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
|
||||
fxn(test_fakes=test_fakes,
|
||||
fxn(
|
||||
test_fakes=test_fakes,
|
||||
persistent_batch=persistent_batch,
|
||||
logits_new=logits_w_lp,
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
|
||||
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
|
||||
def test_logitsprocs(device: str, reqs_per_logitproc: int,
|
||||
logitsprocs_under_test: list[str]):
|
||||
def test_logitsprocs(
|
||||
device: str, reqs_per_logitproc: int, logitsprocs_under_test: list[str]
|
||||
):
|
||||
random.seed(40)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Define a shuffled batch of requests which individually use a different
|
||||
# logitproc, or no logitproc at all
|
||||
workload_params = _generate_mixed_logitsprocs_batch_params(
|
||||
reqs_per_logitproc=reqs_per_logitproc,
|
||||
logitsprocs_types=logitsprocs_under_test)
|
||||
reqs_per_logitproc=reqs_per_logitproc, logitsprocs_types=logitsprocs_under_test
|
||||
)
|
||||
workload_size = len(workload_params)
|
||||
|
||||
# Create fake test data structures for testing.
|
||||
test_fakes = _generate_test_fakes(workload_size, device)
|
||||
|
||||
wdx = 0 # Next request index in workload to add
|
||||
persistent_batch: list[LogitsProcsRequestParams] = [
|
||||
] # Persistent batch state, as list of workload indices
|
||||
persistent_batch: list[
|
||||
LogitsProcsRequestParams
|
||||
] = [] # Persistent batch state, as list of workload indices
|
||||
|
||||
# Generate fake removed request indices from current persistent
|
||||
# batch before adds
|
||||
|
||||
@@ -7,32 +7,44 @@ from typing import Union
|
||||
import pytest
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
|
||||
# yapf: disable
|
||||
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
DUMMY_LOGITPROC_MODULE,
|
||||
MAX_TOKENS, MODEL_NAME,
|
||||
POOLING_MODEL_NAME, TEMP_GREEDY,
|
||||
CustomLogitprocSource,
|
||||
DummyLogitsProcessor,
|
||||
WrappedPerReqLogitsProcessor,
|
||||
dummy_module)
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
DUMMY_LOGITPROC_MODULE,
|
||||
MAX_TOKENS,
|
||||
MODEL_NAME,
|
||||
POOLING_MODEL_NAME,
|
||||
TEMP_GREEDY,
|
||||
CustomLogitprocSource,
|
||||
DummyLogitsProcessor,
|
||||
WrappedPerReqLogitsProcessor,
|
||||
dummy_module,
|
||||
prompts,
|
||||
)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
from tests.v1.logits_processors.utils import prompts
|
||||
|
||||
# yapf: enable
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS,
|
||||
LogitsProcessor)
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
STR_POOLING_REJECTS_LOGITSPROCS,
|
||||
LogitsProcessor,
|
||||
)
|
||||
|
||||
# Create a mixture of requests which do and don't utilize the dummy logitproc
|
||||
sampling_params_list = [
|
||||
SamplingParams(temperature=TEMP_GREEDY,
|
||||
max_tokens=MAX_TOKENS,
|
||||
extra_args={DUMMY_LOGITPROC_ARG: 128}),
|
||||
SamplingParams(
|
||||
temperature=TEMP_GREEDY,
|
||||
max_tokens=MAX_TOKENS,
|
||||
extra_args={DUMMY_LOGITPROC_ARG: 128},
|
||||
),
|
||||
SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
|
||||
SamplingParams(temperature=TEMP_GREEDY,
|
||||
max_tokens=MAX_TOKENS,
|
||||
extra_args={DUMMY_LOGITPROC_ARG: 67}),
|
||||
SamplingParams(
|
||||
temperature=TEMP_GREEDY,
|
||||
max_tokens=MAX_TOKENS,
|
||||
extra_args={DUMMY_LOGITPROC_ARG: 67},
|
||||
),
|
||||
SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
|
||||
]
|
||||
|
||||
@@ -49,7 +61,7 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
|
||||
2. Server has *not* loaded dummy logitproc; test that all requests
|
||||
behave as if logitproc is *not* operating (output matches reference
|
||||
`LLM` output.)
|
||||
|
||||
|
||||
Args:
|
||||
kwargs: `LLM` constructor kwargs
|
||||
logitproc_loaded: server has loaded dummy logitproc if True
|
||||
@@ -73,7 +85,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
|
||||
|
||||
# Validate outputs
|
||||
for bdx, (out_lp, out_ref, params) in enumerate(
|
||||
zip(outputs_logitproc, outputs_ref, sampling_params_list)):
|
||||
zip(outputs_logitproc, outputs_ref, sampling_params_list)
|
||||
):
|
||||
lp_toks = out_lp.outputs[0].token_ids
|
||||
if logitproc_loaded and params.extra_args:
|
||||
# This request exercises custom logitproc; validate that logitproc
|
||||
@@ -81,8 +94,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
|
||||
target_token = params.extra_args[DUMMY_LOGITPROC_ARG]
|
||||
if not all(x == target_token for x in lp_toks):
|
||||
raise AssertionError(
|
||||
f"Request {bdx} generated {lp_toks}, should all be "
|
||||
f"{target_token}")
|
||||
f"Request {bdx} generated {lp_toks}, should all be {target_token}"
|
||||
)
|
||||
else:
|
||||
# This request does not exercise custom logitproc (or custom
|
||||
# logitproc is not enabled on this server); validate against
|
||||
@@ -90,16 +103,15 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
|
||||
ref_toks = out_ref.outputs[0].token_ids
|
||||
if lp_toks != ref_toks:
|
||||
raise AssertionError(
|
||||
f"Request {bdx} generated {lp_toks}, should match "
|
||||
f"{ref_toks}")
|
||||
f"Request {bdx} generated {lp_toks}, should match {ref_toks}"
|
||||
)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource))
|
||||
def test_custom_logitsprocs(monkeypatch,
|
||||
logitproc_source: CustomLogitprocSource):
|
||||
def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource):
|
||||
"""Test offline Python interface for passing custom logitsprocs
|
||||
|
||||
|
||||
Construct an `LLM` instance which loads a custom logitproc that has a
|
||||
well-defined behavior (mask out all tokens except one `target_token`)
|
||||
|
||||
@@ -118,7 +130,7 @@ def test_custom_logitsprocs(monkeypatch,
|
||||
instance output
|
||||
* Logitproc passed in via {entrypoint, class object, fully-qualified class
|
||||
name (FQCN)} - test that dummy logitproc is utilized correctly when
|
||||
provided via any of these three possible sources
|
||||
provided via any of these three possible sources
|
||||
|
||||
Args:
|
||||
monkeypatch: for setting env vars
|
||||
@@ -142,6 +154,7 @@ def test_custom_logitsprocs(monkeypatch,
|
||||
# Scenario: vLLM loads a logitproc from a preconfigured entrypoint
|
||||
# To that end, mock a dummy logitproc entrypoint
|
||||
import importlib.metadata
|
||||
|
||||
importlib.metadata.entry_points = fake_entry_points # type: ignore
|
||||
|
||||
# fork is required for workers to see entrypoint patch
|
||||
@@ -165,7 +178,7 @@ def test_custom_logitsprocs(monkeypatch,
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_logitsprocs_req(monkeypatch):
|
||||
"""Test passing request-level logits processor to offline Python interface
|
||||
|
||||
|
||||
Wrap a request-level logits processor to create a batch level logits
|
||||
processor that has a well-defined behavior (mask out all tokens except one
|
||||
`target_token`)
|
||||
@@ -190,18 +203,23 @@ def test_custom_logitsprocs_req(monkeypatch):
|
||||
# Test that logitproc info is passed to workers
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
|
||||
random.seed(40)
|
||||
_run_test({"logits_processors": [WrappedPerReqLogitsProcessor]},
|
||||
logitproc_loaded=True)
|
||||
_run_test(
|
||||
{"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True
|
||||
)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("logitproc_source", [
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"logitproc_source",
|
||||
[
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
|
||||
],
|
||||
)
|
||||
def test_pooling_rejects_custom_logitsprocs(
|
||||
monkeypatch, logitproc_source: CustomLogitprocSource):
|
||||
monkeypatch, logitproc_source: CustomLogitprocSource
|
||||
):
|
||||
"""Validate that vLLM engine initialization properly rejects custom
|
||||
logitsprocs when the model is a pooling model.
|
||||
|
||||
@@ -233,6 +251,7 @@ def test_pooling_rejects_custom_logitsprocs(
|
||||
|
||||
# Patch in dummy logitproc entrypoint
|
||||
import importlib.metadata
|
||||
|
||||
importlib.metadata.entry_points = fake_entry_points # type: ignore
|
||||
|
||||
# fork is required for entrypoint patch to be visible to workers,
|
||||
@@ -245,10 +264,15 @@ def test_pooling_rejects_custom_logitsprocs(
|
||||
gpu_memory_utilization=0.1,
|
||||
)
|
||||
# Require that no logitsprocs have been loaded
|
||||
assert sum([
|
||||
1 for _ in llm.llm_engine.model_executor.driver_worker.worker.
|
||||
model_runner.input_batch.logitsprocs.all
|
||||
]) == 0
|
||||
assert (
|
||||
sum(
|
||||
[
|
||||
1
|
||||
for _ in llm.llm_engine.model_executor.driver_worker.worker.model_runner.input_batch.logitsprocs.all
|
||||
]
|
||||
)
|
||||
== 0
|
||||
)
|
||||
return
|
||||
|
||||
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
|
||||
|
||||
@@ -10,16 +10,20 @@ import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import (RemoteOpenAIServerCustom,
|
||||
create_new_process_for_each_test)
|
||||
from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test
|
||||
|
||||
# yapf: disable
|
||||
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
DUMMY_LOGITPROC_MODULE,
|
||||
MAX_TOKENS, MODEL_NAME,
|
||||
TEMP_GREEDY, dummy_module)
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
DUMMY_LOGITPROC_MODULE,
|
||||
MAX_TOKENS,
|
||||
MODEL_NAME,
|
||||
TEMP_GREEDY,
|
||||
dummy_module,
|
||||
prompts,
|
||||
)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
from tests.v1.logits_processors.utils import prompts
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@@ -33,11 +37,12 @@ def _server_with_logitproc_entrypoint(
|
||||
|
||||
# Patch `entry_points` to inject logitproc entrypoint
|
||||
import importlib.metadata
|
||||
|
||||
importlib.metadata.entry_points = fake_entry_points # type: ignore
|
||||
from vllm.entrypoints.cli import main
|
||||
|
||||
# fork is required for workers to see entrypoint patch
|
||||
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
|
||||
if env_dict is not None:
|
||||
os.environ.update(env_dict)
|
||||
|
||||
@@ -55,10 +60,11 @@ def _server_with_logitproc_module(
|
||||
|
||||
# Patch `modules` to inject dummy logitproc module
|
||||
from vllm.entrypoints.cli import main
|
||||
|
||||
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
|
||||
|
||||
# fork is required for workers to see entrypoint patch
|
||||
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
|
||||
if env_dict is not None:
|
||||
os.environ.update(env_dict)
|
||||
|
||||
@@ -80,8 +86,9 @@ def default_server_args():
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function",
|
||||
params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]])
|
||||
@pytest.fixture(
|
||||
scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]
|
||||
)
|
||||
def server(default_server_args, request, monkeypatch):
|
||||
"""Consider two server configurations:
|
||||
(1) --logits-processors cli arg specifies dummy logits processor via fully-
|
||||
@@ -102,8 +109,7 @@ def server(default_server_args, request, monkeypatch):
|
||||
args = default_server_args
|
||||
_server_fxn = _server_with_logitproc_entrypoint
|
||||
|
||||
with RemoteOpenAIServerCustom(MODEL_NAME, args,
|
||||
_server_fxn) as remote_server:
|
||||
with RemoteOpenAIServerCustom(MODEL_NAME, args, _server_fxn) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -133,7 +139,7 @@ api_keyword_args = {
|
||||
)
|
||||
async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
|
||||
"""Test custom logitsprocs when starting OpenAI server from CLI
|
||||
|
||||
|
||||
Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
|
||||
that has a well-defined behavior (mask out all tokens except one
|
||||
`target_token`).
|
||||
@@ -157,9 +163,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# For requests which activate the dummy logitproc, choose one of
|
||||
# two `target_token` values which are known not to be EOS tokens
|
||||
request_keyword_args["extra_body"] = {
|
||||
"vllm_xargs": {
|
||||
DUMMY_LOGITPROC_ARG: target_token
|
||||
}
|
||||
"vllm_xargs": {DUMMY_LOGITPROC_ARG: target_token}
|
||||
}
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
@@ -173,8 +177,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
|
||||
choices: openai.types.CompletionChoice = batch.choices
|
||||
toks = choices[0].logprobs.tokens
|
||||
if not all([x == toks[0] for x in toks]):
|
||||
raise AssertionError(
|
||||
f"Generated {toks} should all be {toks[0]}")
|
||||
raise AssertionError(f"Generated {toks} should all be {toks[0]}")
|
||||
|
||||
# Alternate whether to activate dummy logitproc for each request
|
||||
use_dummy_logitproc = not use_dummy_logitproc
|
||||
|
||||
@@ -10,10 +10,13 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP,
|
||||
AdapterLogitsProcessor,
|
||||
BatchUpdate, LogitsProcessor,
|
||||
RequestLogitsProcessor)
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
LOGITSPROCS_GROUP,
|
||||
AdapterLogitsProcessor,
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
RequestLogitsProcessor,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -30,6 +33,7 @@ DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
|
||||
|
||||
class CustomLogitprocSource(Enum):
|
||||
"""How to source a logitproc for testing purposes"""
|
||||
|
||||
LOGITPROC_SOURCE_NONE = auto() # No custom logitproc
|
||||
LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint
|
||||
LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN)
|
||||
@@ -48,8 +52,9 @@ prompts = [
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||
is_pin_memory: bool):
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
self.req_info: dict[int, int] = {}
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
@@ -60,8 +65,8 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
process_dict_updates(
|
||||
self.req_info,
|
||||
batch_update,
|
||||
lambda params, _, __: params.extra_args and
|
||||
(params.extra_args.get("target_token")),
|
||||
lambda params, _, __: params.extra_args
|
||||
and (params.extra_args.get("target_token")),
|
||||
)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
@@ -69,16 +74,16 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
return logits
|
||||
|
||||
# Save target values before modification
|
||||
cols = torch.tensor(list(self.req_info.values()),
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
rows = torch.tensor(list(self.req_info.keys()),
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
cols = torch.tensor(
|
||||
list(self.req_info.values()), dtype=torch.long, device=logits.device
|
||||
)
|
||||
rows = torch.tensor(
|
||||
list(self.req_info.keys()), dtype=torch.long, device=logits.device
|
||||
)
|
||||
values_to_keep = logits[rows, cols].clone()
|
||||
|
||||
# Mask all but target tokens
|
||||
logits[rows] = float('-inf')
|
||||
logits[rows] = float("-inf")
|
||||
logits[rows, cols] = values_to_keep
|
||||
|
||||
return logits
|
||||
@@ -154,14 +159,17 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
Returns:
|
||||
`Callable` request logits processor, or None
|
||||
"""
|
||||
target_token: Optional[
|
||||
Any] = params.extra_args and params.extra_args.get("target_token")
|
||||
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is None:
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.", target_token)
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user