[Hardware] Initial TPU integration (#5292)

Woosuk Kwon
2024-06-12 11:53:03 -07:00
committed by GitHub
parent 847cdcca1c
commit 1a8bfd92d5
22 changed files with 1322 additions and 28 deletions


@@ -7,7 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import is_cpu, is_hip, is_tpu
 
 logger = init_logger(__name__)
@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
+    PALLAS = enum.auto()
 
 
 @lru_cache(maxsize=None)
@@ -66,6 +67,10 @@ def get_attn_backend(
"Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
else:
raise ValueError("Invalid attention backend.")
@@ -80,7 +85,6 @@ def which_attn_to_use(
     block_size: int,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
-
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
@@ -100,6 +104,11 @@ def which_attn_to_use(
         logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA
 
+    if is_tpu():
+        if selected_backend != _Backend.PALLAS:
+            logger.info("Cannot use %s backend on TPU.", selected_backend)
+        return _Backend.PALLAS
+
     if is_hip():
         # AMD GPUs.
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
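
Note: the TPU check mirrors the CPU check just above it — platform detection overrides whatever backend the user requested, logging a note when the choice is ignored, since Pallas is the only attention backend that runs on TPU. A self-contained sketch of the override pattern; the libtpu-based probe is an assumption about how is_tpu might work, not vLLM's confirmed implementation:

import enum
import logging
from functools import lru_cache

logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    PALLAS = enum.auto()


@lru_cache(maxsize=None)
def is_tpu() -> bool:
    # Assumed probe: treat an importable libtpu package as evidence
    # of a TPU host. Cached so the import is attempted only once.
    try:
        import libtpu  # noqa: F401
    except ImportError:
        return False
    return True


def which_attn_to_use(selected_backend: _Backend) -> _Backend:
    # Platform checks take precedence over the user's selection.
    if is_tpu():
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        return _Backend.PALLAS
    return selected_backend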