[Test] Add xformer and flash attn tests (#3961)
Co-authored-by: Simon Mo <simon.mo@hey.com>
@@ -1,4 +1,5 @@
import enum
import os
from functools import lru_cache
from typing import Type

@@ -10,6 +11,8 @@ from vllm.utils import is_cpu, is_hip

logger = init_logger(__name__)

VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()

@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
            "Cannot use FlashAttention backend because the flash_attn package "
            "is not found. Please install it for better performance.")
        return _Backend.XFORMERS

    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]

    # Default case.
    return _Backend.FLASH_ATTN
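For context on the new lines above: the override works by looking the environment variable's value up as a _Backend member name, so VLLM_ATTENTION_BACKEND must be set to an exact name such as FLASH_ATTN or XFORMERS. A minimal, self-contained sketch of that pattern follows; the pick_backend helper is a name made up here for illustration and is not part of vLLM's API.

import enum
import os


# Standalone sketch of the env-var override pattern added in this commit.
# The enum members and the VLLM_ATTENTION_BACKEND constant mirror the diff;
# pick_backend() is a hypothetical helper, not vLLM's actual selector.
class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


def pick_backend(default: _Backend = _Backend.FLASH_ATTN) -> _Backend:
    # Same lookup as the diff: the env var value is treated as a _Backend
    # member name, so _Backend["XFORMERS"] works and an unknown name raises
    # KeyError rather than silently falling back to the default.
    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]
    return default


if __name__ == "__main__":
    os.environ[VLLM_ATTENTION_BACKEND] = "XFORMERS"
    print(pick_backend())  # _Backend.XFORMERS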