2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2024-10-22 01:24:07 -07:00
|
|
|
"""A layer that samples the next tokens from the model's outputs."""
|
|
|
|
|
|
2025-08-20 21:28:32 -07:00
|
|
|
from typing import Optional
|
|
|
|
|
|
2024-10-22 01:24:07 -07:00
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
2025-10-08 15:10:00 +01:00
|
|
|
from vllm.config.model import LogprobsMode
|
2025-07-02 12:10:42 -04:00
|
|
|
from vllm.utils import is_pin_memory_available
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
2024-10-22 01:24:07 -07:00
|
|
|
from vllm.v1.sample.metadata import SamplingMetadata
|
2025-03-08 14:50:26 -08:00
|
|
|
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
2025-07-21 13:47:47 -07:00
|
|
|
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
|
2025-07-02 12:10:42 -04:00
|
|
|
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
2024-12-27 09:32:38 +09:00
|
|
|
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
2024-10-22 01:24:07 -07:00
|
|
|
|
|
|
|
|
# Temperatures below this threshold are treated as zero, i.e. the request is
# considered greedy (see `Sampler.apply_temperature` and `Sampler.sample`).
_SAMPLING_EPS = 1e-5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Sampler(nn.Module):
|
2025-08-20 21:28:32 -07:00
|
|
|
"""
|
|
|
|
|
A layer that samples the next tokens from the model's outputs
|
|
|
|
|
with the following steps in order:
|
|
|
|
|
|
2025-10-05 15:06:22 +01:00
|
|
|
1. If logprobs are requested:
|
2025-08-20 21:28:32 -07:00
|
|
|
a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
|
2025-10-05 15:06:22 +01:00
|
|
|
as the final logprobs to return.
|
2025-08-20 21:28:32 -07:00
|
|
|
b) If `logprobs_mode` is `raw_logits`, clone the logits
|
2025-10-05 15:06:22 +01:00
|
|
|
as the final logprobs to return.
|
|
|
|
|
2. Convert logits to float32.
|
|
|
|
|
3. Apply allowed token ids whitelist.
|
|
|
|
|
4. Apply bad words exclusion.
|
2025-08-20 21:28:32 -07:00
|
|
|
5. Apply logit processors which are not argmax-invariant,
|
2025-10-05 15:06:22 +01:00
|
|
|
i.e. that can impact greedy sampling.
|
|
|
|
|
a) Min tokens processor
|
|
|
|
|
b) Logit bias processor
|
|
|
|
|
6. Apply penalties
|
|
|
|
|
a) Repetition penalty
|
|
|
|
|
b) Frequency penalty
|
|
|
|
|
c) Presence penalty
|
|
|
|
|
7. Sample the next tokens. `sample` method performs the following steps:
|
2025-08-20 21:28:32 -07:00
|
|
|
a) If not `all_random`, perform greedy sampling. If `all_greedy`,
|
2025-10-05 15:06:22 +01:00
|
|
|
return the greedily sampled tokens and final logprobs if requested.
|
|
|
|
|
b) Apply temperature.
|
2025-08-20 21:28:32 -07:00
|
|
|
c) Apply logit processors which are argmax-invariant, by default
|
2025-10-05 15:06:22 +01:00
|
|
|
the min_p processor.
|
|
|
|
|
d) Apply top_k and/or top_p.
|
|
|
|
|
e) Sample the next tokens with the probability distribution.
|
2025-08-20 21:28:32 -07:00
|
|
|
f) If `all_random` or temperature >= epsilon (1e-5), return the
|
|
|
|
|
randomly sampled tokens and final logprobs if requested. Else,
|
2025-10-05 15:06:22 +01:00
|
|
|
return the greedily sampled tokens and logprobs if requested.
|
2025-08-20 21:28:32 -07:00
|
|
|
8. Gather the logprobs of the top `max_num_logprobs` and sampled token
|
|
|
|
|
(if requested). Note that if the sampled token is within the top
|
|
|
|
|
`max_num_logprobs`, the logprob will be eventually merged in
|
|
|
|
|
`LogprobsProcessor` during output processing. Therefore, the
|
|
|
|
|
final output may contain either `max_num_logprobs + 1` or
|
2025-10-05 15:06:22 +01:00
|
|
|
`max_num_logprobs` logprobs.
|
2025-08-20 21:28:32 -07:00
|
|
|
9. Return the final `SamplerOutput`.
|
|
|
|
|
"""
|
|
|
|
|
|
2025-09-19 17:22:33 +01:00
|
|
|
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
    """Build the sampler.

    Args:
        logprobs_mode: which flavor of logprobs to return
            (raw vs. processed, logits vs. logprobs). Also forwarded
            to the top-k/top-p sampler so both stay consistent.
    """
    super().__init__()
    # Keep the mode around; `forward`/`sample` branch on it.
    self.logprobs_mode = logprobs_mode
    self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
    self.pin_memory = is_pin_memory_available()
