# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5
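# A temperature below _SAMPLING_EPS is treated as zero, i.e. the request is
# sampled greedily (see the temperature comparison in Sampler.sample).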


class Sampler(nn.Module):
    """Samples next tokens from logits and optionally gathers logprobs."""

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()
        self.pin_memory = is_pin_memory_available()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # are used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)

        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # Convert sampled token ids to int64 (long) to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the top-k and sampled tokens (if requested).
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant).
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled

        # Mixed batch: requests with near-zero temperature take the greedy
        # token; all others keep the randomly sampled token.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
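        # Illustrative example (not executed): with temperatures [0.0, 0.7],
        # greedy_sampled [5, 9] and random_sampled [3, 7], the result is
        # [5, 7]: the zero-temperature request keeps its greedy token.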
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements
                       Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Gather the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)
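        # Illustrative example (not executed): for a row with logprobs
        # [-0.1, -1.2, -2.3] and token id 1, two entries satisfy
        # (logprobs >= -1.2), so the token's 1-indexed rank is 2.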

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits
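

# Illustrative usage sketch (assumptions: a populated SamplingMetadata and
# float logits of shape [num_reqs, vocab_size]; not part of this module's
# public surface):
#
#     sampler = Sampler()
#     out = sampler(logits, sampling_metadata)
#     # One sampled token per request:
#     assert out.sampled_token_ids.shape == (logits.shape[0], 1)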