[V1] Implement vLLM V1 [1/N] (#9289)
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
 
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
+    FLASH_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
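The new FLASH_ATTN_VLLM_V1 member slots the V1 backend into the same enum that drives selection by name (this selector also honors the VLLM_ATTENTION_BACKEND environment variable). A self-contained sketch of the name-based lookup, using only the members shown in the hunk:

import enum


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()


# Enum members are addressable by name, e.g. when the name comes from an
# environment variable; adding a member is enough to make it selectable.
assert _Backend["FLASH_ATTN_VLLM_V1"] is _Backend.FLASH_ATTN_VLLM_V1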
@@ -110,6 +111,10 @@ def get_attn_backend(
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
         return FlashAttentionBackend
+    if backend == _Backend.FLASH_ATTN_VLLM_V1:
+        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
+            FlashAttentionBackend as FlashAttentionBackendV1)
+        return FlashAttentionBackendV1
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
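Each branch of get_attn_backend imports its backend class only when that backend is actually chosen, so loading the selector never drags in unselected backends or their optional dependencies. A minimal sketch of the same lazy-import dispatch; the _load helper is hypothetical, only the module paths and class name come from the diff:

import importlib


def _load(module: str, name: str) -> type:
    # Hypothetical helper: resolve a backend class on demand, so
    # unselected backends stay unimported.
    return getattr(importlib.import_module(module), name)


def get_attn_backend_sketch(backend: str) -> type:
    if backend == "FLASH_ATTN":
        return _load("vllm.attention.backends.flash_attn",
                     "FlashAttentionBackend")
    if backend == "FLASH_ATTN_VLLM_V1":
        # V1 ships its own FlashAttention backend under the vllm.v1 tree.
        return _load("vllm.v1.attention.backends.flash_attn",
                     "FlashAttentionBackend")
    raise ValueError(f"unsupported backend: {backend}")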
@@ -215,6 +220,9 @@ def which_attn_to_use(
             logger.info("%s is not supported in AMD GPUs.", selected_backend)
         return _Backend.ROCM_FLASH
 
+    if envs.VLLM_USE_V1:
+        return _Backend.FLASH_ATTN_VLLM_V1
+
     # FlashAttn in NVIDIA GPUs.
     if selected_backend == _Backend.FLASH_ATTN:
         if not current_platform.has_device_capability(80):
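The envs.VLLM_USE_V1 check sits after the ROCm branch but before the NVIDIA capability checks, so on non-AMD platforms the V1 engine short-circuits selection and always gets its dedicated FlashAttention backend. A hedged sketch of that gate; the flag parsing below is illustrative (vLLM's envs module defines its own accessor for VLLM_USE_V1):

import os

# Illustrative parsing: treat "1" as enabled, anything else as disabled.
VLLM_USE_V1 = os.getenv("VLLM_USE_V1", "0") == "1"


def which_attn_to_use_sketch(selected_backend: str) -> str:
    if VLLM_USE_V1:
        # Short-circuit before any NVIDIA capability checks.
        return "FLASH_ATTN_VLLM_V1"
    return selected_backend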