[Speculative decoding] Support target-model logprobs (#4378)

This commit is contained in:
Cade Daniel
2024-05-03 15:52:01 -07:00
committed by GitHub
parent 43c413ec57
commit ab50275111
15 changed files with 727 additions and 86 deletions

View File

@@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
@@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int):
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
@@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
num_lookahead_slots=k)
expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
token_ids=rejection_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)])
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
@@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int):
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
assert actual_by_step[i].logprobs == expected_by_step[i].logprobs
@pytest.mark.parametrize('k', [1, 2])
@@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]