[Kernel] Use flashinfer for decoding (#4353)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
Author: Lily Liu
Date: 2024-05-03 15:51:27 -07:00
Committed by: GitHub
Parent: f8e7adda21
Commit: 43c413ec57
15 changed files with 600 additions and 53 deletions
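For context, a minimal usage sketch (not part of this diff) of how the new backend can be selected. It assumes the VLLM_ATTENTION_BACKEND environment variable that the backend selector consults; the model name and sampling settings below are illustrative, not taken from the commit.

import os

# Must be set before vLLM resolves the attention backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# Assumption: enforce_eager mirrors the warning this commit adds; the
# backend does not run under CUDA graph capture at this point.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)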


@@ -17,6 +17,7 @@ class _Backend(enum.Enum):
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
+    FLASHINFER = enum.auto()
 
 
 @lru_cache(maxsize=None)
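The hunk above extends an existing pattern: an enum of candidate backends plus a selector memoized with @lru_cache, so backend resolution runs once per process. A self-contained sketch of that pattern follows; the selection logic is stubbed out (the real selector weighs dtype, hardware, and installed packages), and _select_backend is a hypothetical name used here for illustration.

import enum
from functools import lru_cache


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()


@lru_cache(maxsize=None)
def _select_backend() -> _Backend:
    # Stub: always picks FLASHINFER. The real function falls back through
    # the alternatives based on dtype, GPU capability, and what is installed.
    return _Backend.FLASHINFER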
@@ -41,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
         logger.info("Using Torch SDPA backend.")
         from vllm.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend
+    elif backend == _Backend.FLASHINFER:
+        logger.info("Using Flashinfer backend.")
+        logger.warning("Eager mode is enforced for the Flashinfer backend. ")
+        from vllm.attention.backends.flashinfer import FlashInferBackend
+        return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")
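Note that each branch imports its backend module only when that backend is actually chosen. This lazy-import shape means an optional dependency such as the flashinfer package can be absent without breaking configurations that never select it. A simplified sketch of the dispatch, using a hypothetical load_backend helper:

def load_backend(name: str):
    # Lazy imports: only the selected backend's module (and its
    # dependencies) is loaded; the other branches never execute.
    if name == "FLASHINFER":
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    if name == "TORCH_SDPA":
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    raise ValueError(f"Invalid attention backend: {name}")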