|
2024-12-27 09:32:38 +09:00
|
|
|
|
2024-10-22 01:24:07 -07:00
|
|
|
def forward(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    predict_bonus_token: bool = False,
) -> SamplerOutput:
    """Run the sampling pipeline on ``logits`` and package the result.

    Args:
        logits: (num tokens) x (vocab) tensor of model outputs.
        sampling_metadata: batched sampling parameters (temperature,
            top-k/top-p, logits processors, requested logprob count, ...).
        predict_bonus_token: passed through to
            ``self.apply_logits_processors`` (semantics defined there).

    Returns:
        A ``SamplerOutput`` with sampled token ids of shape
        [num_requests, 1] (int32) and, if logprobs were requested,
        the gathered ``LogprobsTensors``.
    """
    # NOTE(woosuk): Use the original logits (before any penalties or
    # temperature scaling) for the top-k logprobs.
    # This is different from the V0 sampler, which uses the logits that
    # is used for sampling (after penalties and temperature scaling).
    num_logprobs = sampling_metadata.max_num_logprobs
    if num_logprobs is not None:
        if self.logprobs_mode == "raw_logprobs":
            raw_logprobs = self.compute_logprobs(logits)
        elif self.logprobs_mode == "raw_logits":
            raw_logprobs = logits.clone()
        # NOTE: for the "processed_*" modes, `raw_logprobs` is not bound
        # here; it is assigned below from `sample()`'s second return value.

    # Use float32 for the logits.
    logits = logits.to(torch.float32)
    logits = self.apply_logits_processors(
        logits, sampling_metadata, predict_bonus_token
    )
    # Sample the next token.
    sampled, processed_logprobs = self.sample(logits, sampling_metadata)
    if processed_logprobs is not None:
        raw_logprobs = processed_logprobs
    # Convert sampled token ids to int64 (long) type to ensure compatibility
    # with subsequent operations that may use these values as indices.
    # This conversion is necessary because FlashInfer sampling operations
    # return int32 (while PyTorch argmax and topk return int64).
    sampled = sampled.long()

    # Gather the logprobs of the topk and sampled token (if requested).
    # Get logprobs and rank tensors (if requested)
    logprobs_tensors = (
        None
        if num_logprobs is None
        else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
    )

    # Use int32 to reduce the tensor size.
    sampled = sampled.to(torch.int32)

    # These are GPU tensors.
    sampler_output = SamplerOutput(
        # The sampled tokens are expanded to 2D tensor with shape
        # [num_requests, 1], where each row represents one generated
        # token per request.
        sampled_token_ids=sampled.unsqueeze(-1),
        logprobs_tensors=logprobs_tensors,
    )
    return sampler_output
|
|
|
|
|
|
|
|
|
|
def apply_temperature(
    self,
    logits: torch.Tensor,
    temp: torch.Tensor,
    all_random: bool,
) -> torch.Tensor:
    """Scale ``logits`` in place by the per-request temperature.

    Args:
        logits: (num tokens) x (vocab) tensor; modified in place.
        temp: per-request temperature, one entry per row of ``logits``.
        all_random: when True, no request is greedy and ``temp`` can be
            used as-is.

    Returns:
        The same ``logits`` tensor, divided row-wise by the temperature.
    """
    if all_random:
        divisor = temp
    else:
        # Greedy requests carry a (near-)zero temperature; dividing by it
        # would explode, so substitute a no-op temperature of 1.0 there.
        divisor = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
    # In-place division avoids allocating a new logits tensor.
    return logits.div_(divisor.unsqueeze(dim=1))
|
2024-10-22 01:24:07 -07:00
|
|
|
|
2024-12-27 09:32:38 +09:00
|
|
|
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
    """Return the highest-logit token id for each row of ``logits``."""
    best_token_ids = torch.argmax(logits, dim=-1)
    # Flatten to a 1D tensor with one entry per request.
    return best_token_ids.view(-1)
