[Bugfix] Check NVIDIA artifactory is accessible before using flashinfer cubin kernels (#21893)
This commit is contained in:
@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils.flashinfer import use_trtllm_decode_attention
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
||||
@@ -38,7 +38,6 @@ logger = init_logger(__name__)
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
cached_sm100a_supported: Optional[bool] = None
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
|
||||
@staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    attn_head_size: int,
) -> bool:
    """Decide whether the TRTLLM decode attention path should be used.

    The decision is made in three stages:

    1. The platform must report device capability 100; the result of that
       query is memoized on ``FlashInferBackend.cached_sm100a_supported``.
    2. The head layout must be compatible: at most 8 query heads per KV
       head, query heads evenly divisible by KV heads, and a head size of
       exactly 128.
    3. If ``VLLM_USE_TRTLLM_DECODE_ATTENTION`` is set, it decides the
       outcome (only the value ``"0"`` disables the path). Otherwise the
       path is auto-enabled for ``batch_size <= 256``,
       ``max_seq_len < 131072`` and ``kv_cache_dtype == "auto"``.

    Args:
        batch_size: Number of decode requests in the batch.
        max_seq_len: Maximum sequence length in the batch.
        kv_cache_dtype: KV cache dtype string; only ``"auto"`` is eligible
            for auto-detection.
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        attn_head_size: Attention head size.

    Returns:
        True if TRTLLM decode attention should be used, False otherwise.
    """
    backend = FlashInferBackend

    # Query the platform only once; later calls reuse the memoized answer.
    if backend.cached_sm100a_supported is None:
        backend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not backend.cached_sm100a_supported:
        return False

    # Head-layout constraints, checked in the same order as before.
    heads_per_kv = num_qo_heads // num_kv_heads
    if heads_per_kv > 8:
        return False
    if num_qo_heads % num_kv_heads != 0:
        return False
    if attn_head_size != 128:
        return False

    override = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if override is None:
        # Environment variable not set - use auto-detection.
        # The capability flag is already known to be truthy at this point,
        # so only the workload-dependent conditions remain.
        auto_enabled = (batch_size <= 256 and max_seq_len < 131072
                        and kv_cache_dtype == "auto")
        if auto_enabled:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return auto_enabled

    # Environment variable is set - respect it. Only an explicit "0" turns
    # the path off, because it is otherwise enabled automatically when the
    # batch size condition is satisfied.
    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     override)
    forced_on = override != "0"
    if forced_on:
        logger.info_once(
            "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
            "using TRTLLM decode attention.")
    return forced_on
|
||||
|
||||
@staticmethod
|
||||
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
if num_decodes > 0:
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
if not use_trtllm_decode_attention(
|
||||
num_decodes, attn_metadata.max_seq_len,
|
||||
self.cache_config.cache_dtype,
|
||||
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
|
||||
@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
decode_query = query[:num_decode_tokens]
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
assert decode_wrapper is not None
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
if not use_trtllm_decode_attention(
|
||||
attn_metadata.num_decodes, attn_metadata.max_seq_len,
|
||||
self.kv_cache_dtype, attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||
|
||||
@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.utils.flashinfer import has_nvidia_artifactory
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
get_per_layer_parameters, infer_global_hyperparameters,
|
||||
@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
|
||||
def use_flashinfer_prefill() -> bool:
    """Return whether the FlashInfer prefill path should be used.

    True only when all of the following hold:

    * ``flashinfer_available`` is truthy,
    * ``envs.VLLM_USE_CUDNN_PREFILL`` is not set (the cuDNN prefill path
      takes precedence when it is), and
    * ``current_platform.is_device_capability(100)`` is truthy.

    Returns:
        True if FlashInfer prefill should be used, False otherwise.
    """
    # For Blackwell, default to FlashInfer prefill if it is available, since
    # it is faster than FA2.
    return bool(flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
                and current_platform.is_device_capability(100))
|
||||
|
||||
|
||||
def use_cudnn_prefill() -> bool:
    """Return whether the cuDNN prefill path should be used.

    True only when all of the following hold:

    * ``flashinfer_available`` is truthy,
    * ``envs.VLLM_USE_CUDNN_PREFILL`` is set,
    * ``current_platform.is_device_capability(100)`` is truthy, and
    * ``has_nvidia_artifactory()`` is truthy.

    Returns:
        True if cuDNN prefill should be used, False otherwise.
    """
    # has_nvidia_artifactory() is deliberately last so that it is only
    # evaluated when every other condition already holds.
    # NOTE(review): presumably it probes whether NVIDIA's artifactory (used
    # to fetch FlashInfer cubin kernels) is reachable - confirm against
    # vllm.utils.flashinfer.
    return bool(flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
                and current_platform.is_device_capability(100)
                and has_nvidia_artifactory())
|
||||
|
||||
|
||||
# Currently 394MB, this can be tuned based on GEMM sizes used.
|
||||
|
||||
Reference in New Issue
Block a user