[CI/Build] Avoid CUDA initialization (#8534)
@@ -203,7 +203,7 @@ def which_attn_to_use(
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
-            if current_platform.get_device_capability()[0] != 9:
+            if not current_platform.has_device_capability(90):
                 # not Instinct series GPUs.
                 logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
@@ -212,7 +212,7 @@ def which_attn_to_use(
 
     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
-        if current_platform.get_device_capability()[0] < 8:
+        if not current_platform.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
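
Context for the change: indexing into get_device_capability()[0] compares the raw major version, and reading the capability through torch.cuda typically initializes a CUDA context, which this commit's title says CI builds should avoid. The has_device_capability(80) / has_device_capability(90) form moves the comparison behind the platform abstraction, where the capability can be read without touching the CUDA runtime. Below is a minimal sketch of what such a helper could look like; the DeviceCapability type, the NVML-based lookup, and all function bodies are assumptions for illustration, not vLLM's actual implementation. Only the names get_device_capability and has_device_capability come from the diff.

# Illustrative sketch only; not vLLM's actual implementation.
from typing import NamedTuple, Optional


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        # Encode (8, 0) as 80 so callers can write has_device_capability(80).
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


def get_device_capability(device_id: int = 0) -> Optional[DeviceCapability]:
    # Assumption: read the compute capability through NVML instead of
    # torch.cuda, so no CUDA context is created in this process.
    try:
        import pynvml
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        finally:
            pynvml.nvmlShutdown()
        return DeviceCapability(major=major, minor=minor)
    except Exception:
        # No NVIDIA driver or no visible device: report "unknown".
        return None


def has_device_capability(capability: int, device_id: int = 0) -> bool:
    # Returns False rather than raising when no device is visible, which
    # keeps the check safe on CPU-only CI machines.
    current = get_device_capability(device_id)
    return current is not None and current.to_int() >= capability

One semantic note on the ROCm hunk: the old check (!= 9) flagged only GPUs whose major version was not exactly 9, while not has_device_capability(90) flags anything below capability 9.0 as a threshold test, matching the NVIDIA hunk's "at least SM 8.0" pattern.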