[V1][spec decode] return logprobs for spec decoding (#26060)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import replace
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -44,17 +48,22 @@ class RejectionSampler(nn.Module):
|
||||
output tokens = accepted tokens + recovered tokens + bonus tokens
|
||||
"""
|
||||
|
||||
def __init__(self, sampler: Sampler):
|
||||
super().__init__()
|
||||
self.sampler = sampler
|
||||
logprobs_mode = self.sampler.logprobs_mode
|
||||
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
|
||||
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
metadata: SpecDecodeMetadata,
|
||||
# [num_tokens, vocab_size]
|
||||
draft_probs: torch.Tensor | None,
|
||||
# [num_tokens, vocab_size]
|
||||
target_logits: torch.Tensor,
|
||||
# [batch_size, 1]
|
||||
bonus_token_ids: torch.Tensor,
|
||||
# [num_tokens + batch_size, vocab_size]
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
) -> SamplerOutput:
|
||||
"""
|
||||
Args:
|
||||
metadata:
|
||||
@@ -63,43 +72,65 @@ class RejectionSampler(nn.Module):
|
||||
Probability distribution for the draft tokens. Shape is
|
||||
[num_tokens, vocab_size]. Can be None if probabilities are
|
||||
not provided, which is the case for ngram spec decode.
|
||||
target_logits (torch.Tensor):
|
||||
logits (torch.Tensor):
|
||||
Target model's logits probability distribution.
|
||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
||||
different requests are flattened into a single tensor because
|
||||
this is the shape of the output logits.
|
||||
NOTE: `target_logits` can be updated in place to save memory.
|
||||
bonus_token_ids (torch.Tensor):
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
proposed tokens are accepted. We generate the bonus tokens
|
||||
outside of the rejection sampler with the default sampling
|
||||
strategy. It allows for more flexibility in the sampling
|
||||
process such as top_p, top_k sampling.
|
||||
Shape is [num_tokens + batch_size, vocab_size]. Here,
|
||||
probabilities from different requests are flattened into a
|
||||
single tensor because this is the shape of the output logits.
|
||||
NOTE: `logits` can be updated in place to save memory.
|
||||
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
|
||||
Additional metadata needed for sampling, such as temperature,
|
||||
top-k/top-p parameters, or other relevant information.
|
||||
Returns:
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
SamplerOutput:
|
||||
Contains the final output token IDs and their logprobs if
|
||||
requested.
|
||||
"""
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
|
||||
# Use float32 for the target_logits.
|
||||
target_logits = target_logits.to(torch.float32)
|
||||
bonus_logits_indices = metadata.bonus_logits_indices
|
||||
target_logits_indices = metadata.target_logits_indices
|
||||
|
||||
target_logits = self.apply_logits_processors(
|
||||
target_logits, sampling_metadata, metadata
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
assert logits is not None
|
||||
bonus_logits = logits[bonus_logits_indices]
|
||||
bonus_sampler_output = self.sampler(
|
||||
logits=bonus_logits,
|
||||
sampling_metadata=replace(
|
||||
sampling_metadata,
|
||||
max_num_logprobs=-1,
|
||||
),
|
||||
predict_bonus_token=True,
|
||||
# Override the logprobs mode to return logits because they are
|
||||
# needed later to compute the accepted token logprobs.
|
||||
logprobs_mode_override="processed_logits"
|
||||
if self.is_processed_logprobs_mode
|
||||
else "raw_logits",
|
||||
)
|
||||
bonus_token_ids = bonus_sampler_output.sampled_token_ids
|
||||
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
raw_target_logits = logits[target_logits_indices]
|
||||
# Use float32 for the target_logits.
|
||||
raw_target_logits = raw_target_logits.to(torch.float32)
|
||||
target_logits = self.apply_logits_processors(
|
||||
raw_target_logits, sampling_metadata, metadata
|
||||
)
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `compute_probs` function.
|
||||
target_probs = compute_probs(
|
||||
# `apply_sampling_constraints` function.
|
||||
target_logits = apply_sampling_constraints(
|
||||
target_logits,
|
||||
metadata.cu_num_draft_tokens,
|
||||
sampling_metadata,
|
||||
)
|
||||
# Compute probability distribution from target logits.
|
||||
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
output_token_ids = rejection_sample(
|
||||
metadata.draft_token_ids,
|
||||
@@ -111,7 +142,63 @@ class RejectionSampler(nn.Module):
|
||||
bonus_token_ids,
|
||||
sampling_metadata,
|
||||
)
|
||||
return output_token_ids
|
||||
|
||||
logprobs_tensors = None
|
||||
if sampling_metadata.max_num_logprobs:
|
||||
logprobs_tensors = self._get_logprobs_tensors(
|
||||
sampling_metadata.max_num_logprobs,
|
||||
metadata,
|
||||
logits,
|
||||
target_logits if self.is_processed_logprobs_mode else raw_target_logits,
|
||||
bonus_sampler_output.logprobs_tensors.logprobs,
|
||||
output_token_ids,
|
||||
)
|
||||
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=output_token_ids,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
|
||||
def _get_logprobs_tensors(
|
||||
self,
|
||||
max_num_logprobs: int,
|
||||
metadata: SpecDecodeMetadata,
|
||||
logits: torch.Tensor,
|
||||
target_logits: torch.Tensor,
|
||||
bonus_logits: torch.Tensor,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
|
||||
cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
|
||||
|
||||
# Collect target and bonus logits.
|
||||
bonus_logits_indices = metadata.bonus_logits_indices
|
||||
target_logits_indices = metadata.target_logits_indices
|
||||
final_logits = torch.zeros_like(logits, dtype=torch.float32)
|
||||
final_logits[target_logits_indices] = target_logits.to(torch.float32)
|
||||
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
|
||||
|
||||
# Compute accepted token indices.
|
||||
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
|
||||
num_accepted_tokens = accepted_mask.sum(dim=-1)
|
||||
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
|
||||
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
|
||||
num_accepted_tokens
|
||||
)
|
||||
|
||||
# Compute logprobs for accepted tokens.
|
||||
accepted_logits = final_logits[accepted_logit_indices]
|
||||
accepted_logprobs = (
|
||||
accepted_logits
|
||||
if self.is_logits_logprobs_mode
|
||||
else self.sampler.compute_logprobs(accepted_logits)
|
||||
)
|
||||
accepted_tokens = sampled_token_ids[accepted_mask]
|
||||
return self.sampler.gather_logprobs(
|
||||
accepted_logprobs,
|
||||
max_num_logprobs,
|
||||
accepted_tokens.to(torch.int64),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_output(
|
||||
@@ -119,14 +206,12 @@ class RejectionSampler(nn.Module):
|
||||
vocab_size: int,
|
||||
) -> list[list[int]]:
|
||||
"""Parse the output of the rejection sampler.
|
||||
|
||||
Args:
|
||||
output_token_ids: The sampled token IDs in shape
|
||||
[batch_size, max_spec_len + 1]. The rejected tokens are
|
||||
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
|
||||
and will be filtered out in this function.
|
||||
vocab_size: The size of the vocabulary.
|
||||
|
||||
Returns:
|
||||
A list of lists of token IDs.
|
||||
"""
|
||||
@@ -328,27 +413,26 @@ def rejection_sample(
|
||||
return output_token_ids
|
||||
|
||||
|
||||
def compute_probs(
|
||||
def apply_sampling_constraints(
|
||||
logits: torch.Tensor, # [num_tokens, vocab_size]
|
||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Compute probability distribution from logits based on sampling metadata.
|
||||
"""Process logits based on sampling metadata.
|
||||
|
||||
This function applies temperature scaling to the logits and converts
|
||||
them to probabilities using softmax. For greedy decoding, it returns
|
||||
This function applies temperature scaling to the logits,
|
||||
as well as top-k and top-p. For greedy decoding, it returns
|
||||
the original logits.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor to be converted to probabilities.
|
||||
logits: Input logits tensor to be processed.
|
||||
cu_num_draft_tokens: Cumulative number of draft tokens.
|
||||
sampling_metadata: Metadata containing sampling parameters such as
|
||||
temperature and whether greedy sampling is used.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Probability distribution (softmax of scaled logits)
|
||||
if non-greedy sampling is used, otherwise returns the
|
||||
original logits.
|
||||
torch.Tensor: Processed logits if non-greedy sampling is used,
|
||||
otherwise returns the original logits.
|
||||
"""
|
||||
assert logits.ndim == 2
|
||||
assert cu_num_draft_tokens.ndim == 1
|
||||
@@ -384,9 +468,7 @@ def compute_probs(
|
||||
|
||||
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
|
||||
# which is slow for large vocab sizes. This may cause performance issues.
|
||||
logits = apply_top_k_top_p(logits, top_k, top_p)
|
||||
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return output_prob
|
||||
return apply_top_k_top_p(logits, top_k, top_p)
|
||||
|
||||
|
||||
def expand_batch_to_tokens(
|
||||
|
||||
Reference in New Issue
Block a user