[Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451)

This commit is contained in:
SangBin Cho
2024-05-02 11:24:13 +09:00
committed by GitHub
parent b8afa8b95a
commit 0d62fe58db
6 changed files with 172 additions and 13 deletions

View File

@@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
metrics_collector)
exception_secret = 'artifical stop'
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
execute_model_data, _, _ = create_batch(batch_size, k)
@@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
exception_secret = 'artifical stop'
exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
@@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artifical stop'
exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):