[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)
This commit is contained in:
@@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
|
||||
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
||||
assert torch.equal(
|
||||
actual.target_probs,
|
||||
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
|
||||
assert torch.equal(actual.target_with_bonus_probs,
|
||||
target_token_probs.reshape(batch_size, k + 1, -1))
|
||||
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
|
||||
assert torch.equal(actual.draft_probs, proposal_probs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user