|
|
|
|
|
|
|
|
|
|
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Sample logits based on sampling metadata.

    The various logits processing functions called in this method
    may update the logits tensor in-place.

    Returns:
        Tuple of (sampled token ids, processed logits/logprobs). The
        second element is non-None only when logprobs were requested
        and ``logprobs_mode`` is one of the "processed_*" modes (for the
        random-sampling path it comes from ``self.topk_topp_sampler``).
    """
    # A batch cannot be simultaneously all-greedy and all-random.
    assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
    if sampling_metadata.all_random:
        greedy_sampled = None
    else:
        greedy_sampled = self.greedy_sample(logits)
        if sampling_metadata.all_greedy:
            # Fast path: every request is greedy, so skip temperature,
            # argmax-invariant processors, and top-k/top-p entirely.
            processed_logprobs = None
            if sampling_metadata.max_num_logprobs is not None:
                if self.logprobs_mode == "processed_logits":
                    processed_logprobs = logits
                elif self.logprobs_mode == "processed_logprobs":
                    processed_logprobs = self.compute_logprobs(logits)
            return greedy_sampled, processed_logprobs

    # At least one request samples randomly, so temperature must exist.
    assert sampling_metadata.temperature is not None

    # Apply temperature.
    logits = self.apply_temperature(
        logits, sampling_metadata.temperature, sampling_metadata.all_random
    )

    # Apply logits processors that only apply to random sampling
    # (argmax invariant)
    for processor in sampling_metadata.logitsprocs.argmax_invariant:
        logits = processor.apply(logits)

    # Apply top_k and/or top_p.
    random_sampled, processed_logprobs = self.topk_topp_sampler(
        logits,
        sampling_metadata.generators,
        sampling_metadata.top_k,
        sampling_metadata.top_p,
    )

    if greedy_sampled is None:
        # All-random batch: nothing to merge.
        return random_sampled, processed_logprobs

    # Mixed batch: take the greedy token where temperature is (near-)zero,
    # the randomly sampled token elsewhere.
    sampled = torch.where(
        sampling_metadata.temperature < _SAMPLING_EPS,
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    return sampled, processed_logprobs
|
2024-10-22 01:24:07 -07:00
|
|
|
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
    """Log-softmax over the vocab dimension, computed in float32."""
    return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
def gather_logprobs(
|
2024-12-27 09:32:38 +09:00
|
|
|
self,
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
logprobs: torch.Tensor,
|
|
|
|
|
num_logprobs: int,
|
|
|
|
|
token_ids: torch.Tensor,
|
|
|
|
|
) -> LogprobsTensors:
|
|
|
|
|
"""
|
|
|
|
|
Gather logprobs for topk and sampled/prompt token.
|
|
|
|
|
|
|
|
|
|
Args:
|
2025-03-24 23:45:32 +08:00
|
|
|
logprobs: (num tokens) x (vocab) tensor
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
num_logprobs: minimum number of logprobs to
|
|
|
|
|
retain per token
|
|
|
|
|
token_ids: prompt tokens (if prompt logprobs)
|
|
|
|
|
or sampled tokens (if sampled
|
|
|
|
|
logprobs); 1D token ID tensor
|
|
|
|
|
with (num tokens) elements
|
2025-03-18 23:52:19 -07:00
|
|
|
Must be int64.
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
|
|
|
|
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
|
|
|
|
Sampled token rank tensor, (num tokens)
|
|
|
|
|
"""
|
2025-03-18 23:52:19 -07:00
|
|
|
assert token_ids.dtype == torch.int64
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
# Find the topK values.
|
2025-10-05 15:06:22 +01:00
|
|
|
topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
|
|
|
|
|
# Get with the logprob of the prompt or sampled token.
|
2025-03-18 23:52:19 -07:00
|
|
|
token_ids = token_ids.unsqueeze(-1)
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
token_logprobs = logprobs.gather(-1, token_ids)
|
|
|
|
|
|
|
|
|
|
# Compute the ranks of the actual token.
|
2025-07-21 13:47:47 -07:00
|
|
|
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
|
|
|
|
|
# Concatenate together with the topk.
|
|
|
|
|
indices = torch.cat((token_ids, topk_indices), dim=1)
|
|
|
|
|
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
|
|
|
|
|
2024-12-27 09:32:38 +09:00
|
|
|
# Use int32 to reduce the tensor size.
|
[V1] Logprobs and prompt logprobs support (#9880)
This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.
New behavior:
- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
2025-02-07 10:26:20 -05:00
|
|
|
indices = indices.to(torch.int32)
|
|
|
|
|
|
|
|
|
|
return LogprobsTensors(indices, logprobs, token_ranks)
|
2024-10-22 01:24:07 -07:00
|
|
|
|
2025-10-07 21:02:49 +01:00
|
|
|
def _combine_outputs_with_spec_tokens(
    self,
    output_token_ids: list[list[int]],
    spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
    """Extend each request's output token ids with its speculative token ids.

    Args:
        output_token_ids: Per-request generated token ids.
        spec_token_ids: Per-request speculative (draft) token ids, or None
            when speculative decoding is not in use.

    Returns:
        `output_token_ids` unchanged when `spec_token_ids` is None;
        otherwise a new list where each row with a non-empty spec list is
        `base + spec` (rows with an empty spec list are reused as-is).
    """
    if spec_token_ids is None:
        return output_token_ids

    combined: list[list[int]] = []
    for base, extra in zip(output_token_ids, spec_token_ids):
        combined.append(base + extra if extra else base)
    return combined
|
|
|
|
|
|
|
|
|
|
def apply_logits_processors(
    self,
    logits: torch.Tensor,
    sampling_metadata: "SamplingMetadata",
    predict_bonus_token: bool,
) -> torch.Tensor:
    """Apply token masks, bad-words exclusion, logits processors and penalties.

    Args:
        logits: Logits tensor for the batch; may be modified in place by
            the allowed-token-ids masking step.
        sampling_metadata: Per-batch sampling parameters.
        predict_bonus_token: True when speculative decoding may append
            bonus tokens; the per-request output token ids are then
            extended with the spec token ids before penalty / bad-words
            computation.

    Returns:
        The transformed logits tensor.
    """
    any_penalties_or_bad_words = (
        sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
    )

    output_token_ids = sampling_metadata.output_token_ids
    if predict_bonus_token and any_penalties_or_bad_words:
        # Combine base outputs with spec tokens when speculative decoding
        # is enabled.
        output_token_ids = self._combine_outputs_with_spec_tokens(
            sampling_metadata.output_token_ids,
            sampling_metadata.spec_token_ids,
        )

    # Apply allowed token ids.
    if sampling_metadata.allowed_token_ids_mask is not None:
        logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

    # Apply bad words exclusion.
    if sampling_metadata.bad_words_token_ids:
        # NOTE: output_token_ids is never None here — it is initialized
        # from sampling_metadata.output_token_ids above — so the previous
        # `x if x is not None else sampling_metadata.output_token_ids`
        # fallback was dead code and has been removed.
        apply_bad_words(
            logits,
            sampling_metadata.bad_words_token_ids,
            output_token_ids,
        )

    # Apply logits processors which can impact greedy sampling.
    for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
        logits = processor.apply(logits)

    # Apply penalties (e.g., freq_penalties).
    logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
    return logits
|
2025-03-08 14:50:26 -08:00
|
|
|
|
2025-10-07 21:02:49 +01:00
|
|
|
def apply_penalties(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    output_token_ids: Optional[list[list[int]]] = None,
) -> torch.Tensor:
    """Apply presence/frequency/repetition penalties to the logits.

    Args:
        logits: Logits tensor for the batch.
        sampling_metadata: Per-batch sampling parameters; when
            `no_penalties` is set the logits are returned untouched.
        output_token_ids: Optional override for the per-request output
            token ids (e.g. extended with speculative tokens); falls back
            to `sampling_metadata.output_token_ids` when None.

    Returns:
        The (possibly penalized) logits tensor.
    """
    if sampling_metadata.no_penalties:
        # Fast path: no request in the batch uses penalties.
        return logits

    assert sampling_metadata.prompt_token_ids is not None
    # Use the caller-supplied token ids when given, otherwise the
    # metadata's own output token ids.
    token_ids = (
        sampling_metadata.output_token_ids
        if output_token_ids is None
        else output_token_ids
    )
    return apply_all_penalties(
        logits,
        sampling_metadata.prompt_token_ids,
        sampling_metadata.presence_penalties,
        sampling_metadata.frequency_penalties,
        sampling_metadata.repetition_penalties,
        token_ids,
    )
|