[V1] Logits processors extensibility (#19912)
Signed-off-by: Andrew Feldman <afeldman@redhat.com> Signed-off-by: Andrew Feldman <afeld2012@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
185
vllm/v1/sample/logits_processor/__init__.py
Normal file
185
vllm/v1/sample/logits_processor/__init__.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
MinTokensLogitsProcessor)
|
||||
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality)
|
||||
from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder,
|
||||
LogitsProcessors)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Error message when the user tries to initialize vLLM with a pooling model
# and custom logits processors
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
                                   " logits processors.")

# Entry-point group name under which logits-processor plugins register
# themselves for discovery by _load_logitsprocs_plugins().
LOGITSPROCS_GROUP = 'vllm.logits_processors'

# Logits processor types that vLLM instantiates by default for every
# non-pooling model (see build_logitsprocs()).
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
]
|
||||
|
||||
|
||||
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Load all installed logit processor plugins"""

    import sys

    # entry_points(group=...) keyword requires Python 3.10+; fall back to
    # the importlib_metadata backport on older interpreters.
    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
    else:
        from importlib.metadata import entry_points

    discovered = entry_points(group=LOGITSPROCS_GROUP)
    if not discovered:
        logger.debug("No logitsprocs plugins installed (group %s).",
                     LOGITSPROCS_GROUP)
        return []

    # Load every discovered plugin entrypoint, failing fast on the first
    # one that cannot be loaded.
    logger.debug("Loading installed logitsprocs plugins (group %s):",
                 LOGITSPROCS_GROUP)
    loaded: list[type[LogitsProcessor]] = []
    for ep in discovered:
        try:
            logger.debug("- Loading logitproc plugin entrypoint=%s target=%s",
                         ep.name, ep.value)
            loaded.append(ep.load())
        except Exception as e:
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {ep}") from e
    return loaded
|
||||
|
||||
|
||||
def _load_logitsprocs_by_fqcns(
    logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]]
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
      logits_processors: Potentially mixed list of logitsprocs types and FQCN
                         strings for logitproc types

    Returns:
      List of logitproc types

    Raises:
      ValueError: malformed FQCN, unresolvable attribute path, or resolved
        object is not a LogitsProcessor subclass
      RuntimeError: FQCN module could not be imported
    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.", len(logits_processors))

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            # Already a class object; only validate its base class.
            logger.debug(" - Already-loaded logit processor: %s",
                         logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        logger.debug("- Loading logits processor %s", logitproc)
        # Validate the FQCN shape up front so the user gets an actionable
        # error instead of a bare tuple-unpacking ValueError.
        module_path, sep, qualname = logitproc.partition(":")
        if not sep or not module_path or not qualname or ":" in qualname:
            raise ValueError(
                f"Invalid logits processor FQCN '{logitproc}': expected "
                "<module>:<type> syntax, e.g. x.y.z:CustomLogitProc")

        try:
            # Load module
            module = importlib.import_module(module_path)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class
        obj = module
        for attr in qualname.split("."):
            try:
                obj = getattr(obj, attr)
            except AttributeError as e:
                # Wrap for parity with the module-import failure path above
                # rather than leaking a raw AttributeError.
                raise ValueError(
                    f"Module '{module_path}' has no attribute path "
                    f"'{qualname}' for logits processor {logitproc}") from e
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(
                f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes
|
||||
|
||||
|
||||
def _load_custom_logitsprocs(
    logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]],
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    * First load all installed logitproc plugins
    * Second load custom logitsprocs passed by the user at initialization time

    Args:
      logits_processors: potentially mixed list of logitproc types and
                         logitproc type fully-qualified names (FQCNs)
                         which need to be loaded

    Returns:
      A list of all loaded logitproc types

    Raises:
      ValueError: custom logits processors were explicitly requested on a
        platform (TPU) that does not support them
    """
    from vllm.platforms import current_platform
    if current_platform.is_tpu():
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs.
        # Reject explicitly-requested logitsprocs instead of silently
        # dropping user configuration; installed plugins are still skipped.
        if logits_processors:
            raise ValueError(
                "vLLM V1 on TPU does not support custom logits processors.")
        return []

    return (_load_logitsprocs_plugins() +
            _load_logitsprocs_by_fqcns(logits_processors))
