[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)

Cade Daniel
2024-04-16 13:09:21 -07:00
committed by GitHub
parent 69e1d2fb69
commit e95cd87959
31 changed files with 1347 additions and 407 deletions
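
The hunks below are from the spec-decode worker test suite. Three mechanical changes recur throughout: `worker.execute_model` now takes `num_lookahead_slots` instead of `num_spec_tokens`, the mocked target worker returns a list of `SamplerOutput` rather than a bare `SamplerOutput`, and the `return_python_output=False` kwarg to the draft worker is dropped. A minimal sketch of the updated call shape, with the test fixtures (`create_batch`, `worker`) assumed from the surrounding file:

    # Sketch of the renamed keyword; create_batch and worker are the test
    # fixtures used throughout the hunks below.
    execute_model_data, _, _ = create_batch(batch_size, k)

    # num_spec_tokens=k becomes num_lookahead_slots=k: the engine/scheduler
    # now expresses speculation as k reserved lookahead slots per sequence.
    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_lookahead_slots=k)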


@@ -6,6 +6,7 @@ import torch
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplerOutput
+from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
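
The new import suggests the draft worker's proposals now travel through a dedicated container. A plausible shape for `SpeculativeProposals`, assuming the dataclass-of-tensors layout of `vllm.spec_decode.interfaces` at this point; the field names and shapes here are an informed guess, not quoted from the diff:

    from dataclasses import dataclass
    import torch

    @dataclass
    class SpeculativeProposals:
        # Token ids proposed by the draft model, shape [batch_size, k].
        proposal_token_ids: torch.Tensor
        # Probabilities of the proposal tokens, shape [batch_size, k, vocab_size].
        proposal_probs: torch.Tensor
        # Number of valid proposal tokens per sequence; may be zero.
        proposal_lens: torch.Tensor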
@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     execute_model_data, _, _ = create_batch(batch_size, k)

     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
+        worker.execute_model(**execute_model_data.to_dict(),
+                             num_lookahead_slots=k)

     call_args_list = draft_worker.get_spec_proposals.call_args_list
     assert len(call_args_list) == 1
@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     target_worker.execute_model.side_effect = ValueError(exception_secret)

     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
+        worker.execute_model(**execute_model_data.to_dict(),
+                             num_lookahead_slots=k)

     seen_contexts = []
@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     target_output = create_sampler_output_list(target_token_ids,
                                                target_token_probs)
-    target_worker.execute_model.return_value = target_output[0]
+    target_worker.execute_model.return_value = [target_output[0]]

     exception_secret = 'artifical stop'
     rejection_sampler.side_effect = ValueError(exception_secret)

     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
+        worker.execute_model(**execute_model_data.to_dict(),
+                             num_lookahead_slots=k)

     assert len(rejection_sampler.call_args_list) == 1
     args, _ = rejection_sampler.call_args_list[0]
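
The `return_value = [target_output[0]]` change (here and in the later hunks) tracks a new contract: a worker's `execute_model` returns a list of `SamplerOutput`, one per step, and the caller unwraps the single element. A small mock sketch of that contract, using the same `unittest.mock` pattern as these tests:

    from unittest.mock import MagicMock
    from vllm.sequence import SamplerOutput

    target_worker = MagicMock()
    # One decode step -> a one-element list, not a bare SamplerOutput.
    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]

    # Caller side: the spec-decode worker indexes out the single step.
    step_output = target_worker.execute_model()[0]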
@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     target_output = create_sampler_output_list(target_token_ids,
                                                target_token_probs)
-    target_worker.execute_model.return_value = target_output[0]
+    target_worker.execute_model.return_value = [target_output[0]]

     rejection_sampler_output = torch.randint(low=0,
                                              high=vocab_size,
@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     rejection_sampler.return_value = rejection_sampler_output

     output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_spec_tokens=k)
+                                  num_lookahead_slots=k)

     expected_output = create_sampler_output_list(
         rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
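
The transpose in the expected output is worth unpacking: the rejection sampler emits token ids per sequence, shape [batch_size, k + 1], while `create_sampler_output_list` wants a step-major layout, one batch-sized slice per decode step. A shape-only sketch under those assumptions:

    import torch

    batch_size, k, vocab_size = 4, 3, 32_000

    # Per-sequence accepted tokens: up to k speculative tokens plus one
    # bonus/recovered token, hence k + 1 columns.
    rejection_sampler_output = torch.randint(low=0, high=vocab_size,
                                             size=(batch_size, k + 1))

    # Step-major view expected by create_sampler_output_list.
    per_step = rejection_sampler_output.transpose(0, 1)
    assert per_step.shape == (k + 1, batch_size)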
@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     target_output = create_sampler_output_list(target_token_ids,
                                                target_token_probs)
-    target_worker.execute_model.return_value = target_output[0]
+    target_worker.execute_model.return_value = [target_output[0]]

     rejection_sampler_output = torch.randint(low=0,
                                              high=vocab_size,
@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
         mock_rejsample_metrics)

     output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_spec_tokens=k)
+                                  num_lookahead_slots=k)

     assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

     call_args_list = (
@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

+    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int):
         batch_size, k, prev_output_token_len=0)

     out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_spec_tokens=k)
+                               num_lookahead_slots=k)

     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int):
         0].sampled_tokens is None, "expect gpu tensor references to be None"

     draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict(), return_python_output=False)
+        **execute_model_data.to_dict())
     target_worker.execute_model.assert_called_once_with(
         **execute_model_data.to_dict())
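
Taken together, the k == 0 assertions pin down a no-speculation fast path: both workers run the batch once as a plain decode (the draft worker no longer takes `return_python_output=False`), and a single `SamplerOutput` comes back with its GPU tensor references cleared. A hedged sketch of that path; the method and attribute names (`_run_no_spec`, `proposer_worker`, `scorer_worker`) are assumptions about the implementation, not quoted from this diff:

    def _run_no_spec(self, **execute_model_data):
        # No lookahead slots: run the draft model purely to keep its KV
        # cache in step, then sample once from the target model.
        self.proposer_worker.execute_model(**execute_model_data)
        sampler_output = self.scorer_worker.execute_model(
            **execute_model_data)[0]

        # Drop GPU tensor references before handing the output back,
        # matching the probs/sampled_tokens-are-None assertions above.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        return [sampler_output]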
@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int):
     rejection_sampler.token_id_dtype = torch.int64
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

+    target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+
     draft_worker.device = 'cuda'
     target_worker.device = 'cuda'
@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int):
         batch_size, k, prev_output_token_len=0)

     out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_spec_tokens=k)
+                               num_lookahead_slots=k)

     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int):
         0].sampled_tokens is None, "expect gpu tensor references to be None"

     draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict(), return_python_output=False)
+        **execute_model_data.to_dict())
     target_worker.execute_model.assert_called_once_with(
         **execute_model_data.to_dict())