[SpecDecode] [Minor] Fix spec decode sampler tests (#7183)
This commit is contained in:
@@ -78,7 +78,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler()
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -111,7 +111,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -171,7 +171,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -225,7 +225,7 @@ def test_temperature_zero_target_distribution(seed: int,
|
||||
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# Simulate temperature 0 probability distribution for target probabilities
|
||||
# and create target probabilities such that only 1 token id has
|
||||
# probability 1.0
|
||||
@@ -285,7 +285,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# For sequences 0 and 2 set the distribution to a temperature
|
||||
# zero distribution. For sequences 1 and 3 set it to a uniform
|
||||
# distribution.
|
||||
@@ -352,7 +352,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# Create a temperature zero target probability distribution and ensure
|
||||
# all draft token ids correspond to the tokens with 1.0 probability.
|
||||
# Verify that all of them are accepted.
|
||||
@@ -414,7 +414,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
# Simulate temperature 0 probability distribution for target
|
||||
# probabilities and create target probabilities such that only 1 token
|
||||
# id has probability 1.0 and others have a very low probability of
|
||||
@@ -447,7 +447,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
posterior_threshold=0.0,
|
||||
posterior_alpha=0.0)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
output_token_ids = typical_acceptance_sampler(
|
||||
target_probs,
|
||||
bonus_token_ids,
|
||||
@@ -485,7 +485,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
|
||||
torch.set_default_device(device)
|
||||
typical_acceptance_sampler = get_acceptance_sampler(
|
||||
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
|
||||
- typical_acceptance_sampler.init_gpu_tensors(rank=0)
|
||||
+ typical_acceptance_sampler.init_gpu_tensors(device=device)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
expected_replacement_tokens = -torch.ones(
|
||||
(batch_size, k), dtype=torch.long)
|
||||
|
||||
Reference in New Issue
Block a user