[Kernel] Lazy import FlashInfer (#26977)
This commit is contained in:
@@ -5,20 +5,13 @@ import torch
|
||||
from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (
|
||||
apply_top_k_top_p,
|
||||
is_flashinfer_available,
|
||||
)
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
|
||||
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
|
||||
if is_flashinfer_available:
|
||||
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_default_device():
|
||||
@@ -65,6 +58,14 @@ def test_flashinfer_sampler():
|
||||
sampling results due to randomness), so we will compare the probability
|
||||
renormed consecutively by top-k and then top-p of FlashInfer implementation.
|
||||
"""
|
||||
try:
|
||||
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
|
||||
|
||||
is_flashinfer_available = True
|
||||
except ImportError:
|
||||
is_flashinfer_available = False
|
||||
|
||||
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
|
||||
|
||||
if not FLASHINFER_ENABLED:
|
||||
pytest.skip("FlashInfer not installed or not available on this platform.")
|
||||
|
||||
Reference in New Issue
Block a user