[LogitsProcs] Deduplicate built-in LP implementation logic (#23362)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-08-27 08:11:33 -07:00
committed by GitHub
parent 83f555f637
commit 3ce8285d6d
4 changed files with 95 additions and 143 deletions

View File

@@ -42,8 +42,8 @@ from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
# Hypothetical custom logits processor
@@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor):
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
self.req_info: dict[int, SamplingParams] = {}
self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return
# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
if params.extra_args and (
target_token := params.extra_args.get("target_token")
):
self.req_info[index] = target_token
if self.req_info:
# Process removed requests.
for index in batch_update.removed:
self.req_info.pop(index, None)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for adx, bdx, direct in batch_update.moved:
a_val = self.req_info.pop(adx, None)
b_val = self.req_info.pop(bdx, None)
if a_val is not None:
self.req_info[bdx] = a_val
if direct == MoveDirectionality.SWAP and b_val is not None:
self.req_info[adx] = b_val
process_dict_updates(
self.req_info,
batch_update,
# This function returns the LP's per-request state based on the
# request details, or None if this LP does not apply to the
# request.
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info: