[LogitsProcs] Deduplicate built-in LP implementation logic (#23362)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-08-27 08:11:33 -07:00
committed by GitHub
parent 83f555f637
commit 3ce8285d6d
4 changed files with 95 additions and 143 deletions

View File

@@ -8,10 +8,9 @@ from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
LogitsProcessor,
MoveDirectionality)
LogitsProcessor)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
self.req_info: dict[int, SamplingParams] = {}
self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return
# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
if params.extra_args and (target_token :=
params.extra_args.get("target_token")):
self.req_info[index] = target_token
if self.req_info:
# Process removed requests.
for index in batch_update.removed:
self.req_info.pop(index, None)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for adx, bdx, direct in batch_update.moved:
a_val = self.req_info.pop(adx, None)
b_val = self.req_info.pop(bdx, None)
if a_val is not None:
self.req_info[bdx] = a_val
if direct == MoveDirectionality.SWAP and b_val is not None:
self.req_info[adx] = b_val
process_dict_updates(
self.req_info,
batch_update,
lambda params, _, __: params.extra_args and
(params.extra_args.get("target_token")),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info: