[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)

This commit is contained in:
Lily Liu
2024-09-22 12:34:14 -07:00
committed by GitHub
parent 5b59532760
commit c6bd70d772
7 changed files with 33 additions and 115 deletions

View File

@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
disable_bonus_tokens, strict_mode)
strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
@@ -154,11 +153,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, disable_bonus_tokens: bool, device: str):
seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
@@ -166,17 +164,14 @@ def test_uniform_target_distribution_accepts_all_tokens(
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted. The test also ensures that the behavior
regarding bonus tokens is consistent with the `disable_bonus_tokens`
flag.
draft tokens being accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_temperature_zero_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
device: str):
def test_mixed_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
- When `disable_bonus_tokens` is False, the bonus tokens are also accepted
for sequences with a uniform distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# if disable_bonus_tokens is false then we also accept the bonus tokens.
# we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
if disable_bonus_tokens:
assert torch.all(output_token_ids[[1, 3], -1] == -1)
else:
assert torch.all(output_token_ids[[1, 3], -1] != -1)
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_partially(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
@@ -384,10 +361,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
disable_bonus_tokens: bool,
device: str):
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True,
disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0,
posterior_alpha=0.0)
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
if disable_bonus_tokens:
assert torch.all(output_token_ids[:, -1] == -1)
else:
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
device: str):
def test_replacement_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = -torch.ones(