[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

This commit is contained in:
Cade Daniel
2024-04-23 01:02:36 -07:00
committed by GitHub
parent 050f285ff6
commit 62b8aebc6f
22 changed files with 1164 additions and 175 deletions

View File

@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)