[V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729)
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import List
 
 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import pad_sequence
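The newly imported pad_sequence is what both sampling paths below use to batch per-request draft sequences of unequal length into one rectangular tensor. A minimal sketch of that padding step, assuming INVALID_TOKEN_ID is a negative sentinel (the real constant is defined elsewhere in this file):

import torch
from torch.nn.utils.rnn import pad_sequence

INVALID_TOKEN_ID = -1  # assumption; stands in for the module's constant

# Two requests with 3 and 1 draft tokens, mirroring the List[List[int]] input.
draft_token_ids = [[5, 9, 2], [7]]
tensors = [torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids]
padded = pad_sequence(tensors, batch_first=True,
                      padding_value=INVALID_TOKEN_ID)
# padded == tensor([[ 5,  9,  2],
#                   [ 7, -1, -1]]); short rows are filled with the sentinel.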
@@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
         else:
             self.forward_method = self.forward_native
 
-    def forward(self, logits: torch.Tensor,
+    def forward(self, draft_token_ids: List[List[int]],
+                target_probs: torch.Tensor,
                 sampling_metadata: SamplingMetadata) -> SamplerOutput:
         if not sampling_metadata.all_greedy:
             raise NotImplementedError(
                 "Currently, only greedy sampling is supported by "
                 "rejection sampler.")
-        return self.forward_method(logits, sampling_metadata)
+        return self.forward_method(draft_token_ids, target_probs,
+                                   sampling_metadata)
 
     def flashinfer_sample(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
         # NOTE: The following input preparation can be moved
         # to the model runner in a persistent manner for better
         # performance.
-        assert sampling_metadata.spec_token_ids is not None
-        spec_token_ids = sampling_metadata.spec_token_ids
-        max_spec_len = max(len(s) for s in spec_token_ids)
-        batch_size = len(spec_token_ids)
-        draft_token_ids = torch.full((batch_size, max_spec_len),
-                                     INVALID_TOKEN_ID,
-                                     device="cpu",
-                                     dtype=torch.long)
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
+        ]
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
 
-        target_token_ids = torch.full((batch_size, max_spec_len + 1),
-                                      fill_value=INVALID_TOKEN_ID,
-                                      device=logits.device,
-                                      dtype=torch.long)
+        if sampling_metadata.all_greedy:
+            target_token_ids = target_probs.argmax(dim=-1).view(-1)
+            target_token_ids = target_token_ids.split(sample_lens)
+            target_token_ids = pad_sequence(target_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
 
-        # TODO: Vectorize the following loop for better performance.
-        start_loc = 0
-        for i in range(batch_size):
-            num_spec_tokens = len(spec_token_ids[i])
-            draft_token_ids[i, :num_spec_tokens] = torch.tensor(
-                spec_token_ids[i], device="cpu", dtype=torch.long)
-            end_loc = start_loc + num_spec_tokens + 1
-            # Assume greedy sampling.
-            target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
-                logits[start_loc:end_loc], dim=-1)
-            start_loc = end_loc
-
-        vocab_size = logits.size(-1)
-        # NOTE: CPU <-> GPU synchronization happens here.
-        draft_token_ids = draft_token_ids.to(logits.device)
-        draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
-                                                 logits.device)
-        target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
-                                                  logits.device)
-        uniform_samples = torch.zeros(batch_size,
-                                      max_spec_len + 1,
-                                      device=logits.device)
+            vocab_size = target_probs.size(-1)
+            # NOTE: CPU <-> GPU synchronization happens here.
+            draft_token_ids_tensor = draft_token_ids_tensor.to(
+                target_probs.device)
+            draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
+                                                     vocab_size,
+                                                     target_probs.device)
+            target_probs = _create_greedy_token_probs(target_token_ids,
+                                                      vocab_size,
+                                                      target_probs.device)
+            uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
+                                          draft_token_ids_tensor.size(1) + 1,
+                                          device=target_probs.device)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
 
         sampled_token_ids, _, _ = fs.chain_speculative_sampling(
             draft_probs,
-            draft_token_ids,
+            draft_token_ids_tensor,
             uniform_samples,
             target_probs,
         )
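_create_greedy_token_probs is defined elsewhere in this file, but its call sites show the idea: padded token IDs are expanded into per-position probability distributions so that flashinfer's fs.chain_speculative_sampling, which consumes draft and target distributions, can be reused for purely greedy verification. A hedged sketch of what such a helper could look like (the scatter-based body is an assumption, not the actual implementation):

import torch

INVALID_TOKEN_ID = -1  # assumption, as in the sketch above

def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
                               out_device: torch.device) -> torch.Tensor:
    # token_ids: (batch, num_tokens) int64, padded with INVALID_TOKEN_ID.
    token_ids = token_ids.to(out_device)
    batch_size, num_tokens = token_ids.shape
    token_probs = torch.zeros(batch_size, num_tokens, vocab_size,
                              dtype=torch.float, device=out_device)
    valid = token_ids != INVALID_TOKEN_ID
    # Write probability 1.0 at each valid token's vocab index; padded
    # positions keep an all-zero distribution.
    token_probs.scatter_(2, token_ids.clamp(min=0).unsqueeze(-1),
                         valid.unsqueeze(-1).float())
    return token_probs

With one-hot distributions on both sides, rejection sampling degenerates to exact-match acceptance: a draft token survives only if it equals the target argmax, which is presumably why uniform_samples can be all zeros rather than random draws.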
@@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
     # TODO: The following method can be optimized for better performance.
     def forward_native(
         self,
-        logits: torch.Tensor,
+        draft_token_ids: List[List[int]],
+        target_probs: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        assert sampling_metadata.spec_token_ids is not None
-        spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
-        # Add 1 to include the 'bonus' token.
-        sample_lens = [x + 1 for x in spec_lens]
-
-        output_token_ids = logits.argmax(dim=-1).view(-1)
-        output_token_ids = output_token_ids.split(sample_lens)
-        output_token_ids = pad_sequence(output_token_ids,
-                                        batch_first=True,
-                                        padding_value=INVALID_TOKEN_ID)
-
-        # Convert spec token IDs to a tensor, split by sample_lens, then pad.
-        spec_token_ids = [
-            torch.tensor(x,
-                         dtype=output_token_ids.dtype,
-                         device=output_token_ids.device)
-            for x in sampling_metadata.spec_token_ids
+        sample_lens = [len(x) + 1 for x in draft_token_ids]
+        # Convert draft token IDs to a tensor, split by sample_lens, then pad.
+        draft_token_ids = [
+            torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
         ]
-        spec_token_ids = pad_sequence(spec_token_ids,
-                                      batch_first=True,
-                                      padding_value=INVALID_TOKEN_ID)
-
-        # Produce a mask that remains 1 (True) until the first
-        # mismatch (cumprod turns 0 after a mismatch).
-        accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
-            dim=1)
+        draft_token_ids_tensor = pad_sequence(draft_token_ids,
+                                              batch_first=True,
+                                              padding_value=INVALID_TOKEN_ID)
+        draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
+        # Add 1 to include the 'bonus' token.
+        if sampling_metadata.all_greedy:
+            output_token_ids = target_probs.argmax(dim=-1).view(-1)
+            output_token_ids = output_token_ids.split(sample_lens)
+            output_token_ids = pad_sequence(output_token_ids,
+                                            batch_first=True,
+                                            padding_value=INVALID_TOKEN_ID)
+            # Produce a mask that remains 1 (True) until the first
+            # mismatch (cumprod turns 0 after a mismatch).
+            accept_mask = (
+                output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
+                    dim=1)
+        else:
+            raise NotImplementedError(
+                "Currently, only greedy sampling is supported by "
+                "rejection sampler.")
         # Identify valid positions (non-padding).
         valid_mask = output_token_ids != INVALID_TOKEN_ID
         # Generate mask with bonus token.
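The cumprod line above is the heart of greedy verification: the boolean match vector stays 1 only along the prefix of draft tokens that agree with the target's greedy choices, and the first mismatch zeroes out every later position. A worked toy example (values invented):

import torch

# One request with 4 draft tokens; the target output has 5 slots
# (4 verification positions plus the bonus token).
draft = torch.tensor([[11, 22, 33, 44]])
output = torch.tensor([[11, 22, 99, 44, 55]])  # target disagrees at slot 2

accept_mask = (output[:, :-1] == draft).cumprod(dim=1)
# tensor([[1, 1, 0, 0]]): slot 3 matches again (44 == 44) but stays
# rejected, because acceptance must be a contiguous prefix.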
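Putting the pieces together, here is a self-contained toy run of the new native path's greedy logic outside vLLM (the flattened target_probs layout and the padding scheme follow the diff above; the data is illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

INVALID_TOKEN_ID = -1  # assumption, as above
VOCAB_SIZE = 8

# Per-request draft tokens: the new List[List[int]] argument.
draft_token_ids = [[3, 5, 1], [2]]
# One target distribution per draft slot plus one bonus slot per request,
# flattened across the batch: (3 + 1) + (1 + 1) = 6 rows.
sample_lens = [len(x) + 1 for x in draft_token_ids]
torch.manual_seed(0)
target_probs = torch.rand(sum(sample_lens), VOCAB_SIZE).softmax(dim=-1)

output = target_probs.argmax(dim=-1).view(-1)
output = pad_sequence(list(output.split(sample_lens)), batch_first=True,
                      padding_value=INVALID_TOKEN_ID)
draft = pad_sequence([torch.tensor(x) for x in draft_token_ids],
                     batch_first=True, padding_value=INVALID_TOKEN_ID)
accept_mask = (output[:, :-1] == draft).cumprod(dim=1)
print(accept_mask)  # per-request prefix of accepted draft positions

Note that padded slots compare equal (both sides hold INVALID_TOKEN_ID), which is exactly why the real code follows the mask with the valid_mask check to exclude padding.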