[Speculative Decoding 2/2] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
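The diff below makes two mechanical changes to the typical-acceptance sampler tests: it adds a `get_acceptance_sampler()` factory so tests no longer construct `TypicalAcceptanceSampler` directly, and it switches every sampler call to keyword arguments with an explicit `draft_probs=None`. The sketch below strings those pieces together the way the updated tests do. The import path, device string, and tensor sizes are assumptions for illustration; only the helper body, the `init_gpu_tensors(rank=0)` call, and the call signature come from the diff.

```python
import torch

# Assumed import path; the diff below only names the class.
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)


def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    disable_bonus_tokens: bool = False,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """Test helper added in the first hunk below."""
    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
                                    disable_bonus_tokens, strict_mode)


# Placeholder sizes; the tests parametrize these. A CUDA device is assumed,
# as in the tests.
batch_size, k, vocab_size = 4, 5, 30_000
torch.set_default_device("cuda:0")

sampler = get_acceptance_sampler(strict_mode=True)
sampler.init_gpu_tensors(rank=0)

target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0, high=vocab_size,
                                size=(batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(low=0, high=vocab_size,
                                size=(batch_size, k), dtype=torch.int64)

# New call convention: keyword arguments with an explicit draft_probs=None.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == k + 1
```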
@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
     return draft_token_ids
 
 
+def get_acceptance_sampler(
+    posterior_threshold: float = 0.03,
+    posterior_alpha: float = 0.9,
+    disable_bonus_tokens: bool = False,
+    strict_mode: bool = False,
+) -> TypicalAcceptanceSampler:
+    """
+    Initializes and returns a TypicalAcceptanceSampler.
+    """
+    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
+                                    disable_bonus_tokens, strict_mode)
+
+
 @pytest.mark.parametrize("k", list(range(1, 6)))
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     different combinations of k, vocab_size, batch_size and num devices.
     """
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+    typical_acceptance_sampler(target_probs,
+                               bonus_token_ids,
+                               draft_probs=None,
+                               draft_token_ids=draft_token_ids)
 
 
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
         oob_token_ids[0][0] = rogue_token_id
 
     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs, bonus_token_ids,
-                                   draft_token_ids)
+        typical_acceptance_sampler(target_probs,
+                                   bonus_token_ids,
+                                   draft_probs=None,
+                                   draft_token_ids=draft_token_ids)
 
 
 @pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # We are using a uniform target probability distribution.
     # For a uniform distribution the entropy is very high and it
     # should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
     # 1.0 tokens in the target distribution we will reject all of them and
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, -1] == -1)
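For reference, the fill pattern those assertions describe, written out as a plain tensor construction. This is illustrative only: when every draft token is rejected, the comments above say the sampler greedily picks one token per sequence, and the output keeps shape (batch_size, k + 1) with -1 in the remaining slots; the token id and sizes here are placeholders, not values from the test.

```python
import torch

batch_size, k = 5, 3  # placeholder sizes
# One greedy fallback token per sequence, then -1 padding for the rejected
# draft slots and the (unused) bonus slot.
fallback_token = torch.full((batch_size, 1), 1234, dtype=torch.int64)  # hypothetical id
padding = torch.full((batch_size, k), -1, dtype=torch.int64)
expected_layout = torch.cat((fallback_token, padding), dim=1)

assert expected_layout.shape == (batch_size, k + 1)
assert torch.all(expected_layout[:, -1] == -1)
```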
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # verify the shape of output_token_ids
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
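The second call in this test lowers both posterior knobs to zero, and the assertion that follows expects every draft position to come back unchanged, i.e. all drafts accepted. A minimal sketch of that configuration, continuing the setup from the sketch under the title; the test's actual constructor call is only partially visible above, so routing it through the helper here is an assumption.

```python
# Sketch: zero posterior_threshold/posterior_alpha make acceptance maximally
# permissive, so the k draft positions are expected back unchanged.
permissive_sampler = get_acceptance_sampler(posterior_threshold=0.0,
                                            posterior_alpha=0.0)
permissive_sampler.init_gpu_tensors(rank=0)
output_token_ids = permissive_sampler(target_probs,
                                      bonus_token_ids,
                                      draft_probs=None,
                                      draft_token_ids=draft_token_ids)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
```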
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)