[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)
@@ -1,7 +1,7 @@
 import itertools
 import random
 from typing import Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
 
     assert tokens1[0] == tokens2[1]
     assert tokens1[1] == tokens2[0]
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_include_gpu_probs_tensor(device: str):
+    set_random_seed(42)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+    sampler.include_gpu_probs_tensor = True
+    sampler.should_modify_greedy_probs_inplace = False
+
+    sampling_params = SamplingParams(temperature=0)
+
+    mock_inplace = Mock()
+    with patch(
+            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
+            mock_inplace):
+
+        sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                    sampling_params, device)
+        mock_inplace.assert_not_called()
+
+    assert sampler_output.sampled_token_probs is not None
+    assert sampler_output.logprobs is not None
+    assert sampler_output.sampled_token_ids is not None
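For context on what the new test exercises: after this change, returning the on-GPU probability tensor (include_gpu_probs_tensor) and rewriting greedy-sampled probability rows in place (should_modify_greedy_probs_inplace) are controlled by separate sampler flags, so a caller can request the probs tensor without triggering the in-place rewrite. The sketch below is a minimal toy illustration of that gating, not the actual vLLM Sampler; the ToySampler class and its forward signature are invented for illustration.

# Illustrative sketch only -- not the actual vLLM Sampler implementation.
# It mirrors the behaviour the new test checks: the in-place rewrite of
# greedy probabilities is gated by should_modify_greedy_probs_inplace,
# independently of include_gpu_probs_tensor.
import torch


class ToySampler:

    def __init__(self) -> None:
        # Both flags default off; callers (e.g. spec-decode workers) opt in.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

    def _modify_greedy_probs_inplace(self, probs: torch.Tensor,
                                     sampled: torch.Tensor) -> None:
        # Turn each row into a one-hot distribution over the sampled token.
        probs.zero_()
        probs.scatter_(1, sampled.unsqueeze(-1), 1.0)

    def forward(self, logits: torch.Tensor):
        probs = torch.softmax(logits, dim=-1)
        sampled = probs.argmax(dim=-1)  # greedy (temperature=0) path

        if self.should_modify_greedy_probs_inplace:
            # Only this flag gates the in-place rewrite after the change.
            self._modify_greedy_probs_inplace(probs, sampled)

        # Returning the GPU probs tensor is gated separately.
        gpu_probs = probs if self.include_gpu_probs_tensor else None
        return sampled, gpu_probs


# Mirrors the new test's setup: the probs tensor is returned,
# but no in-place rewrite happens.
toy = ToySampler()
toy.include_gpu_probs_tensor = True
toy.should_modify_greedy_probs_inplace = False
tokens, probs_out = toy.forward(torch.randn(4, 32))
assert probs_out is not None

Keeping the two flags independent is what lets the multi-step speculative-decoding path ask for raw sampled probabilities without the greedy rows being overwritten, which is exactly the combination the mocked _modify_greedy_probs_inplace assertion verifies.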