[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

This commit is contained in:
Cade Daniel
2024-04-23 01:02:36 -07:00
committed by GitHub
parent 050f285ff6
commit 62b8aebc6f
22 changed files with 1164 additions and 175 deletions

View File

@@ -1,4 +1,5 @@
 import random
+from types import SimpleNamespace
 from unittest.mock import MagicMock
 import pytest
@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
 """Verify SpecDecodeWorker calls the target model with correct
 inputs. Everything else is mocked out.
 """
-draft_worker = mock_worker(cls=MultiStepWorker)
-target_worker = mock_worker()
+draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
+target_worker = mock_worker(use_spec=False)
 rejection_sampler = MagicMock(spec=RejectionSampler)
 rejection_sampler.token_id_dtype = torch.int64
 metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
 """
 vocab_size = 32_000
-draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
-target_worker = mock_worker(vocab_size=vocab_size)
+draft_worker = mock_worker(cls=MultiStepWorker,
+vocab_size=vocab_size,
+use_spec=False)
+target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
 rejection_sampler = MagicMock(spec=RejectionSampler)
 rejection_sampler.token_id_dtype = torch.int64
 metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
 num_lookahead_slots=k)
 assert len(rejection_sampler.call_args_list) == 1
-args, _ = rejection_sampler.call_args_list[0]
-(actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
-actual_proposal_token_ids) = args
+_, kwargs = rejection_sampler.call_args_list[0]
+actual = SimpleNamespace(**kwargs)
-assert torch.equal(actual_bonus_token_ids,
+assert torch.equal(actual.bonus_token_ids,
 target_token_ids.reshape(batch_size, k + 1)[:, -1:])
 assert torch.equal(
-actual_proposal_scores,
+actual.target_probs,
 target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
-assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
-assert torch.equal(actual_proposal_probs, proposal_probs)
+assert torch.equal(actual.draft_token_ids, proposal_token_ids)
+assert torch.equal(actual.draft_probs, proposal_probs)
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
 """
 vocab_size = 32_000
-draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
-target_worker = mock_worker(vocab_size=vocab_size)
+draft_worker = mock_worker(cls=MultiStepWorker,
+vocab_size=vocab_size,
+use_spec=False)
+target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
 rejection_sampler = MagicMock(spec=RejectionSampler)
 rejection_sampler.token_id_dtype = torch.int64
 metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
 """
 vocab_size = 32_000
-draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
-target_worker = mock_worker(vocab_size=vocab_size)
+draft_worker = mock_worker(cls=MultiStepWorker,
+vocab_size=vocab_size,
+use_spec=False)
+target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
 rejection_sampler = MagicMock(spec=RejectionSampler)
 rejection_sampler.token_id_dtype = torch.int64
 metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -500,8 +506,8 @@ def test_init_device():
 """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
 well as other GPU initialization.
 """
-draft_worker = mock_worker(cls=MultiStepWorker)
-target_worker = mock_worker()
+draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
+target_worker = mock_worker(use_spec=False)
 rejection_sampler = MagicMock(spec=RejectionSampler)
 rejection_sampler.token_id_dtype = torch.int64
 metrics_collector = MagicMock(spec=AsyncMetricsCollector)