[V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729)

Author: Lily Liu
Date: 2025-02-25 18:14:48 -08:00
Committed by: GitHub
Parent: 9ba28043b5
Commit: 5629f26df7

8 changed files with 102 additions and 109 deletions
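
This commit replaces the logits-based rejection sampler entry point with one that takes the per-request draft token IDs and the target model's probabilities directly. Below is a minimal sketch of how a call site changes under the new signature; everything outside the diff (tensor shapes, the `rejection_sampler` instance, the `sampling_metadata` object) is illustrative only:

from typing import List

import torch

# Before: the sampler pulled spec tokens out of sampling_metadata and
# recomputed the target picks from raw logits.
#     output = rejection_sampler(logits, sampling_metadata)
#
# After: the caller passes the ragged draft proposals and target probabilities.
draft_token_ids: List[List[int]] = [[5, 9], [7, 7, 7]]  # one list per request
num_tokens = sum(len(ids) + 1 for ids in draft_token_ids)  # +1 bonus slot each
vocab_size = 32_000  # assumed vocabulary size
target_probs = torch.rand(num_tokens, vocab_size).softmax(dim=-1)
#     output = rejection_sampler(draft_token_ids, target_probs,
#                                sampling_metadata)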


@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import List
 
 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import pad_sequence
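
The two new imports support the new call convention: `List` for the `List[List[int]]` draft-token argument, and `pad_sequence` to turn those ragged per-request lists into a rectangular tensor. A small self-contained illustration (using -1 in place of the INVALID_TOKEN_ID constant used in the diff):

import torch
from torch.nn.utils.rnn import pad_sequence

draft_token_ids = [[5, 9], [7, 7, 7]]  # ragged proposals, one list per request
tensors = [torch.tensor(x, dtype=torch.long) for x in draft_token_ids]
padded = pad_sequence(tensors, batch_first=True, padding_value=-1)
print(padded)
# tensor([[ 5,  9, -1],
#         [ 7,  7,  7]])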
@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
         else:
             self.forward_method = self.forward_native
 
-    def forward(self, logits: torch.Tensor,
+    def forward(self, draft_token_ids: List[List[int]],
+                target_probs: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
                 "Currently, only greedy sampling is supported by "
                 "rejection sampler.")
-        return self.forward_method(logits, sampling_metadata)
+        return self.forward_method(draft_token_ids, target_probs,
+                                   sampling_metadata)
 
     def flashinfer_sample(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        # NOTE: The following input preparation can be moved
-        # to the model runner in a persistent manner for better
-        # performance.
-        assert sampling_metadata.spec_token_ids is not None
-        spec_token_ids = sampling_metadata.spec_token_ids
-        max_spec_len = max(len(s) for s in spec_token_ids)
-        batch_size = len(spec_token_ids)
-        draft_token_ids = torch.full((batch_size, max_spec_len),
-                                     INVALID_TOKEN_ID,
-                                     device="cpu",
-                                     dtype=torch.long)
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
+        ]
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
-        target_token_ids = torch.full((batch_size, max_spec_len + 1),
-                                      fill_value=INVALID_TOKEN_ID,
-                                      device=logits.device,
-                                      dtype=torch.long)
+        if sampling_metadata.all_greedy:
+            target_token_ids = target_probs.argmax(dim=-1).view(-1)
+            target_token_ids = target_token_ids.split(sample_lens)
+            target_token_ids = pad_sequence(target_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
-        # TODO: Vectorize the following loop for better performance.
-        start_loc = 0
-        for i in range(batch_size):
-            num_spec_tokens = len(spec_token_ids[i])
-            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
-                spec_token_ids[i], device="cpu", dtype=torch.long)
-            end_loc = start_loc + num_spec_tokens + 1
-            # Assume greedy sampling.
-            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
-                logits[start_loc:end_loc], dim=-1)
-            start_loc = end_loc
-        vocab_size = logits.size(-1)
-        # NOTE: CPU <-> GPU synchronization happens here.
-        draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
-                                                 logits.device)
-        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
-                                                  logits.device)
-        uniform_samples = torch.zeros(batch_size,
-                                      max_spec_len + 1,
-                                      device=logits.device)
+            vocab_size = target_probs.size(-1)
+            # NOTE: CPU <-> GPU synchronization happens here.
+            draft_token_ids_tensor = draft_token_ids_tensor.to(
+                target_probs.device)
+            draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
+                                                     vocab_size,
+                                                     target_probs.device)
+            target_probs = _create_greedy_token_probs(target_token_ids,
+                                                      vocab_size,
+                                                      target_probs.device)
+            uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
+                                          draft_token_ids_tensor.size(1) + 1,
+                                          device=target_probs.device)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
         sampled_token_ids, _, _ = fs.chain_speculative_sampling(
             draft_probs,
-            draft_token_ids,
+            draft_token_ids_tensor,
             uniform_samples,
             target_probs,
         )
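
In the greedy-only path above, both distributions handed to flashinfer's chain_speculative_sampling are one-hot tensors built from token IDs via _create_greedy_token_probs (a helper defined elsewhere in this file), and uniform_samples is all zeros, so acceptance effectively reduces to an exact-match check between each draft token and the target argmax. A rough, hypothetical sketch of such a one-hot construction, not the actual helper:

import torch

def one_hot_token_probs(token_ids: torch.Tensor, vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    # Turn padded token IDs of shape (batch, num_tokens) into one-hot
    # "probabilities"; padded positions (negative IDs) stay all-zero.
    token_ids = token_ids.to(device)
    probs = torch.zeros(*token_ids.shape, vocab_size, device=device)
    valid = token_ids >= 0
    probs[valid] = torch.nn.functional.one_hot(
        token_ids[valid], num_classes=vocab_size).to(probs.dtype)
    return probs

# e.g. a (1, 3) batch with one padded slot and a vocabulary of 10:
ids = torch.tensor([[4, 7, -1]])
print(one_hot_token_probs(ids, 10, torch.device("cpu")).shape)  # (1, 3, 10)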
@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
     # TODO: The following method can be optimized for better performance.
     def forward_native(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
        sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        assert sampling_metadata.spec_token_ids is not None
-        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
-        # Add 1 to include the 'bonus' token.
-        sample_lens = [x + 1 for x in spec_lens]
-        output_token_ids = logits.argmax(dim=-1).view(-1)
-        output_token_ids = output_token_ids.split(sample_lens)
-        output_token_ids = pad_sequence(output_token_ids,
-                                        batch_first=True,
-                                        padding_value=INVALID_TOKEN_ID)
-        # Convert spec token IDs to a tensor, split by sample_lens, then pad.
-        spec_token_ids = [
-            torch.tensor(x,
-                         dtype=output_token_ids.dtype,
-                         device=output_token_ids.device)
-            for x in sampling_metadata.spec_token_ids
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
         ]
-        spec_token_ids = pad_sequence(spec_token_ids,
-                                      batch_first=True,
-                                      padding_value=INVALID_TOKEN_ID)
-        # Produce a mask that remains 1 (True) until the first
-        # mismatch (cumprod turns 0 after a mismatch).
-        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
-            dim=1)
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
+        # Add 1 to include the 'bonus' token.
+        if sampling_metadata.all_greedy:
+            output_token_ids = target_probs.argmax(dim=-1).view(-1)
+            output_token_ids = output_token_ids.split(sample_lens)
+            output_token_ids = pad_sequence(output_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+            # Produce a mask that remains 1 (True) until the first
+            # mismatch (cumprod turns 0 after a mismatch).
+            accept_mask = (
+                output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
+                    dim=1)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
         # Identify valid positions (non-padding).
         valid_mask = output_token_ids != INVALID_TOKEN_ID
         # Generate mask with bonus token.
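
In the native path, greedy rejection reduces to: take the target's argmax at each draft-plus-bonus position, keep draft tokens until the first disagreement, then use the target's own token at that position as the recovery/bonus token. A toy, self-contained illustration of the cumprod accept-mask used above (made-up values; the .long() cast is added only so the prefix product runs on a plain bool comparison):

import torch

# One request: 4 draft tokens, so the target contributes 5 greedy picks
# (one per draft position plus the bonus position).
target_argmax = torch.tensor([[3, 5, 8, 2, 6]])
draft_tokens = torch.tensor([[3, 5, 9, 2]])  # third draft token disagrees

# 1 until the first mismatch, 0 afterwards (cumprod zeroes the tail).
accept_mask = (target_argmax[:, :-1] == draft_tokens).long().cumprod(dim=1)
num_accepted = accept_mask.sum(dim=1)

print(accept_mask)   # tensor([[1, 1, 0, 0]])
print(num_accepted)  # tensor([2])
# Emitted tokens: the 2 accepted drafts [3, 5] plus the target's token at the
# first rejected position, target_argmax[0, 2] == 8.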