Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -10,10 +10,13 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP,
-                                             AdapterLogitsProcessor,
-                                             BatchUpdate, LogitsProcessor,
-                                             RequestLogitsProcessor)
+from vllm.v1.sample.logits_processor import (
+    LOGITSPROCS_GROUP,
+    AdapterLogitsProcessor,
+    BatchUpdate,
+    LogitsProcessor,
+    RequestLogitsProcessor,
+)
 from vllm.v1.sample.logits_processor.builtin import process_dict_updates
 
 logger = init_logger(__name__)
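Note: the exploded import above is consistent with ruff format's black-compatible "magic trailing comma" rule (an inference from the output, not stated in this commit): once a collection ends in a trailing comma, it stays one element per line even when it would fit on a single line. A minimal sketch with hypothetical imports:

    # The trailing comma after the last name keeps this import exploded,
    # one name per line, although it would fit within the line limit.
    from typing import (
        Any,
        Optional,
    )

Conversely, the __init__ signature in a later hunk has no trailing comma, so the formatter collapses it onto as few lines as fit.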
@@ -30,6 +33,7 @@ DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
 
 class CustomLogitprocSource(Enum):
     """How to source a logitproc for testing purposes"""
+
     LOGITPROC_SOURCE_NONE = auto()  # No custom logitproc
     LOGITPROC_SOURCE_ENTRYPOINT = auto()  # Via entrypoint
     LOGITPROC_SOURCE_FQCN = auto()  # Via fully-qualified class name (FQCN)
@@ -48,8 +52,9 @@ prompts = [
 class DummyLogitsProcessor(LogitsProcessor):
     """Fake logit processor to support unit testing and examples"""
 
-    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
-                 is_pin_memory: bool):
+    def __init__(
+        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
+    ):
         self.req_info: dict[int, int] = {}
 
     def is_argmax_invariant(self) -> bool:
@@ -60,8 +65,8 @@ class DummyLogitsProcessor(LogitsProcessor):
         process_dict_updates(
             self.req_info,
             batch_update,
-            lambda params, _, __: params.extra_args and
-            (params.extra_args.get("target_token")),
+            lambda params, _, __: params.extra_args
+            and (params.extra_args.get("target_token")),
         )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
@@ -69,16 +74,16 @@ class DummyLogitsProcessor(LogitsProcessor):
             return logits
 
         # Save target values before modification
-        cols = torch.tensor(list(self.req_info.values()),
-                            dtype=torch.long,
-                            device=logits.device)
-        rows = torch.tensor(list(self.req_info.keys()),
-                            dtype=torch.long,
-                            device=logits.device)
+        cols = torch.tensor(
+            list(self.req_info.values()), dtype=torch.long, device=logits.device
+        )
+        rows = torch.tensor(
+            list(self.req_info.keys()), dtype=torch.long, device=logits.device
+        )
         values_to_keep = logits[rows, cols].clone()
 
         # Mask all but target tokens
-        logits[rows] = float('-inf')
+        logits[rows] = float("-inf")
         logits[rows, cols] = values_to_keep
 
         return logits
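The hunk above changes only wrapping and quote style; the masking trick itself is unchanged. A self-contained sketch of that trick, with hypothetical batch and vocabulary sizes:

    import torch

    # rows are batch indices of affected requests; cols are their target tokens.
    logits = torch.randn(4, 8)
    rows = torch.tensor([0, 2], dtype=torch.long)
    cols = torch.tensor([5, 1], dtype=torch.long)

    values_to_keep = logits[rows, cols].clone()  # save targets before masking
    logits[rows] = float("-inf")                 # mask the affected rows entirely
    logits[rows, cols] = values_to_keep          # restore only the target tokens

After this, greedy sampling on rows 0 and 2 can only pick tokens 5 and 1; unaffected rows keep their original logits.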
@@ -154,14 +159,17 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
         Returns:
             `Callable` request logits processor, or None
         """
-        target_token: Optional[
-            Any] = params.extra_args and params.extra_args.get("target_token")
+        target_token: Optional[Any] = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
         if target_token is None:
             return None
         if not isinstance(target_token, int):
             logger.warning(
                 "target_token value %s is not int; not applying logits"
-                " processor to request.", target_token)
+                " processor to request.",
+                target_token,
+            )
             return None
         return DummyPerReqLogitsProcessor(target_token)
 
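Both the reflowed lambda in the @@ -60,8 hunk and the target_token assignment above rely on the same short-circuit idiom: params.extra_args and params.extra_args.get("target_token") is None when extra_args is None, and otherwise defers to .get(). A minimal illustration with a hypothetical Params stand-in (not the vLLM SamplingParams class):

    from typing import Any, Optional

    class Params:
        def __init__(self, extra_args: Optional[dict[str, Any]] = None):
            self.extra_args = extra_args

    def target_of(params: Params) -> Optional[Any]:
        # None extra_args short-circuits to None; a present dict defers to .get().
        # Caveat: an empty dict short-circuits to {} (falsy, but not None).
        return params.extra_args and params.extra_args.get("target_token")

    assert target_of(Params()) is None
    assert target_of(Params({"other": 1})) is None
    assert target_of(Params({"target_token": 42})) == 42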