diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 112c3a5a9..8b5edc0d3 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -229,7 +229,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_and_maybe_dequant_weights,
 )
 from vllm.platforms import current_platform
-from vllm.utils.flashinfer import has_nvidia_artifactory
+from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
 from vllm.utils.math_utils import cdiv, round_down
 from vllm.utils.torch_utils import (
     direct_register_custom_op,
@@ -599,13 +599,6 @@ except ImportError:
     is_vllm_fa = False
 
 
-@functools.cache
-def flashinfer_available() -> bool:
-    import importlib.util
-
-    return importlib.util.find_spec("flashinfer") is not None
-
-
 def dynamic_per_batched_tensor_quant(
     x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
 ):
@@ -824,7 +817,7 @@ def use_flashinfer_prefill() -> bool:
     vllm_config = get_current_vllm_config()
     if not (
         not vllm_config.attention_config.disable_flashinfer_prefill
-        and flashinfer_available()
+        and has_flashinfer()
         and not vllm_config.attention_config.use_cudnn_prefill
         and current_platform.is_device_capability_family(100)
     ):
@@ -838,7 +831,7 @@ def use_cudnn_prefill() -> bool:
     vllm_config = get_current_vllm_config()
 
     return (
-        flashinfer_available()
+        has_flashinfer()
         and vllm_config.attention_config.use_cudnn_prefill
         and current_platform.is_device_capability_family(100)
         and has_nvidia_artifactory()
@@ -851,7 +844,7 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
     vllm_config = get_current_vllm_config()
 
     if not (
-        flashinfer_available
+        has_flashinfer()
         and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
         and current_platform.is_device_capability_family(100)
     ):
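
This diff consolidates the FlashInfer availability probe: the module-local `flashinfer_available()` helper is deleted in favor of `has_flashinfer()` imported from `vllm.utils.flashinfer`. It also fixes a latent bug in `use_trtllm_ragged_deepseek_prefill()`: the old code referenced `flashinfer_available` without calling it, so the condition tested the function object itself (always truthy) rather than whether FlashInfer is actually installed. Below is a minimal sketch of what the shared helper presumably looks like, assuming it mirrors the removed implementation; the `has_flashinfer` name comes from the import in the diff, but the body shown here is inferred from the deleted code, not taken from `vllm/utils/flashinfer.py`.

```python
# Hypothetical sketch of the shared helper in vllm/utils/flashinfer.py.
# The real implementation may differ; this mirrors the removed
# flashinfer_available() helper shown in the diff above.
import functools
import importlib.util


@functools.cache
def has_flashinfer() -> bool:
    """Return True if the flashinfer package is importable (result cached)."""
    return importlib.util.find_spec("flashinfer") is not None
```

Centralizing the check means every call site shares one cached probe, and call-site bugs like the uncalled `flashinfer_available` reference become harder to reintroduce, since the only way to consume the utility is to call it.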