[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
|
||||
execute_model_data, _, _ = create_batch(batch_size, k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
|
||||
worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_lookahead_slots=k)
|
||||
|
||||
call_args_list = draft_worker.get_spec_proposals.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
target_worker.execute_model.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
|
||||
worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_lookahead_slots=k)
|
||||
|
||||
seen_contexts = []
|
||||
|
||||
@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs)
|
||||
|
||||
target_worker.execute_model.return_value = target_output[0]
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artifical stop'
|
||||
rejection_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
|
||||
worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_lookahead_slots=k)
|
||||
|
||||
assert len(rejection_sampler.call_args_list) == 1
|
||||
args, _ = rejection_sampler.call_args_list[0]
|
||||
@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs)
|
||||
|
||||
target_worker.execute_model.return_value = target_output[0]
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
rejection_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
rejection_sampler.return_value = rejection_sampler_output
|
||||
|
||||
output = worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_spec_tokens=k)
|
||||
num_lookahead_slots=k)
|
||||
|
||||
expected_output = create_sampler_output_list(
|
||||
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
|
||||
@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs)
|
||||
|
||||
target_worker.execute_model.return_value = target_output[0]
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
rejection_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
mock_rejsample_metrics)
|
||||
|
||||
output = worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_spec_tokens=k)
|
||||
num_lookahead_slots=k)
|
||||
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
|
||||
|
||||
call_args_list = (
|
||||
@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
batch_size, k, prev_output_token_len=0)
|
||||
|
||||
out = worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_spec_tokens=k)
|
||||
num_lookahead_slots=k)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].probs is None, "expect gpu tensor references to be None"
|
||||
@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
0].sampled_tokens is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(
|
||||
**execute_model_data.to_dict(), return_python_output=False)
|
||||
**execute_model_data.to_dict())
|
||||
target_worker.execute_model.assert_called_once_with(
|
||||
**execute_model_data.to_dict())
|
||||
|
||||
@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
batch_size, k, prev_output_token_len=0)
|
||||
|
||||
out = worker.execute_model(**execute_model_data.to_dict(),
|
||||
num_spec_tokens=k)
|
||||
num_lookahead_slots=k)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].probs is None, "expect gpu tensor references to be None"
|
||||
@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
0].sampled_tokens is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(
|
||||
**execute_model_data.to_dict(), return_python_output=False)
|
||||
**execute_model_data.to_dict())
|
||||
target_worker.execute_model.assert_called_once_with(
|
||||
**execute_model_data.to_dict())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user