[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import random
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
|
||||
"""Verify SpecDecodeWorker calls the target model with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
|
||||
target_worker = mock_worker(vocab_size=vocab_size)
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
|
||||
num_lookahead_slots=k)
|
||||
|
||||
assert len(rejection_sampler.call_args_list) == 1
|
||||
args, _ = rejection_sampler.call_args_list[0]
|
||||
(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
|
||||
actual_proposal_token_ids) = args
|
||||
_, kwargs = rejection_sampler.call_args_list[0]
|
||||
actual = SimpleNamespace(**kwargs)
|
||||
|
||||
assert torch.equal(actual_bonus_token_ids,
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
||||
assert torch.equal(
|
||||
actual_proposal_scores,
|
||||
actual.target_probs,
|
||||
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
|
||||
assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
|
||||
assert torch.equal(actual_proposal_probs, proposal_probs)
|
||||
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
|
||||
assert torch.equal(actual.draft_probs, proposal_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
|
||||
target_worker = mock_worker(vocab_size=vocab_size)
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
|
||||
target_worker = mock_worker(vocab_size=vocab_size)
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
@@ -500,8 +506,8 @@ def test_init_device():
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||
rejection_sampler.token_id_dtype = torch.int64
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
Reference in New Issue
Block a user