[V1] Logit processors for rejection sampler (#19482)
Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.sample.utils import create_allowed_token_ids
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
@@ -51,26 +50,6 @@ def _create_prompt_tokens_tensor(
|
||||
)
|
||||
|
||||
|
||||
def _create_allowed_token_ids(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
num_allowed_token_ids: int,
|
||||
device: torch.device,
|
||||
) -> Optional[torch.Tensor]:
|
||||
mask: Optional[torch.Tensor] = None
|
||||
for i in range(batch_size):
|
||||
if i % 2 == 1:
|
||||
continue
|
||||
if mask is None:
|
||||
mask = torch.zeros(
|
||||
(batch_size, vocab_size), dtype=torch.bool, device=device
|
||||
)
|
||||
start = min(i, vocab_size - 1)
|
||||
end = min(i + num_allowed_token_ids, vocab_size - 1)
|
||||
mask[i, start:end] = True
|
||||
return mask
|
||||
|
||||
|
||||
def _create_bad_words_token_ids(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
@@ -173,6 +152,7 @@ def _create_default_sampling_metadata(
|
||||
prompt_token_ids, vocab_size, device
|
||||
),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[[] for _ in range(batch_size)],
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||
@@ -241,7 +221,9 @@ def test_sampler_presence_penalty(
|
||||
)
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(
|
||||
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
|
||||
)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
# Since all tokens initially have the same logits, the non-penalized
|
||||
@@ -293,7 +275,9 @@ def test_sampler_frequency_penalty(
|
||||
sampling_metadata.output_token_ids = output_token_ids
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(
|
||||
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
|
||||
)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
@@ -343,7 +327,9 @@ def test_sampler_repetition_penalty(
|
||||
)
|
||||
sampling_metadata.no_penalties = False
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_penalties(
|
||||
fake_logits, sampling_metadata, sampling_metadata.output_token_ids
|
||||
)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
non_penalized_token_id = logits[batch_idx].argmax().item()
|
||||
@@ -394,7 +380,7 @@ def test_sampler_allowed_token_ids(
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)
|
||||
)
|
||||
mask = _create_allowed_token_ids(
|
||||
mask = create_allowed_token_ids(
|
||||
batch_size=batch_size,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
num_allowed_token_ids=num_allowed_token_ids,
|
||||
@@ -402,7 +388,9 @@ def test_sampler_allowed_token_ids(
|
||||
)
|
||||
sampling_metadata.allowed_token_ids_mask = mask
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_logits_processors(
|
||||
fake_logits, sampling_metadata, predict_bonus_token=False
|
||||
)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
@@ -444,7 +432,9 @@ def test_sampler_bad_words(
|
||||
sampling_metadata, VOCAB_SIZE
|
||||
)
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
|
||||
logits = sampler.apply_logits_processors(
|
||||
fake_logits, sampling_metadata, predict_bonus_token=False
|
||||
)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
|
||||
Reference in New Issue
Block a user