[Bugfix] Check NVIDIA artifactory is accessible before using flashinfer cubin kernels (#21893)

Michael Goin
2025-08-01 08:28:45 -04:00
committed by GitHub
parent fb0e0d46fc
commit f81c1bb055
4 changed files with 93 additions and 99 deletions
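The fix centers on a new `has_nvidia_artifactory()` helper in the `vllm.utils.flashinfer` module (imported in the MLA hunk below): FlashInfer's TRTLLM cubin kernels are downloaded from NVIDIA's artifactory, so the cubin-backed paths should only be enabled when that endpoint is reachable. The helper's body is not shown in this excerpt; the following is a minimal sketch of such a probe, assuming a cached `requests`-based reachability check, with the URL constant purely illustrative:

```python
import functools
import logging

import requests

logger = logging.getLogger(__name__)

# Illustrative placeholder; the committed code keeps the real repository URL
# as a constant inside vllm.utils.flashinfer.
CUBIN_REPOSITORY_URL = "https://edge.urm.nvidia.com/artifactory/"


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return True if NVIDIA's artifactory (the cubin kernel repo) responds.

    Cached so the network probe runs at most once per process.
    """
    try:
        # A short timeout keeps startup responsive on machines without
        # connectivity to the artifactory.
        response = requests.get(CUBIN_REPOSITORY_URL, timeout=5)
        accessible = response.status_code == 200
    except requests.RequestException:
        accessible = False
    if not accessible:
        logger.warning("NVIDIA artifactory is not accessible; FlashInfer "
                       "TRTLLM cubin kernels will be disabled.")
    return accessible
```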


@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -38,7 +38,6 @@ logger = init_logger(__name__)
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
cached_sm100a_supported: Optional[bool] = None
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return stride_order
@staticmethod
def use_trtllm_decode_attention(
batch_size: int,
max_seq_len: int,
kv_cache_dtype: str,
num_qo_heads: int,
num_kv_heads: int,
attn_head_size: int,
) -> bool:
if FlashInferBackend.cached_sm100a_supported is None:
FlashInferBackend.cached_sm100a_supported = (
current_platform.has_device_capability(100))
if not FlashInferBackend.cached_sm100a_supported:
return False
if (num_qo_heads // num_kv_heads > 8
or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
return False
env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
env_value)
# Environment variable is set - respect it
# Making the conditional check for zero because
# the path is automatically enabled if the batch size condition
# is satisfied.
no_use_trtllm = env_value == "0"
if not no_use_trtllm:
logger.info_once(
"VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
"using TRTLLM decode attention.")
return not no_use_trtllm
else:
# Environment variable not set - use auto-detection
# Only supports attention head size of 128
use_trtllm = (FlashInferBackend.cached_sm100a_supported
and batch_size <= 256 and max_seq_len < 131072
and kv_cache_dtype == "auto")
if use_trtllm:
logger.warning_once(
"Using TRTLLM decode attention (auto-detected).")
return use_trtllm
@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
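The staticmethod removed in the hunk above now lives as a free function in `vllm.utils.flashinfer` (see the new import in the first hunk), where it can also consult the artifactory check. Its relocated body is not shown in this diff, so the sketch below is a standalone reconstruction from the removed logic; the artifactory gate and its placement are assumptions, not the committed code:

```python
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory


def use_trtllm_decode_attention(batch_size: int, max_seq_len: int,
                                kv_cache_dtype: str, num_qo_heads: int,
                                num_kv_heads: int,
                                attn_head_size: int) -> bool:
    # Hardware gate: the TRTLLM decode kernels target SM100 (Blackwell).
    if not current_platform.has_device_capability(100):
        return False
    # Connectivity gate motivating this PR: the TRTLLM cubins are fetched
    # from NVIDIA's artifactory, so skip the path when it is unreachable.
    if not has_nvidia_artifactory():
        return False
    # Kernel shape constraints: grouped-query ratio at most 8:1 and a head
    # size of exactly 128 (checked before any env override).
    if (num_qo_heads % num_kv_heads != 0
            or num_qo_heads // num_kv_heads > 8 or attn_head_size != 128):
        return False
    # An explicit VLLM_USE_TRTLLM_DECODE_ATTENTION setting overrides the
    # auto-detection below ("0" disables, anything else enables).
    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is not None:
        return env_value != "0"
    # Auto-detection keeps the heuristics from the removed staticmethod.
    return (batch_size <= 256 and max_seq_len < 131072
            and kv_cache_dtype == "auto")
```

The call sites in the hunks below pass the same arguments as before; only the receiver changes from `FlashInferBackend` to the module-level helper.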
@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
if not FlashInferBackend.use_trtllm_decode_attention(
if not use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
if not FlashInferBackend.use_trtllm_decode_attention(
if not use_trtllm_decode_attention(
attn_metadata.num_decodes, attn_metadata.max_seq_len,
self.kv_cache_dtype, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):


@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
get_per_layer_parameters, infer_global_hyperparameters,
@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)
def use_flashinfer_prefill() -> bool:
if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
# For blackwell default to flashinfer prefill if its available since
# its faster than FA2.
return current_platform.has_device_capability(100)
return False
# For blackwell default to flashinfer prefill if its available since
# it is faster than FA2.
return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
and current_platform.is_device_capability(100))
def use_cudnn_prefill() -> bool:
if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
return current_platform.has_device_capability(100)
return False
return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
and current_platform.is_device_capability(100)
and has_nvidia_artifactory())
# Currently 394MB, this can be tuned based on GEMM sizes used.
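For context, the two prefill predicates above are mutually exclusive because both branch on `envs.VLLM_USE_CUDNN_PREFILL`, and in this hunk only the cuDNN path is additionally gated on artifactory reachability. A hypothetical selector, written as if it sat in the same module right after those predicates, showing how they might be consumed; the function name and the FA2 fallback label are assumptions:

```python
def select_prefill_backend() -> str:
    # Hypothetical dispatch for illustration only; order does not matter
    # because the two predicates branch on the same env flag.
    if use_cudnn_prefill():
        # Opt-in path: VLLM_USE_CUDNN_PREFILL set, Blackwell GPU, and the
        # NVIDIA artifactory reachable for the required cubins.
        return "cudnn"
    if use_flashinfer_prefill():
        # Default on Blackwell when FlashInfer is installed.
        return "flashinfer"
    # Otherwise fall back to the FlashAttention-based (FA2) prefill path.
    return "flash-attn"
```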