[SpecDec][Misc] Cleanup, remove bonus token logic. (#8701)
This commit is contained in:
@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
|
||||
@pytest.mark.parametrize(
|
||||
"which_tokens_accepted",
|
||||
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
|
||||
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_flashinfer", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
disable_bonus_tokens: bool, device: str,
|
||||
use_flashinfer: bool):
|
||||
device: str, use_flashinfer: bool):
|
||||
"""Verify the output has correct format given predetermined accepted matrix.
|
||||
"""
|
||||
if use_flashinfer and disable_bonus_tokens:
|
||||
pytest.skip("Flashinfer rejection sampler must enable bonus token.")
|
||||
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
|
||||
rejection_sampler = RejectionSampler(
|
||||
disable_bonus_tokens=disable_bonus_tokens,
|
||||
use_flashinfer=use_flashinfer)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
|
||||
accepted,
|
||||
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
)
|
||||
|
||||
expected_bonus_token_ids = bonus_token_ids.clone()
|
||||
# If bonus tokens disabled. Verify they are set to -1.
|
||||
# See https://github.com/vllm-project/vllm/issues/4212
|
||||
if disable_bonus_tokens:
|
||||
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
|
||||
|
||||
if which_tokens_accepted == "all_tokens_accepted":
|
||||
# Expect all tokens to be equal to draft tokens.
|
||||
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
device: str, use_flashinfer: bool):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
|
||||
use_flashinfer=use_flashinfer)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
|
||||
frac_seeded: float, n_rep: int, device: str,
|
||||
use_flashinfer: bool):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
|
||||
use_flashinfer=use_flashinfer)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
@@ -251,8 +238,7 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
|
||||
}
|
||||
|
||||
for use_flashinfer in [True, False]:
|
||||
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
|
||||
use_flashinfer=use_flashinfer)
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
# We use seeded sequences to ensure the same tokens are accepted
|
||||
# for both flashinfer and nonflashinfer backends.
|
||||
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
|
||||
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
|
||||
use_flashinfer=use_flashinfer,
|
||||
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
|
||||
strict_mode=True)
|
||||
rejection_sampler.init_gpu_tensors(device=device)
|
||||
|
||||
@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
|
||||
set_random_seed(seed)
|
||||
helper = _CorrectnessTestHelper(
|
||||
vocab_size=10,
|
||||
rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
|
||||
use_flashinfer=use_flashinfer),
|
||||
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
|
||||
)
|
||||
|
||||
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
|
||||
|
||||
Reference in New Issue
Block a user