[Hardware] Initial TPU integration (#5292)
@@ -7,7 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import is_cpu, is_hip, is_tpu

 logger = init_logger(__name__)

@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
+    PALLAS = enum.auto()


 @lru_cache(maxsize=None)
@@ -66,6 +67,10 @@ def get_attn_backend(
             "Please make sure --enforce-eager is set.")
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
+    elif backend == _Backend.PALLAS:
+        logger.info("Using Pallas backend.")
+        from vllm.attention.backends.pallas import PallasAttentionBackend
+        return PallasAttentionBackend
     else:
         raise ValueError("Invalid attention backend.")

@@ -80,7 +85,6 @@ def which_attn_to_use(
     block_size: int,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
-
     # Default case.
     selected_backend = _Backend.FLASH_ATTN

@@ -100,6 +104,11 @@ def which_attn_to_use(
         logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA

+    if is_tpu():
+        if selected_backend != _Backend.PALLAS:
+            logger.info("Cannot use %s backend on TPU.", selected_backend)
+        return _Backend.PALLAS
+
     if is_hip():
         # AMD GPUs.
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
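For context, here is a minimal, runnable Python sketch of the selection flow this diff implements. It is illustrative only: the boolean platform flags stand in for the real vllm.utils.is_cpu()/is_tpu() checks, and the actual which_attn_to_use takes additional parameters (e.g. block_size) that are elided in the hunks above.

import enum
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()
    PALLAS = enum.auto()


def which_attn_to_use(selected_backend: _Backend,
                      on_cpu: bool = False,
                      on_tpu: bool = False) -> _Backend:
    """Platform checks override the requested backend (simplified)."""
    if on_cpu:
        # Existing behavior: CPU always falls back to the torch SDPA backend.
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        return _Backend.TORCH_SDPA
    if on_tpu:
        # New in this commit: a TPU host always gets the Pallas backend,
        # logging a note if the user requested something else.
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        return _Backend.PALLAS
    return selected_backend


# Requesting FLASH_ATTN on a TPU host logs a note and falls back to PALLAS.
assert which_attn_to_use(_Backend.FLASH_ATTN, on_tpu=True) is _Backend.PALLAS

With this override in place, get_attn_backend receives _Backend.PALLAS on TPU hosts and imports and returns PallasAttentionBackend, as the second hunk shows.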