[Kernel][RFC] Refactor the punica kernel based on Triton (#5036)
@@ -1,14 +1,17 @@
 import gc
+from unittest.mock import patch
 
 import pytest
 import torch
 import triton
 import triton.language as tl
 
-from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential,
+from vllm.model_executor.layers.ops.sample import (_sample_triton,
+                                                   _uniform_to_exponential,
                                                    sample)
 from vllm.model_executor.sampling_metadata import SamplingTensors
 from vllm.model_executor.utils import set_random_seed
+from vllm.triton_utils.libentry import LibEntry
 from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
                                       get_num_triton_sampler_splits)
@@ -76,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
     seeds = torch.randint(1,
                           torch.iinfo(torch.long).max, (n_splits, bs),
                           device="cuda").mul_(random_sampling_mask)
-    sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
-        probs=probs,
-        logprobs=logprobs,
-        sample_indices=sample_indices,
-        seeds=seeds,
-        max_best_of=max_best_of,
-        modify_greedy_probs=modify_greedy_probs,
-        save_logprobs=save_logprobs,
-        _save_modified_probs=True)
+    # The current _sample_triton does not use the LibEntry
+    # decoration; patching it in here lets this test also verify
+    # the correctness of LibEntry.
+    with patch("vllm.model_executor.layers.ops.sample._sample_triton",
+               LibEntry(_sample_triton)):
+        sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
+            probs=probs,
+            logprobs=logprobs,
+            sample_indices=sample_indices,
+            seeds=seeds,
+            max_best_of=max_best_of,
+            modify_greedy_probs=modify_greedy_probs,
+            save_logprobs=save_logprobs,
+            _save_modified_probs=True)
     assert sampled_tokens.shape == (bs, max_best_of)
     for i in range(bs):
         assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
@@ -130,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
     [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
 def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                 modify_greedy_probs, seed, vocab_size):
+
     set_random_seed(seed)
     prompt_sizes = [16, 32, 64, 128] * 2
     samples = 8
@@ -157,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
     seeds = torch.randint(1,
                           torch.iinfo(torch.long).max, (n_splits, samples),
                           device="cuda").mul_(random_sampling_mask)
-    sampled_tokens, sampled_logprobs, _ = sample(
-        probs=probs,
-        logprobs=logprobs,
-        sample_indices=sample_indices,
-        seeds=seeds,
-        max_best_of=max_best_of,
-        modify_greedy_probs=modify_greedy_probs,
-        save_logprobs=True)
+    # ditto: patch _sample_triton with LibEntry, as above
+    with patch("vllm.model_executor.layers.ops.sample._sample_triton",
+               LibEntry(_sample_triton)):
+        sampled_tokens, sampled_logprobs, _ = sample(
+            probs=probs,
+            logprobs=logprobs,
+            sample_indices=sample_indices,
+            seeds=seeds,
+            max_best_of=max_best_of,
+            modify_greedy_probs=modify_greedy_probs,
+            save_logprobs=True)
     assert sampled_tokens.shape == (samples, max_best_of)
     assert sampled_logprobs.shape == (samples, max_best_of)
     for i, t in enumerate(sample_indices):
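
For context, a minimal sketch of what the patched-in object looks like in isolation, assuming LibEntry follows Triton's KernelInterface convention (grid-indexed launch) as the test usage above suggests. The _add_one_kernel below is a hypothetical toy kernel for illustration; only LibEntry itself comes from this PR:

import torch
import triton
import triton.language as tl

from vllm.triton_utils.libentry import LibEntry


@triton.jit
def _add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized slice of the input.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)


# Wrapping the JIT kernel in LibEntry keeps the usual kernel[grid](...)
# launch syntax; this is the kind of object the tests above swap in
# for _sample_triton via unittest.mock.patch.
wrapped_kernel = LibEntry(_add_one_kernel)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128), )
wrapped_kernel[grid](x, out, x.numel(), BLOCK=128)
assert torch.allclose(out, x + 1)

Patching at the import path ("vllm.model_executor.layers.ops.sample._sample_triton") rather than decorating the kernel in the source keeps the production code path unchanged: the tests exercise LibEntry's dispatch only for the duration of the with block, while still running the full sample() code path.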