Files
vllm/vllm/v1/sample/sampler.py

137 lines
5.1 KiB
Python
Raw Normal View History

2024-10-22 01:24:07 -07:00
"""A layer that samples the next tokens from the model's outputs."""
from typing import Tuple
2024-10-22 01:24:07 -07:00
import torch
import torch.nn as nn
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_min_token_penalties,
apply_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
2024-10-22 01:24:07 -07:00
# Temperatures below this threshold are treated as zero: such rows are not
# temperature-scaled (divisor forced to 1.0) and take the greedy token.
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
    """Layer that picks the next token ids from the model's output logits.

    ``forward`` runs a fixed pipeline: optional top-k logprobs (taken from
    the untouched logits), cast to float32, penalties, temperature scaling,
    and finally greedy and/or random sampling.
    """

    def __init__(self):
        super().__init__()
        # Performs the random (top-k / top-p) sampling path.
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per row of ``logits``.

        Args:
            logits: Model output logits. Presumably shaped
                [num_requests, vocab_size] -- TODO confirm against caller.
            sampling_metadata: Batch-level sampling settings (temperature,
                top-k/top-p, penalties, number of logprobs, ...).

        Returns:
            A ``SamplerOutput`` whose ``sampled_token_ids`` is a Python list,
            whose logprob fields hold the top-k tensors (or ``None`` when no
            logprobs were requested), and whose prompt-logprob fields are
            always ``None``.
        """
        topk_logprobs = topk_indices = None
        if sampling_metadata.max_num_logprobs > 0:
            # NOTE(woosuk): The top-k logprobs come from the original logits,
            # i.e. before penalties and temperature scaling. The V0 sampler
            # instead uses the logits that are actually sampled from.
            # NOTE: This must run first. The steps below may modify the
            # logits tensor in-place, and cloning it is avoided for memory
            # efficiency.
            topk_logprobs, topk_indices = self.get_topk_logprobs(
                logits, sampling_metadata)

        # Work in float32 from here on.
        logits = logits.to(torch.float32)
        # Penalties first (e.g., min_tokens, freq_penalties), then temperature.
        logits = self.apply_penalties(logits, sampling_metadata)
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Draw the next tokens; int32 keeps the id tensor small.
        token_ids = self.sample(logits, sampling_metadata).to(torch.int32)

        # NOTE: CPU-GPU synchronization happens at the `.tolist()` call.
        return SamplerOutput(
            sampled_token_ids=token_ids.tolist(),
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide ``logits`` in-place by the temperatures and return it.

        Temperatures below ``_SAMPLING_EPS`` are replaced by 1.0 so the
        division never hits zero; those rows are left unscaled. ``temp`` is
        given a new axis at position 1 before the division so it broadcasts
        against ``logits``.
        """
        divisor = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # `div_` mutates `logits` and returns that same tensor, so no new
        # tensor is allocated for the result.
        return logits.div_(divisor.unsqueeze(1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax over the last dimension, flattened to 1-D."""
        return torch.argmax(logits, dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Choose token ids greedily, randomly, or a per-row mix of both.

        All-greedy batches return the argmax directly and all-random batches
        return the top-k/top-p sampler's result. Otherwise both are computed
        and rows whose temperature is below ``_SAMPLING_EPS`` take the greedy
        token while the rest take the random one.
        """
        meta = sampling_metadata
        # A batch cannot be flagged all-greedy and all-random at once.
        assert not meta.all_greedy or not meta.all_random

        if meta.all_greedy:
            return self.greedy_sample(logits)

        # NOTE(review): the random pass runs before the argmax below. Whether
        # the sampler touches `logits` in-place is not visible from this
        # file, so keep this order.
        from_random = self.topk_topp_sampler(
            logits,
            meta.generators,
            meta.no_top_k,
            meta.top_k,
            meta.no_top_p,
            meta.top_p,
        )
        if meta.all_random:
            return from_random

        from_greedy = self.greedy_sample(logits)
        is_greedy_row = meta.temperature < _SAMPLING_EPS
        return torch.where(is_greedy_row, from_greedy, from_random)

    def get_topk_logprobs(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the top ``max_num_logprobs`` log-probabilities and ids.

        The log-softmax is taken over the last dimension in float32.

        Returns:
            ``(topk_logprobs, topk_indices)``; the indices are cast to int32
            to reduce the tensor size.
        """
        all_logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
        # FIXME: Mask the sampled token_id, get topk logprobs,
        # and concatenate the topk with the sampled token_id.
        top = all_logprobs.topk(sampling_metadata.max_num_logprobs, dim=-1)
        return top.values, top.indices.to(torch.int32)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply the min-tokens penalty, then the remaining penalties.

        The min-tokens step always runs. The presence / frequency /
        repetition penalties are applied only when
        ``sampling_metadata.no_penalties`` is false.
        """
        # The return value is ignored here, so this helper presumably updates
        # `logits` in-place -- TODO confirm in ops/penalties.
        apply_min_token_penalties(
            logits,
            sampling_metadata.output_token_ids,
            sampling_metadata.stop_token_ids,
            sampling_metadata.min_tokens,
        )
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            sampling_metadata.output_token_ids,
        )