[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in:
@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
bonus_token_ids,
|
||||
)
|
||||
|
||||
# Bonus tokens are currently disabled. Verify they're set to -1.
|
||||
# See https://github.com/vllm-project/vllm/issues/4212
|
||||
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
|
||||
|
||||
if which_tokens_accepted == "all_tokens_accepted":
|
||||
# Expect all tokens to be equal to draft tokens.
|
||||
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
|
||||
|
||||
# Expect all bonus tokens to be included.
|
||||
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
|
||||
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
|
||||
elif which_tokens_accepted == "no_tokens_accepted":
|
||||
# Expect first token to be equal to recovered tokens.
|
||||
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
|
||||
@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
torch.ones_like(output_token_ids[:, 1:]) * -1)
|
||||
elif which_tokens_accepted == "some_tokens_accepted":
|
||||
recovered_plus_bonus = torch.cat(
|
||||
(recovered_token_ids, bonus_token_ids), dim=-1)
|
||||
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
|
||||
# Assert first rejected token is a recovered token or bonus token.
|
||||
assert torch.equal(
|
||||
recovered_plus_bonus[torch.arange(0, batch_size),
|
||||
|
||||
Reference in New Issue
Block a user