"""A layer that samples the next tokens from the model's outputs."""
from typing import Tuple

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_min_token_penalties,
                                          apply_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

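# Requests with a temperature below this threshold are sampled greedily.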
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
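    """Samples next token ids from the model's logits.

    The forward pass optionally gathers top-k logprobs, applies penalties
    and temperature to the logits, and then samples each request's next
    token either greedily or with top-k/top-p random sampling.
    """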

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
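        """Samples one token per request and returns a SamplerOutput.

        `logits` is expected to hold one row per request
        ([num_requests, vocab_size]). The sampled token ids are returned as a
        Python list, so this call synchronizes the CPU with the GPU.
        """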
        needs_logprobs = sampling_metadata.max_num_logprobs > 0
        if needs_logprobs:
            # NOTE(woosuk): Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs.
            # This is different from the V0 sampler, which uses the logits
            # that are used for sampling (after penalties and temperature
            # scaling).
            # NOTE: We compute logprobs first because the ops below may
            # modify the logits tensor in-place (and we don't want to clone
            # the logits tensor for memory efficiency).
            topk_logprobs, topk_indices = self.get_topk_logprobs(
                logits, sampling_metadata)
        else:
            topk_logprobs = None
            topk_indices = None

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # NOTE: CPU-GPU synchronization happens here.
        sampler_output = SamplerOutput(
            sampled_token_ids=sampled.tolist(),
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
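        """Divides the logits in place by each request's temperature."""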
        # Avoid division by zero.
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        logits.div_(temp.unsqueeze(dim=1))
        return logits

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
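        """Picks the argmax token id for every request."""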
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
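        """Returns one sampled token id per request.

        Greedy requests (temperature below _SAMPLING_EPS) take the argmax;
        the rest go through the top-k/top-p sampler. Mixed batches compute
        both and pick per request with torch.where.
        """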
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_greedy:
            return self.greedy_sample(logits)

        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )
        if sampling_metadata.all_random:
            return random_sampled

        greedy_sampled = self.greedy_sample(logits)
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
        )
        return sampled

    def get_topk_logprobs(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
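        """Returns the top-k logprobs and their token ids (as int32).

        In the forward pass this is called on the raw logits, before
        penalties and temperature scaling are applied.
        """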
        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
        # FIXME: Mask the sampled token_id, get topk logprobs,
        # and concatenate the topk with the sampled token_id.
        topk_logprobs, topk_indices = torch.topk(
            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
        # Use int32 to reduce the tensor size.
        topk_indices = topk_indices.to(torch.int32)
        return topk_logprobs, topk_indices

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
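        """Applies min-token penalties and, when enabled for the batch,
        presence, frequency, and repetition penalties to the logits.
        """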
        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
                                  sampling_metadata.stop_token_ids,
                                  sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_penalties(logits,
                                     sampling_metadata.prompt_token_ids,
                                     sampling_metadata.presence_penalties,
                                     sampling_metadata.frequency_penalties,
                                     sampling_metadata.repetition_penalties,
                                     sampling_metadata.output_token_ids)
        return logits
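
# Illustrative usage sketch (not part of this module): it assumes the engine
# has already built a SamplingMetadata for the current batch and that `logits`
# holds one row per request ([num_requests, vocab_size]). The helper names
# below (e.g. `compute_logits`) are placeholders, not a specific vLLM API.
#
#     sampler = Sampler()
#     logits = model.compute_logits(hidden_states)   # placeholder helper
#     output = sampler(logits, sampling_metadata)
#     next_token_ids = output.sampled_token_ids      # list[int], one per request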