[V1][Spec Decode] Enable spec decode for top-p & top-k sampling (#15063)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -8,6 +8,7 @@ import triton.language as tl
 
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -245,23 +246,79 @@ def compute_probs(
         return logits
 
     num_tokens = logits.shape[0]
-    batch_size = cu_num_draft_tokens.shape[0]
-    expanded_temperature = torch.empty(
-        (num_tokens, 1),
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    expand_kernel[(batch_size, )](
-        expanded_temperature,
+    temperature = expand_batch_to_tokens(
         sampling_metadata.temperature,
         cu_num_draft_tokens,
-        GREEDY_TEMPERATURE,  # replace_from
-        1,  # replace_to
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,
+        num_tokens,
+        replace_from=GREEDY_TEMPERATURE,
+        replace_to=1,
     )
+    # TODO(woosuk): Consider using in-place op to reduce memory usage.
+    logits = logits / temperature.unsqueeze(-1)
+
+    # Get expanded top_k and top_p tensors.
+    top_k = None
+    if sampling_metadata.top_k is not None:
+        top_k = expand_batch_to_tokens(
+            sampling_metadata.top_k,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+    top_p = None
+    if sampling_metadata.top_p is not None:
+        top_p = expand_batch_to_tokens(
+            sampling_metadata.top_p,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+
+    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+    # which is slow for large vocab sizes. This may cause performance issues.
+    logits = apply_top_k_top_p(logits, top_k, top_p)
+
+    output_prob = compiled_softmax(logits)
+    return output_prob
+
+
+def expand_batch_to_tokens(
+    x: torch.Tensor,  # [batch_size]
+    cu_num_tokens: torch.Tensor,  # [batch_size]
+    num_tokens: int,
+    replace_from: int = 0,
+    replace_to: int = 0,
+) -> torch.Tensor:
+    """Expand [batch_size] tensor to [num_tokens] tensor based on the number of
+    tokens per batch in cu_num_tokens.
+
+    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
+    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
+
+    Args:
+        x: [batch_size] tensor to expand.
+        cu_num_tokens: [batch_size] tensor containing the cumulative number of
+            tokens per batch. Each element represents the total number of
+            tokens up to and including that batch.
+        num_tokens: Total number of tokens.
+        replace_from: int = 0
+            Value to be replaced if it is found in x.
+        replace_to: int = 0
+            Value to replace with when replace_from is found.
+    Returns:
+        expanded_x: [num_tokens] tensor.
+    """
+    batch_size = x.shape[0]
+    assert cu_num_tokens.shape[0] == batch_size
+    expanded_x = x.new_empty(num_tokens)
+    expand_kernel[(batch_size, )](
+        expanded_x,
+        x,
+        cu_num_tokens,
+        replace_from,
+        replace_to,
+        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        num_warps=1,
+    )
-    output_prob = compiled_softmax(logits, expanded_temperature)
-    return output_prob
+    return expanded_x
 
 
 def generate_uniform_probs(
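The docstring example in expand_batch_to_tokens above can be reproduced with stock PyTorch. The snippet below is only an illustrative sketch: torch.repeat_interleave stands in for the Triton expand_kernel, and it does not perform the replace_from/replace_to rewrite that the kernel also applies in the same launch.

import torch

# x = [a, b, c] with cu_num_tokens = [2, 5, 6], as in the docstring example.
x = torch.tensor([10.0, 20.0, 30.0])        # one value per request
cu_num_tokens = torch.tensor([2, 5, 6])     # cumulative draft-token counts
counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))  # -> [2, 3, 1]
expanded_x = torch.repeat_interleave(x, counts)
print(expanded_x)  # tensor([10., 10., 20., 20., 20., 30.])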
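Taken together, the new compute_probs path expands the per-request sampling parameters to per-draft-token values, rescales the logits by temperature, applies the top-k/top-p mask, and finally takes a softmax. The sketch below mirrors that flow in plain PyTorch under stated assumptions: repeat_interleave replaces the Triton kernel, a naive sort-based mask replaces apply_top_k_top_p (the slow path the NOTE in the diff warns about), non-positive temperature is treated as the greedy sentinel, and compute_probs_sketch is a hypothetical name rather than a vLLM API.

import torch

def compute_probs_sketch(logits, cu_num_draft_tokens, temperature, top_k, top_p):
    # logits: [num_tokens, vocab_size]; the other tensors are per-request [batch_size].
    counts = torch.diff(cu_num_draft_tokens,
                        prepend=cu_num_draft_tokens.new_zeros(1))
    temp = torch.repeat_interleave(temperature, counts)
    # Treat non-positive temperature as the greedy sentinel: divide by 1 instead.
    temp = torch.where(temp <= 0, torch.ones_like(temp), temp)
    k = torch.repeat_interleave(top_k, counts)   # per-token top-k
    p = torch.repeat_interleave(top_p, counts)   # per-token top-p

    logits = logits / temp.unsqueeze(-1)

    # Sort-based top-k / top-p masking: simple, but slow for large vocabularies.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    ranks = torch.arange(logits.shape[-1], device=logits.device)
    keep = ranks.unsqueeze(0) < k.unsqueeze(-1)                  # top-k cutoff
    probs_sorted = sorted_logits.softmax(dim=-1)
    cum_before = probs_sorted.cumsum(dim=-1) - probs_sorted
    keep &= cum_before < p.unsqueeze(-1)                         # top-p cutoff
    keep[:, 0] = True                                            # never drop the best token
    sorted_logits = sorted_logits.masked_fill(~keep, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter_(-1, sorted_idx, sorted_logits)
    return logits.softmax(dim=-1)

The in-tree version differs mainly in that the expansion and the GREEDY_TEMPERATURE-to-1 rewrite happen inside a single Triton kernel, and the final softmax goes through compiled_softmax.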