[V1] Implement vLLM V1 [1/N] (#9289)

Woosuk Kwon authored 2024-10-22 01:24:07 -07:00, committed by GitHub
parent 3ddbe25502
commit 6c5af09b39
27 changed files with 3058 additions and 180 deletions


@@ -17,6 +17,7 @@ logger = init_logger(__name__)
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
+    FLASH_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
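
The first hunk registers the V1 flash-attention backend as a new `_Backend` member. A minimal standalone sketch (not vLLM's actual module) of why this is a safe, additive change: `enum.auto()` numbers members by definition order, so members defined after the new entry get shifted values, but selection code compares members by identity, which is unaffected.

```python
import enum


class _Backend(enum.Enum):
    """Trimmed-down stand-in for vLLM's backend enum."""
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()  # new member added by this commit
    XFORMERS = enum.auto()


def describe(backend: _Backend) -> str:
    # Branching on the new member leaves all existing branches untouched.
    if backend == _Backend.FLASH_ATTN_VLLM_V1:
        return "FlashAttention backend for the V1 engine"
    return f"V0 backend: {backend.name}"


print(describe(_Backend.FLASH_ATTN_VLLM_V1))
print(describe(_Backend.XFORMERS))  # identity comparison still works
```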
@@ -110,6 +111,10 @@ def get_attn_backend(
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
         return FlashAttentionBackend
+    if backend == _Backend.FLASH_ATTN_VLLM_V1:
+        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
+            FlashAttentionBackend as FlashAttentionBackendV1)
+        return FlashAttentionBackendV1
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
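
The second hunk wires the new member into `get_attn_backend`, importing the V1 class lazily and aliasing it so it does not clash with the V0 class of the same name. A simplified sketch of that dispatch pattern follows; `_resolve_backend_class` is a hypothetical helper, while the module paths are the ones imported in the hunk above.

```python
def _resolve_backend_class(backend: "_Backend"):
    """Hypothetical helper illustrating lazy, per-branch backend imports."""
    if backend == _Backend.FLASH_ATTN:
        # Import only when selected, so unused backends never get loaded.
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend
    if backend == _Backend.FLASH_ATTN_VLLM_V1:
        # The V1 implementation lives under the new vllm.v1 package; the
        # alias keeps it distinct from the V0 class of the same name.
        from vllm.v1.attention.backends.flash_attn import (
            FlashAttentionBackend as FlashAttentionBackendV1)
        return FlashAttentionBackendV1
    raise ValueError(f"No backend class registered for {backend}")
```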
@@ -215,6 +220,9 @@ def which_attn_to_use(
             logger.info("%s is not supported in AMD GPUs.", selected_backend)
         return _Backend.ROCM_FLASH
+    if envs.VLLM_USE_V1:
+        return _Backend.FLASH_ATTN_VLLM_V1
     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
         if not current_platform.has_device_capability(80):
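
The third hunk makes `which_attn_to_use` short-circuit to the V1 backend whenever the `VLLM_USE_V1` environment flag is set, before the usual hardware heuristics run (FlashAttention itself still requires compute capability 8.0, i.e. Ampere or newer). Below is a rough standalone sketch of that opt-in flow; the flag parsing and the `pick_backend` helper are assumptions for illustration, not vLLM's actual `envs` module or selector.

```python
import os

# Assumed parsing: VLLM_USE_V1 is read from the environment; "1" opts in.
VLLM_USE_V1 = bool(int(os.environ.get("VLLM_USE_V1", "0")))


def pick_backend(device_capability: int) -> "_Backend":
    """Hypothetical selector mirroring the order of checks in the hunk."""
    if VLLM_USE_V1:
        # The V1 opt-in wins before any other heuristic is consulted.
        return _Backend.FLASH_ATTN_VLLM_V1
    if device_capability >= 80:  # 80 encodes compute capability 8.0 (Ampere)
        return _Backend.FLASH_ATTN
    return _Backend.XFORMERS
```

Launching with `VLLM_USE_V1=1` would therefore route attention through the new V1 FlashAttention backend regardless of the remaining selection logic shown here.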