[V1] LogitsProcessor programming model (#16728)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
afeldman-nm
2025-07-02 12:10:42 -04:00
committed by GitHub
parent c1909e7e8c
commit 48fb076cbc
13 changed files with 1401 additions and 393 deletions

View File

@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from enum import Enum
from typing import Optional
from typing import NamedTuple, Optional
import regex as re
import torch
from vllm import CompletionOutput
from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
class BatchLogprobsComposition(Enum):
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
logprobs = completion_output.logprobs
assert logprobs is not None
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])
def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
return fake_logits
def create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)
def create_prompt_tokens_tensor(
prompt_token_ids: list[list[int]],
vocab_size: int,
device: torch.device,
) -> torch.Tensor:
return make_tensor_with_pad(
prompt_token_ids,
pad=vocab_size,
device=device,
dtype=torch.int64,
pin_memory=False,
)
class LogitsprocsTestFakes(NamedTuple):
"""Wraps fake data structures to support testing"""
logits: torch.Tensor
sampling_metadata: SamplingMetadata
def get_logitsprocs_by_cls(
self,
cls: type[LogitsProcessor],
) -> Iterator[LogitsProcessor]:
"""Yield logits processors of a specific class.
Args:
cls: :class:`LogitsProcessor` subclass
Returns:
Iterator over logits processors
"""
return (lp for lp in self.sampling_metadata.logitsprocs.all
if isinstance(lp, cls))
def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
return self.sampling_metadata.logitsprocs.all
def fake_update_logitsprocs_state(
test_fakes: LogitsprocsTestFakes,
batch_update: BatchUpdate,
) -> None:
"""Imitate logits processors persistent batch state update
in engine core"""
for logitproc in test_fakes.get_logitsprocs():
logitproc.update_state(batch_update)
def fake_apply_logitsprocs(
test_fakes: LogitsprocsTestFakes,
slice_indices: list[int],
) -> torch.Tensor:
"""Imitate application of logits processors in engine core"""
logits = test_fakes.logits[torch.tensor(slice_indices,
dtype=torch.long)].clone()
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits