[Core] Use flashinfer sampling kernel when available (#7137)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import pytest
|
||||
import torch
|
||||
from transformers import GenerationConfig, GenerationMixin
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
@@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
|
||||
for prob in probs], None)
|
||||
|
||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
|
||||
# top-k and top-p is only calculated when flashinfer kernel is not available
|
||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
|
||||
patch("vllm.model_executor.layers.sampler."
|
||||
"flashinfer_top_k_top_p_sampling", None):
|
||||
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
assert sample_probs is not None
|
||||
@@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_flashinfer_fallback(seed: int, device: str):
|
||||
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
pytest.skip("Flashinfer sampler is disabled")
|
||||
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
def failing_flashinfer_sampling(*_args, **_kwargs):
|
||||
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
seed=random.randint(0, 10000),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
with patch(
|
||||
"vllm.model_executor.layers.sampler."
|
||||
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
|
||||
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
assert sampler_output == fallback_sampler_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_repetition_penalty_mixed(device: str):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user