[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)
This commit is contained in:
@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
def mock_sample(probs, *args, **kwargs):
|
||||
nonlocal sample_probs
|
||||
sample_probs = probs
|
||||
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
|
||||
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
|
||||
for prob in probs], None)
|
||||
|
||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
||||
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
Reference in New Issue
Block a user