[V1] Logits processors extensibility (#19912)
Signed-off-by: Andrew Feldman <afeldman@redhat.com> Signed-off-by: Andrew Feldman <afeld2012@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
127
tests/v1/logits_processors/utils.py
Normal file
127
tests/v1/logits_processors/utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality)
|
||||
|
||||
# Models exercised by the logits processor tests.
MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"

# Sampling settings shared by the tests.
TEMP_GREEDY = 0.0
MAX_TOKENS = 20

# Key looked up in `SamplingParams.extra_args` by the dummy logits processor.
DUMMY_LOGITPROC_ARG = "target_token"

# Identifiers under which the dummy logits processor is exposed:
# an entrypoint name, a fake module name, and the resulting
# "<module>:<class>" fully-qualified class name.
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule"
DUMMY_LOGITPROC_FQCN = DUMMY_LOGITPROC_MODULE + ":DummyLogitsProcessor"
|
||||
|
||||
|
||||
class CustomLogitprocSource(Enum):
    """Ways a test can supply a custom logits processor."""

    # No custom logitproc
    LOGITPROC_SOURCE_NONE = auto()
    # Via entrypoint
    LOGITPROC_SOURCE_ENTRYPOINT = auto()
    # Via fully-qualified class name (FQCN)
    LOGITPROC_SOURCE_FQCN = auto()
    # Via provided class object
    LOGITPROC_SOURCE_CLASS = auto()
|
||||
|
||||
|
||||
# Prompts fed to the model by the logits processor tests.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
|
||||
|
||||
|
||||
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logits processor to support unit testing and examples.

    A request opts in by supplying a target token id under the
    `DUMMY_LOGITPROC_ARG` key of its sampling params' `extra_args`. For
    every such request, all logits in that request's row are masked to
    `-inf` except the logit of the target token.
    """

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        """Create an empty processor; the arguments are accepted but unused."""
        # Maps a logits row index to the target token id for that row.
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Return False: masking non-target logits can change the argmax."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        """Apply added/removed/moved request changes to `req_info`.

        Args:
            batch_update: Batch changes since the previous step, or a falsy
                value when nothing changed (in which case this is a no-op).
        """
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            extra_args = params.extra_args
            target_token = (extra_args.get(DUMMY_LOGITPROC_ARG)
                            if extra_args else None)
            if target_token is not None:
                # Compare against None rather than truthiness so that
                # token id 0 is honored.
                self.req_info[index] = target_token
            else:
                # The new request does not use this processor; drop any
                # entry left behind by a request that previously occupied
                # this row so the new request is not masked by mistake.
                self.req_info.pop(index, None)

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Mask every non-target logit in each tracked row.

        Args:
            logits: Tensor indexed as `[row, token]`; modified in place.

        Returns:
            The same `logits` tensor, with tracked rows set to `-inf`
            everywhere except at their target token.
        """
        if not self.req_info:
            return logits

        # Keys and values of a dict iterate in the same order, so rows[i]
        # pairs with cols[i].
        rows = torch.tensor(list(self.req_info.keys()),
                            dtype=torch.long,
                            device=logits.device)
        cols = torch.tensor(list(self.req_info.values()),
                            dtype=torch.long,
                            device=logits.device)

        # Save target values before modification
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float('-inf')
        logits[rows, cols] = values_to_keep

        return logits
|
||||
|
||||
|
||||
"""Dummy module with dummy logitproc class"""
|
||||
# In-memory module object named DUMMY_LOGITPROC_MODULE that exposes the
# dummy logits processor class as an attribute.
dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE)
setattr(dummy_module, "DummyLogitsProcessor", DummyLogitsProcessor)
|
||||
|
||||
|
||||
class EntryPoint:
    """Stand-in for a single entrypoint object in logitsprocs tests.

    Carries `name` and `value` attributes and a `load()` method that
    yields the dummy logits processor class.
    """

    def __init__(self):
        # Entrypoint name and its "<module>:<class>" target string.
        self.name, self.value = (DUMMY_LOGITPROC_ENTRYPOINT,
                                 DUMMY_LOGITPROC_FQCN)

    def load(self):
        """Return the `DummyLogitsProcessor` class itself (not an instance)."""
        return DummyLogitsProcessor
|
||||
|
||||
|
||||
class EntryPoints(list):
    """Stand-in entrypoint collection for logitsprocs tests.

    Behaves as a list of `EntryPoint` objects and additionally exposes a
    `names` attribute listing their names.
    """

    def __init__(self, group: str):
        # Only the logits processor group has a (single) dummy entrypoint;
        # any other group yields an empty collection.
        found = []
        if group == LOGITSPROCS_GROUP:
            found.append(EntryPoint())

        # Emulate list-like functionality
        super().__init__(found)

        # Extra attributes
        self.names = [entry.name for entry in found]
|
||||
|
||||
|
||||
"""Fake version of importlib.metadata.entry_points"""
|
||||
def entry_points(group):
    """Fake version of importlib.metadata.entry_points.

    Args:
        group: Entrypoint group name to look up.

    Returns:
        An `EntryPoints` collection built for `group`.
    """
    return EntryPoints(group)
|
||||
Reference in New Issue
Block a user