[Sampler] Adapt to FlashInfer 0.2.3 sampler API (#15777)
Signed-off-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
@@ -169,7 +169,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
 @pytest.mark.parametrize("n_rep", [100])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_flashinfer", [True, False])
+# @pytest.mark.parametrize("use_flashinfer", [True, False])
+# Not testing FlashInfer now, since 0.2.3 API removed the ability
+# to pass in uniform samples.
+@pytest.mark.parametrize("use_flashinfer", [False])
 @torch.inference_mode()
 def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                    frac_seeded: float, n_rep: int, device: str,
@@ -214,7 +217,10 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.parametrize("use_flashinfer", [True, False])
+# @pytest.mark.parametrize("use_flashinfer", [True, False])
+# Not testing FlashInfer now, since 0.2.3 API removed the ability
+# to pass in uniform samples.
+@pytest.mark.parametrize("use_flashinfer", [False])
 @torch.inference_mode()
 def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
                             device: str, use_flashinfer: bool):
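Because FlashInfer 0.2.3 no longer accepts caller-supplied uniform samples, the seeded tests above can only check determinism by reseeding the RNG and rerunning. A minimal sketch of that seed-and-rerun pattern in plain PyTorch follows; sample_once is a hypothetical helper for illustration, not vLLM's RejectionSampler API.

import torch

def sample_once(probs: torch.Tensor, seed: int) -> torch.Tensor:
    # Fresh generator per call so every run starts from an identical RNG state.
    generator = torch.Generator(device=probs.device).manual_seed(seed)
    return torch.multinomial(probs, num_samples=1, generator=generator)

probs = torch.softmax(torch.randn(8, 32_000), dim=-1)
first = sample_once(probs, seed=42)
second = sample_once(probs, seed=42)
assert torch.equal(first, second)  # same seed, same sampled tokens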
@@ -284,6 +290,10 @@ def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
     Test the flashinfer and nonflashinfer backend generate
     the same output metrics.
     """
+
+    pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed "
+                "the ability to pass in uniform samples.")
+
     torch.set_default_device(device)
     torch.manual_seed(0)
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
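The skipped comparison test relied on feeding both backends the same uniform draws so their outputs were directly comparable; with 0.2.3 drawing those samples internally, that equivalence can no longer be asserted. The toy sketch below (not vLLM's or FlashInfer's implementation) shows rejection sampling with explicit uniform samples, which is the mechanism the old comparison depended on.

import torch

def toy_rejection_sample(target_probs: torch.Tensor,
                         draft_probs: torch.Tensor,
                         draft_token_ids: torch.Tensor,
                         uniform_samples: torch.Tensor) -> torch.Tensor:
    # Accept each draft token with probability min(1, p_target / p_draft).
    rows = torch.arange(draft_token_ids.shape[0])
    p_target = target_probs[rows, draft_token_ids]
    p_draft = draft_probs[rows, draft_token_ids]
    accept = uniform_samples < torch.clamp(p_target / p_draft, max=1.0)
    # Rejected positions resample from the residual max(0, target - draft) distribution.
    residual = torch.clamp(target_probs - draft_probs, min=0.0)
    residual = residual / residual.sum(dim=-1, keepdim=True).clamp_min(1e-9)
    resampled = torch.multinomial(residual, num_samples=1).squeeze(-1)
    return torch.where(accept, draft_token_ids, resampled)

batch_size, vocab_size = 4, 16
target = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
draft = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
draft_ids = torch.multinomial(draft, num_samples=1).squeeze(-1)
u = torch.rand(batch_size)  # pre-0.2.3, draws like these could be handed to both backends
print(toy_rejection_sample(target, draft, draft_ids, u))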