[Speculative decoding] Support target-model logprobs (#4378)
This commit is contained in:
@@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len():
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, ),
|
||||
@@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k():
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
|
||||
Reference in New Issue
Block a user