[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
vllm/v1/sample/rejection_sampler.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from vllm.logger import init_logger
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata

try:
    import flashinfer.sampling as fs
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False

logger = init_logger(__name__)

INVALID_TOKEN_ID = -1

class RejectionSampler(nn.Module):

    def forward(self, logits: torch.Tensor,
                sampling_metadata: SamplingMetadata) -> SamplerOutput:
        if not sampling_metadata.all_greedy:
            raise NotImplementedError(
                "Only greedy sampling is supported by rejection sampler.")

        if is_flashinfer_available:
            logger.info("Using FlashInfer for rejection sampling.")
            return RejectionSampler.flashinfer_sample(logits,
                                                      sampling_metadata)
        else:
            logger.warning(
                "FlashInfer is not available. Falling back to the PyTorch-"
                "native implementation of rejection sampling.")
            return RejectionSampler.greedy_sample_native(
                logits, sampling_metadata)

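    # FlashInfer path: express the greedy draft/target choices as one-hot
    # distributions and reuse FlashInfer's chain speculative sampling
    # kernel to compute the accepted prefix on-device.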
    @staticmethod
    def flashinfer_sample(
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> SamplerOutput:
        # NOTE: The following input preparation can be moved to the
        # model runner and persisted across steps for better performance.
        spec_token_ids = sampling_metadata.spec_token_ids
        max_spec_len = max(len(s) for s in spec_token_ids)
        batch_size = len(spec_token_ids)
        draft_token_ids = torch.full((batch_size, max_spec_len),
                                     INVALID_TOKEN_ID,
                                     device="cpu",
                                     dtype=torch.long)

        target_token_ids = torch.full((batch_size, max_spec_len + 1),
                                      fill_value=INVALID_TOKEN_ID,
                                      device=logits.device,
                                      dtype=torch.long)

        # TODO: Vectorize the following loop for better performance.
        start_loc = 0
        for i in range(batch_size):
            num_spec_tokens = len(spec_token_ids[i])
            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
                spec_token_ids[i], device="cpu", dtype=torch.long)
            end_loc = start_loc + num_spec_tokens + 1
            # Assume greedy sampling.
            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
                logits[start_loc:end_loc], dim=-1)
            start_loc = end_loc

        vocab_size = logits.size(-1)
        # NOTE: CPU <-> GPU synchronization happens here.
        draft_token_ids = draft_token_ids.to(logits.device)
        draft_probs = RejectionSampler._create_greedy_token_probs(
            draft_token_ids, vocab_size, logits.device)
        target_probs = RejectionSampler._create_greedy_token_probs(
            target_token_ids, vocab_size, logits.device)
        uniform_samples = torch.zeros(batch_size,
                                      max_spec_len + 1,
                                      device=logits.device)

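        # Both draft_probs and target_probs are one-hot (greedy)
        # distributions, so rejection sampling reduces to exact-match
        # acceptance: a draft token is accepted iff it equals the target
        # argmax. Acceptance probabilities are therefore 0 or 1, which is
        # why all-zero uniform samples are sufficient here.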
        sampled_token_ids, _, _ = fs.chain_speculative_sampling(
            draft_probs,
            draft_token_ids,
            uniform_samples,
            target_probs,
        )
        return SamplerOutput(sampled_token_ids=sampled_token_ids,
                             logprobs_tensors=None)

    # TODO: The following method can be optimized for better performance.
    @staticmethod
    def greedy_sample_native(
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> SamplerOutput:
        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
        # Add 1 to include the 'bonus' token.
        sample_lens = [x + 1 for x in spec_lens]

        output_token_ids = logits.argmax(dim=-1).view(-1)
        output_token_ids = output_token_ids.split(sample_lens)
        output_token_ids = pad_sequence(output_token_ids,
                                        batch_first=True,
                                        padding_value=INVALID_TOKEN_ID)

        # Convert spec token IDs to a tensor, split by sample_lens, then pad.
        spec_token_ids = [
            torch.tensor(x,
                         dtype=output_token_ids.dtype,
                         device=output_token_ids.device)
            for x in sampling_metadata.spec_token_ids
        ]
        spec_token_ids = pad_sequence(spec_token_ids,
                                      batch_first=True,
                                      padding_value=INVALID_TOKEN_ID)

        # Produce a mask that remains 1 (True) until the first
        # mismatch (cumprod turns 0 after a mismatch).
        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
            dim=1)
        # Identify valid positions (non-padding).
        valid_mask = output_token_ids != INVALID_TOKEN_ID
        # Append a 0 slot for the bonus token, then mask out padding.
        generate_mask = torch.cat([
            accept_mask,
            torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
        ],
                                  dim=1).to(torch.bool) & valid_mask
        zeros_mask = (generate_mask == 0)
        first_zero_idx = zeros_mask.float().argmax(dim=1)
        # Figure out which rows actually contain at least one zero.
        rows_with_zero = zeros_mask.any(dim=1)
        # Set the first zero in each such row to 1: the token right after
        # the last accepted draft token (or the bonus token) is emitted.
        generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1

        output_token_ids[~generate_mask] = INVALID_TOKEN_ID
        return SamplerOutput(sampled_token_ids=output_token_ids,
                             logprobs_tensors=None)

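    # Helper: expand a (batch, num_tokens) matrix of token IDs into one-hot
    # "probabilities" of shape (batch, num_tokens, vocab_size): 1.0 at each
    # valid token ID, and an all-zero row wherever the input is padding.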
    @staticmethod
    def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
                                   out_device: torch.device) -> torch.Tensor:
        batch_size, num_tokens = token_ids.shape

        token_probs = torch.zeros(batch_size,
                                  num_tokens,
                                  vocab_size,
                                  dtype=torch.float,
                                  device=out_device)

        # Ignore INVALID_TOKEN_ID.
        valid_mask = (token_ids != INVALID_TOKEN_ID)
        valid_indices = token_ids.clone()
        valid_indices[~valid_mask] = 0

        token_probs.scatter_(dim=2,
                             index=valid_indices.unsqueeze(-1),
                             src=valid_mask.unsqueeze(-1).float())

        return token_probs
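
The PyTorch-native path above packs per-request greedy tokens into a padded matrix and uses a cumprod mask to find the longest accepted draft prefix. Below is a minimal standalone sketch of that mask logic with made-up token IDs (illustrative only, not part of the commit):

import torch

INVALID_TOKEN_ID = -1

# One request with draft tokens [5, 8, 2]; the target model's greedy
# picks over the 3 + 1 sampled positions are [5, 8, 7, 9].
spec_token_ids = torch.tensor([[5, 8, 2]])
output_token_ids = torch.tensor([[5, 8, 7, 9]])

# 1 until the first mismatch, 0 afterwards: [[1, 1, 0]]
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(dim=1)
valid_mask = output_token_ids != INVALID_TOKEN_ID
# Append a 0 slot for the bonus token, then drop padded positions.
generate_mask = torch.cat(
    [accept_mask,
     torch.zeros(1, 1, dtype=accept_mask.dtype)], dim=1).bool() & valid_mask
# Re-enable the first zero: the token right after the last accepted
# draft token (the "recovery" token, here 7) is still emitted.
zeros_mask = generate_mask == 0
first_zero_idx = zeros_mask.float().argmax(dim=1)
rows_with_zero = zeros_mask.any(dim=1)
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = True

output_token_ids[~generate_mask] = INVALID_TOKEN_ID
print(output_token_ids)  # tensor([[ 5,  8,  7, -1]])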