[Speculative decoding] Support target-model logprobs (#4378)
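Threads target-model logprobs through the speculative-decoding worker tests: each test now builds a target_token_logprobs tensor alongside target_token_probs and passes it to create_sampler_output_list. The helper itself is outside these hunks; below is a minimal, hypothetical sketch of the signature the updated calls assume, with parameter names inferred from the keyword-argument call in the third hunk (illustrative only; the real helper in vLLM's test utilities may differ):

    # Hypothetical sketch -- parameter names inferred from the keyword
    # arguments used in the diff below; the real vLLM test helper may differ.
    from typing import List, Optional, Sequence

    import torch

    def create_sampler_output_list(
            token_ids: torch.Tensor,                     # [steps, batch_size]
            probs: Sequence[Optional[torch.Tensor]],     # one entry per step
            logprobs: Sequence[Optional[torch.Tensor]],  # one entry per step
    ) -> List[dict]:
        # Bundle each step's token ids with its probs/logprobs into one
        # sampler-output-like record (a plain dict stands in for vLLM's
        # SamplerOutput here).
        num_steps = len(token_ids)
        assert len(probs) == num_steps and len(logprobs) == num_steps
        return [{
            'sampled_token_ids': token_ids[step],
            'sampled_token_probs': probs[step],
            'logprobs': logprobs[step],
        } for step in range(num_steps)]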
@@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
                                      vocab_size,
                                      dtype=torch.float32,
                                      device='cuda')
+    target_token_logprobs = torch.rand(1,
+                                       batch_size * (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
     target_output = create_sampler_output_list(target_token_ids,
-                                               target_token_probs)
+                                               target_token_probs,
+                                               target_token_logprobs)
 
     target_worker.execute_model.return_value = [target_output[0]]
 
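For concreteness, the new tensor mirrors the shape of target_token_probs: a single target-model step scoring batch_size * (k + 1) flattened positions over the vocabulary. A toy instantiation with hypothetical sizes (CPU instead of CUDA so it runs anywhere):

    import torch

    k, batch_size, vocab_size = 2, 3, 11  # toy values, not from the test
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32)
    print(target_token_logprobs.shape)  # torch.Size([1, 9, 11])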
@@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
                                      vocab_size,
                                      dtype=torch.float32,
                                      device='cuda')
+    target_token_logprobs = torch.rand(1,
+                                       batch_size * (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
     target_output = create_sampler_output_list(target_token_ids,
-                                               target_token_probs)
+                                               target_token_probs,
+                                               target_token_logprobs)
 
     target_worker.execute_model.return_value = [target_output[0]]
 
@@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
                                               num_lookahead_slots=k)
 
     expected_output = create_sampler_output_list(
-        rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
+        token_ids=rejection_sampler_output.transpose(0, 1),
+        probs=[None for _ in range(k + 1)],
+        logprobs=[None for _ in range(k + 1)])
 
     seq_ids = [
         next(iter(seq_group_metadata.seq_data.keys()))
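The rewritten call above also moves from positional to keyword arguments; with three parameters of similar shape, that keeps probs and logprobs from being silently swapped. The (k + 1)-length None placeholders imply the expected output carries token ids only. A toy illustration of the transpose, with shapes inferred from those placeholder lists and hypothetical values:

    import torch

    k, batch_size, vocab_size = 2, 4, 11  # toy values, not from the test
    # [batch_size, k + 1] sampled ids, transposed so that step (not sequence)
    # becomes the leading dimension, matching the (k + 1)-entry placeholders.
    rejection_sampler_output = torch.randint(0, vocab_size,
                                             (batch_size, k + 1))
    token_ids = rejection_sampler_output.transpose(0, 1)
    print(token_ids.shape)  # torch.Size([3, 4]): k + 1 steps of batch_size ids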
@@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
             continue
         assert actual_by_step[i].output_token == expected_by_step[
             i].output_token
-        assert actual_by_step[i].logprobs == expected_by_step[i].logprobs
 
 
 @pytest.mark.parametrize('k', [1, 2])
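(The per-step logprobs equality assertion is dropped in this hunk; since expected_output is now built with logprobs=[None, ...] while the worker returns real target-model logprobs, a direct equality check would no longer hold.)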
@@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
                                      vocab_size,
                                      dtype=torch.float32,
                                      device='cuda')
+    target_token_logprobs = torch.rand(1,
+                                       batch_size * (k + 1),
+                                       vocab_size,
+                                       dtype=torch.float32,
+                                       device='cuda')
     target_output = create_sampler_output_list(target_token_ids,
-                                               target_token_probs)
+                                               target_token_probs,
+                                               target_token_logprobs)
 
     target_worker.execute_model.return_value = [target_output[0]]
 