[Core] Refactor _prepare_model_input_tensors - take 2 (#6164)

commit 2fa4623d9e
parent a9a2e74d21
Author: Cody Yu
Date: 2024-07-17 09:37:16 -07:00
Committed by: GitHub
12 changed files with 1050 additions and 470 deletions


@@ -7,6 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
 
 logger = init_logger(__name__)
@@ -136,7 +137,7 @@ def which_attn_to_use(
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
-            if torch.cuda.get_device_capability()[0] != 9:
+            if current_platform.get_device_capability()[0] != 9:
                 # not Instinct series GPUs.
                 logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
@@ -145,7 +146,7 @@ def which_attn_to_use(
     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
-        if torch.cuda.get_device_capability()[0] < 8:
+        if current_platform.get_device_capability()[0] < 8:
             # Volta and Turing NVIDIA GPUs.
             logger.info(
                 "Cannot use FlashAttention-2 backend for Volta and Turing "