[V1] LogitsProcessor programming model (#16728)
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Signed-off-by: Andrew Feldman <afeldman@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterator
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm import CompletionOutput
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
|
||||
class BatchLogprobsComposition(Enum):
|
||||
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
|
||||
logprobs = completion_output.logprobs
|
||||
assert logprobs is not None
|
||||
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])
|
||||
|
||||
|
||||
def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
|
||||
fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
|
||||
return fake_logits
|
||||
|
||||
|
||||
def create_penalty_tensor(batch_size: int, penalty_value: float,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
return torch.full((batch_size, ),
|
||||
fill_value=penalty_value,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
|
||||
|
||||
def create_prompt_tokens_tensor(
|
||||
prompt_token_ids: list[list[int]],
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
return make_tensor_with_pad(
|
||||
prompt_token_ids,
|
||||
pad=vocab_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
|
||||
class LogitsprocsTestFakes(NamedTuple):
|
||||
"""Wraps fake data structures to support testing"""
|
||||
logits: torch.Tensor
|
||||
sampling_metadata: SamplingMetadata
|
||||
|
||||
def get_logitsprocs_by_cls(
|
||||
self,
|
||||
cls: type[LogitsProcessor],
|
||||
) -> Iterator[LogitsProcessor]:
|
||||
"""Yield logits processors of a specific class.
|
||||
|
||||
Args:
|
||||
cls: :class:`LogitsProcessor` subclass
|
||||
|
||||
Returns:
|
||||
Iterator over logits processors
|
||||
"""
|
||||
return (lp for lp in self.sampling_metadata.logitsprocs.all
|
||||
if isinstance(lp, cls))
|
||||
|
||||
def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
|
||||
"""Iterator over all logits processors."""
|
||||
return self.sampling_metadata.logitsprocs.all
|
||||
|
||||
|
||||
def fake_update_logitsprocs_state(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
batch_update: BatchUpdate,
|
||||
) -> None:
|
||||
"""Imitate logits processors persistent batch state update
|
||||
in engine core"""
|
||||
for logitproc in test_fakes.get_logitsprocs():
|
||||
logitproc.update_state(batch_update)
|
||||
|
||||
|
||||
def fake_apply_logitsprocs(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
slice_indices: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""Imitate application of logits processors in engine core"""
|
||||
logits = test_fakes.logits[torch.tensor(slice_indices,
|
||||
dtype=torch.long)].clone()
|
||||
for processor in test_fakes.get_logitsprocs():
|
||||
logits = processor.apply(logits)
|
||||
return logits
|
||||
|
||||
Reference in New Issue
Block a user