[V1] Wrapper which plumbs request-level logits processors into vLLM batch-level logits processing (#23656)

Signed-off-by: Andrew Feldman <afeldman@redhat.com>
This commit is contained in:
afeldman-nm
2025-09-02 22:52:51 -04:00
committed by GitHub
parent e32a0e8678
commit 136d853e65
6 changed files with 524 additions and 5 deletions

View File

@@ -3,15 +3,21 @@
import types
from enum import Enum, auto
from typing import Optional
from typing import Any, Optional
import torch
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
LogitsProcessor)
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP,
AdapterLogitsProcessor,
BatchUpdate, LogitsProcessor,
RequestLogitsProcessor)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
logger = init_logger(__name__)
MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
DUMMY_LOGITPROC_ARG = "target_token"
@@ -104,5 +110,60 @@ class EntryPoints(list):
self.names = [ep.name for ep in eps]
class DummyPerReqLogitsProcessor:
"""The request-level logits processor masks out all logits except the
token id identified by `target_token`"""
def __init__(self, target_token: int) -> None:
"""Specify `target_token`"""
self.target_token = target_token
def __call__(
self,
output_ids: list[int],
logits: torch.Tensor,
) -> torch.Tensor:
val_to_keep = logits[self.target_token].item()
logits[:] = float("-inf")
logits[self.target_token] = val_to_keep
return logits
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""
def is_argmax_invariant(self) -> bool:
return False
def new_req_logits_processor(
self,
params: SamplingParams,
) -> Optional[RequestLogitsProcessor]:
"""This method returns a new request-level logits processor, customized
to the `target_token` value associated with a particular request.
Returns None if the logits processor should not be applied to the
particular request. To use the logits processor the request must have
a "target_token" custom argument with an integer value.
Args:
params: per-request sampling params
Returns:
`Callable` request logits processor, or None
"""
target_token: Optional[
Any] = params.extra_args and params.extra_args.get("target_token")
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.", target_token)
return None
return DummyPerReqLogitsProcessor(target_token)
"""Fake version of importlib.metadata.entry_points"""
entry_points = lambda group: EntryPoints(group)