[V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Sergei Skvortsov
2025-10-07 21:02:49 +01:00
committed by GitHub
parent 0c824fc46f
commit 6ebaf43ee4
12 changed files with 471 additions and 92 deletions

View File

@@ -8,6 +8,8 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -83,6 +85,14 @@ class RejectionSampler(nn.Module):
A tensor containing the final output token IDs.
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# Use float32 for the target_logits.
target_logits = target_logits.to(torch.float32)
target_logits = self.apply_logits_processors(
target_logits, sampling_metadata, metadata
)
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
@@ -131,6 +141,100 @@ class RejectionSampler(nn.Module):
]
return outputs
def apply_logits_processors(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
) -> torch.Tensor:
    """Run per-request logit processors over the draft-token target logits.

    Applies (in order) penalties, the allowed-token-ids mask, and bad-words
    exclusion. `logits` has one row per draft token across all requests
    (rows per request given by `metadata.num_draft_tokens`) — TODO confirm
    against the caller.
    """
    needs_token_history = (
        sampling_metadata.bad_words_token_ids
        or not sampling_metadata.no_penalties
    )
    token_history = sampling_metadata.output_token_ids
    if needs_token_history:
        # Penalties and bad-words need the tokens preceding each draft
        # position, so extend each request's outputs with its draft prefix.
        token_history = self._combine_outputs_with_spec_tokens(
            sampling_metadata.output_token_ids,
            sampling_metadata.spec_token_ids,
        )

    if (
        not sampling_metadata.no_penalties
        or sampling_metadata.allowed_token_ids_mask is not None
    ):
        # Map each logit row back to the index of the request that owns it,
        # so per-request tensors can be gathered per draft token.
        n_reqs = len(sampling_metadata.output_token_ids)
        drafts_per_req = torch.tensor(metadata.num_draft_tokens, device="cpu")
        row_to_req_cpu = torch.arange(n_reqs, device="cpu").repeat_interleave(
            drafts_per_req
        )
        row_to_req = row_to_req_cpu.to(device=logits.device, non_blocking=True)

        logits = self.apply_penalties(
            logits, sampling_metadata, metadata, row_to_req, token_history
        )

        # Mask out every token outside each request's allowed set.
        if sampling_metadata.allowed_token_ids_mask is not None:
            row_mask = sampling_metadata.allowed_token_ids_mask[row_to_req]
            logits.masked_fill_(row_mask, float("-inf"))

    # Exclude bad-words sequences, accounting for draft tokens.
    if sampling_metadata.bad_words_token_ids:
        apply_bad_words_with_drafts(
            logits,
            sampling_metadata.bad_words_token_ids,
            token_history,
            metadata.num_draft_tokens,
        )
    return logits
def apply_penalties(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
metadata: SpecDecodeMetadata,
repeat_indices: torch.Tensor,
output_token_ids: list[list[int]],
) -> torch.Tensor:
if sampling_metadata.no_penalties:
return logits
assert sampling_metadata.prompt_token_ids is not None
prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]
logits = apply_all_penalties(
logits,
prompt_token_ids,
presence_penalties,
frequency_penalties,
repetition_penalties,
output_token_ids,
)
return logits
def _combine_outputs_with_spec_tokens(
    self,
    output_token_ids: list[list[int]],
    spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
    """Build the token history preceding each draft-token position.

    For a request with drafts ``[d0, d1, ..., dk-1]`` the result contains
    ``k`` histories: the request's outputs, then the outputs extended by
    ``d0``, by ``d0 d1``, and so on up to (but excluding) the last draft.
    Requests with no draft tokens contribute nothing. When
    ``spec_token_ids`` is ``None`` the input list is returned as-is.
    """
    if spec_token_ids is None:
        return output_token_ids

    combined: list[list[int]] = []
    for history, drafts in zip(output_token_ids, spec_token_ids):
        if not drafts:
            # No draft tokens -> no logit rows for this request.
            continue
        combined.append(history)
        # Each subsequent history extends the previous one by one draft
        # token; the final draft never precedes a row, so it is dropped.
        for tok in drafts[:-1]:
            combined.append([*combined[-1], tok])
    return combined
def rejection_sample(
# [num_tokens]