[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)
This commit is contained in:
@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler()
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_with_bonus_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
# Verify that sampling succeeds for all cases.
|
||||
typical_acceptance_sampler(target_probs,
|
||||
typical_acceptance_sampler(target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_with_bonus_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
oob_token_ids[0][0] = rogue_token_id
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
typical_acceptance_sampler(target_probs,
|
||||
typical_acceptance_sampler(target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_with_bonus_probs = torch.rand(batch_size,
|
||||
k + 1,
|
||||
vocab_size,
|
||||
dtype=torch.float32)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
|
||||
# Simulate temperature 0 probability distribution for target probabilities
|
||||
# and create target probabilities such that only 1 token id has
|
||||
# probability 1.0
|
||||
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size)
|
||||
target_with_bonus_probs, zero_temperature_token_ids = \
|
||||
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
# Populate draft_token_ids such that they exclude the token_ids
|
||||
# with probability = 1.0
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
|
||||
# fallback to the greedy sampling for selecting 1 token for each sequence.
|
||||
# Verify the same.
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
|
||||
# For sequences 0 and 2 set the distribution to a temperature
|
||||
# zero distribution. For sequences 1 and 3 set it to a uniform
|
||||
# distribution.
|
||||
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size))
|
||||
target_with_bonus_probs, zero_temperature_token_ids = \
|
||||
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
target_probs = target_with_bonus_probs[:, :-1]
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
|
||||
@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
# Create a temperature zero target probability distribution and ensure
|
||||
# all draft token ids correspond to the tokens with 1.0 probability.
|
||||
# Verify that all of them are accepted.
|
||||
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size))
|
||||
target_with_bonus_probs, zero_temperature_token_ids = \
|
||||
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
draft_token_ids = zero_temperature_token_ids
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
draft_token_ids = torch.cat(
|
||||
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
target_with_bonus_probs,
|
||||
bonus_token_ids,
|
||||
draft_probs=None,
|
||||
draft_token_ids=draft_token_ids)
|
||||
@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
|
||||
# with probability = 1.0. Without any changes to the posterior thresholds
|
||||
# none of the draft tokens are accepted.
|
||||
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
|
||||
batch_size, k, vocab_size))
|
||||
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
|
||||
batch_size, k + 1, vocab_size)
|
||||
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
|
||||
target_probs[target_probs == 0] = 0.00001
|
||||
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
|
||||
zero_temperature_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user