202 lines
7.5 KiB
Python
202 lines
7.5 KiB
Python
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm import envs
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
try:
|
|
import flashinfer.sampling
|
|
is_flashinfer_available = True
|
|
except ImportError:
|
|
is_flashinfer_available = False
|
|
|
|
|
|
class TopKTopPSampler(nn.Module):
    """Module that samples token ids from logits with optional top-k/top-p.

    The implementation is chosen once, at construction time, by rebinding
    ``self.forward`` to either ``forward_cuda`` (FlashInfer-based) or
    ``forward_native`` (pure PyTorch).
    """

    def __init__(self):
        """Select the sampling backend for the current platform."""
        super().__init__()
        # `is_cuda` must be *called*: a bare `current_platform.is_cuda` is a
        # bound method, which is always truthy, so the non-CUDA branch below
        # would never be taken.
        if current_platform.is_cuda():
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling.

        Args:
            logits: Scores to sample from; softmax is taken over the last dim.
            generators: Per-row generators, keyed by row index, for rows that
                have their own seed.
            no_top_k: If True, top-k filtering is skipped.
            k: Top-k values (only used when ``no_top_k`` is False).
            no_top_p: If True, top-p filtering is skipped.
            p: Top-p values (only used when ``no_top_p`` is False).

        Returns:
            A flattened tensor of sampled indices.
        """
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling.

        Takes the same arguments as ``forward_native`` and returns a flattened
        tensor of sampled indices. Filtering is delegated to
        ``flashinfer_sample`` instead of sorting the logits.
        """
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
|
|
|
|
|
|
def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    """Mask out logits that fall outside the top-k / top-p selection.

    Entries that are filtered out are set to ``-inf``. When both filters are
    disabled the input tensor is returned unchanged.

    Note that this sorts every row of ``logits``, which can be slow for large
    batches.
    """
    use_top_k = not no_top_k
    use_top_p = not no_top_p
    if not (use_top_k or use_top_p):
        return logits

    neg_inf = -float("inf")
    # Ascending order: the largest logit of each row ends up in the last column.
    sorted_logits, sort_order = logits.sort(dim=-1, descending=False)

    if use_top_k:
        # Column holding the k-th largest value of each row.
        kth_pos = sorted_logits.size(1) - k.to(torch.long)
        kth_value = sorted_logits.gather(1, kth_pos.unsqueeze(dim=1))
        # Everything strictly below the k-th largest value is dropped.
        below_kth = sorted_logits < kth_value
        sorted_logits.masked_fill_(below_kth, neg_inf)

    if use_top_p:
        # Cumulative mass from the smallest entry upward; the low-probability
        # tail whose total mass is at most (1 - p) is dropped.
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        tail = cumulative <= 1 - p.unsqueeze(dim=1)
        # Always keep at least the largest entry of each row.
        tail[:, -1] = False
        sorted_logits.masked_fill_(tail, neg_inf)

    # Undo the sort so each value returns to its original column.
    return sorted_logits.scatter(dim=-1, index=sort_order, src=sorted_logits)
|
|
|
|
|
|
def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one index per row of ``probs``.

    This is used in place of torch.multinomial, which causes CPU-GPU
    synchronization. Each row is divided by exponential noise and the argmax
    is taken. ``probs`` is modified in place by the division.
    """
    noise = torch.empty_like(probs)
    num_rows = probs.shape[0]
    # NOTE(woosuk): Requests without their own seed are the common case, so
    # the whole buffer is filled in one batched call first (skipped only when
    # every row has its own generator). Rows with a dedicated generator are
    # then overwritten below.
    if len(generators) != num_rows:
        noise.exponential_()
    # TODO(woosuk): This can be slow because seeded requests are handled one
    # by one. Optimize this.
    for row, row_generator in generators.items():
        noise[row].exponential_(generator=row_generator)
    scaled = probs.div_(noise)
    return scaled.argmax(dim=-1).view(-1)
|
|
|
|
|
|
def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from ``probs`` with top-k and/or top-p using FlashInfer.

    Statistically this is equivalent to `random_sample`, but it is faster
    because it uses rejection sampling instead of sorting the logits tensor.

    NOTE: The outputs are only statistically equivalent to those of
    `random_sample`; they do not necessarily match value for value.

    NOTE: Unlike `random_sample`, this function includes a CPU-GPU
    synchronization. Call it at the end of the forward pass to minimize the
    synchronization overhead.
    """
    # At least one of the two filters must be active.
    assert not (no_top_k and no_top_p)

    num_rounds = 32
    num_reqs = probs.shape[0]
    # One row of uniform samples per rejection round, one column per request.
    rand = torch.empty((num_rounds, num_reqs), device=probs.device)
    # Fill everything in one batched call unless every request is seeded.
    if len(generators) != num_reqs:
        rand.uniform_()
    # Requests with their own generator get their column redrawn from it.
    for req_idx, req_generator in generators.items():
        rand[:, req_idx].uniform_(generator=req_generator)

    use_top_k = not no_top_k
    use_top_p = not no_top_p
    if use_top_k and use_top_p:
        # Both top-k and top-p.
        token_ids, ok = flashinfer.sampling.top_k_top_p_sampling_from_probs(
            probs, rand, k, p, deterministic=True)
    elif use_top_p:
        # Top-p only.
        token_ids, ok = flashinfer.sampling.top_p_sampling_from_probs(
            probs, rand, p, deterministic=True)
    else:
        # Top-k only.
        token_ids, ok = flashinfer.sampling.top_k_sampling_from_probs(
            probs, rand, k, deterministic=True)

    # NOTE: CPU-GPU synchronization happens here.
    if ok.all():
        return token_ids.view(-1)

    # Some row was not reported as successful: renormalize the probabilities
    # under the active filters and sample from them directly.
    if use_top_k:
        probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
    if use_top_p:
        probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
    token_ids = flashinfer.sampling.sampling_from_probs(
        probs, rand[0], deterministic=True)
    return token_ids.view(-1)
|