[V1] Logit processors for rejection sampler (#19482)
Signed-off-by: southfreebird <yvorott@gmail.com> Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com> Signed-off-by: Sergei Skvortsov <yvorott@gmail.com> Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -8,6 +8,8 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
@@ -83,6 +85,14 @@ class RejectionSampler(nn.Module):
|
||||
A tensor containing the final output token IDs.
|
||||
"""
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
|
||||
# Use float32 for the target_logits.
|
||||
target_logits = target_logits.to(torch.float32)
|
||||
|
||||
target_logits = self.apply_logits_processors(
|
||||
target_logits, sampling_metadata, metadata
|
||||
)
|
||||
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `compute_probs` function.
|
||||
@@ -131,6 +141,100 @@ class RejectionSampler(nn.Module):
|
||||
]
|
||||
return outputs
|
||||
|
||||
def apply_logits_processors(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
) -> torch.Tensor:
    """Apply penalties, the allowed-token mask and bad-words exclusion.

    The three processors run in that order, and each one is skipped when
    the corresponding field of ``sampling_metadata`` is unset.

    Args:
        logits: Target logits to process. Presumably one row per draft
            token, i.e. ``[num_tokens, vocab_size]`` -- TODO confirm
            against the caller.
        sampling_metadata: Source of ``bad_words_token_ids``,
            ``no_penalties``, ``allowed_token_ids_mask``,
            ``output_token_ids`` and ``spec_token_ids``.
        metadata: Speculative-decoding metadata; only
            ``num_draft_tokens`` is read here.

    Returns:
        The processed logits. The allowed-token mask is written with
        ``masked_fill_``, so the tensor returned by ``apply_penalties``
        is modified in place.
    """
    any_penalties_or_bad_words = (
        sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
    )

    # Only penalties and bad-words matching consume the token histories,
    # so the combined outputs + draft tokens are built only when one of
    # them is active.
    output_token_ids = sampling_metadata.output_token_ids
    if any_penalties_or_bad_words:
        output_token_ids = self._combine_outputs_with_spec_tokens(
            sampling_metadata.output_token_ids,
            sampling_metadata.spec_token_ids,
        )

    # Calculate indices of target logits.
    if (
        sampling_metadata.allowed_token_ids_mask is not None
        or not sampling_metadata.no_penalties
    ):
        # Repeat each request index once per draft token of that request,
        # giving an index tensor that expands per-request tensors to one
        # entry per draft token.
        num_requests = len(sampling_metadata.output_token_ids)
        num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
        original_indices = torch.arange(num_requests, device="cpu")
        repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
        repeat_indices = repeat_indices_cpu.to(
            device=logits.device, non_blocking=True
        )
        # `apply_penalties` returns the logits untouched when penalties
        # are disabled, so this call is safe on the mask-only path.
        logits = self.apply_penalties(
            logits, sampling_metadata, metadata, repeat_indices, output_token_ids
        )

        # Apply allowed token ids.
        if sampling_metadata.allowed_token_ids_mask is not None:
            # Positions where the mask is True are set to -inf.
            token_mask = sampling_metadata.allowed_token_ids_mask[repeat_indices]
            logits.masked_fill_(token_mask, float("-inf"))

    # Apply bad words exclusion.
    if sampling_metadata.bad_words_token_ids:
        apply_bad_words_with_drafts(
            logits,
            sampling_metadata.bad_words_token_ids,
            output_token_ids,
            metadata.num_draft_tokens,
        )

    return logits
|
||||
|
||||
def apply_penalties(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    metadata: SpecDecodeMetadata,
    repeat_indices: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """Apply presence, frequency and repetition penalties to ``logits``.

    Args:
        logits: Logits to penalize.
        sampling_metadata: Source of ``no_penalties``, ``prompt_token_ids``
            and the three penalty tensors.
        metadata: Not used in this method's body.
        repeat_indices: Index tensor used to select rows of the
            per-request tensors; presumably one request index per draft
            token -- TODO confirm against the caller.
        output_token_ids: Token histories forwarded to
            ``apply_all_penalties``.

    Returns:
        ``logits`` unchanged when ``sampling_metadata.no_penalties`` is
        truthy, otherwise the result of ``apply_all_penalties``.

    Raises:
        AssertionError: If penalties are enabled but
            ``sampling_metadata.prompt_token_ids`` is ``None``.
    """
    if sampling_metadata.no_penalties:
        return logits

    assert sampling_metadata.prompt_token_ids is not None

    # Select the per-request values row by row with `repeat_indices`.
    prompt_token_ids = sampling_metadata.prompt_token_ids[repeat_indices]
    presence_penalties = sampling_metadata.presence_penalties[repeat_indices]
    frequency_penalties = sampling_metadata.frequency_penalties[repeat_indices]
    repetition_penalties = sampling_metadata.repetition_penalties[repeat_indices]

    logits = apply_all_penalties(
        logits,
        prompt_token_ids,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        output_token_ids,
    )
    return logits
|
||||
|
||||
def _combine_outputs_with_spec_tokens(
    self,
    output_token_ids: list[list[int]],
    spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
    """Build one token history per draft position of every request.

    A request with ``k`` draft tokens contributes ``k`` rows: its output
    tokens, then the output tokens extended by the first draft token,
    and so on up to the first ``k - 1`` draft tokens. The last draft
    token is never appended. Requests with no draft tokens contribute
    no rows.

    Args:
        output_token_ids: Output tokens of each request.
        spec_token_ids: Draft tokens of each request, paired with
            ``output_token_ids`` by position. When ``None``,
            ``output_token_ids`` is returned as-is.

    Returns:
        The flattened list of histories. The first row of each request
        is the original ``output_token_ids`` entry itself, not a copy.
    """
    if spec_token_ids is None:
        return output_token_ids

    combined: list[list[int]] = []
    for history, drafts in zip(output_token_ids, spec_token_ids):
        if not drafts:
            continue
        prefix = history
        combined.append(prefix)
        # Every draft token except the last extends the running prefix.
        for draft_token in drafts[:-1]:
            prefix = [*prefix, draft_token]
            combined.append(prefix)
    return combined
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_tokens]
|
||||
|
||||
Reference in New Issue
Block a user