[SpecDecode] [Minor] Fix spec decode sampler tests (#7183)
This commit is contained in:
@@ -25,7 +25,7 @@ def mock_causal_accepted_tensor(
|
||||
|
||||
accepted = (torch.arange(k).expand(batch_size, k) <=
|
||||
last_accepted_indices.unsqueeze(-1).broadcast_to(
|
||||
batch_size, k)).to(device="cuda")
|
||||
batch_size, k))
|
||||
|
||||
# Sprinkle accepted values after the contiguous initial accepted values.
|
||||
# This replicates the behavior of rejection sampling, which may "accept"
|
||||
@@ -33,7 +33,7 @@ def mock_causal_accepted_tensor(
|
||||
sprinkle_candidates = (
|
||||
torch.arange(k).expand(batch_size, k) >
|
||||
last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
|
||||
sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
|
||||
sprinkle = torch.rand(batch_size, k) > 0.5
|
||||
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
|
||||
return accepted
|
||||
|
||||
@@ -86,7 +86,7 @@ def test_correct_output_format(which_tokens_accepted: str,
|
||||
|
||||
rejection_sampler = RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens)
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
|
||||
accepted,
|
||||
recovered_token_ids,
|
||||
@@ -138,7 +138,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
device: str):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler()
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
@@ -167,7 +167,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
device: str):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler()
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
@@ -211,7 +211,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
torch.set_default_device(device)
|
||||
|
||||
rejection_sampler = RejectionSampler(strict_mode=True)
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
@@ -339,7 +339,7 @@ class _CorrectnessTestHelper:
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_range = (0, vocab_size)
|
||||
|
||||
self.rejection_sampler.init_gpu_tensors(rank=0)
|
||||
self.rejection_sampler.init_gpu_tensors(device=0)
|
||||
|
||||
# Keep test simple, use k=1
|
||||
self.k = 1
|
||||
|
||||
Reference in New Issue
Block a user