[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
This commit is contained in:
@@ -11,9 +11,15 @@ distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality. This gives us good coverage of temp=0.
|
||||
|
||||
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
|
||||
highest probability in the target distribution are accepted. Therefore, we can
|
||||
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
|
||||
|
||||
For temp>0, we rely on unit tests on the rejection sampler to verify that the
|
||||
output distribution is the same with spec decode vs. no spec decode (this would
|
||||
be prohibitively expensive to run with a real model).
|
||||
be prohibitively expensive to run with a real model). Similarly, for the
|
||||
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
|
||||
test cases.
|
||||
|
||||
NOTE: Speculative decoding's distribution equality requires that the measured
|
||||
distributions of the target model and proposal model be deterministic given the
|
||||
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-160m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
|
||||
}
|
||||
# Try a range of common k.
|
||||
for k in [1, 2, 3]
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_typical_acceptance_sampling(baseline_llm_generator,
|
||||
test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
"""Verify that speculative decoding produces exact equality to without spec
|
||||
decode with TypicalAcceptanceSampler as the draft token acceptance
|
||||
sampling method.
|
||||
"""
|
||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
|
||||
@@ -3,33 +3,35 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import create_batch, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('queue_size', [4])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('k', [1])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that speculative tokens are disabled when the batch size
|
||||
exceeds the threshold.
|
||||
"""
|
||||
disable_by_batch_size = 3
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
rejection_sampler=rejection_sampler,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
metrics_collector=metrics_collector,
|
||||
disable_by_batch_size=disable_by_batch_size)
|
||||
|
||||
|
||||
@@ -10,16 +10,16 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
def test_initial_call_returns_none():
|
||||
"""Expect first call to get metrics to return None.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collector = AsyncMetricsCollector(rej_sampler)
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert maybe_metrics is None
|
||||
@@ -28,14 +28,14 @@ def test_initial_call_returns_none():
|
||||
def test_second_call_returns_metrics():
|
||||
"""Expect second call to not return None.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
@@ -43,7 +43,7 @@ def test_second_call_returns_metrics():
|
||||
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
@@ -56,16 +56,16 @@ def test_second_call_returns_metrics():
|
||||
def test_nonzero_rank_noop(rank):
|
||||
"""Verify nonzero ranks don't collect metrics.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collector = AsyncMetricsCollector(rej_sampler)
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler)
|
||||
collector.init_gpu_tensors(rank=rank)
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
@@ -75,14 +75,14 @@ def test_nonzero_rank_noop(rank):
|
||||
def test_noop_until_time():
|
||||
"""Verify metrics aren't collected until enough time passes.
|
||||
"""
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = 0
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
@@ -91,7 +91,7 @@ def test_noop_until_time():
|
||||
collect_interval_s + 0.1, collect_interval_s + 0.1
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
@@ -122,14 +122,14 @@ def test_initial_metrics_has_correct_values(has_data: bool):
|
||||
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
|
||||
num_draft_tokens, k)
|
||||
|
||||
rej_sampler = MagicMock()
|
||||
rej_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
rej_sampler.num_draft_tokens = num_draft_tokens
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = num_draft_tokens
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
@@ -137,7 +137,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
|
||||
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(rejection_sampler=rej_sampler,
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
|
||||
@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
@@ -16,23 +15,26 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
|
||||
from .test_utils import mock_spec_decode_sampler
|
||||
from .utils import create_batch, create_sampler_output_list, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int):
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the draft worker with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
@@ -53,15 +55,16 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
def test_correctly_calls_target_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the target model with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
draft_worker.device = 'cuda'
|
||||
@@ -69,8 +72,9 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
vocab_size = 32_000
|
||||
@@ -133,8 +137,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the rejection sampler with
|
||||
correct inputs. Everything else is mocked out.
|
||||
"""
|
||||
@@ -144,15 +151,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@@ -199,15 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
rejection_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
spec_decode_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
assert len(rejection_sampler.call_args_list) == 1
|
||||
_, kwargs = rejection_sampler.call_args_list[0]
|
||||
assert len(spec_decode_sampler.call_args_list) == 1
|
||||
_, kwargs = spec_decode_sampler.call_args_list[0]
|
||||
actual = SimpleNamespace(**kwargs)
|
||||
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
@@ -221,8 +228,11 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_formats_output(k: int, batch_size: int):
|
||||
def test_correctly_formats_output(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker formats sampler output correctly.
|
||||
Everything else is mocked out.
|
||||
"""
|
||||
@@ -232,15 +242,13 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@@ -286,24 +294,23 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
rejection_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
rejection_sampler_output[i][
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
rejection_sampler.return_value = rejection_sampler_output
|
||||
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
expected_output = create_sampler_output_list(
|
||||
token_ids=rejection_sampler_output.transpose(0, 1),
|
||||
token_ids=spec_decode_sampler_output.transpose(0, 1),
|
||||
probs=[None for _ in range(k + 1)],
|
||||
logprobs=[None for _ in range(k + 1)])
|
||||
|
||||
@@ -350,8 +357,11 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
@pytest.mark.parametrize('k', [1, 2])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('returns_metrics', [True, False])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker collects metrics.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
@@ -360,15 +370,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'cuda'
|
||||
target_worker.device = 'cuda'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
@@ -414,17 +423,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
rejection_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
rejection_sampler_output[i][
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
rejection_sampler.return_value = rejection_sampler_output
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
|
||||
mock_rejsample_metrics = MagicMock(
|
||||
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
|
||||
@@ -445,15 +453,16 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
|
||||
@pytest.mark.parametrize('k', [0])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_k_equals_zero(k: int, batch_size: int):
|
||||
def test_k_equals_zero(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when k is zero. This happens during prefill.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
@@ -465,8 +474,9 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
@@ -487,16 +497,17 @@ def test_k_equals_zero(k: int, batch_size: int):
|
||||
|
||||
@pytest.mark.parametrize('k', [0, 5])
|
||||
@pytest.mark.parametrize('batch_size', [0])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_empty_input_batch(k: int, batch_size: int):
|
||||
def test_empty_input_batch(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when the input batch is empty. This can happen if the engine communicates
|
||||
to the workers information without scheduling a batch.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
@@ -508,8 +519,9 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
@@ -528,18 +540,19 @@ def test_empty_input_batch(k: int, batch_size: int):
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_init_device():
|
||||
def test_init_device(acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
|
||||
metrics_collector)
|
||||
|
||||
worker.init_device()
|
||||
@@ -549,22 +562,23 @@ def test_init_device():
|
||||
target_worker.init_device.assert_called_once()
|
||||
|
||||
metrics_collector.init_gpu_tensors.assert_called_once()
|
||||
rejection_sampler.init_gpu_tensors.assert_called_once()
|
||||
spec_decode_sampler.init_gpu_tensors.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_initialize_cache():
|
||||
def test_initialize_cache(acceptance_sampler_method):
|
||||
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
|
||||
workers.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
|
||||
worker.initialize_cache(**kwargs)
|
||||
@@ -577,19 +591,20 @@ def test_initialize_cache():
|
||||
@pytest.mark.parametrize('available_cpu_blocks', [500])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
available_cpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int):
|
||||
draft_kv_size_bytes: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
||||
Specifically, it should run profiling in the scorer worker, and then evenly
|
||||
split the blocks between proposer and scorer worker.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.determine_num_available_blocks.return_value = (
|
||||
@@ -598,8 +613,9 @@ def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
target_cache_block_size_bytes)
|
||||
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
|
||||
metrics_collector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
||||
from vllm.spec_decode.util import split_batch_by_proposal_len
|
||||
|
||||
@@ -109,3 +113,21 @@ def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def mock_spec_decode_sampler(acceptance_sampler_method):
|
||||
"""
|
||||
Returns either a RejectionSampler or TypicalAcceptanceSampler
|
||||
object depending on whether acceptance_sampler_method is
|
||||
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
|
||||
"""
|
||||
if acceptance_sampler_method == "rejection_sampler":
|
||||
sampler = MagicMock(spec=RejectionSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
elif acceptance_sampler_method == "typical_acceptance_sampler":
|
||||
sampler = MagicMock(spec=TypicalAcceptanceSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
else:
|
||||
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
|
||||
|
||||
Reference in New Issue
Block a user