[V1] Logits processors extensibility (#19912)

Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Signed-off-by: Andrew Feldman <afeld2012@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: afeldman-nm
Date: 2025-08-16 15:59:17 -04:00 (committed by GitHub)
Parent: 4fc722eca4
Commit: bf7f470b22
22 changed files with 1312 additions and 334 deletions

View File

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import itertools
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union
import torch
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor)
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder,
LogitsProcessors)
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
# Error message when the user tries to initialize vLLM with a pooling model
# and custom logits processors
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
" logits processors.")
LOGITSPROCS_GROUP = 'vllm.logits_processors'
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
]
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
"""Load all installed logit processor plugins"""
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
if len(installed_logitsprocs_plugins) == 0:
logger.debug("No logitsprocs plugins installed (group %s).",
LOGITSPROCS_GROUP)
return []
# Load logitsprocs plugins
logger.debug("Loading installed logitsprocs plugins (group %s):",
LOGITSPROCS_GROUP)
classes: list[type[LogitsProcessor]] = []
for entrypoint in installed_logitsprocs_plugins:
try:
logger.debug("- Loading logitproc plugin entrypoint=%s target=%s",
entrypoint.name, entrypoint.value)
classes.append(entrypoint.load())
except Exception as e:
raise RuntimeError(
f"Failed to load LogitsProcessor plugin {entrypoint}") from e
return classes
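Plugins are discovered purely through the vllm.logits_processors entry-point group, so a third-party package opts in at packaging time. Below is a minimal sketch of how such a package might register itself; the package and class names are hypothetical, not part of vLLM:

# Hypothetical third-party package "my_logitsprocs" (illustrative only).
#
# pyproject.toml:
#     [project.entry-points."vllm.logits_processors"]
#     no_think = "my_logitsprocs.processors:NoThinkLogitsProcessor"
#
# my_logitsprocs/processors.py defines NoThinkLogitsProcessor as a
# LogitsProcessor subclass. Once the package is installed,
# _load_logitsprocs_plugins() above discovers and loads it automatically,
# with no further configuration.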
def _load_logitsprocs_by_fqcns(
logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]]
) -> list[type[LogitsProcessor]]:
"""Load logit processor types, identifying them by fully-qualified class
names (FQCNs).
Effectively, a mixed list of logitproc types and FQCN strings is converted
into a list consisting entirely of logitproc types, by loading from the FQCNs.
FQCN syntax is <module>:<type>, e.g. x.y.z:CustomLogitProc
Already-loaded logitproc types must be subclasses of LogitsProcessor
Args:
logits_processors: Potentially mixed list of logitsprocs types and FQCN
strings for logitproc types
Returns:
List of logitproc types
"""
if not logits_processors:
return []
logger.debug(
"%s additional custom logits processors specified, checking whether "
"they need to be loaded.", len(logits_processors))
classes: list[type[LogitsProcessor]] = []
for ldx, logitproc in enumerate(logits_processors):
if isinstance(logitproc, type):
logger.debug(" - Already-loaded logit processor: %s",
logitproc.__name__)
if not issubclass(logitproc, LogitsProcessor):
raise ValueError(
f"{logitproc.__name__} is not a subclass of LogitsProcessor"
)
classes.append(logitproc)
continue
logger.debug("- Loading logits processor %s", logitproc)
module_path, qualname = logitproc.split(":")
try:
# Load module
module = importlib.import_module(module_path)
except Exception as e:
            raise RuntimeError(
                f"Failed to load LogitsProcessor {logitproc} (at index {ldx})"
            ) from e
# Walk down dotted name to get logitproc class
obj = module
for attr in qualname.split("."):
obj = getattr(obj, attr)
if not isinstance(obj, type):
raise ValueError("Loaded logit processor must be a type.")
if not issubclass(obj, LogitsProcessor):
raise ValueError(
f"{obj.__name__} must be a subclass of LogitsProcessor")
classes.append(obj)
return classes
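For illustration, a hedged sketch of the mixed-input behavior this function supports; the FQCN string below is hypothetical:

# Illustrative only: FQCN strings and already-loaded classes may be mixed.
procs = _load_logitsprocs_by_fqcns([
    "my_logitsprocs.processors:NoThinkLogitsProcessor",  # imported via FQCN
    MinPLogitsProcessor,  # already a type; validated and passed through
])
# Both entries come back as LogitsProcessor subclasses.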
def _load_custom_logitsprocs(
logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]],
) -> list[type[LogitsProcessor]]:
"""Load all custom logits processors.
* First, load all installed logitsproc plugins
* Second, load custom logitsprocs passed by the user at initialization time
Args:
logits_processors: potentially mixed list of logitproc types and
logitproc type fully-qualified names (FQCNs)
which need to be loaded
Returns:
A list of all loaded logitproc types
"""
from vllm.platforms import current_platform
if current_platform.is_tpu():
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs;
        # none are loaded here, regardless of what the caller specified
return []
return (_load_logitsprocs_plugins() +
_load_logitsprocs_by_fqcns(logits_processors))
def build_logitsprocs(
vllm_config: "VllmConfig",
device: torch.device,
is_pin_memory: bool,
is_pooling_model: bool,
custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
if is_pooling_model:
if custom_logitsprocs:
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
logger.debug("Skipping logits processor loading because pooling models"
" do not support logits processors.")
return LogitsProcessors()
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(
ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
__all__ = [
"LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
"MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
"MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
"STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP"
]
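Putting the pieces together, a minimal sketch of how a worker might invoke build_logitsprocs(); the vllm_config object and the custom FQCN are assumed for illustration:

# Sketch only: assumes a constructed VllmConfig and a hypothetical plugin.
logitsprocs = build_logitsprocs(
    vllm_config=vllm_config,
    device=torch.device("cuda:0"),
    is_pin_memory=True,
    is_pooling_model=False,
    custom_logitsprocs=["my_logitsprocs.processors:NoThinkLogitsProcessor"],
)
# Builtins plus the custom processor, bucketed by argmax invariance:
for proc in logitsprocs.all:
    print(type(proc).__name__, proc.is_argmax_invariant())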

View File

@@ -1,241 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
import torch
from torch._prims_common import DeviceLikeType
from vllm import PoolingParams, SamplingParams
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
logger = init_logger(__name__)
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = 0
# Two-way i1<->i2 req swap within batch
SWAP = 1
# (index, params, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclasses.dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Note: each added request is represented as
# (index, params, output_tok_ids)
# Key assumption: output_tok_ids is a reference to the
# request's running output tokens list; in this way
# the logits processors always see the latest list of
# generated tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class BatchUpdateBuilder:
"""Helps track persistent batch state changes and build
a batch update data structure for logitsprocs
Assumptions:
* All information about requests removed from persistent batch
during a step is aggregated in self._removed through calls to
self.removed_append() at the beginning of a step. This must happen
before the first time that self.removed, self.pop_removed()
or self.peek_removed() are invoked in a given step
* After the first time that self.removed, self.pop_removed()
or self.peek_removed() are read in a step, no new removals
are registered using self.removed_append()
* Elements of self._removed are never directly modified, added or
removed (i.e. modification is only via self.removed_append() and
self.pop_removed())
Guarantees under above assumptions:
* self.removed is always sorted in descending order
* self.pop_removed() and self.peek_removed() both return
the lowest removed request index in the current step
"""
_removed: list[RemovedRequest]
_is_removed_sorted: bool
moved: list[MovedRequest]
added: list[AddedRequest]
def __init__(
self,
removed: Optional[list[RemovedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
added: Optional[list[AddedRequest]] = None,
) -> None:
self._removed = removed or []
self.moved = moved or []
self.added = added or []
self._is_removed_sorted = False
def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
Idempotent after first call in a
given step, until reset.
"""
if not self._is_removed_sorted:
self._removed.sort(reverse=True)
self._is_removed_sorted = True
@property
def removed(self) -> list[RemovedRequest]:
"""Removed request indices sorted in
descending order"""
self._ensure_removed_sorted()
return self._removed
def removed_append(self, index: int) -> None:
"""Register the removal of a request from
the persistent batch.
Must not be called after the first time
self.removed, self.pop_removed() or
self.peek_removed() are invoked.
Args:
index: request index
"""
if self._is_removed_sorted:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
def has_removed(self) -> bool:
return bool(self._removed)
def peek_removed(self) -> Optional[int]:
"""Return lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed[-1]
return None
def pop_removed(self) -> Optional[int]:
"""Pop lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed.pop()
return None
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure
and reset internal batch update builder state.
Args:
batch_size: current persistent batch size
Returns:
Frozen logitsprocs batch update instance; `None` if no updates
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not any((self._removed, self.moved, self.added)):
# No update; short-circuit
return None
# Build batch state update
batch_update = BatchUpdate(
batch_size=batch_size,
removed=self._removed,
moved=self.moved,
added=self.added,
)
# Reset removed/moved/added update lists
self._removed = []
self.moved = []
self.added = []
return batch_update
class LogitsProcessor(ABC):
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
TODO(andy): won't be utilized until logits
processors are user-extensible
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: Optional[BatchUpdate],
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
        Args:
            batch_update: non-None iff the batch makeup has changed
"""
raise NotImplementedError
@dataclass
class LogitsProcessorManager:
"""Encapsulates initialized logitsproc objects."""
argmax_invariant: list[LogitsProcessor] = field(
default_factory=list) # argmax-invariant logitsprocs
non_argmax_invariant: list[LogitsProcessor] = field(
default_factory=list) # non-argmax-invariant logitsprocs
@property
def all(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
return chain(self.argmax_invariant, self.non_argmax_invariant)
###### ----- Built-in LogitsProcessor impls below here
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MinPLogitsProcessor(LogitsProcessor):
def __init__(self, max_num_reqs: int, pin_memory: bool,
device: DeviceLikeType):
super().__init__()
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.min_p_count: int = 0
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
pin_memory=is_pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.use_double_tensor = torch.device("cpu") != torch.device(device)
self.use_double_tensor = torch.device(device).type != "cpu"
if self.use_double_tensor:
# Pre-allocated device tensor
@@ -260,8 +51,8 @@ class MinPLogitsProcessor(LogitsProcessor):
needs_update = False
# Process added requests.
for index, params, _ in batch_update.added:
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
for index, params, _, _ in batch_update.added:
min_p = params.min_p
if self.min_p_cpu[index] != min_p:
needs_update = True
self.min_p_cpu[index] = min_p
@@ -316,11 +107,10 @@ class MinPLogitsProcessor(LogitsProcessor):
class LogitBiasLogitsProcessor(LogitsProcessor):
def __init__(self, pin_memory: bool, device: torch.device):
super().__init__()
self.biases: dict[int, dict[int, float]] = {}
def __init__(self, _, device: torch.device, is_pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self.pin_memory = is_pin_memory
self.biases: dict[int, dict[int, float]] = {}
self.bias_tensor: torch.Tensor = torch.tensor(())
self.logits_slice = (self._device_tensor([], torch.int32),
@@ -337,9 +127,8 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
needs_update: bool = False
# Process added requests.
for index, params, _ in batch_update.added:
if isinstance(params, SamplingParams) and (lb :=
params.logit_bias):
for index, params, _, _ in batch_update.added:
if lb := params.logit_bias:
self.biases[index] = lb
needs_update = True
else:
@@ -400,12 +189,12 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
class MinTokensLogitsProcessor(LogitsProcessor):
def __init__(self, pin_memory: bool, device: torch.device):
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
# index -> (min_toks, output_token_ids, stop_token_ids)
super().__init__()
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
self.device = device
self.pin_memory = pin_memory
self.pin_memory = is_pin_memory
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
# (req_idx_tensor, eos_tok_id_tensor)
self.logits_slice: tuple[torch.Tensor,
@@ -424,9 +213,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
if batch_update:
# Process added requests.
for index, params, output_tok_ids in batch_update.added:
if (isinstance(params, SamplingParams)
and (min_tokens := params.min_tokens)
for index, params, _, output_tok_ids in batch_update.added:
if ((min_tokens := params.min_tokens)
and len(output_tok_ids) < min_tokens):
# Replace request metadata at batch index
self.min_toks[index] = (min_tokens, output_tok_ids,
@@ -499,35 +287,3 @@ class MinTokensLogitsProcessor(LogitsProcessor):
# Inhibit EOS token for requests which have not reached min length
logits[self.logits_slice] = -float("inf")
return logits
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device: torch.device) -> LogitsProcessorManager:
"""Construct 'builtin' vLLM logitsprocs which the engine
loads by default.
Args:
pin_memory_available: whether pinned memory is available
    for use by logitsprocs
max_num_reqs: ceiling on request count in persistent batch
device: inference device
Returns:
Data structure encapsulating loaded logitsprocs
"""
min_tokens_logitproc = MinTokensLogitsProcessor(
pin_memory=pin_memory_available, device=device)
logit_bias_logitproc = LogitBiasLogitsProcessor(
pin_memory=pin_memory_available, device=device)
min_p_logitproc = MinPLogitsProcessor(
pin_memory=pin_memory_available,
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)
return LogitsProcessorManager(
non_argmax_invariant=[
min_tokens_logitproc,
logit_bias_logitproc,
],
argmax_invariant=[min_p_logitproc],
)

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
import torch
from vllm import SamplingParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = auto()
# Two-way i1<->i2 req swap within batch
SWAP = auto()
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Key assumption: the `output_tok_ids` list (which is an element of each
# tuple in `added`) is a reference to the request's running output tokens
# list; via this reference, the logits processors always see the latest
# list of generated output tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class LogitsProcessor(ABC):
@abstractmethod
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool) -> None:
raise NotImplementedError
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: Optional["BatchUpdate"],
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
        Args:
            batch_update: non-None iff the batch makeup has changed
"""
raise NotImplementedError
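The interface above is everything a custom processor needs to implement. Below is a minimal sketch of a conforming subclass; the class name, the suppressed token id, and the extra_args-based opt-in are all hypothetical, only the LogitsProcessor interface itself comes from this diff:

from typing import TYPE_CHECKING, Optional

import torch

from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                       LogitsProcessor,
                                                       MoveDirectionality)

if TYPE_CHECKING:
    from vllm.config import VllmConfig

SUPPRESSED_TOKEN_ID = 42  # hypothetical token id, illustrative only


class SuppressTokenLogitsProcessor(LogitsProcessor):
    """Sketch: forces one token's logit to -inf for opted-in requests."""

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        # Persistent-batch indices of requests this processor is active for
        self.active: set[int] = set()

    def is_argmax_invariant(self) -> bool:
        # Masking a token to -inf can change the greedy argmax
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        if batch_update is None:
            return
        for index, params, _, _ in batch_update.added:
            # Hypothetical opt-in via SamplingParams.extra_args
            if params.extra_args and params.extra_args.get("suppress_token"):
                self.active.add(index)
            else:
                self.active.discard(index)
        for index in batch_update.removed:
            self.active.discard(index)
        for i1, i2, direction in batch_update.moved:
            if direction == MoveDirectionality.SWAP:
                if (i1 in self.active) != (i2 in self.active):
                    # Exactly one of the two is active: swap membership
                    self.active.symmetric_difference_update((i1, i2))
            else:
                # UNIDIRECTIONAL: the request at i1 now lives at i2
                was_active = i1 in self.active
                self.active.discard(i1)
                self.active.discard(i2)
                if was_active:
                    self.active.add(i2)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Row i of logits corresponds to persistent-batch index i
        for index in self.active:
            logits[index, SUPPRESSED_TOKEN_ID] = float("-inf")
        return logits

Loaded either as an installed plugin or by passing its FQCN (or the class itself) at initialization time, such a processor is then constructed and driven entirely by the engine.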

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from itertools import chain
from typing import TYPE_CHECKING, Optional
from vllm.v1.sample.logits_processor.interface import (AddedRequest,
BatchUpdate,
MovedRequest,
RemovedRequest)
if TYPE_CHECKING:
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
class BatchUpdateBuilder:
"""Helps track persistent batch state changes and build
a batch update data structure for logitsprocs
Assumptions:
* All information about requests removed from persistent batch
during a step is aggregated in self._removed through calls to
self.removed_append() at the beginning of a step. This must happen
before the first time that self.removed, self.pop_removed()
or self.peek_removed() are invoked in a given step
* After the first time that self.removed, self.pop_removed()
or self.peek_removed() are read in a step, no new removals
are registered using self.removed_append()
* Elements of self._removed are never directly modified, added or
removed (i.e. modification is only via self.removed_append() and
self.pop_removed())
Guarantees under above assumptions:
* self.removed is always sorted in descending order
* self.pop_removed() and self.peek_removed() both return
the lowest removed request index in the current step
"""
_removed: list[RemovedRequest]
_is_removed_sorted: bool
moved: list[MovedRequest]
added: list[AddedRequest]
def __init__(
self,
removed: Optional[list[RemovedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
added: Optional[list[AddedRequest]] = None,
) -> None:
self._removed = removed or []
self.moved = moved or []
self.added = added or []
self._is_removed_sorted = False
def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
Idempotent after first call in a
given step, until reset.
"""
if not self._is_removed_sorted:
self._removed.sort(reverse=True)
self._is_removed_sorted = True
@property
def removed(self) -> list[RemovedRequest]:
"""Removed request indices sorted in
descending order"""
self._ensure_removed_sorted()
return self._removed
def removed_append(self, index: int) -> None:
"""Register the removal of a request from the persistent batch.
Must not be called after the first time self.removed,
self.pop_removed() or self.peek_removed() are invoked.
Args:
index: request index
"""
if self._is_removed_sorted:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
def has_removed(self) -> bool:
return bool(self._removed)
def peek_removed(self) -> Optional[int]:
"""Return lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed[-1]
return None
def pop_removed(self) -> Optional[int]:
"""Pop lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed.pop()
return None
def _is_update(self) -> bool:
"""True if there is a batch state change"""
return any((self._removed, self.moved, self.added))
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure and reset
internal batch update builder state.
Args:
batch_size: current persistent batch size
Returns:
Frozen logitsprocs batch update instance; `None` if no updates
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not self._is_update():
# No update; short-circuit
return None
# Build batch state update
batch_update = BatchUpdate(
batch_size=batch_size,
removed=self._removed,
moved=self.moved,
added=self.added,
)
self._removed = []
self.moved = []
self.added = []
return batch_update
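To make the removal-ordering guarantees concrete, a small illustrative walk-through (the indices are arbitrary examples):

builder = BatchUpdateBuilder()
builder.removed_append(7)
builder.removed_append(2)
builder.removed_append(5)
assert builder.peek_removed() == 2  # lowest removed index
assert builder.pop_removed() == 2   # consumed lowest-first
assert builder.removed == [7, 5]    # remaining, sorted descending
# removed_append() would now raise RuntimeError until the next reset.
update = builder.get_and_reset(batch_size=8)  # frozen BatchUpdate (or None)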
class LogitsProcessors:
"""Encapsulates initialized logitsproc objects."""
def __init__(
self,
logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None:
self.argmax_invariant: list[LogitsProcessor] = []
self.non_argmax_invariant: list[LogitsProcessor] = []
if logitsprocs:
for logitproc in logitsprocs:
(self.argmax_invariant if logitproc.is_argmax_invariant() else
self.non_argmax_invariant).append(logitproc)
@property
def all(self) -> Iterator["LogitsProcessor"]:
"""Iterator over all logits processors."""
return chain(self.argmax_invariant, self.non_argmax_invariant)
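The one-time bucketing means samplers can treat the two groups differently, e.g. skipping argmax-invariant processors for all-greedy batches. A tiny sketch, with proc_a and proc_b standing in for already-constructed processor instances:

procs = LogitsProcessors(iter([proc_a, proc_b]))
assert all(p.is_argmax_invariant() for p in procs.argmax_invariant)
assert not any(p.is_argmax_invariant() for p in procs.non_argmax_invariant)
ordered = list(procs.all)  # argmax-invariant first, then the rest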

View File

@@ -6,7 +6,7 @@ from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
@@ -40,4 +40,4 @@ class SamplingMetadata:
bad_words_token_ids: dict[int, list[list[int]]]
# Loaded logits processors
logitsprocs: LogitsProcessorManager
logitsprocs: LogitsProcessors

View File

@@ -18,8 +18,8 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
MoveDirectionality,
init_builtin_logitsprocs)
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
@@ -78,8 +78,11 @@ class InputBatch:
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -221,14 +224,6 @@ class InputBatch:
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# Define logits processors.
# TODO(andy): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed.
self.logitsprocs = init_builtin_logitsprocs(
pin_memory_available=pin_memory,
max_num_reqs=max_num_reqs + 1,
device=device)
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
@@ -244,6 +239,10 @@ class InputBatch:
self.req_output_token_ids: list[Optional[list[int]]] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@@ -255,28 +254,35 @@ class InputBatch:
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _get_next_add_index(self) -> int:
if (req_index := self.batch_update_builder.pop_removed()) is not None:
# Fill the empty index.
return req_index
# Append to end
return self.num_reqs
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations"""
req_index = self._get_next_add_index()
assert req_index < self.max_num_reqs
params = (request.sampling_params
if request.sampling_params else request.pooling_params)
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append(
(req_index, params, request.output_token_ids))
return req_index
(new_req_index, request.sampling_params, request.prompt_token_ids,
request.output_token_ids))
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
req_index = self._register_add_request(request)
if not self.is_pooling_model:
# New request index bookkeeping for autoregressive models.
req_index = self._register_add_request(request)
else:
req_index = self.num_reqs
req_id = request.req_id
if req_index == len(self._req_ids):
@@ -411,7 +417,10 @@ class InputBatch:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
@@ -446,6 +455,8 @@ class InputBatch:
return req_index
def swap_states(self, i1: int, i2: int) -> None:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1]
@@ -513,11 +524,18 @@ class InputBatch:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
num_reqs = self.num_reqs
if self.is_pooling_model:
# Will be contiguous in pooling case, just trim the lists.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
return
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
@@ -541,6 +559,8 @@ class InputBatch:
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
@@ -596,15 +616,20 @@ class InputBatch:
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:]
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply batch updates, reset input batch at end of step
"""Apply any batch updates to sampling metadata."""
* Apply batch add/remove/permute to logits procs' states
* If batch state is modified, update sampling metadata
"""
if self.is_pooling_model:
# Batch changes every step for pooling models.
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
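For reference, a comment-only sketch of how these pieces sequence within one engine step (method names as shown in this diff; the surrounding scheduler loop is assumed):

# 1. Batch mutation phase:
#      remove path        -> batch_update_builder.removed_append(index)
#      add_request(...)   -> batch_update_builder.added.append((...))
#      swap_states(i1,i2) -> batch_update_builder.moved.append((i1, i2, SWAP))
#      condense()         -> moved.append((last, empty, UNIDIRECTIONAL))
# 2. refresh_metadata():
#      batch_update = batch_update_builder.get_and_reset(num_reqs)
#      every logitproc.update_state(batch_update)
#      sampling metadata rebuilt if the batch makeup changed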

View File

@@ -68,6 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@@ -80,7 +81,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
@@ -221,6 +221,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -2447,7 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
output_token_ids=[[] for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
logitsprocs=LogitsProcessors(),
)
try:
sampler_output = self.sampler(logits=logits,
@@ -2968,6 +2973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
)
def _allocate_kv_cache_tensors(
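End to end, the custom processors referenced by model_config.logits_processors above originate from the user at engine construction time. A hedged sketch of what that might look like, assuming the constructor argument mirrors the config field name (the plugin FQCN is hypothetical):

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    logits_processors=[
        "my_logitsprocs.processors:NoThinkLogitsProcessor",  # FQCN string
    ],
)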