[Kernel] Optimize sample_recovered_tokens_kernel (#34974)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -11,7 +11,11 @@ from tests.v1.sample.utils import create_allowed_token_ids
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
|
||||
from vllm.v1.sample.rejection_sampler import (
|
||||
PLACEHOLDER_TOKEN_ID,
|
||||
RejectionSampler,
|
||||
sample_recovered_tokens,
|
||||
)
|
||||
from vllm.v1.sample.sampler import Sampler, SamplerOutput
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
@@ -518,6 +522,70 @@ def estimate_rejection_sampling_pdf(
|
||||
return hist.hist
|
||||
|
||||
|
||||
def native_sample_recovered_tokens(
|
||||
max_spec_len: int,
|
||||
num_draft_tokens: list[int],
|
||||
cu_num_draft_tokens: torch.Tensor, # [batch_size]
|
||||
draft_token_ids: torch.Tensor, # [num_tokens]
|
||||
draft_probs: torch.Tensor | None, # [num_tokens, vocab_size]
|
||||
target_probs: torch.Tensor, # [num_tokens, vocab_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
|
||||
states = {
|
||||
i: generator.get_state()
|
||||
for i, generator in sampling_metadata.generators.items()
|
||||
}
|
||||
for i, generator in sampling_metadata.generators.items():
|
||||
# Do not generate random numbers for requests with no draft tokens.
|
||||
# This can be important for reproducibility.
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
# In order to generate the same exponential later, reset the CUDA RNG
|
||||
# state because RNG state advances after each call.
|
||||
generator.set_state(states[i])
|
||||
|
||||
inv_q = q.reciprocal()
|
||||
|
||||
out = torch.empty_like(draft_token_ids)
|
||||
|
||||
for req_idx in range(batch_size):
|
||||
start_idx = 0 if req_idx == 0 else int(cu_num_draft_tokens[req_idx - 1].item())
|
||||
end_idx = int(cu_num_draft_tokens[req_idx].item())
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
for pos in range(max_spec_len):
|
||||
if pos >= num_tokens:
|
||||
continue
|
||||
token_idx = start_idx + pos
|
||||
|
||||
if draft_probs is None:
|
||||
# prob is target_probs[token_idx] except draft_token_id is zeroed
|
||||
prob = target_probs[token_idx].clone()
|
||||
draft_token_id = draft_token_ids[token_idx]
|
||||
prob[draft_token_id] = 0.0
|
||||
else:
|
||||
prob = (target_probs[token_idx] - draft_probs[token_idx]).clamp_min_(
|
||||
0.0
|
||||
)
|
||||
|
||||
score = prob * inv_q[req_idx]
|
||||
recovered_id = torch.argmax(score, dim=-1)
|
||||
out[token_idx] = recovered_id
|
||||
return out
|
||||
|
||||
|
||||
def _test_masked_logits(
|
||||
rejection_sampler,
|
||||
batch_size: int,
|
||||
@@ -778,3 +846,60 @@ def test_allowed_token_ids(rejection_sampler):
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 100])
@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
@pytest.mark.parametrize("max_spec_len", [1, 3])
@pytest.mark.parametrize("no_draft_probs", [True, False])
def test_sample_recovered_tokens(
    batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
):
    """Compare the Triton recovered-token sampler against the native reference."""
    num_tokens = batch_size * max_spec_len

    # Random draft and target distributions over the vocabulary.
    draft_probs = F.softmax(
        torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE),
        dim=-1,
    )
    target_logits = torch.rand(
        num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
    )
    target_probs = F.softmax(target_logits, dim=-1)

    # Draw one draft token id per position from the draft distributions.
    # NOTE: multinomial with num_samples=1 yields shape [num_tokens, 1].
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)

    # Seeded per-request generators so both implementations see the same RNG.
    generators = {
        i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
    }
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE),
        generators=generators,
    )
    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.reshape(batch_size, max_spec_len).tolist(), target_logits
    )

    # Both implementations receive the exact same arguments.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        None if no_draft_probs else draft_probs,
        target_probs,
        sampling_metadata,
    )
    ref_recovered_token_ids = native_sample_recovered_tokens(*shared_args, device=DEVICE)
    recovered_token_ids = sample_recovered_tokens(*shared_args, device=DEVICE)
    assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
|
||||
|
||||
@@ -623,16 +623,19 @@ def sample_recovered_tokens(
|
||||
if num_draft_tokens[i] > 0:
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
inv_q = q.reciprocal()
|
||||
|
||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||
BLOCK_SIZE = 8192
|
||||
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
|
||||
recovered_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
q,
|
||||
inv_q,
|
||||
vocab_size,
|
||||
triton.next_power_of_2(vocab_size),
|
||||
BLOCK_SIZE,
|
||||
NO_DRAFT_PROBS=draft_probs is None,
|
||||
)
|
||||
return recovered_token_ids
|
||||
@@ -776,9 +779,9 @@ def sample_recovered_tokens_kernel(
|
||||
draft_token_ids_ptr, # [num_tokens]
|
||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||
target_probs_ptr, # [num_tokens, vocab_size]
|
||||
q_ptr, # [batch_size, vocab_size]
|
||||
inv_q_ptr, # [batch_size, vocab_size]
|
||||
vocab_size,
|
||||
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
NO_DRAFT_PROBS: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
@@ -791,33 +794,50 @@ def sample_recovered_tokens_kernel(
|
||||
if pos >= num_draft_tokens:
|
||||
return
|
||||
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
|
||||
other=0,
|
||||
)
|
||||
else:
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
token_idx = start_idx + pos
|
||||
|
||||
q = tl.load(
|
||||
q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"),
|
||||
)
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
|
||||
|
||||
max_val = float("-inf")
|
||||
recovered_id = 0
|
||||
for v in range(0, vocab_size, BLOCK_SIZE):
|
||||
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
|
||||
vocab_mask = vocab_offset < vocab_size
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
prob = tl.load(
|
||||
target_probs_ptr + token_idx * vocab_size + vocab_offset,
|
||||
mask=(vocab_mask & (vocab_offset != draft_token_id)),
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + token_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_mask,
|
||||
other=0.0,
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + token_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_mask,
|
||||
other=0.0,
|
||||
)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0.0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
|
||||
inv_q = tl.load(
|
||||
inv_q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_mask,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Local tile reduction
|
||||
score = prob * inv_q
|
||||
local_max, local_id = tl.max(score, axis=0, return_indices=True)
|
||||
|
||||
if local_max > max_val:
|
||||
max_val = local_max
|
||||
recovered_id = v + local_id
|
||||
|
||||
tl.store(output_token_ids_ptr + token_idx, recovered_id)
|
||||
|
||||
Reference in New Issue
Block a user