[Core] Refactor _prepare_model_input_tensors - take 2 (#6164)
@@ -7,6 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
 
 logger = init_logger(__name__)
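
The import added above brings in current_platform so that device-capability queries no longer call torch.cuda directly. The following is a minimal sketch of the idea behind such a platform object; the class names are hypothetical and this is not vLLM's actual vllm.platforms implementation.

from typing import Tuple

import torch


class CudaPlatform:
    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        # Delegate to torch; covers NVIDIA CUDA and ROCm builds of torch.
        return torch.cuda.get_device_capability(device_id)


class CpuPlatform:
    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        # CPUs have no compute capability; return a harmless sentinel.
        return (0, 0)


# One module-level object, analogous to the imported current_platform, lets
# call sites write current_platform.get_device_capability()[0] instead of
# reaching into torch.cuda at every check.
current_platform = CudaPlatform() if torch.cuda.is_available() else CpuPlatform()
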
@@ -136,7 +137,7 @@ def which_attn_to_use(
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
-            if torch.cuda.get_device_capability()[0] != 9:
+            if current_platform.get_device_capability()[0] != 9:
                 # not Instinct series GPUs.
                 logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
@@ -145,7 +146,7 @@ def which_attn_to_use(
 
     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
-        if torch.cuda.get_device_capability()[0] < 8:
+        if current_platform.get_device_capability()[0] < 8:
             # Volta and Turing NVIDIA GPUs.
             logger.info(
                 "Cannot use FlashAttention-2 backend for Volta and Turing "
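
Both hunks apply the same pattern: gate the FlashAttention backend on the GPU's compute-capability major version, now queried through current_platform. A rough standalone restatement of that gating logic follows; the helper name is hypothetical and not part of the diff.

def flash_attn_supported(capability: tuple, is_rocm: bool) -> bool:
    """Return True if the FlashAttention backend can be used.

    capability is the (major, minor) pair returned by
    current_platform.get_device_capability().
    """
    if is_rocm:
        # On ROCm, flash_attn is only supported on Instinct-series GPUs
        # (major == 9), not NAVI GPUs.
        return capability[0] == 9
    # On NVIDIA, FlashAttention-2 cannot be used on Volta and Turing
    # (major < 8); Ampere or newer is required.
    return capability[0] >= 8
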