[LogitsProcs] Deduplicate built-in LP implementation logic (#23362)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -8,10 +8,9 @@ from typing import Optional
 import torch
 
 from vllm.config import VllmConfig
-from vllm.sampling_params import SamplingParams
 from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
-                                             LogitsProcessor,
-                                             MoveDirectionality)
+                                             LogitsProcessor)
+from vllm.v1.sample.logits_processor.builtin import process_dict_updates
 
 MODEL_NAME = "facebook/opt-125m"
 POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
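
Note the import changes: `MoveDirectionality` drops out because the add/remove/move bookkeeping that used it moves into the shared `process_dict_updates` helper (sketched after the next hunk), and `SamplingParams` is no longer needed once `req_info` maps batch indices to plain `int` target tokens.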
@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
 
     def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                  is_pin_memory: bool):
-        self.req_info: dict[int, SamplingParams] = {}
+        self.req_info: dict[int, int] = {}
 
     def is_argmax_invariant(self) -> bool:
         """Never impacts greedy sampling"""
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if not batch_update:
-            return
-
-        # Process added requests.
-        for index, params, _, _ in batch_update.added:
-            assert params is not None
-            if params.extra_args and (target_token :=
-                                      params.extra_args.get("target_token")):
-                self.req_info[index] = target_token
-
-        if self.req_info:
-            # Process removed requests.
-            for index in batch_update.removed:
-                self.req_info.pop(index, None)
-
-            # Process moved requests, unidirectional move (a->b) and swap
-            # (a<->b)
-            for adx, bdx, direct in batch_update.moved:
-                a_val = self.req_info.pop(adx, None)
-                b_val = self.req_info.pop(bdx, None)
-                if a_val is not None:
-                    self.req_info[bdx] = a_val
-                if direct == MoveDirectionality.SWAP and b_val is not None:
-                    self.req_info[adx] = b_val
+        process_dict_updates(
+            self.req_info,
+            batch_update,
+            lambda params, _, __: params.extra_args and
+            (params.extra_args.get("target_token")),
+        )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         if not self.req_info:
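
For context, here is a minimal sketch of what the shared `process_dict_updates` helper plausibly does, reconstructed from the hand-rolled logic deleted above. The helper name and module come from the diff; the signature, the parameter names `prompt_tok_ids` / `output_tok_ids`, the pop-on-falsy branch for added requests, and the `None` return type are assumptions, and the actual code in `vllm/v1/sample/logits_processor/builtin.py` may differ.

from typing import Callable, Optional, TypeVar

from vllm.v1.sample.logits_processor import BatchUpdate, MoveDirectionality

V = TypeVar("V")


def process_dict_updates(
        req_info: dict[int, V], batch_update: Optional[BatchUpdate],
        new_state: Callable[..., Optional[V]]) -> None:
    """Apply a BatchUpdate to per-request state kept in a dict.

    Sketch only: mirrors the logic removed in the hunk above.
    """
    if not batch_update:
        return

    # Added requests: the callback derives per-request state from the
    # request's params and token-id lists; a falsy result means the
    # request is not tracked by this logits processor.
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (value := new_state(params, prompt_tok_ids, output_tok_ids)):
            req_info[index] = value
        else:
            # Assumed: clear stale state when a batch slot is reused.
            req_info.pop(index, None)

    if req_info:
        # Removed requests.
        for index in batch_update.removed:
            req_info.pop(index, None)

        # Moved requests: unidirectional move (a->b) and swap (a<->b).
        for adx, bdx, direct in batch_update.moved:
            a_val = req_info.pop(adx, None)
            b_val = req_info.pop(bdx, None)
            if a_val is not None:
                req_info[bdx] = a_val
            if direct == MoveDirectionality.SWAP and b_val is not None:
                req_info[adx] = b_val

With this shape, `DummyLogitsProcessor.update_state` only has to express what state an added request contributes (the `target_token` pulled from `extra_args`); the add/remove/move plumbing is shared with the built-in logits processors, which is the deduplication the commit title refers to.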