[Bugfix] Fix spec decoding when seed is none in a batch (#10863)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
assert torch.equal(results[j][i], results[0][i])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                            device: str, use_flashinfer: bool):
    """A batch mixing seeded and unseeded rows must match per-row runs.

    Row 0 is given no generator (seed ``None``); every other row ``i`` gets
    its own ``torch.Generator`` seeded with ``i``. The rejection sampler is
    run once on the whole batch and then once per row (each row as a batch
    of size one, with a fresh generator using the same seed), and the
    outputs are required to be identical row by row.

    Args:
        k: Number of draft tokens per sequence.
        vocab_size: Size of the last dimension of the probability tensors.
        batch_size: Number of sequences in the mixed batch.
        device: Device used as the torch default and for the generators.
        use_flashinfer: Forwarded to ``RejectionSampler``.
    """
    torch.set_default_device(device)
    set_random_seed(0)

    # Inputs for the full batch. The creation order is kept fixed because
    # each call draws from the default RNG.
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Per-row copies, each reshaped into a batch of size one. Cloning keeps
    # them independent of the tensors handed to the batched call.
    single_batches = [
        (draft_probs[i].clone().unsqueeze(0),
         draft_token_ids[i].clone().unsqueeze(0),
         target_probs[i].clone().unsqueeze(0),
         bonus_token_ids[i].clone().unsqueeze(0))
        for i in range(batch_size)
    ]

    # Pass 1: the whole batch at once. set_random_seed(0) is called before
    # each pass so both start from the same seed.
    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    seeded_seqs = {
        i: torch.Generator(device=device).manual_seed(i)
        for i in range(1, batch_size)  # 0 is seed None
    }
    batch_result = rejection_sampler(target_probs.clone(),
                                     bonus_token_ids.clone(),
                                     draft_probs.clone(),
                                     draft_token_ids.clone(), seeded_seqs)

    # Pass 2: one row at a time with a fresh sampler.
    set_random_seed(0)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    results = []
    for i, row in enumerate(single_batches):
        # Distinct names so the batched tensors above are not shadowed.
        (row_draft_probs, row_draft_token_ids, row_target_probs,
         row_bonus_token_ids) = row
        # In a batch of one the row index is always 0. A new generator with
        # the same seed is used because the one in `seeded_seqs` was already
        # passed to the batched call.
        request_seeded_seqs = ({
            0: torch.Generator(device=device).manual_seed(i)
        } if seeded_seqs.get(i) is not None else None)
        results.append(
            rejection_sampler(row_target_probs, row_bonus_token_ids,
                              row_draft_probs, row_draft_token_ids,
                              request_seeded_seqs))

    for i, single_result in enumerate(results):
        assert torch.equal(batch_result[i], single_result.squeeze(0)), (
            f"row {i}: batched output differs from single-request output")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", [1, 3, 6])
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
|
||||
Reference in New Issue
Block a user