|
||||
|
||||
|
||||
def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
    """Instantiate the builtin logits processors plus any custom ones.

    Args:
        vllm_config: engine configuration, forwarded to each logitproc ctor
        device: device forwarded to each logitproc ctor
        is_pin_memory: pinned-memory availability flag forwarded to each ctor
        is_pooling_model: pooling models get no logits processors at all
        custom_logitsprocs: mixed list of logitproc types and FQCN strings

    Returns:
        LogitsProcessors container holding the constructed logitsprocs
    """
    if is_pooling_model:
        # Pooling models reject custom logitsprocs and load no builtins.
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug("Skipping logits processor loading because pooling models"
                     " do not support logits processors.")
        return LogitsProcessors()

    custom_classes = _load_custom_logitsprocs(custom_logitsprocs)
    all_classes = itertools.chain(BUILTIN_LOGITS_PROCESSORS, custom_classes)
    return LogitsProcessors(
        cls(vllm_config, device, is_pin_memory) for cls in all_classes)
|
||||
|
||||
|
||||
# Public API of the logits_processor package: the interface types, builtin
# logitsprocs, batch-update helpers, and the build entry point.
__all__ = [
    "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
    "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
    "MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
    "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP"
]
|
||||
@@ -1,241 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Optional, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
from torch._prims_common import DeviceLikeType
|
||||
|
||||
from vllm import PoolingParams, SamplingParams
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MoveDirectionality(Enum):
|
||||
# One-way i1->i2 req move within batch
|
||||
UNIDIRECTIONAL = 0
|
||||
# Two-way i1<->i2 req swap within batch
|
||||
SWAP = 1
|
||||
|
||||
|
||||
# (index, params, output_tok_ids) tuples for new
|
||||
# requests added to the batch.
|
||||
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
|
||||
# (index 1, index 2, directionality) tuples representing
|
||||
# one-way moves or two-way swaps of requests in batch
|
||||
MovedRequest = tuple[int, int, MoveDirectionality]
|
||||
# Batch indices of any removed requests.
|
||||
RemovedRequest = int
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BatchUpdate:
|
||||
"""Persistent batch state change info for logitsprocs"""
|
||||
batch_size: int # Current num reqs in batch
|
||||
|
||||
# Metadata for requests added to, removed from, and moved
|
||||
# within the persistent batch.
|
||||
#
|
||||
# Note: each added request is represented as
|
||||
# (index, params, output_tok_ids)
|
||||
# Key assumption: output_tok_ids is a reference to the
|
||||
# request's running output tokens list; in this way
|
||||
# the logits processors always see the latest list of
|
||||
# generated tokens
|
||||
removed: Sequence[RemovedRequest]
|
||||
moved: Sequence[MovedRequest]
|
||||
added: Sequence[AddedRequest]
|
||||
|
||||
|
||||
class BatchUpdateBuilder:
|
||||
"""Helps track persistent batch state changes and build
|
||||
a batch update data structure for logitsprocs
|
||||
|
||||
Assumptions:
|
||||
* All information about requests removed from persistent batch
|
||||
during a step is aggregated in self._removed through calls to
|
||||
self.removed_append() at the beginning of a step. This must happen
|
||||
before the first time that self.removed, self.pop_removed()
|
||||
or self.peek_removed() are invoked in a given step
|
||||
* After the first time that self.removed, self.pop_removed()
|
||||
or self.peek_removed() are read in a step, no new removals
|
||||
are registered using self.removed_append()
|
||||
* Elements of self._removed are never directly modified, added or
|
||||
removed (i.e. modification is only via self.removed_append() and
|
||||
self.pop_removed())
|
||||
|
||||
Guarantees under above assumptions:
|
||||
* self.removed is always sorted in descending order
|
||||
* self.pop_removed() and self.peek_removed() both return
|
||||
the lowest removed request index in the current step
|
||||
"""
|
||||
|
||||
_removed: list[RemovedRequest]
|
||||
_is_removed_sorted: bool
|
||||
moved: list[MovedRequest]
|
||||
added: list[AddedRequest]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
removed: Optional[list[RemovedRequest]] = None,
|
||||
moved: Optional[list[MovedRequest]] = None,
|
||||
added: Optional[list[AddedRequest]] = None,
|
||||
) -> None:
|
||||
self._removed = removed or []
|
||||
self.moved = moved or []
|
||||
self.added = added or []
|
||||
self._is_removed_sorted = False
|
||||
|
||||
def _ensure_removed_sorted(self) -> None:
|
||||
"""Sort removed request indices in
|
||||
descending order.
|
||||
|
||||
Idempotent after first call in a
|
||||
given step, until reset.
|
||||
"""
|
||||
if not self._is_removed_sorted:
|
||||
self._removed.sort(reverse=True)
|
||||
self._is_removed_sorted = True
|
||||
|
||||
@property
|
||||
def removed(self) -> list[RemovedRequest]:
|
||||
"""Removed request indices sorted in
|
||||
descending order"""
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed
|
||||
|
||||
def removed_append(self, index: int) -> None:
|
||||
"""Register the removal of a request from
|
||||
the persistent batch.
|
||||
|
||||
Must not be called after the first time
|
||||
self.removed, self.pop_removed() or
|
||||
self.peek_removed() are invoked.
|
||||
|
||||
Args:
|
||||
index: request index
|
||||
"""
|
||||
if self._is_removed_sorted:
|
||||
raise RuntimeError("Cannot register new removed request after"
|
||||
" self.removed has been read.")
|
||||
self._removed.append(index)
|
||||
|
||||
def has_removed(self) -> bool:
|
||||
return bool(self._removed)
|
||||
|
||||
def peek_removed(self) -> Optional[int]:
|
||||
"""Return lowest removed request index"""
|
||||
if self.has_removed():
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed[-1]
|
||||
return None
|
||||
|
||||
def pop_removed(self) -> Optional[int]:
|
||||
"""Pop lowest removed request index"""
|
||||
if self.has_removed():
|
||||
self._ensure_removed_sorted()
|
||||
return self._removed.pop()
|
||||
return None
|
||||
|
||||
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
|
||||
"""Generate a logitsprocs batch update data structure
|
||||
and reset internal batch update builder state.
|
||||
|
||||
Args:
|
||||
batch_size: current persistent batch size
|
||||
|
||||
Returns:
|
||||
Frozen logitsprocs batch update instance; `None` if no updates
|
||||
"""
|
||||
# Reset removal-sorting logic
|
||||
self._is_removed_sorted = False
|
||||
if not any((self._removed, self.moved, self.added)):
|
||||
# No update; short-circuit
|
||||
return None
|
||||
# Build batch state update
|
||||
batch_update = BatchUpdate(
|
||||
batch_size=batch_size,
|
||||
removed=self._removed,
|
||||
moved=self.moved,
|
||||
added=self.added,
|
||||
)
|
||||
# Reset removed/moved/added update lists
|
||||
self._removed = []
|
||||
self.moved = []
|
||||
self.added = []
|
||||
return batch_update
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""True if logits processor has no impact on the
|
||||
argmax computation in greedy sampling.
|
||||
NOTE: may or may not have the same value for all
|
||||
instances of a given LogitsProcessor subclass,
|
||||
depending on subclass implementation.
|
||||
TODO(andy): won't be utilized until logits
|
||||
processors are user-extensible
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update_state(
|
||||
self,
|
||||
batch_update: Optional[BatchUpdate],
|
||||
) -> None:
|
||||
"""Called when there are new output tokens, prior
|
||||
to each forward pass.
|
||||
|
||||
Args:
|
||||
batch_update is non-None iff there have been
|
||||
changes to the batch makeup.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogitsProcessorManager:
|
||||
"""Encapsulates initialized logitsproc objects."""
|
||||
argmax_invariant: list[LogitsProcessor] = field(
|
||||
default_factory=list) # argmax-invariant logitsprocs
|
||||
non_argmax_invariant: list[LogitsProcessor] = field(
|
||||
default_factory=list) # non-argmax-invariant logitsprocs
|
||||
|
||||
@property
|
||||
def all(self) -> Iterator[LogitsProcessor]:
|
||||
"""Iterator over all logits processors."""
|
||||
return chain(self.argmax_invariant, self.non_argmax_invariant)
|
||||
|
||||
|
||||
###### ----- Built-in LogitsProcessor impls below here
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class MinPLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, max_num_reqs: int, pin_memory: bool,
|
||||
device: DeviceLikeType):
|
||||
super().__init__()
|
||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||
is_pin_memory: bool):
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self.min_p_count: int = 0
|
||||
|
||||
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
pin_memory=is_pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
|
||||
self.use_double_tensor = torch.device("cpu") != torch.device(device)
|
||||
self.use_double_tensor = torch.device(device).type != "cpu"
|
||||
|
||||
if self.use_double_tensor:
|
||||
# Pre-allocated device tensor
|
||||
@@ -260,8 +51,8 @@ class MinPLogitsProcessor(LogitsProcessor):
|
||||
|
||||
needs_update = False
|
||||
# Process added requests.
|
||||
for index, params, _ in batch_update.added:
|
||||
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
|
||||
for index, params, _, _ in batch_update.added:
|
||||
min_p = params.min_p
|
||||
if self.min_p_cpu[index] != min_p:
|
||||
needs_update = True
|
||||
self.min_p_cpu[index] = min_p
|
||||
@@ -316,11 +107,10 @@ class MinPLogitsProcessor(LogitsProcessor):
|
||||
|
||||
class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, pin_memory: bool, device: torch.device):
|
||||
super().__init__()
|
||||
self.biases: dict[int, dict[int, float]] = {}
|
||||
def __init__(self, _, device: torch.device, is_pin_memory: bool):
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.pin_memory = is_pin_memory
|
||||
self.biases: dict[int, dict[int, float]] = {}
|
||||
|
||||
self.bias_tensor: torch.Tensor = torch.tensor(())
|
||||
self.logits_slice = (self._device_tensor([], torch.int32),
|
||||
@@ -337,9 +127,8 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
|
||||
needs_update: bool = False
|
||||
# Process added requests.
|
||||
for index, params, _ in batch_update.added:
|
||||
if isinstance(params, SamplingParams) and (lb :=
|
||||
params.logit_bias):
|
||||
for index, params, _, _ in batch_update.added:
|
||||
if lb := params.logit_bias:
|
||||
self.biases[index] = lb
|
||||
needs_update = True
|
||||
else:
|
||||
@@ -400,12 +189,12 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
|
||||
class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, pin_memory: bool, device: torch.device):
|
||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||
is_pin_memory: bool):
|
||||
# index -> (min_toks, output_token_ids, stop_token_ids)
|
||||
super().__init__()
|
||||
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.pin_memory = is_pin_memory
|
||||
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
|
||||
|
||||
# (req_idx_tensor,eos_tok_id_tensor)
|
||||
self.logits_slice: tuple[torch.Tensor,
|
||||
@@ -424,9 +213,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
if batch_update:
|
||||
# Process added requests.
|
||||
for index, params, output_tok_ids in batch_update.added:
|
||||
if (isinstance(params, SamplingParams)
|
||||
and (min_tokens := params.min_tokens)
|
||||
for index, params, _, output_tok_ids in batch_update.added:
|
||||
if ((min_tokens := params.min_tokens)
|
||||
and len(output_tok_ids) < min_tokens):
|
||||
# Replace request metadata at batch index
|
||||
self.min_toks[index] = (min_tokens, output_tok_ids,
|
||||
@@ -499,35 +287,3 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
# Inhibit EOS token for requests which have not reached min length
|
||||
logits[self.logits_slice] = -float("inf")
|
||||
return logits
|
||||
|
||||
|
||||
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
|
||||
device: torch.device) -> LogitsProcessorManager:
|
||||
"""Construct 'builtin' vLLM logitsprocs which the engine
|
||||
loads by default.
|
||||
|
||||
Args:
|
||||
pin_memory_available: pinned memory is available for use
|
||||
for use by logitsproc
|
||||
max_num_reqs: ceiling on request count in persistent batch
|
||||
device: inference device
|
||||
|
||||
Returns:
|
||||
Data structure encapsulating loaded logitsprocs
|
||||
"""
|
||||
min_tokens_logitproc = MinTokensLogitsProcessor(
|
||||
pin_memory=pin_memory_available, device=device)
|
||||
logit_bias_logitproc = LogitBiasLogitsProcessor(
|
||||
pin_memory=pin_memory_available, device=device)
|
||||
min_p_logitproc = MinPLogitsProcessor(
|
||||
pin_memory=pin_memory_available,
|
||||
device=device,
|
||||
# +1 for temporary swap space
|
||||
max_num_reqs=max_num_reqs + 1)
|
||||
return LogitsProcessorManager(
|
||||
non_argmax_invariant=[
|
||||
min_tokens_logitproc,
|
||||
logit_bias_logitproc,
|
||||
],
|
||||
argmax_invariant=[min_p_logitproc],
|
||||
)
|
||||
86
vllm/v1/sample/logits_processor/interface.py
Normal file
86
vllm/v1/sample/logits_processor/interface.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class MoveDirectionality(Enum):
    """How a request is repositioned within the persistent batch."""
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()
|
||||
|
||||
|
||||
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]

# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]

# Batch indices of any removed requests.
RemovedRequest = int


@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs.

    Frozen: a given instance describes the batch changes of exactly one
    step and is never mutated after construction.
    """
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
    """Abstract base class for vLLM V1 logits processors.

    Implementations transform the batch logits tensor prior to sampling
    (apply) and are notified of persistent-batch membership changes
    before each forward pass (update_state).
    """

    @abstractmethod
    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool) -> None:
        """Initialize the logits processor.

        Args:
            vllm_config: engine-wide vLLM configuration
            device: device on which logits tensors reside
            is_pin_memory: whether pinned CPU memory is available
                (presumably for CPU->device staging buffers; confirm
                against concrete implementations)
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Transform and return the logits tensor for the current batch."""
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional["BatchUpdate"],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError
|
||||
149
vllm/v1/sample/logits_processor/state.py
Normal file
149
vllm/v1/sample/logits_processor/state.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterator
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.v1.sample.logits_processor.interface import (AddedRequest,
|
||||
BatchUpdate,
|
||||
MovedRequest,
|
||||
RemovedRequest)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||
|
||||
|
||||
class BatchUpdateBuilder:
    """Accumulates persistent-batch state changes during a step and
    produces a frozen BatchUpdate data structure for the logitsprocs.

    Assumptions:
    * All removals for a step are registered via self.removed_append()
      before the first read of self.removed, self.pop_removed() or
      self.peek_removed() in that step
    * After the first such read in a step, no further removals are
      registered via self.removed_append()
    * self._removed is mutated only through self.removed_append() and
      self.pop_removed()

    Guarantees under the above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both operate on the
      lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    moved: list[MovedRequest]
    added: list[AddedRequest]

    def __init__(
        self,
        removed: Optional[list[RemovedRequest]] = None,
        moved: Optional[list[MovedRequest]] = None,
        added: Optional[list[AddedRequest]] = None,
    ) -> None:
        self._is_removed_sorted = False
        self._removed = removed or []
        self.moved = moved or []
        self.added = added or []

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in descending order; idempotent
        after the first call in a step, until reset."""
        if self._is_removed_sorted:
            return
        self._removed.sort(reverse=True)
        self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices, sorted in descending order."""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from the persistent batch.

        Must not be called after the first time self.removed,
        self.pop_removed() or self.peek_removed() are invoked.

        Args:
            index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError("Cannot register new removed request after"
                               " self.removed has been read.")
        self._removed.append(index)

    def has_removed(self) -> bool:
        """True if any removals are currently registered."""
        return bool(self._removed)

    def peek_removed(self) -> Optional[int]:
        """Return lowest removed request index"""
        if not self._removed:
            return None
        self._ensure_removed_sorted()
        return self._removed[-1]

    def pop_removed(self) -> Optional[int]:
        """Pop lowest removed request index"""
        if not self._removed:
            return None
        self._ensure_removed_sorted()
        return self._removed.pop()

    def _is_update(self) -> bool:
        """True if there is a batch state change"""
        return bool(self._removed or self.moved or self.added)

    def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
        """Generate a logitsprocs batch update data structure and reset
        internal batch update builder state.

        Args:
            batch_size: current persistent batch size

        Returns:
            Frozen logitsprocs batch update instance; `None` if no updates
        """
        # The next step starts over with an unsorted removal list.
        self._is_removed_sorted = False
        if not self._is_update():
            # No update; short-circuit
            return None
        update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        # Ownership of the current lists passes to the frozen update;
        # start collecting into fresh lists.
        self._removed = []
        self.moved = []
        self.added = []
        return update
|
||||
|
||||
|
||||
class LogitsProcessors:
    """Encapsulates initialized logitsproc objects."""

    def __init__(
            self,
            logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None:
        """Bucket each provided logitproc by its argmax invariance."""
        self.argmax_invariant: list[LogitsProcessor] = []
        self.non_argmax_invariant: list[LogitsProcessor] = []
        for logitproc in (logitsprocs or ()):
            bucket = (self.argmax_invariant
                      if logitproc.is_argmax_invariant() else
                      self.non_argmax_invariant)
            bucket.append(logitproc)

    @property
    def all(self) -> Iterator["LogitsProcessor"]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessorManager
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -40,4 +40,4 @@ class SamplingMetadata:
|
||||
bad_words_token_ids: dict[int, list[list[int]]]
|
||||
|
||||
# Loaded logits processors
|
||||
logitsprocs: LogitsProcessorManager
|
||||
logitsprocs: LogitsProcessors
|
||||
|
||||
@@ -18,8 +18,8 @@ from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
MoveDirectionality,
|
||||
init_builtin_logitsprocs)
|
||||
LogitsProcessors,
|
||||
MoveDirectionality)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
@@ -78,8 +78,11 @@ class InputBatch:
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
@@ -221,14 +224,6 @@ class InputBatch:
|
||||
# updates. Should reset each step.
|
||||
self.batch_update_builder = BatchUpdateBuilder()
|
||||
|
||||
# Define logits processors.
|
||||
# TODO(andy): logits processor list should be extensible via engine
|
||||
# constructor argument; for now the list is fixed.
|
||||
self.logitsprocs = init_builtin_logitsprocs(
|
||||
pin_memory_available=pin_memory,
|
||||
max_num_reqs=max_num_reqs + 1,
|
||||
device=device)
|
||||
|
||||
# TODO convert this to LogitsProcessor
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
@@ -244,6 +239,10 @@ class InputBatch:
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
# Store provided logitsprocs. If none are provided, initialize empty
|
||||
# data structure
|
||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
@@ -255,28 +254,35 @@ class InputBatch:
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def _get_next_add_index(self) -> int:
|
||||
if (req_index := self.batch_update_builder.pop_removed()) is not None:
|
||||
# Fill the empty index.
|
||||
return req_index
|
||||
# Append to end
|
||||
return self.num_reqs
|
||||
|
||||
def _register_add_request(self, request: "CachedRequestState") -> int:
|
||||
"""Track add-request operations"""
|
||||
req_index = self._get_next_add_index()
|
||||
assert req_index < self.max_num_reqs
|
||||
params = (request.sampling_params
|
||||
if request.sampling_params else request.pooling_params)
|
||||
"""Track add-request operations for logits processors.
|
||||
Not applicable to pooling models.
|
||||
"""
|
||||
|
||||
# Detailed added request metadata is only required for non-pooling
|
||||
# models, to support logitsprocs
|
||||
assert request.sampling_params
|
||||
|
||||
# Fill the next empty index if there is one.
|
||||
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
|
||||
# Append to end otherwise.
|
||||
new_req_index = self.num_reqs
|
||||
|
||||
assert new_req_index < self.max_num_reqs
|
||||
self.batch_update_builder.added.append(
|
||||
(req_index, params, request.output_token_ids))
|
||||
return req_index
|
||||
(new_req_index, request.sampling_params, request.prompt_token_ids,
|
||||
request.output_token_ids))
|
||||
return new_req_index
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
) -> int:
|
||||
req_index = self._register_add_request(request)
|
||||
if not self.is_pooling_model:
|
||||
# New request index bookkeeping for autoregressive models.
|
||||
req_index = self._register_add_request(request)
|
||||
else:
|
||||
req_index = self.num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
@@ -411,7 +417,10 @@ class InputBatch:
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
if not self.is_pooling_model:
|
||||
# Autoregressive models require bookkeeping of removed requests to
|
||||
# support logitsprocs.
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
@@ -446,6 +455,8 @@ class InputBatch:
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
# For autoregressive models, track detailed request reordering info
|
||||
# to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(i1, i2, MoveDirectionality.SWAP))
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
@@ -513,11 +524,18 @@ class InputBatch:
|
||||
swaps: list of (from,to) swap tuples for moved requests
|
||||
empty_req_indices: indices not filled by condensation
|
||||
"""
|
||||
num_reqs = self.num_reqs
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Will be contiguous in pooling case, just trim the lists.
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
return
|
||||
|
||||
if not (empty_req_indices := self.batch_update_builder.removed):
|
||||
# All removed requests were replaced by added requests, or else no
|
||||
# requests were removed at all. No condense() needed
|
||||
return
|
||||
num_reqs = self.num_reqs
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
@@ -541,6 +559,8 @@ class InputBatch:
|
||||
# Move active request down into empty request
|
||||
# index.
|
||||
self.batch_update_builder.pop_removed()
|
||||
# Autoregressive models require detailed tracking of condense
|
||||
# operations to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(last_req_index, empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
@@ -596,15 +616,20 @@ class InputBatch:
|
||||
last_req_index -= 1
|
||||
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[self.num_reqs:]
|
||||
del self.req_output_token_ids[self.num_reqs:]
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
|
||||
def refresh_metadata(self):
|
||||
"""Apply batch updates, reset input batch at end of step
|
||||
"""Apply any batch updates to sampling metadata."""
|
||||
|
||||
* Apply batch add/remove/permute to logits procs' states
|
||||
* If batch state is modified, update sampling metadata
|
||||
"""
|
||||
if self.is_pooling_model:
|
||||
# Batch changes every step for pooling models.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
return
|
||||
|
||||
# For non-pooling models - generate and apply logitsprocs update;
|
||||
# reset batch update tracking.
|
||||
# Update sampling metadata if batch state is changed.
|
||||
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
|
||||
for logit_proc in self.logitsprocs.all:
|
||||
logit_proc.update_state(batch_update)
|
||||
|
||||
@@ -68,6 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
@@ -80,7 +81,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorModelRunnerMixin, KVConnectorOutput)
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from ..sample.logits_processor import LogitsProcessorManager
|
||||
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
|
||||
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||
@@ -221,6 +221,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=[self.cache_config.block_size],
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
logitsprocs=build_logitsprocs(
|
||||
self.vllm_config, self.device, self.pin_memory,
|
||||
self.is_pooling_model,
|
||||
self.vllm_config.model_config.logits_processors),
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
)
|
||||
|
||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||
@@ -2447,7 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=LogitsProcessorManager(),
|
||||
logitsprocs=LogitsProcessors(),
|
||||
)
|
||||
try:
|
||||
sampler_output = self.sampler(logits=logits,
|
||||
@@ -2968,6 +2973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
logitsprocs=self.input_batch.logitsprocs,
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
|
||||
Reference in New Issue
Block a user