[SpecDecode][Kernel] Flashinfer Rejection Sampling (#7244)
@@ -44,12 +44,16 @@ def mock_causal_accepted_tensor(
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
 @pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
-def test_correct_output_format(which_tokens_accepted: str,
-                               disable_bonus_tokens: bool, seed: int,
-                               device: str):
+def test_correct_output_format(which_tokens_accepted: str, seed: int,
+                               disable_bonus_tokens: bool, device: str,
+                               use_flashinfer: bool):
     """Verify the output has correct format given predetermined accepted matrix.
     """
+    if use_flashinfer and disable_bonus_tokens:
+        pytest.skip("Flashinfer rejection sampler must enable bonus token.")
+
     set_random_seed(seed)
     torch.set_default_device(device)
 
@@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str,
                                    dtype=torch.int64)
 
     rejection_sampler = RejectionSampler(
-        disable_bonus_tokens=disable_bonus_tokens)
+        disable_bonus_tokens=disable_bonus_tokens,
+        use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
-                                    device: str):
+                                    device: str, use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
-                                   frac_seeded: float, n_rep: int,
-                                   device: str):
+                                   frac_seeded: float, n_rep: int, device: str,
+                                   use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler()
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])
 
 
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
+                                       batch_size: int, device: str):
+    """
+    Test the flashinfer and nonflashinfer backend generate
+    the same output metrics.
+    """
+    torch.set_default_device(device)
+    torch.manual_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    num_accepted_tokens = []
+    num_emitted_tokens = []
+    num_draft_tokens = []
+
+    def get_seeded_seqs():
+        return {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size)
+        }
+
+    for use_flashinfer in [True, False]:
+        rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                             use_flashinfer=use_flashinfer)
+        rejection_sampler.init_gpu_tensors(device=device)
+        # We use seeded sequences to ensure the same tokens are accepted
+        # for both flashinfer and nonflashinfer backends.
+        seeded_seqs = get_seeded_seqs()
+        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                          draft_token_ids, seeded_seqs)
+        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
+        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
+        num_draft_tokens.append(rejection_sampler.num_draft_tokens)
+
+    assert num_accepted_tokens[0] == num_accepted_tokens[1]
+    assert num_emitted_tokens[0] == num_emitted_tokens[1]
+    assert num_draft_tokens[0] == num_draft_tokens[1]
+
+
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
 @pytest.mark.parametrize("which_token_ids",
                          ["bonus_token_ids", "draft_token_ids"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
-                               which_token_ids: str, device: str):
+                               which_token_ids: str, device: str,
+                               use_flashinfer: bool):
     k = 3
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
 
-    rejection_sampler = RejectionSampler(strict_mode=True)
+    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
+                                         use_flashinfer=use_flashinfer,
+                                         strict_mode=True)
     rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
-    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(batch_size, 1),
@@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
 @pytest.mark.parametrize("seed", list(range(5)))
+@pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_rejection_sampling_approximates_target_distribution(
-        seed: int, draft_and_target_probs_equal: bool):
+        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
     """Verify rejection sampling approximates target distribution,
     despite sampling from a potentially distinct draft distribution.
 
@@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution(
     """
     torch.set_default_device("cpu")
     set_random_seed(seed)
 
     helper = _CorrectnessTestHelper(
         vocab_size=10,
-        rejection_sampler=RejectionSampler(),
+        rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
+                                           use_flashinfer=use_flashinfer),
     )
 
     draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
@@ -398,10 +476,10 @@ class _CorrectnessTestHelper:
         draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
             num_samples, 1, 1)
 
-        # Repeat target probs num_samples * k times.
+        # Repeat target probs num_samples * (k + 1) times.
+        # Rejection sampler requires bonus token probs, but they aren't used.
         target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
-            num_samples, self.k, 1)
+            num_samples, self.k + 1, 1)
 
         # Randomly sample draft token ids from draft probs.
         draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
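Read together, the updated tests pin down how the FlashInfer-backed sampler is driven: bonus tokens must stay enabled, target probabilities now carry k + 1 positions, and per-sequence generators make acceptance reproducible across backends. Below is a minimal sketch of that call pattern; the import path for RejectionSampler is assumed (the diff does not show the test file's imports), and the probabilities are left unnormalized as in the smoke tests above, so it only illustrates shapes and wiring, not sampling quality.

import torch

# Assumed import path; not shown in this diff.
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

k, vocab_size, batch_size = 3, 30_000, 8
device = "cuda:0"
torch.set_default_device(device)

# The FlashInfer path requires bonus tokens to be enabled,
# mirroring the pytest.skip() guard in test_correct_output_format.
sampler = RejectionSampler(disable_bonus_tokens=False, use_flashinfer=True)
sampler.init_gpu_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
# Target probs now carry k + 1 positions; the extra slot backs the bonus token.
target_probs = torch.rand(batch_size, k + 1, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)

# Per-sequence generators, as in test_compare_nonflashinfer_backend,
# make the accept/reject decisions reproducible.
seeded_seqs = {
    i: torch.Generator(device=device).manual_seed(i)
    for i in range(batch_size)
}

# Call signature as exercised by the tests; the returned token ids are
# assumed to include the bonus slot, per _create_output above.
output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                           draft_token_ids, seeded_seqs)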