[Sampler] Adapt to FlashInfer 0.2.3 sampler API (#15777)

Signed-off-by: Bowen Wang <abmfy@icloud.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Bowen Wang
2025-05-16 15:14:03 -07:00
committed by GitHub
parent aef94c6d07
commit 7fdfa01530
7 changed files with 122 additions and 88 deletions

View File

@@ -1,14 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
from torch import Generator
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
is_flashinfer_available)
DEVICE = "cuda"
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
def test_topk_impl_equivalance():
@@ -35,3 +41,67 @@ def test_topk_impl_equivalance():
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
assert torch.allclose(result1, result2)
def test_flashinfer_sampler():
'''
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.
NOTE: FlashInfer did not directly expose an interface for fused top-k and
top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
'''
if not FLASHINFER_ENABLED:
pytest.skip(
"FlashInfer not installed or not available on this platform.")
with torch.device(DEVICE):
generator = Generator(device=DEVICE).manual_seed(42)
# Generate random logits
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
# Generate various top-k and top-p values
k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
p_values = torch.rand(
(BATCH_SIZE, ),
generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0]
# Sometimes disable top-k (k=vocab_size)
k_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), VOCAB_SIZE)
# Sometimes disable top-p (p=1.0)
p_values.masked_fill_(
torch.randint(0,
2, (BATCH_SIZE, ),
generator=generator,
dtype=torch.bool), 1.0)
python_logits = apply_top_k_top_p(
logits=logits.clone(),
k=k_values,
p=p_values,
)
python_probs = torch.softmax(python_logits, dim=-1)
# FlashInfer only exposed renorm interfaces for probs so convert first
flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
flashinfer_probs = top_k_renorm_probs(
probs=flashinfer_probs,
top_k=k_values,
)
flashinfer_probs = top_p_renorm_probs(
probs=flashinfer_probs,
top_p=p_values,
)
# Compare the results
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
"FlashInfer and Python sampling implementations do not match!"