[Core] Use flashinfer sampling kernel when available (#7137)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Peng Guanwen
2024-08-19 11:24:03 +08:00
committed by GitHub
parent ff7ec82c4d
commit f710fb5265
5 changed files with 129 additions and 27 deletions

View File

@@ -8,6 +8,7 @@ import pytest
import torch
from transformers import GenerationConfig, GenerationMixin
import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
@@ -634,7 +635,10 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
# top-k and top-p is only calculated when flashinfer kernel is not available
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
patch("vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", None):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
assert sample_probs is not None
@@ -645,6 +649,37 @@ def test_sampler_top_k_top_p(seed: int, device: str):
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled")
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
def failing_flashinfer_sampling(*_args, **_kwargs):
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
with patch(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
assert sampler_output == fallback_sampler_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):