[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)

This commit is contained in:
Lily Liu
2024-09-01 21:23:29 -07:00
committed by GitHub
parent f8d60145b4
commit e6a26ed037
9 changed files with 306 additions and 109 deletions

View File

@@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual.target_with_bonus_probs,
target_token_probs.reshape(batch_size, k + 1, -1))
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)